/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.RowType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexNode;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.Pair;

public class BeamJoinTransforms {
    /**
     * Concatenates two rows into one, optionally reversing their order.
     *
     * @param leftRow row whose fields come first when {@code swap} is {@code false}
     * @param rightRow row whose fields come second when {@code swap} is {@code false}
     * @param swap when {@code true}, {@code rightRow}'s fields are placed before
     *     {@code leftRow}'s
     * @return the row built by {@code combineTwoRowsIntoOneHelper} for the chosen order
     */
    private static Row combineTwoRowsIntoOne(Row leftRow, Row rightRow, boolean swap) {
        Row first = swap ? rightRow : leftRow;
        Row second = swap ? leftRow : rightRow;
        return BeamJoinTransforms.combineTwoRowsIntoOneHelper(first, second);
    }

    /**
     * Builds a new row whose fields are those of {@code leftRow} followed by those of
     * {@code rightRow}.
     *
     * <p>The resulting row type is assembled from the field names and field coders of both
     * inputs, in the same left-then-right order as the values.
     *
     * @param leftRow row supplying the leading fields and values
     * @param rightRow row supplying the trailing fields and values
     * @return a row containing every value of {@code leftRow} and then every value of
     *     {@code rightRow}
     */
    private static Row combineTwoRowsIntoOneHelper(Row leftRow, Row rightRow) {
        int fieldCount = leftRow.getFieldCount() + rightRow.getFieldCount();

        List<String> names = new ArrayList<>(fieldCount);
        names.addAll(leftRow.getRowType().getFieldNames());
        names.addAll(rightRow.getRowType().getFieldNames());

        List<Coder> types = new ArrayList<>(fieldCount);
        types.addAll(leftRow.getRowType().getRowCoder().getCoders());
        types.addAll(rightRow.getRowType().getRowCoder().getCoders());

        RowType type = RowType.fromNamesAndCoders(names, types);
        return Row.withRowType(type)
            .addValues(leftRow.getValues())
            .addValues(rightRow.getValues())
            .build();
    }

    public static class JoinAsLookup
    extends PTransform<PCollection<Row>, PCollection<Row>> {
        BeamSqlSeekableTable seekableTable;
        RowType lkpRowType;
        RowType joinSubsetType;
        List<Integer> factJoinIdx;

        public JoinAsLookup(RexNode joinCondition, BeamSqlSeekableTable seekableTable, RowType lkpRowType, int factTableColSize) {
            this.seekableTable = seekableTable;
            this.lkpRowType = lkpRowType;
            this.joinFieldsMapping(joinCondition, factTableColSize);
        }

        private void joinFieldsMapping(RexNode joinCondition, int factTableColSize) {
            this.factJoinIdx = new ArrayList<Integer>();
            ArrayList<String> lkpJoinFieldsName = new ArrayList<String>();
            ArrayList<Coder> lkpJoinFieldsType = new ArrayList<Coder>();
            RexCall call = (RexCall)joinCondition;
            if ("AND".equals(call.getOperator().getName())) {
                List<RexNode> operands = call.getOperands();
                for (RexNode rexNode : operands) {
                    this.factJoinIdx.add(((RexInputRef)((RexCall)rexNode).getOperands().get(0)).getIndex());
                    int lkpJoinIdx = ((RexInputRef)((RexCall)rexNode).getOperands().get(1)).getIndex() - factTableColSize;
                    lkpJoinFieldsName.add(this.lkpRowType.getFieldName(lkpJoinIdx));
                    lkpJoinFieldsType.add(this.lkpRowType.getFieldCoder(lkpJoinIdx));
                }
            } else if ("=".equals(call.getOperator().getName())) {
                this.factJoinIdx.add(((RexInputRef)call.getOperands().get(0)).getIndex());
                int lkpJoinIdx = ((RexInputRef)call.getOperands().get(1)).getIndex() - factTableColSize;
                lkpJoinFieldsName.add(this.lkpRowType.getFieldName(lkpJoinIdx));
                lkpJoinFieldsType.add(this.lkpRowType.getFieldCoder(lkpJoinIdx));
            } else {
                throw new UnsupportedOperationException("Operator " + call.getOperator().getName() + " is not supported in join condition");
            }
            this.joinSubsetType = RowType.fromNamesAndCoders(lkpJoinFieldsName, lkpJoinFieldsType);
        }

        public PCollection<Row> expand(PCollection<Row> input) {
            return (PCollection)input.apply("join_as_lookup", (PTransform)ParDo.of((DoFn)new DoFn<Row, Row>(){

                @DoFn.Setup
                public void setup() {
                    seekableTable.setUp();
                }

                @DoFn.ProcessElement
                public void processElement(DoFn.ProcessContext context) {
                    Row factRow = (Row)context.element();
                    Row joinSubRow = this.extractJoinSubRow(factRow);
                    List<Row> lookupRows = seekableTable.seekRow(joinSubRow);
                    for (Row lr : lookupRows) {
                        context.output((Object)BeamJoinTransforms.combineTwoRowsIntoOneHelper(factRow, lr));
                    }
                }

                @DoFn.Teardown
                public void teardown() {
                    seekableTable.tearDown();
                }

                private Row extractJoinSubRow(Row factRow) {
                    List joinSubsetValues = factJoinIdx.stream().map(arg_0 -> ((Row)factRow).getValue(arg_0)).collect(Collectors.toList());
                    return Row.withRowType((RowType)joinSubsetType).addValues(joinSubsetValues).build();
                }
            }));
        }
    }

