/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.ImmutableList;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.prepare.CalcitePrepareImpl;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.RelNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules.MultiJoin;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexUtil;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Pair;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Util;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.mapping.Mappings;

public class MultiJoinOptimizeBushyRule
extends RelOptRule {
    /** Shared rule instance, constructed with {@code RelFactories.LOGICAL_BUILDER}. */
    public static final MultiJoinOptimizeBushyRule INSTANCE = new MultiJoinOptimizeBushyRule(RelFactories.LOGICAL_BUILDER);
    /**
     * Debug trace sink: a writer over {@code System.out} when {@code CalcitePrepareImpl.DEBUG}
     * is true, otherwise {@code null} (callers in this class null-check before printing).
     */
    private final PrintWriter pw = CalcitePrepareImpl.DEBUG ? Util.printWriter(System.out) : null;

    /**
     * Creates a rule that matches any {@link MultiJoin}.
     *
     * @param relBuilderFactory factory handed to the {@link RelOptRule} superclass
     */
    public MultiJoinOptimizeBushyRule(RelBuilderFactory relBuilderFactory) {
        // Operand: a MultiJoin with any children; no explicit description (null).
        super(operand(MultiJoin.class, any()), relBuilderFactory, null);
    }

    /**
     * Creates a rule from separate join and project factories by wrapping them with
     * {@code RelBuilder.proto} and delegating to the builder-factory constructor.
     *
     * @param joinFactory    join factory passed to {@code RelBuilder.proto}
     * @param projectFactory project factory passed to {@code RelBuilder.proto}
     * @deprecated use {@link #MultiJoinOptimizeBushyRule(RelBuilderFactory)}, to which this
     *     constructor delegates
     */
    @Deprecated
    public MultiJoinOptimizeBushyRule(RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
        this(RelBuilder.proto(joinFactory, projectFactory));
    }

    /**
     * Rewrites the matched {@link MultiJoin} into a tree of binary inner joins topped by a
     * project, and registers the result via {@code call.transformTo}.
     *
     * <p>The method works in three phases:
     * <ol>
     *   <li>create one {@link LeafVertex} per join factor (cost = row count from metadata) and
     *       one edge per join filter;</li>
     *   <li>repeatedly pick an edge with {@link #chooseBestEdge}, merge the two vertexes it
     *       connects into a new {@link JoinVertex}, and re-point the remaining edges at the new
     *       vertex; when no edge is left, keep merging the last vertex with any vertex it does
     *       not yet cover;</li>
     *   <li>replay the vertex list in creation order to build the actual join tree, tracking
     *       for every node a mapping from the original field positions to its output.</li>
     * </ol>
     *
     * @param call rule call whose operand 0 is the {@link MultiJoin} to expand
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final MultiJoin multiJoinRel = (MultiJoin) call.rel(0);
        final RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);

        // Phase 1a: one leaf vertex per join factor. The vertex id equals the factor index and
        // therefore its position in 'vertexes'; fieldOffset is the running sum of field counts.
        final List<Vertex> vertexes = new ArrayList<>();
        int fieldOffset = 0;
        for (int i = 0; i < multiJoin.getNumJoinFactors(); ++i) {
            final RelNode rel = multiJoin.getJoinFactor(i);
            final double cost = mq.getRowCount(rel);
            vertexes.add(new LeafVertex(i, rel, cost, fieldOffset));
            fieldOffset += rel.getRowType().getFieldCount();
        }
        assert fieldOffset == multiJoin.getNumTotalFields();

        // Phase 1b: one edge per join filter.
        final List<LoptMultiJoin.Edge> unusedEdges = new ArrayList<>();
        for (RexNode node : multiJoin.getJoinFilters()) {
            unusedEdges.add(multiJoin.createEdge(node));
        }

        // Orders edges by the absolute cost difference between the two vertexes they connect.
        // chooseBestEdge (by default) selects the minimum under this ordering.
        final Comparator<LoptMultiJoin.Edge> edgeComparator = new Comparator<LoptMultiJoin.Edge>() {
            @Override
            public int compare(LoptMultiJoin.Edge e0, LoptMultiJoin.Edge e1) {
                return Double.compare(rowCountDiff(e0), rowCountDiff(e1));
            }

            private double rowCountDiff(LoptMultiJoin.Edge edge) {
                assert edge.factors.cardinality() == 2 : edge.factors;
                final int factor0 = edge.factors.nextSetBit(0);
                final int factor1 = edge.factors.nextSetBit(factor0 + 1);
                return Math.abs(vertexes.get(factor0).cost - vertexes.get(factor1).cost);
            }
        };

        // Phase 2: merge vertexes pairwise until nothing is left to merge.
        final List<LoptMultiJoin.Edge> usedEdges = new ArrayList<>();
        while (true) {
            final int edgeOrdinal = chooseBestEdge(unusedEdges, edgeComparator);
            if (this.pw != null) {
                trace(vertexes, unusedEdges, usedEdges, edgeOrdinal, this.pw);
            }

            final int[] factors;
            if (edgeOrdinal == -1) {
                // No edge was chosen. Look for a vertex id below the last vertex's id that is
                // not in the last vertex's factor set; if there is none, we are done.
                final Vertex lastVertex = Util.last(vertexes);
                final int z = lastVertex.factors.previousClearBit(lastVertex.id - 1);
                if (z < 0) {
                    break;
                }
                factors = new int[] {z, lastVertex.id};
            } else {
                final LoptMultiJoin.Edge bestEdge = unusedEdges.get(edgeOrdinal);
                // The algorithm only handles edges between exactly two factors.
                assert bestEdge.factors.cardinality() == 2;
                factors = bestEdge.factors.toArray();
            }

            // The vertex with the lower (or equal) cost becomes the "major" (left) input.
            final int majorFactor;
            final int minorFactor;
            if (vertexes.get(factors[0]).cost <= vertexes.get(factors[1]).cost) {
                majorFactor = factors[0];
                minorFactor = factors[1];
            } else {
                majorFactor = factors[1];
                minorFactor = factors[0];
            }
            final Vertex majorVertex = vertexes.get(majorFactor);
            final Vertex minorVertex = vertexes.get(minorFactor);

            // The new vertex gets the next free id and covers both inputs' factors plus itself.
            final int v = vertexes.size();
            final ImmutableBitSet newFactors =
                majorVertex.factors.rebuild().addAll(minorVertex.factors).set(v).build();

            // Every unused edge whose factors are all covered by the new vertex becomes one of
            // its join conditions and moves from 'unusedEdges' to 'usedEdges'.
            final List<RexNode> conditions = new ArrayList<>();
            final Iterator<LoptMultiJoin.Edge> edgeIterator = unusedEdges.iterator();
            while (edgeIterator.hasNext()) {
                final LoptMultiJoin.Edge edge = edgeIterator.next();
                if (newFactors.contains(edge.factors)) {
                    conditions.add(edge.condition);
                    edgeIterator.remove();
                    usedEdges.add(edge);
                }
            }

            // Estimated cost: product of the input costs scaled by the guessed selectivity of
            // the conjunction of the collected conditions.
            final double cost = majorVertex.cost * minorVertex.cost
                * RelMdUtil.guessSelectivity(RexUtil.composeConjunction(rexBuilder, conditions));
            final Vertex newVertex = new JoinVertex(v, majorFactor, minorFactor, newFactors, cost,
                ImmutableList.copyOf(conditions));
            vertexes.add(newVertex);

            // Re-point every remaining edge that touches one of the two merged vertexes at the
            // new vertex: drop the factors now covered by it and add its id instead.
            final ImmutableBitSet merged = ImmutableBitSet.of(minorFactor, majorFactor);
            for (int i = 0; i < unusedEdges.size(); ++i) {
                final LoptMultiJoin.Edge edge = unusedEdges.get(i);
                if (edge.factors.intersects(merged)) {
                    final ImmutableBitSet newEdgeFactors =
                        edge.factors.rebuild().removeAll(newFactors).set(v).build();
                    assert newEdgeFactors.cardinality() == 2;
                    unusedEdges.set(i,
                        new LoptMultiJoin.Edge(edge.condition, newEdgeFactors, edge.columns));
                }
            }
        }

        // Phase 3: build the relational tree. 'relNodes' is filled in vertex order, so the entry
        // at index k belongs to the vertex with id k; a join vertex is always created after its
        // inputs, hence both inputs are already present when it is processed.
        final List<Pair<RelNode, Mappings.TargetMapping>> relNodes = new ArrayList<>();
        for (Vertex vertex : vertexes) {
            if (vertex instanceof LeafVertex) {
                final LeafVertex leafVertex = (LeafVertex) vertex;
                // Identity mapping over the leaf's fields, shifted to the leaf's offset within
                // the MultiJoin's full field list.
                final Mappings.TargetMapping mapping = Mappings.offsetSource(
                    Mappings.createIdentity(leafVertex.rel.getRowType().getFieldCount()),
                    leafVertex.fieldOffset,
                    multiJoin.getNumTotalFields());
                relNodes.add(Pair.of(leafVertex.rel, mapping));
            } else {
                final JoinVertex joinVertex = (JoinVertex) vertex;
                final Pair<RelNode, Mappings.TargetMapping> leftPair = relNodes.get(joinVertex.leftFactor);
                final RelNode left = leftPair.left;
                final Mappings.TargetMapping leftMapping = leftPair.right;
                final Pair<RelNode, Mappings.TargetMapping> rightPair = relNodes.get(joinVertex.rightFactor);
                final RelNode right = rightPair.left;
                final Mappings.TargetMapping rightMapping = rightPair.right;
                // The right input's targets are shifted past the left input's fields.
                final Mappings.TargetMapping mapping = Mappings.merge(
                    leftMapping,
                    Mappings.offsetTarget(rightMapping, left.getRowType().getFieldCount()));
                if (this.pw != null) {
                    this.pw.println("left: " + leftMapping);
                    this.pw.println("right: " + rightMapping);
                    this.pw.println("combined: " + mapping);
                    this.pw.println();
                }
                // Rewrite the condition's input references through the combined mapping.
                final RexPermuteInputsShuttle shuttle = new RexPermuteInputsShuttle(mapping, left, right);
                final RexNode condition = RexUtil.composeConjunction(rexBuilder, joinVertex.conditions);
                final RelNode join = relBuilder.push(left).push(right)
                    .join(JoinRelType.INNER, condition.accept(shuttle))
                    .build();
                relNodes.add(Pair.of(join, mapping));
            }
            if (this.pw != null) {
                this.pw.println(Util.last(relNodes));
            }
        }

        // The last node is the root of the join tree; project its fields through its mapping.
        final Pair<RelNode, Mappings.TargetMapping> top = Util.last(relNodes);
        relBuilder.push(top.left).project(relBuilder.fields(top.right));
        call.transformTo(relBuilder.build());
    }

    /**
     * Dumps the current state of the merge loop to {@code pw}: the chosen edge ordinal, then
     * every vertex, every unused edge and every used edge, followed by a blank line and a flush.
     *
     * @param vertexes    vertexes created so far
     * @param unusedEdges edges not yet consumed by a join vertex
     * @param usedEdges   edges already consumed
     * @param edgeOrdinal ordinal of the edge chosen for this iteration
     * @param pw          writer receiving the dump
     */
    private void trace(List<Vertex> vertexes, List<LoptMultiJoin.Edge> unusedEdges, List<LoptMultiJoin.Edge> usedEdges, int edgeOrdinal, PrintWriter pw) {
        pw.println("bestEdge: " + edgeOrdinal);
        printSection(pw, "vertexes:", vertexes);
        printSection(pw, "unused edges:", unusedEdges);
        printSection(pw, "edges:", usedEdges);
        pw.println();
        pw.flush();
    }

    /** Prints {@code heading} on its own line, then each element of {@code items} on its own line. */
    private static void printSection(PrintWriter out, String heading, List<?> items) {
        out.println(heading);
        for (Object item : items) {
            out.println(item);
        }
    }

    /**
     * Picks the edge to process next.
     *
     * @param edges      candidate edges
     * @param comparator ordering under which the minimum edge is chosen
     * @return index into {@code edges} of the first minimal edge, or -1 if {@code edges} is empty
     */
    int chooseBestEdge(List<LoptMultiJoin.Edge> edges, Comparator<LoptMultiJoin.Edge> comparator) {
        return minPos(edges, comparator);
    }

    /**
     * Returns the index of the minimum element of {@code list} according to {@code fn}.
     *
     * <p>When several elements are minimal, the index of the first one is returned. The list is
     * traversed once with its iterator, so the scan is linear even for lists without fast
     * random access.
     *
     * @param list elements to scan
     * @param fn   comparator defining the ordering; may be a comparator of any supertype of
     *             {@code E}
     * @param <E>  element type
     * @return index of the first minimal element, or -1 if {@code list} is empty
     */
    static <E> int minPos(List<E> list, Comparator<? super E> fn) {
        final Iterator<E> it = list.iterator();
        if (!it.hasNext()) {
            return -1;
        }
        E best = it.next();
        int bestIndex = 0;
        for (int index = 1; it.hasNext(); ++index) {
            final E candidate = it.next();
            // Strict "less than" keeps the earliest element on ties.
            if (fn.compare(candidate, best) < 0) {
                best = candidate;
                bestIndex = index;
            }
        }
        return bestIndex;
    }

    static class JoinVertex
    extends Vertex {
        private final int leftFactor;
        private final int rightFactor;
        final ImmutableList<RexNode> conditions;

        JoinVertex(int id, int leftFactor, int rightFactor, ImmutableBitSet factors, double cost, ImmutableList<RexNode> conditions) {
            super(id, factors, cost);
            this.leftFactor = leftFactor;
            this.rightFactor = rightFactor;
            this.conditions = Objects.requireNonNull(conditions);
        }

        public String toString() {
            return "JoinVertex(id: " + this.id + ", cost: " + Util.human(this.cost) + ", factors: " + this.factors + ", leftFactor: " + this.leftFactor + ", rightFactor: " + this.rightFactor + ")";
        }
    }

    static class LeafVertex
    extends Vertex {
        private final RelNode rel;
        final int fieldOffset;

        LeafVertex(int id, RelNode rel, double cost, int fieldOffset) {
            super(id, ImmutableBitSet.of(id), cost);
            this.rel = rel;
            this.fieldOffset = fieldOffset;
        }

        public String toString() {
            return "LeafVertex(id: " + this.id + ", cost: " + Util.human(this.cost) + ", factors: " + this.factors + ", fieldOffset: " + this.fieldOffset + ")";
        }
    }

    /**
     * Base class for the nodes of the join graph: an id, the set of factors the vertex covers,
     * and a cost estimate. All three values are fixed at construction.
     */
    static abstract class Vertex {
        /** Identifier of this vertex. */
        final int id;
        /** Set of factor/vertex ids associated with this vertex. */
        protected final ImmutableBitSet factors;
        /** Cost estimate supplied by the creator. */
        final double cost;

        /**
         * @param id      identifier of this vertex
         * @param factors factor set of this vertex
         * @param cost    cost estimate of this vertex
         */
        Vertex(int id, ImmutableBitSet factors, double cost) {
            this.id = id;
            this.factors = factors;
            this.cost = cost;
        }
    }
}

