/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.rel;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Calc;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexDynamicParam;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexFieldAccess;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexLocalRef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexProgram;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexShuttle;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexUtil;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexVisitor;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexVisitorImpl;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Litmus;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Util;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.graph.DefaultDirectedGraph;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.graph.DefaultEdge;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.graph.DirectedGraph;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.graph.TopologicalOrderIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;

/**
 * Splits the program of a {@link Calc} into a chain of relational expressions, one per "level",
 * where each level is implemented by one of the supplied {@link RelType}s.
 *
 * <p>Every non-input expression of the program is assigned to the lowest level that is no lower
 * than the levels of the expressions it references and whose rel type can implement it.
 * {@link #execute()} then builds one rel per level, each consuming the previous one. Intermediate
 * levels project the values that later levels still need. The last level projects the program's
 * outputs with the program's output row type.
 *
 * <p>Subclasses may override {@link #handle(RelNode)} to post-process the rel built for each
 * level, and {@link #getCohorts()} to declare cohorts. Cohorts are sets of expression ordinals
 * whose members are each placed no lower than the level required by the inputs of any other
 * member.
 */
public class CalcRelSplitter {
    private static final Logger RULE_LOGGER = RelOptPlanner.LOGGER;
    /** Program of the calc being split. */
    protected final RexProgram program;
    private final RelDataTypeFactory typeFactory;
    private final RelType[] relTypes;
    private final RelOptCluster cluster;
    private final RelTraitSet traits;
    /** Input of the calc being split; the first level reads from it. */
    private final RelNode child;
    protected final RelBuilder relBuilder;

    /**
     * Creates a splitter for the given calc.
     *
     * @param calc calc whose program is split; its cluster, trait set, type factory and input are
     *     used when building the per-level rels
     * @param relBuilder builder passed to {@link RelType#makeRel} for every level
     * @param relTypes candidate rel types; each must be non-null and all must be distinct
     *     (checked only when assertions are enabled). When a new level is opened, the types are
     *     tried in array order.
     */
    public CalcRelSplitter(Calc calc, RelBuilder relBuilder, RelType[] relTypes) {
        this.relBuilder = relBuilder;
        for (int i = 0; i < relTypes.length; ++i) {
            assert relTypes[i] != null;
            for (int j = 0; j < i; ++j) {
                assert relTypes[i] != relTypes[j] : "Rel types must be distinct";
            }
        }
        this.program = calc.getProgram();
        this.cluster = calc.getCluster();
        this.traits = calc.getTraitSet();
        this.typeFactory = calc.getCluster().getTypeFactory();
        this.child = calc.getInput();
        this.relTypes = relTypes;
    }