    public static class JoinParts2WholeRow
    extends SimpleFunction<KV<Row, KV<Row, Row>>, Row> {
        public Row apply(KV<Row, KV<Row, Row>> input) {
            KV parts = (KV)input.getValue();
            Row leftRow = (Row)parts.getKey();
            Row rightRow = (Row)parts.getValue();
            return BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, rightRow, false);
        }
    }

    /**
     * Joins each keyed main-input row against the rows stored under the same key in a
     * side-input map.
     *
     * <p>One combined row is emitted per matching side-input row. When there is no match and
     * the join type is {@code JoinRelType.LEFT}, the main-input row is combined with
     * {@code rightNullRow} instead; for any other join type nothing is emitted.
     */
    public static class SideInputJoinDoFn
    extends DoFn<KV<Row, Row>, Row> {
        private final PCollectionView<Map<Row, Iterable<Row>>> sideInputView;
        private final JoinRelType joinType;
        private final Row rightNullRow;
        private final boolean swap;

        /**
         * @param joinType join type; only {@code LEFT} causes unmatched rows to be emitted
         * @param rightNullRow row combined with an unmatched main-input row for a LEFT join
         * @param sideInputView view of the map from join key to the rows sharing that key
         * @param swap passed to {@code combineTwoRowsIntoOne}; when {@code true} the
         *     side-input row's fields are placed before the main-input row's fields
         */
        public SideInputJoinDoFn(JoinRelType joinType, Row rightNullRow, PCollectionView<Map<Row, Iterable<Row>>> sideInputView, boolean swap) {
            this.joinType = joinType;
            this.rightNullRow = rightNullRow;
            this.sideInputView = sideInputView;
            this.swap = swap;
        }

        // Declared with the parameterized ProcessContext of DoFn<KV<Row, Row>, Row> (not the
        // raw DoFn.ProcessContext) so element(), sideInput() and output() are fully typed.
        @DoFn.ProcessElement
        public void processElement(ProcessContext context) {
            KV<Row, Row> element = context.element();
            Row key = element.getKey();
            Row leftRow = element.getValue();
            Map<Row, Iterable<Row>> key2Rows = context.sideInput(this.sideInputView);
            Iterable<Row> rightRowsIterable = key2Rows.get(key);
            if (rightRowsIterable != null && rightRowsIterable.iterator().hasNext()) {
                for (Row rightRow : rightRowsIterable) {
                    context.output(BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, rightRow, this.swap));
                }
            } else if (this.joinType == JoinRelType.LEFT) {
                // No match: a LEFT join still emits the row, padded with the null row.
                context.output(BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, this.rightNullRow, this.swap));
            }
        }
    }

    public static class ExtractJoinFields
    extends SimpleFunction<Row, KV<Row, Row>> {
        private final List<Integer> joinColumns;

        public ExtractJoinFields(boolean isLeft, List<Pair<Integer, Integer>> joinColumns) {
            this.joinColumns = joinColumns.stream().map(pair -> isLeft ? (Integer)pair.left : (Integer)pair.right).collect(Collectors.toList());
        }

        public KV<Row, Row> apply(Row input) {
            RowType rowType = (RowType)this.joinColumns.stream().map(fieldIndex -> this.toField(input.getRowType(), (Integer)fieldIndex)).collect(RowType.toRowType());
            Row row = (Row)this.joinColumns.stream().map(arg_0 -> ((Row)input).getValue(arg_0)).collect(Row.toRow((RowType)rowType));
            return KV.of((Object)row, (Object)input);
        }

        private RowType.Field toField(RowType rowType, Integer fieldIndex) {
            return RowType.newField((String)("c" + fieldIndex), (Coder)rowType.getFieldCoder(fieldIndex.intValue()));
        }
    }
}

