/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.rel;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamIOSourceRel;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamJoinTransforms;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.base.Optional;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.volcano.RelSubset;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.CorrelationId;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexFieldAccess;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Pair;

public abstract class BeamJoinRel
extends Join
implements BeamRelNode {
    /**
     * Creates the join node by forwarding every argument, unchanged, to the {@link Join}
     * constructor.
     *
     * @param cluster passed through to the superclass
     * @param traits passed through to the superclass
     * @param left left input of the join
     * @param right right input of the join
     * @param condition join condition
     * @param variablesSet passed through to the superclass
     * @param joinType kind of join
     */
    protected BeamJoinRel(
            RelOptCluster cluster,
            RelTraitSet traits,
            RelNode left,
            RelNode right,
            RexNode condition,
            Set<CorrelationId> variablesSet,
            JoinRelType joinType) {
        super(cluster, traits, left, right, condition, variablesSet, joinType);
    }

    /**
     * Returns the inputs of this join that are to be consumed as PCollections.
     *
     * <p>When {@link #isSideInputLookupJoin()} holds, only the non-seekable input is returned;
     * otherwise the default implementation inherited from {@link BeamRelNode} is used.
     *
     * @return a single-element list holding the non-seekable input for a lookup join, else the
     *     result of {@code BeamRelNode.super.getPCollectionInputs()}
     */
    @Override
    public List<RelNode> getPCollectionInputs() {
        if (!this.isSideInputLookupJoin()) {
            return BeamRelNode.super.getPCollectionInputs();
        }
        // isSideInputLookupJoin() has just verified that this Optional is present.
        int nonSeekableIndex = this.nonSeekableInputIndex().get();
        RelNode nonSeekableInput = BeamSqlRelUtils.getBeamRelInput(this.getInputs().get(nonSeekableIndex));
        return ImmutableList.of(nonSeekableInput);
    }

    /**
     * Tells whether this join has both a seekable input and a non-seekable input.
     *
     * @return {@code true} only if {@link #seekableInputIndex()} and
     *     {@link #nonSeekableInputIndex()} are both present
     */
    protected boolean isSideInputLookupJoin() {
        if (!this.seekableInputIndex().isPresent()) {
            return false;
        }
        return this.nonSeekableInputIndex().isPresent();
    }

    /**
     * Finds the first input (left before right) that is seekable according to
     * {@link #seekable(BeamRelNode)}.
     *
     * @return {@code 0} if the left input is seekable, otherwise {@code 1} if the right input is
     *     seekable, otherwise an absent Optional
     */
    protected Optional<Integer> seekableInputIndex() {
        // Both inputs are resolved before any check, so a failure while resolving either side
        // surfaces no matter which side turns out to be seekable.
        BeamRelNode leftRelNode = BeamSqlRelUtils.getBeamRelInput(this.left);
        BeamRelNode rightRelNode = BeamSqlRelUtils.getBeamRelInput(this.right);
        if (BeamJoinRel.seekable(leftRelNode)) {
            return Optional.of(0);
        }
        if (BeamJoinRel.seekable(rightRelNode)) {
            return Optional.of(1);
        }
        return Optional.absent();
    }

    /**
     * Finds the first input (left before right) that is not seekable according to
     * {@link #seekable(BeamRelNode)}.
     *
     * @return {@code 0} if the left input is not seekable, otherwise {@code 1} if the right input
     *     is not seekable, otherwise an absent Optional (both inputs are seekable)
     */
    protected Optional<Integer> nonSeekableInputIndex() {
        // Both inputs are resolved before any check, so a failure while resolving either side
        // surfaces no matter which side turns out to be non-seekable.
        BeamRelNode leftRelNode = BeamSqlRelUtils.getBeamRelInput(this.left);
        BeamRelNode rightRelNode = BeamSqlRelUtils.getBeamRelInput(this.right);
        if (!BeamJoinRel.seekable(leftRelNode)) {
            return Optional.of(0);
        }
        if (!BeamJoinRel.seekable(rightRelNode)) {
            return Optional.of(1);
        }
        return Optional.absent();
    }

    /**
     * Tells whether the given node is a {@link BeamIOSourceRel} whose table is a
     * {@link BeamSqlSeekableTable}.
     *
     * @param relNode node to inspect; {@code null} yields {@code false}
     * @return {@code true} if the node is an IO source backed by a seekable table
     */
    public static boolean seekable(BeamRelNode relNode) {
        if (!(relNode instanceof BeamIOSourceRel)) {
            return false;
        }
        BeamSqlTable table = ((BeamIOSourceRel) relNode).getBeamSqlTable();
        return table instanceof BeamSqlSeekableTable;
    }

    /**
     * Computes the cost of this join from the summed statistics of the join itself and both of
     * its inputs.
     *
     * @param planner not used by this implementation
     * @param mq metadata query used to obtain the node statistics
     * @return a cost built from the row count and rate of the summed statistics
     */
    @Override
    public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        NodeStats leftStats = BeamSqlRelUtils.getNodeStats(this.left, mq);
        NodeStats rightStats = BeamSqlRelUtils.getNodeStats(this.right, mq);
        NodeStats ownStats = BeamSqlRelUtils.getNodeStats(this, mq);
        NodeStats total = ownStats.plus(leftStats).plus(rightStats);
        return BeamCostModel.FACTORY.makeCost(total.getRowCount(), total.getRate());
    }

    /**
     * Estimates the statistics of the join output from the statistics of both inputs and the
     * selectivity of the join condition.
     *
     * @param mq metadata query used for the selectivity and the input statistics
     * @return {@code NodeStats.UNKNOWN} if either input is unknown; otherwise statistics whose
     *     row count and window are the products of the inputs' values, and whose rate is
     *     {@code leftRate * rightWindow + rightRate * leftWindow}, each scaled by the selectivity
     */
    @Override
    public NodeStats estimateNodeStats(RelMetadataQuery mq) {
        double selectivity = mq.getSelectivity((RelNode) this, this.getCondition());
        NodeStats leftStats = BeamSqlRelUtils.getNodeStats(this.left, mq);
        NodeStats rightStats = BeamSqlRelUtils.getNodeStats(this.right, mq);
        boolean anyUnknown = leftStats.isUnknown() || rightStats.isUnknown();
        if (anyUnknown) {
            return NodeStats.UNKNOWN;
        }
        double rowCount = leftStats.getRowCount() * rightStats.getRowCount() * selectivity;
        double rate =
                (leftStats.getRate() * rightStats.getWindow() + rightStats.getRate() * leftStats.getWindow())
                        * selectivity;
        double window = leftStats.getWindow() * rightStats.getWindow() * selectivity;
        return NodeStats.create(rowCount, rate, window);
    }

    /**
     * Checks whether the condition of a join can be decomposed by
     * {@code extractJoinRexNodes}.
     *
     * @param join join whose condition is examined
     * @return {@code false} if the decomposition throws
     *     {@link UnsupportedOperationException}, {@code true} if it completes; any other
     *     exception propagates to the caller
     */
    public static boolean isJoinLegal(Join join) {
        try {
            // Only the success or failure of the extraction matters; the pairs are discarded.
            BeamJoinRel.extractJoinRexNodes(join.getCondition());
            return true;
        } catch (UnsupportedOperationException ignored) {
            return false;
        }
    }

    /**
     * Builds a schema with the same fields as the given one, each marked nullable.
     *
     * @param schema schema whose fields are copied
     * @return a new schema in which every field is the original field with
     *     {@code withNullable(true)} applied
     */
    protected Schema buildNullSchema(Schema schema) {
        List<Schema.Field> nullableFields = new ArrayList<>();
        for (Schema.Field field : schema.getFields()) {
            nullableFields.add(field.withNullable(true));
        }
        Schema.Builder builder = Schema.builder();
        builder.addFields(nullableFields);
        return builder.build();
    }

    /**
     * Replaces the value coder of a KV collection while keeping its current key coder.
     *
     * @param kvs collection whose coder is replaced; its current coder must be a {@link KvCoder}
     * @param valueCoder coder to use for the values
     * @param <K> key type
     * @param <V> value type
     * @return the result of {@code kvs.setCoder(...)} with the rebuilt {@link KvCoder}
     * @throws ClassCastException if the current coder of {@code kvs} is not a {@link KvCoder}
     */
    protected static <K, V> PCollection<KV<K, V>> setValueCoder(PCollection<KV<K, V>> kvs, Coder<V> valueCoder) {
        // Typed downcast instead of a raw KvCoder, so the key coder keeps its <K> parameter.
        KvCoder<K, V> currentCoder = (KvCoder<K, V>) kvs.getCoder();
        return kvs.setCoder(KvCoder.of(currentCoder.getKeyCoder(), valueCoder));
    }

    /**
     * Resolves the schema field referenced by one side of a join clause.
     *
     * @param schema schema of the input the expression refers to
     * @param rexNode a {@link RexInputRef} or {@link RexFieldAccess}
     * @param leftRowColumnCount offset subtracted from the input-ref index before the lookup
     * @return the referenced field
     * @throws UnsupportedOperationException if the expression is of any other kind
     */
    private static Schema.Field getFieldBasedOnRexNode(Schema schema, RexNode rexNode, int leftRowColumnCount) {
        if (rexNode instanceof RexFieldAccess) {
            return BeamJoinRel.getFieldBasedOnRexFieldAccess(schema, (RexFieldAccess) rexNode, leftRowColumnCount);
        }
        if (!(rexNode instanceof RexInputRef)) {
            throw new UnsupportedOperationException("Does not support " + rexNode.getType() + " in JOIN.");
        }
        int fieldIndex = ((RexInputRef) rexNode).getIndex() - leftRowColumnCount;
        return schema.getField(fieldIndex);
    }

    /**
     * Resolves the schema field reached by a (possibly nested) field access.
     *
     * <p>The chain of accesses is followed down to its root, which must be a
     * {@link RexInputRef}. The root field is looked up in {@code schema}, then each access is
     * applied from the innermost to the outermost through the row schema of the current field.
     *
     * @param schema schema of the input the root reference points into
     * @param rexFieldAccess outermost field access
     * @param leftRowColumnCount offset subtracted from the root input-ref index
     * @return the field selected by the outermost access
     * @throws UnsupportedOperationException if the root of the chain is not a {@link RexInputRef}
     */
    private static Schema.Field getFieldBasedOnRexFieldAccess(Schema schema, RexFieldAccess rexFieldAccess, int leftRowColumnCount) {
        // Head of the deque is always the most recently visited (innermost) access.
        ArrayDeque<RexFieldAccess> innermostFirst = new ArrayDeque<>();
        RexNode reference = rexFieldAccess;
        while (reference instanceof RexFieldAccess) {
            RexFieldAccess access = (RexFieldAccess) reference;
            innermostFirst.addFirst(access);
            reference = access.getReferenceExpr();
        }
        if (!(reference instanceof RexInputRef)) {
            throw new UnsupportedOperationException("Does not support " + reference.getType() + " in JOIN.");
        }
        int rootIndex = ((RexInputRef) reference).getIndex() - leftRowColumnCount;
        Schema.Field resolved = schema.getField(rootIndex);
        // Iteration runs head to tail, i.e. innermost access first.
        for (RexFieldAccess access : innermostFirst) {
            resolved = resolved.getType().getRowSchema().getField(access.getField().getIndex());
        }
        return resolved;
    }

    /**
     * Splits a join condition into the operand pairs compared by its equality clauses.
     *
     * <p>Two shapes are accepted: a single {@code =} call, or an {@code AND} call whose operands
     * are all calls accepted by {@code extractJoinPairOfRexNodes}. Every other shape is reported
     * as {@link UnsupportedOperationException} (never {@link ClassCastException} or
     * {@link NullPointerException}), which is the only exception {@code isJoinLegal} catches.
     *
     * @param condition the join condition
     * @return one pair per equality clause, in clause order
     * @throws UnsupportedOperationException if the condition is a literal, is not a call, uses a
     *     top-level operator other than {@code AND} or {@code =}, or contains an unsupported
     *     clause
     */
    static List<Pair<RexNode, RexNode>> extractJoinRexNodes(RexNode condition) {
        if (condition instanceof RexLiteral) {
            Object literalValue = ((RexLiteral) condition).getValue();
            if (Boolean.TRUE.equals(literalValue)) {
                throw new UnsupportedOperationException("CROSS JOIN is not supported!");
            }
            // FALSE, NULL and non-boolean literals carry no join keys either.
            throw new UnsupportedOperationException("Literal join condition " + condition + " is not supported");
        }
        if (!(condition instanceof RexCall)) {
            // e.g. a bare boolean column reference used as the whole condition.
            throw new UnsupportedOperationException("Join condition " + condition + " is not supported");
        }
        RexCall call = (RexCall) condition;
        String operatorName = call.getOperator().getName();
        List<Pair<RexNode, RexNode>> pairs = new ArrayList<>();
        if ("AND".equals(operatorName)) {
            for (RexNode operand : call.getOperands()) {
                if (!(operand instanceof RexCall)) {
                    // A conjunct that is not a call cannot be an equality comparison.
                    throw new UnsupportedOperationException("Non equi-join is not supported");
                }
                pairs.add(BeamJoinRel.extractJoinPairOfRexNodes((RexCall) operand));
            }
        } else if ("=".equals(operatorName)) {
            pairs.add(BeamJoinRel.extractJoinPairOfRexNodes(call));
        } else {
            throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in join condition");
        }
        return pairs;
    }

    /**
     * Turns one equality clause into an ordered pair of its two operands.
     *
     * <p>The operand whose column index (see {@code getColumnIndex}) is strictly smaller becomes
     * the key of the pair; when the indexes are equal the operands are swapped, so the second
     * operand becomes the key.
     *
     * @param rexCall the clause; must be an {@code =} call on column references or field accesses
     * @return the pair of operands, lower column index first
     * @throws UnsupportedOperationException if the operator is not {@code =} or an operand is
     *     neither a column reference nor a struct field access
     */
    private static Pair<RexNode, RexNode> extractJoinPairOfRexNodes(RexCall rexCall) {
        if (!rexCall.getOperator().getName().equals("=")) {
            throw new UnsupportedOperationException("Non equi-join is not supported");
        }
        if (BeamJoinRel.isIllegalJoinConjunctionClause(rexCall)) {
            throw new UnsupportedOperationException("Only support column reference or struct field access in conjunction clause");
        }
        RexNode firstOperand = rexCall.getOperands().get(0);
        RexNode secondOperand = rexCall.getOperands().get(1);
        int firstIndex = BeamJoinRel.getColumnIndex(firstOperand);
        int secondIndex = BeamJoinRel.getColumnIndex(secondOperand);
        if (firstIndex < secondIndex) {
            return new Pair<>(firstOperand, secondOperand);
        }
        return new Pair<>(secondOperand, firstOperand);
    }

    /**
     * Tells whether a clause has an operand that is neither a {@link RexInputRef} nor a
     * {@link RexFieldAccess}.
     *
     * @param rexCall clause whose first two operands are inspected
     * @return {@code true} if the first or the second operand is of an unsupported kind
     */
    private static boolean isIllegalJoinConjunctionClause(RexCall rexCall) {
        RexNode first = rexCall.getOperands().get(0);
        if (!(first instanceof RexInputRef || first instanceof RexFieldAccess)) {
            // The second operand is deliberately not touched once the first one is rejected.
            return true;
        }
        RexNode second = rexCall.getOperands().get(1);
        return !(second instanceof RexInputRef || second instanceof RexFieldAccess);
    }

    /**
     * Returns the index of the input column an expression is rooted at.
     *
     * <p>Field accesses are unwrapped through their reference expressions until something else
     * is reached; that root must be a {@link RexInputRef}.
     *
     * @param rexNode a column reference or a chain of field accesses over one
     * @return the index of the root {@link RexInputRef}
     * @throws UnsupportedOperationException if the root is not a {@link RexInputRef}
     */
    private static int getColumnIndex(RexNode rexNode) {
        RexNode root = rexNode;
        while (root instanceof RexFieldAccess) {
            root = ((RexFieldAccess) root).getReferenceExpr();
        }
        if (!(root instanceof RexInputRef)) {
            throw new UnsupportedOperationException("Cannot get column index from " + root.getType());
        }
        return ((RexInputRef) root).getIndex();
    }

    /**
     * Determines the boundedness of a relational node.
     *
     * <p>A {@link BeamRelNode} answers for itself. Any other node is unbounded if at least one of
     * its inputs is unbounded, evaluated recursively. A {@link RelSubset} input is represented by
     * its best node, or by the first entry of its rel list when no best node is set.
     *
     * @param relNode node to inspect
     * @return {@code UNBOUNDED} if the node or any of its inputs is unbounded, else
     *     {@code BOUNDED} (also for a non-Beam node without inputs)
     */
    public static PCollection.IsBounded getBoundednessOfRelNode(RelNode relNode) {
        if (relNode instanceof BeamRelNode) {
            return ((BeamRelNode) relNode).isBounded();
        }
        boolean anyUnbounded = false;
        // Every input is visited even after an unbounded one has been found.
        for (RelNode inputRel : relNode.getInputs()) {
            RelNode resolved = inputRel;
            if (inputRel instanceof RelSubset) {
                RelSubset subset = (RelSubset) inputRel;
                resolved = subset.getBest();
                if (resolved == null) {
                    resolved = subset.getRelList().get(0);
                }
            }
            if (BeamJoinRel.getBoundednessOfRelNode(resolved) == PCollection.IsBounded.UNBOUNDED) {
                anyUnbounded = true;
            }
        }
        return anyUnbounded ? PCollection.IsBounded.UNBOUNDED : PCollection.IsBounded.BOUNDED;
    }

    /**
     * Tells whether any direct input of a node is seekable.
     *
     * <p>A {@link RelSubset} input is represented by its best node; a subset without a best node
     * is skipped.
     *
     * @param relNode node whose direct inputs are inspected
     * @return {@code true} if some input is a {@link BeamRelNode} for which
     *     {@link #seekable(BeamRelNode)} holds
     */
    public static boolean containsSeekableInput(RelNode relNode) {
        for (RelNode relInput : relNode.getInputs()) {
            RelNode candidate = relInput instanceof RelSubset ? ((RelSubset) relInput).getBest() : relInput;
            // instanceof is false for null, which covers a subset that has no best node yet.
            if (candidate instanceof BeamRelNode && BeamJoinRel.seekable((BeamRelNode) candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Transform that turns the two row collections of this join into two collections of
     * key/value pairs, where the key is a row built from the fields named in the join condition.
     *
     * <p>This is an inner (non-static) class: it reads {@code left}, {@code right} and
     * {@code condition} of the enclosing {@link BeamJoinRel} instance.
     */
    protected class ExtractJoinKeys
    extends PTransform<PCollectionList<Row>, PCollectionList<KV<Row, Row>>> {
        protected ExtractJoinKeys() {
        }

        /**
         * Keys both inputs by their join fields.
         *
         * @param pinput list expected to hold exactly two collections: index 0 is treated as the
         *     left side and index 1 as the right side
         * @return a list holding the keyed left collection followed by the keyed right collection
         */
        public PCollectionList<KV<Row, Row>> expand(PCollectionList<Row> pinput) {
            BeamRelNode leftRelNode = BeamSqlRelUtils.getBeamRelInput(BeamJoinRel.this.left);
            Schema leftSchema = CalciteUtils.toSchema(BeamJoinRel.this.left.getRowType());
            Schema rightSchema = CalciteUtils.toSchema(BeamJoinRel.this.right.getRowType());
            // Only enforced when the JVM runs with assertions enabled (-ea).
            assert (pinput.size() == 2);
            PCollection leftRows = pinput.get(0);
            PCollection rightRows = pinput.get(1);
            // Number of columns of the left input; used below as the offset for right-side
            // column references (getFieldBasedOnRexNode subtracts it from the input-ref index).
            int leftRowColumnCount = leftRelNode.getRowType().getFieldCount();
            // One pair per equality clause of the join condition.
            List<Pair<RexNode, RexNode>> pairs = BeamJoinRel.extractJoinRexNodes(BeamJoinRel.this.condition);
            // Key schemas: pair keys are resolved against the left schema (offset 0), pair values
            // against the right schema (offset leftRowColumnCount).
            Schema extractKeySchemaLeft = (Schema)pairs.stream().map(pair -> BeamJoinRel.getFieldBasedOnRexNode(leftSchema, (RexNode)pair.getKey(), 0)).collect(Schema.toSchema());
            Schema extractKeySchemaRight = (Schema)pairs.stream().map(pair -> BeamJoinRel.getFieldBasedOnRexNode(rightSchema, (RexNode)pair.getValue(), leftRowColumnCount)).collect(Schema.toSchema());
            // NOTE(review): the key coder is built from the LEFT key schema and is also set on the
            // right collection below; extractKeySchemaRight is only handed to ExtractJoinFields.
            // Presumably both key schemas are expected to be coder-compatible - confirm.
            SchemaCoder extractKeyRowCoder = SchemaCoder.of((Schema)extractKeySchemaLeft);
            // Each side: configure the window with the EARLIEST timestamp combiner, map every row
            // through ExtractJoinFields, then set a KV coder (key coder + the input's own coder).
            // The first ExtractJoinFields argument is true for the left side and false for the
            // right side; its exact meaning is defined in BeamJoinTransforms.
            PCollection extractedLeftRows = ((PCollection)((PCollection)leftRows.apply("left_TimestampCombiner", (PTransform)Window.configure().withTimestampCombiner(TimestampCombiner.EARLIEST))).apply("left_ExtractJoinFields", (PTransform)MapElements.via((SimpleFunction)new BeamJoinTransforms.ExtractJoinFields(true, pairs, extractKeySchemaLeft, 0)))).setCoder((Coder)KvCoder.of((Coder)extractKeyRowCoder, (Coder)leftRows.getCoder()));
            PCollection extractedRightRows = ((PCollection)((PCollection)rightRows.apply("right_TimestampCombiner", (PTransform)Window.configure().withTimestampCombiner(TimestampCombiner.EARLIEST))).apply("right_ExtractJoinFields", (PTransform)MapElements.via((SimpleFunction)new BeamJoinTransforms.ExtractJoinFields(false, pairs, extractKeySchemaRight, leftRowColumnCount)))).setCoder((Coder)KvCoder.of((Coder)extractKeyRowCoder, (Coder)rightRows.getCoder()));
            return PCollectionList.of((PCollection)extractedLeftRows).and(extractedRightRows);
        }
    }
}