    /**
     * Splits the calc and returns the rel built for the last level.
     *
     * <p>Each level's rel is created with {@link RelType#makeRel} and then passed through
     * {@link #handle(RelNode)}. The program's condition, if any, is attached to the first level
     * at or after the one that computes it whose rel type supports conditions.
     *
     * @return the rel produced for the last level
     * @throws IllegalArgumentException if the program has a condition that no level could take
     * @throws AssertionError if no rel type can implement some expression
     */
    public RelNode execute() {
        // The program must be valid and contain no complex expressions. Both are checked only
        // when assertions are enabled.
        assert program.isValid(Litmus.THROW, null);
        final List<RexNode> exprList = program.getExprList();
        final RexNode[] exprs = exprList.toArray(new RexNode[0]);
        assert !RexUtil.containComplexExprs(exprList);

        // exprLevels[i] is the level that computes exprs[i] (-1 for input fields).
        final int[] exprLevels = new int[exprs.length];
        // The rel type of level L is relTypes[levelTypeOrdinals[L]]. The array is sized at least
        // 1 because chooseLevels always reports at least one level, and slot 0 is read below even
        // when the program has no expressions at all.
        final int[] levelTypeOrdinals = new int[Math.max(1, exprs.length)];
        final int levelCount = chooseLevels(exprs, -1, exprLevels, levelTypeOrdinals);

        // For each expression, the highest level at which another expression references it.
        final int[] exprMaxUsingLevelOrdinals =
                new HighestUsageFinder(exprs, exprLevels).getMaxUsingLevelOrdinals();

        // Outputs and the condition are needed above every level.
        final List<RexLocalRef> projectRefList = program.getProjectList();
        final RexLocalRef conditionRef = program.getCondition();
        for (RexLocalRef projectRef : projectRefList) {
            exprMaxUsingLevelOrdinals[projectRef.getIndex()] = levelCount;
        }
        if (conditionRef != null) {
            exprMaxUsingLevelOrdinals[conditionRef.getIndex()] = levelCount;
        }

        if (RULE_LOGGER.isTraceEnabled()) {
            traceLevelExpressions(exprs, exprLevels, levelTypeOrdinals, levelCount);
        }

        // Build one rel per level, each reading the previous one.
        RelNode rel = child;
        final int inputFieldCount = program.getInputRowType().getFieldCount();
        int[] inputExprOrdinals = identityArray(inputFieldCount);
        boolean doneCondition = false;
        for (int level = 0; level < levelCount; ++level) {
            final int[] projectExprOrdinals;
            final RelDataType outputRowType;
            if (level == levelCount - 1) {
                // The last level projects exactly the program's outputs.
                outputRowType = program.getOutputRowType();
                projectExprOrdinals = new int[projectRefList.size()];
                for (int i = 0; i < projectExprOrdinals.length; ++i) {
                    projectExprOrdinals[i] = projectRefList.get(i).getIndex();
                }
            } else {
                // Intermediate levels project what has been computed so far and is still needed
                // at a later level. The row type is derived in createProgramForLevel.
                outputRowType = null;
                final List<Integer> projectExprOrdinalList = new ArrayList<>();
                for (int i = 0; i < exprs.length; ++i) {
                    if (exprs[i] instanceof RexLiteral) {
                        // Literals are never projected; they are inlined where they are used.
                        exprLevels[i] = -1;
                        continue;
                    }
                    if (exprLevels[i] <= level && exprMaxUsingLevelOrdinals[i] > level) {
                        projectExprOrdinalList.add(i);
                    }
                }
                projectExprOrdinals = Ints.toArray(projectExprOrdinalList);
            }

            final RelType relType = relTypes[levelTypeOrdinals[level]];

            // Attach the condition here if it has been computed by now and this rel type
            // supports conditions.
            int conditionExprOrdinal = -1;
            if (conditionRef != null && !doneCondition) {
                conditionExprOrdinal = conditionRef.getIndex();
                if (exprLevels[conditionExprOrdinal] > level || !relType.supportsCondition()) {
                    // Not ready yet; try again at a later level.
                    conditionExprOrdinal = -1;
                } else {
                    doneCondition = true;
                }
            }

            final RexProgram levelProgram =
                    createProgramForLevel(
                            level,
                            levelCount,
                            rel.getRowType(),
                            exprs,
                            exprLevels,
                            inputExprOrdinals,
                            projectExprOrdinals,
                            conditionExprOrdinal,
                            outputRowType);
            rel = relType.makeRel(cluster, traits, relBuilder, rel, levelProgram);
            rel = handle(rel);

            // The outputs of this level are the inputs of the next one.
            inputExprOrdinals = projectExprOrdinals;
        }

        Preconditions.checkArgument(doneCondition || conditionRef == null, "unhandled condition");
        return rel;
    }

    /**
     * Hook called with the rel built for each level, in level order. Its result becomes the
     * input of the next level (or the result of {@link #execute()} for the last level).
     *
     * <p>The default implementation returns {@code rel} unchanged.
     */
    protected RelNode handle(RelNode rel) {
        return rel;
    }

