/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexFieldAccess;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexInputRef;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexNode;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair;

public class BeamJoinTransforms {
    public static FieldAccessDescriptor getJoinColumns(boolean isLeft, List<Pair<RexNode, RexNode>> joinColumns, int leftRowColumnCount, Schema schema) {
        List joinColumnsBuilt = joinColumns.stream().map(pair -> SerializableRexNode.builder(isLeft ? (RexNode)pair.left : (RexNode)pair.right).build()).collect(Collectors.toList());
        return FieldAccessDescriptor.union((Iterable)joinColumnsBuilt.stream().map(v -> BeamJoinTransforms.getJoinColumn(v, leftRowColumnCount).resolve(schema)).collect(Collectors.toList()));
    }

    private static FieldAccessDescriptor getJoinColumn(SerializableRexNode serializableRexNode, int leftRowColumnCount) {
        if (serializableRexNode instanceof SerializableRexInputRef) {
            SerializableRexInputRef inputRef = (SerializableRexInputRef)serializableRexNode;
            return FieldAccessDescriptor.withFieldIds((Integer[])new Integer[]{inputRef.getIndex() - leftRowColumnCount});
        }
        List<Integer> indexes = ((SerializableRexFieldAccess)serializableRexNode).getIndexes();
        FieldAccessDescriptor fieldAccessDescriptor = FieldAccessDescriptor.withFieldIds((Integer[])new Integer[]{indexes.get(0) - leftRowColumnCount});
        for (int i = 1; i < indexes.size(); ++i) {
            fieldAccessDescriptor = FieldAccessDescriptor.withFieldIds((FieldAccessDescriptor)fieldAccessDescriptor, (Integer[])new Integer[]{indexes.get(i)});
        }
        return fieldAccessDescriptor;
    }

    private static Row combineTwoRowsIntoOne(Row leftRow, Row rightRow, boolean swap, Schema outputSchema) {
        if (swap) {
            return BeamJoinTransforms.combineTwoRowsIntoOneHelper(rightRow, leftRow, outputSchema);
        }
        return BeamJoinTransforms.combineTwoRowsIntoOneHelper(leftRow, rightRow, outputSchema);
    }

    private static Row combineTwoRowsIntoOneHelper(Row leftRow, Row rightRow, Schema ouputSchema) {
        return Row.withSchema((Schema)ouputSchema).addValues(leftRow.getBaseValues()).addValues(rightRow.getBaseValues()).build();
    }

    public static class JoinAsLookup
    extends PTransform<PCollection<Row>, PCollection<Row>> {
        private final BeamSqlSeekableTable seekableTable;
        private final Schema lkpSchema;
        private final int factColOffset;
        private Schema joinSubsetType;
        private final Schema outputSchema;
        private List<Integer> factJoinIdx;

        public JoinAsLookup(RexNode joinCondition, BeamSqlSeekableTable seekableTable, Schema lkpSchema, Schema outputSchema, int factColOffset, int lkpColOffset) {
            this.seekableTable = seekableTable;
            this.lkpSchema = lkpSchema;
            this.outputSchema = outputSchema;
            this.factColOffset = factColOffset;
            this.joinFieldsMapping(joinCondition, factColOffset, lkpColOffset);
        }

        private void joinFieldsMapping(RexNode joinCondition, int factColOffset, int lkpColOffset) {
            this.factJoinIdx = new ArrayList<Integer>();
            ArrayList<Schema.Field> lkpJoinFields = new ArrayList<Schema.Field>();
            RexCall call = (RexCall)joinCondition;
            if ("AND".equals(call.getOperator().getName())) {
                List operands = call.getOperands();
                for (RexNode rexNode : operands) {
                    this.factJoinIdx.add(((RexInputRef)((RexCall)rexNode).getOperands().get(0)).getIndex() - factColOffset);
                    int lkpJoinIdx = ((RexInputRef)((RexCall)rexNode).getOperands().get(1)).getIndex() - lkpColOffset;
                    lkpJoinFields.add(this.lkpSchema.getField(lkpJoinIdx));
                }
            } else if ("=".equals(call.getOperator().getName())) {
                this.factJoinIdx.add(((RexInputRef)call.getOperands().get(0)).getIndex() - factColOffset);
                int lkpJoinIdx = ((RexInputRef)call.getOperands().get(1)).getIndex() - lkpColOffset;
                lkpJoinFields.add(this.lkpSchema.getField(lkpJoinIdx));
            } else {
                throw new UnsupportedOperationException("Operator " + call.getOperator().getName() + " is not supported in join condition");
            }
            this.joinSubsetType = Schema.builder().addFields(lkpJoinFields).build();
        }

        public PCollection<Row> expand(PCollection<Row> input) {
            return ((PCollection)input.apply("join_as_lookup", (PTransform)ParDo.of((DoFn)new DoFn<Row, Row>(){

                @DoFn.Setup
                public void setup() {
                    seekableTable.setUp();
                }

                @DoFn.ProcessElement
                public void processElement(DoFn.ProcessContext context) {
                    Row factRow = (Row)context.element();
                    Row joinSubRow = this.extractJoinSubRow(factRow);
                    List<Row> lookupRows = seekableTable.seekRow(joinSubRow);
                    for (Row lr : lookupRows) {
                        context.output((Object)BeamJoinTransforms.combineTwoRowsIntoOne(factRow, lr, factColOffset != 0, outputSchema));
                    }
                }

                @DoFn.Teardown
                public void teardown() {
                    seekableTable.tearDown();
                }

                private Row extractJoinSubRow(Row factRow) {
                    List joinSubsetValues = factJoinIdx.stream().map(i -> factRow.getBaseValue(i.intValue(), Object.class)).collect(Collectors.toList());
                    return Row.withSchema((Schema)joinSubsetType).addValues(joinSubsetValues).build();
                }
            }))).setRowSchema(this.joinSubsetType);
        }
    }
}