    /**
     * Assigns each expression to a level and chooses the rel type of each level.
     *
     * <p>Expressions are visited in the order given by {@code computeTopologicalOrdering}. Input
     * fields get level -1. Every other expression starts at the highest level among the
     * expressions it references, with a floor of 0. That level is raised to the corresponding
     * level of any cohort peer. The expression then moves up until it reaches a level whose rel
     * type can implement it. If it reaches a level that does not exist yet, a new level is opened
     * using the first eligible rel type (in array order) that can implement it.
     *
     * @param exprs the program's expressions
     * @param conditionOrdinal ordinal of the expression to check as a condition, or -1 for none
     * @param exprLevels output: the level of each expression
     * @param levelTypeOrdinals output: for each level, the index of its rel type in relTypes
     * @return the number of levels, always at least 1
     * @throws AssertionError if no rel type can implement some expression
     */
    private int chooseLevels(
            RexNode[] exprs, int conditionOrdinal, int[] exprLevels, int[] levelTypeOrdinals) {
        final int inputFieldCount = program.getInputRowType().getFieldCount();
        int levelCount = 0;
        final MaxInputFinder maxInputFinder = new MaxInputFinder(exprLevels);
        final boolean[] relTypesPossibleForTopLevel = new boolean[relTypes.length];
        Arrays.fill(relTypesPossibleForTopLevel, true);

        final List<Set<Integer>> cohorts = getCohorts();
        final List<Integer> permutation = computeTopologicalOrdering(exprs, cohorts);

        for (int i : permutation) {
            final RexNode expr = exprs[i];
            final boolean condition = i == conditionOrdinal;

            if (i < inputFieldCount) {
                assert expr instanceof RexInputRef;
                exprLevels[i] = -1;
                continue;
            }

            // An expression can sit no lower than any expression it references...
            int level = maxInputFinder.maxInputFor(expr);

            // ...nor lower than the inputs of the other members of its cohort.
            final Set<Integer> cohort = findCohort(cohorts, i);
            if (cohort != null) {
                for (Integer exprOrdinal : cohort) {
                    if (exprOrdinal == i) {
                        continue;
                    }
                    level = Math.max(level, maxInputFinder.maxInputFor(exprs[exprOrdinal]));
                }
            }

            // Try to implement the expression at this level, moving up while that is impossible.
            levelLoop:
            for (;; ++level) {
                if (level >= levelCount) {
                    // A level that does not exist yet: any still-eligible rel type may be used.
                    for (int relTypeOrdinal = 0;
                            relTypeOrdinal < relTypes.length;
                            ++relTypeOrdinal) {
                        if (!relTypesPossibleForTopLevel[relTypeOrdinal]
                                || !relTypes[relTypeOrdinal].canImplement(expr, condition)) {
                            continue;
                        }
                        exprLevels[i] = level;
                        levelTypeOrdinals[level] = relTypeOrdinal;
                        assert level == 0
                                        || levelTypeOrdinals[level - 1] != levelTypeOrdinals[level]
                                : "successive levels of same type";

                        // Earlier rel types are no longer possible for this level...
                        for (int j = 0; j < relTypeOrdinal; ++j) {
                            relTypesPossibleForTopLevel[j] = false;
                        }
                        // ...and later ones only if they can implement this expression too.
                        for (int j = relTypeOrdinal + 1; j < relTypes.length; ++j) {
                            if (relTypesPossibleForTopLevel[j]) {
                                relTypesPossibleForTopLevel[j] =
                                        relTypes[j].canImplement(expr, condition);
                            }
                        }

                        levelTypeOrdinals[levelCount] = firstSet(relTypesPossibleForTopLevel);
                        ++levelCount;
                        Arrays.fill(relTypesPossibleForTopLevel, true);
                        break levelLoop;
                    }

                    // No eligible rel type could implement the expression on the new level. If
                    // every rel type was eligible, nothing can implement it.
                    if (count(relTypesPossibleForTopLevel) >= relTypes.length) {
                        throw new AssertionError("cannot implement " + expr);
                    }
                    levelTypeOrdinals[levelCount] = firstSet(relTypesPossibleForTopLevel);
                    ++levelCount;
                    Arrays.fill(relTypesPossibleForTopLevel, true);
                } else if (relTypes[levelTypeOrdinals[level]].canImplement(expr, condition)) {
                    exprLevels[i] = level;
                    break levelLoop;
                }
            }
        }
        // At least one level is always produced, even when there is nothing to compute.
        return levelCount == 0 ? 1 : levelCount;
    }

    /**
     * Orders expression ordinals so that every expression comes after the expressions it
     * references. For an expression in a cohort, every expression it references is ordered
     * before every member of that cohort.
     */
    private static List<Integer> computeTopologicalOrdering(
            RexNode[] exprs, List<Set<Integer>> cohorts) {
        final DirectedGraph<Integer, DefaultEdge> graph = DefaultDirectedGraph.create();
        for (int i = 0; i < exprs.length; ++i) {
            graph.addVertex(i);
        }
        for (int i = 0; i < exprs.length; ++i) {
            final Set<Integer> cohort = findCohort(cohorts, i);
            final Set<Integer> targets = cohort == null ? Collections.singleton(i) : cohort;
            exprs[i].accept(
                    new RexVisitorImpl<Void>(true) {
                        @Override
                        public Void visitLocalRef(RexLocalRef localRef) {
                            // Edge from each referenced expression to every target.
                            for (Integer target : targets) {
                                graph.addEdge(localRef.getIndex(), target);
                            }
                            return null;
                        }
                    });
        }
        final TopologicalOrderIterator<Integer, DefaultEdge> iter =
                new TopologicalOrderIterator<>(graph);
        final List<Integer> permutation = new ArrayList<>();
        while (iter.hasNext()) {
            permutation.add(iter.next());
        }
        return permutation;
    }

    /** Returns the first cohort that contains {@code ordinal}, or null if none does. */
    private static @Nullable Set<Integer> findCohort(List<Set<Integer>> cohorts, int ordinal) {
        for (Set<Integer> cohort : cohorts) {
            if (cohort.contains(ordinal)) {
                return cohort;
            }
        }
        return null;
    }

    /** Returns {@code [0, 1, ..., length - 1]}. */
    private static int[] identityArray(int length) {
        final int[] ints = new int[length];
        for (int i = 0; i < ints.length; ++i) {
            ints[i] = i;
        }
        return ints;
    }

    /**
     * Builds the program evaluated by one level.
     *
     * <p>The program's expression list holds, in order:
     * <ul>
     *   <li>an input reference for each value passed up from the previous level;</li>
     *   <li>the expressions assigned to this level, rewritten by
     *       {@link InputToCommonExprConverter};</li>
     *   <li>on the last level only, also the literals whose level is -1.</li>
     * </ul>
     *
     * @param level the level being built
     * @param levelCount total number of levels
     * @param inputRowType row type of this level's input rel
     * @param allExprs all expressions of the original program
     * @param exprLevels level of each expression in {@code allExprs}
     * @param inputExprOrdinals ordinals (into {@code allExprs}) of this level's input fields
     * @param projectExprOrdinals ordinals (into {@code allExprs}) this level projects
     * @param conditionExprOrdinal ordinal (into {@code allExprs}) of the condition, or -1
     * @param outputRowType output row type, or null to derive one from the projections
     * @return the program for this level
     */
    private RexProgram createProgramForLevel(
            int level,
            int levelCount,
            RelDataType inputRowType,
            RexNode[] allExprs,
            int[] exprLevels,
            int[] inputExprOrdinals,
            int[] projectExprOrdinals,
            int conditionExprOrdinal,
            @Nullable RelDataType outputRowType) {
        final List<RexNode> exprs = new ArrayList<>();
        // exprInverseOrdinals[i] is the position of allExprs[i] in exprs, or -1 if allExprs[i]
        // is not available at this level. The shuttle below reads this same array.
        final int[] exprInverseOrdinals = new int[allExprs.length];
        Arrays.fill(exprInverseOrdinals, -1);
        int j = 0;

        // First the inputs, which were computed at earlier levels.
        for (int i = 0; i < inputExprOrdinals.length; ++i) {
            final int inputExprOrdinal = inputExprOrdinals[i];
            exprs.add(new RexInputRef(i, allExprs[inputExprOrdinal].getType()));
            exprInverseOrdinals[inputExprOrdinal] = j++;
        }

        // Then the expressions computed at this level. The last level also materializes the
        // literals, which execute() took out of every level (level -1).
        final RexShuttle shuttle =
                new InputToCommonExprConverter(
                        exprInverseOrdinals, exprLevels, level, inputExprOrdinals, allExprs);
        final boolean lastLevel = level == levelCount - 1;
        for (int i = 0; i < allExprs.length; ++i) {
            final boolean computedHere = exprLevels[i] == level;
            final boolean literalOnLastLevel =
                    exprLevels[i] == -1 && lastLevel && allExprs[i] instanceof RexLiteral;
            if (!computedHere && !literalOnLastLevel) {
                continue;
            }
            exprs.add(allExprs[i].accept(shuttle));
            assert exprInverseOrdinals[i] == -1;
            exprInverseOrdinals[i] = j++;
        }

        // Projections and the condition refer to allExprs; map them into exprs.
        final List<RexLocalRef> projectRefs = new ArrayList<>(projectExprOrdinals.length);
        final List<String> fieldNames = new ArrayList<>(projectExprOrdinals.length);
        for (int i = 0; i < projectExprOrdinals.length; ++i) {
            final int projectExprOrdinal = projectExprOrdinals[i];
            final int index = exprInverseOrdinals[projectExprOrdinal];
            assert index >= 0;
            final RexNode expr = allExprs[projectExprOrdinal];
            projectRefs.add(new RexLocalRef(index, expr.getType()));
            fieldNames.add(deriveFieldName(expr, i));
        }
        final RexLocalRef conditionRef;
        if (conditionExprOrdinal >= 0) {
            final int index = exprInverseOrdinals[conditionExprOrdinal];
            conditionRef = new RexLocalRef(index, allExprs[conditionExprOrdinal].getType());
        } else {
            conditionRef = null;
        }
        if (outputRowType == null) {
            outputRowType = RexUtil.createStructType(typeFactory, projectRefs, fieldNames, null);
        }
        return new RexProgram(inputRowType, exprs, projectRefs, conditionRef, outputRowType);
    }

    /**
     * Chooses the name of the {@code ordinal}-th projected field.
     *
     * <p>A reference to an input field of the original calc keeps that field's name unless the
     * name starts with {@code "$"} (other than {@code "$EXPR"}). Every other expression is
     * named {@code "$" + ordinal}.
     */
    private String deriveFieldName(RexNode expr, int ordinal) {
        if (expr instanceof RexInputRef) {
            final int inputIndex = ((RexInputRef) expr).getIndex();
            final RelDataTypeField field = child.getRowType().getFieldList().get(inputIndex);
            final String fieldName = field.getName();
            if (!fieldName.startsWith("$") || fieldName.startsWith("$EXPR")) {
                return fieldName;
            }
        }
        return "$" + ordinal;
    }

    /** Logs, at trace level, the program and the expressions assigned to each level. */
    private void traceLevelExpressions(
            RexNode[] exprs, int[] exprLevels, int[] levelTypeOrdinals, int levelCount) {
        final StringWriter traceMsg = new StringWriter();
        final PrintWriter traceWriter = new PrintWriter(traceMsg);
        traceWriter.println("FarragoAutoCalcRule result expressions for: ");
        traceWriter.println(program.toString());
        for (int level = 0; level < levelCount; ++level) {
            traceWriter.println("Rel Level " + level + ", type " + relTypes[levelTypeOrdinals[level]]);
            for (int i = 0; i < exprs.length; ++i) {
                assert exprLevels[i] >= -1 && exprLevels[i] < levelCount
                        : "expression's level is out of range";
                if (exprLevels[i] == level) {
                    traceWriter.println("\t" + i + ": " + exprs[i]);
                }
            }
            traceWriter.println();
        }
        RULE_LOGGER.trace(traceMsg.toString());
    }

    /** Returns the number of {@code true} entries. */
    private static int count(boolean[] booleans) {
        int count = 0;
        for (boolean b : booleans) {
            if (b) {
                ++count;
            }
        }
        return count;
    }

    /** Returns the index of the first {@code true} entry, or -1 if there is none. */
    private static int firstSet(boolean[] booleans) {
        for (int i = 0; i < booleans.length; ++i) {
            if (booleans[i]) {
                return i;
            }
        }
        return -1;
    }

    /** Returns the index of the first occurrence of {@code value} in {@code map}, or -1. */
    private static int indexOf(int value, int[] map) {
        for (int i = 0; i < map.length; ++i) {
            if (map[i] == value) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns whether the rel type named {@code relTypeName} can implement the program of
     * {@code rel}.
     *
     * @throws AssertionError if no configured rel type has that name
     */
    protected boolean canImplement(LogicalCalc rel, String relTypeName) {
        for (RelType relType : relTypes) {
            if (relType.name.equals(relTypeName)) {
                return relType.canImplement(rel.getProgram());
            }
        }
        throw new AssertionError("unknown type " + relTypeName);
    }

    /**
     * Returns the cohorts: sets of expression ordinals constrained together when choosing
     * levels. The default implementation returns an empty list.
     */
    protected List<Set<Integer>> getCohorts() {
        return Collections.emptyList();
    }

    /**
     * Computes, for each expression, the highest level of any expression that references it.
     * Unreferenced expressions and literals get -1; literals are not scanned.
     */
    private static class HighestUsageFinder extends RexVisitorImpl<Void> {
        private final int[] maxUsingLevelOrdinals;
        private int currentLevel;

        HighestUsageFinder(RexNode[] exprs, int[] exprLevels) {
            super(true);
            this.maxUsingLevelOrdinals = new int[exprs.length];
            Arrays.fill(maxUsingLevelOrdinals, -1);
            for (int i = 0; i < exprs.length; ++i) {
                if (exprs[i] instanceof RexLiteral) {
                    maxUsingLevelOrdinals[i] = -1;
                    continue;
                }
                currentLevel = exprLevels[i];
                exprs[i].accept(this);
            }
        }

        public int[] getMaxUsingLevelOrdinals() {
            return maxUsingLevelOrdinals;
        }

        @Override
        public Void visitLocalRef(RexLocalRef ref) {
            final int index = ref.getIndex();
            maxUsingLevelOrdinals[index] = Math.max(maxUsingLevelOrdinals[index], currentLevel);
            return null;
        }
    }

    /**
     * Finds the highest level among the expressions referenced by an expression, with 0 as the
     * floor. It reads the shared {@code exprLevels} array at call time.
     */
    private static class MaxInputFinder extends RexVisitorImpl<Void> {
        private int level;
        private final int[] exprLevels;

        MaxInputFinder(int[] exprLevels) {
            super(true);
            this.exprLevels = exprLevels;
        }

        @Override
        public Void visitLocalRef(RexLocalRef localRef) {
            level = Math.max(level, exprLevels[localRef.getIndex()]);
            return null;
        }

        /** Returns the highest level referenced by {@code expr}, or 0 if lower. */
        public int maxInputFor(RexNode expr) {
            level = 0;
            expr.accept(this);
            return level;
        }
    }

    /**
     * Rewrites an expression of the original program into the expression list of one level's
     * program.
     *
     * <p>Local references to expressions of earlier levels become references to this level's
     * corresponding input, except that literals are inlined. All other references are remapped
     * through {@code exprInverseOrdinals}.
     */
    private static class InputToCommonExprConverter extends RexShuttle {
        private final int[] exprInverseOrdinals;
        private final int[] exprLevels;
        private final int level;
        private final int[] inputExprOrdinals;
        private final RexNode[] allExprs;

        InputToCommonExprConverter(
                int[] exprInverseOrdinals,
                int[] exprLevels,
                int level,
                int[] inputExprOrdinals,
                RexNode[] allExprs) {
            this.exprInverseOrdinals = exprInverseOrdinals;
            this.exprLevels = exprLevels;
            this.level = level;
            this.inputExprOrdinals = inputExprOrdinals;
            this.allExprs = allExprs;
        }

        @Override
        public RexNode visitInputRef(RexInputRef input) {
            final int index = exprInverseOrdinals[input.getIndex()];
            assert index >= 0;
            return new RexLocalRef(index, input.getType());
        }

        @Override
        public RexNode visitLocalRef(RexLocalRef local) {
            final int localIndex = local.getIndex();
            if (exprLevels[localIndex] < level) {
                // Computed at an earlier level (or a literal): read it from the input.
                if (allExprs[localIndex] instanceof RexLiteral) {
                    return allExprs[localIndex];
                }
                final int inputIndex = indexOf(localIndex, inputExprOrdinals);
                assert inputIndex >= 0;
                return new RexLocalRef(inputIndex, local.getType());
            }
            return new RexLocalRef(exprInverseOrdinals[localIndex], local.getType());
        }
    }

    /** Thrown by {@link ImplementTester} to abort a visit; a single shared instance is used. */
    private static class CannotImplement extends RuntimeException {
        static final CannotImplement INSTANCE = new CannotImplement();

        private CannotImplement() {
        }
    }

    /**
     * Shallow visitor that throws {@link CannotImplement} when the rel type rejects the visited
     * call, dynamic parameter, field access or literal. Operands of calls are not visited.
     */
    private static class ImplementTester extends RexVisitorImpl<Void> {
        private final RelType relType;

        ImplementTester(RelType relType) {
            super(false);
            this.relType = relType;
        }

        @Override
        public Void visitCall(RexCall call) {
            if (!relType.canImplement(call)) {
                throw CannotImplement.INSTANCE;
            }
            return null;
        }

        @Override
        public Void visitDynamicParam(RexDynamicParam dynamicParam) {
            if (!relType.canImplement(dynamicParam)) {
                throw CannotImplement.INSTANCE;
            }
            return null;
        }

        @Override
        public Void visitFieldAccess(RexFieldAccess fieldAccess) {
            if (!relType.canImplement(fieldAccess)) {
                throw CannotImplement.INSTANCE;
            }
            return null;
        }

        @Override
        public Void visitLiteral(RexLiteral literal) {
            if (!relType.canImplement(literal)) {
                throw CannotImplement.INSTANCE;
            }
            return null;
        }
    }

    /**
     * A kind of rel that can implement some subset of expressions. Subclasses decide which
     * expressions are supported and may override how the rel for a level is created.
     */
    public static abstract class RelType {
        private final String name;

        protected RelType(String name) {
            this.name = name;
        }

        @Override
        public String toString() {
            return name;
        }

        protected abstract boolean canImplement(RexFieldAccess field);

        protected abstract boolean canImplement(RexDynamicParam param);

        protected abstract boolean canImplement(RexLiteral literal);

        protected abstract boolean canImplement(RexCall call);

        /** Returns whether rels of this type may carry a condition; true by default. */
        protected boolean supportsCondition() {
            return true;
        }

        /** Creates the rel for one level; by default a {@link LogicalCalc} over {@code input}. */
        protected RelNode makeRel(
                RelOptCluster cluster,
                RelTraitSet traitSet,
                RelBuilder relBuilder,
                RelNode input,
                RexProgram program) {
            return LogicalCalc.create(input, program);
        }

        /**
         * Returns whether this type can implement {@code expr}. If {@code condition} is true,
         * the type must also support conditions.
         */
        public boolean canImplement(RexNode expr, boolean condition) {
            if (condition && !supportsCondition()) {
                return false;
            }
            try {
                expr.accept(new ImplementTester(this));
                return true;
            } catch (CannotImplement e) {
                // Expected control flow: the tester rejected some node.
                Util.swallow(e, null);
                return false;
            }
        }

        /**
         * Returns whether this type can implement every expression of {@code program}, and its
         * condition (checked as a condition) if present.
         */
        public boolean canImplement(RexProgram program) {
            final RexLocalRef condition = program.getCondition();
            if (condition != null && !canImplement(condition, true)) {
                return false;
            }
            for (RexNode expr : program.getExprList()) {
                if (!canImplement(expr, false)) {
                    return false;
                }
            }
            return true;
        }
    }
}

