/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.ImmutableList;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.ImmutableSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptUtil;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.RelNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.CorrelationId;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Filter;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Join;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Project;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.LogicVisitor;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexShuttle;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexSubQuery;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexUtil;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.SqlKind;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.SqlOperator;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlQuantifyOperator;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Pair;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Util;

/**
 * Planner rule that replaces {@link RexSubQuery} expressions found in a
 * {@link Project}, {@link Filter} or {@link Join} with equivalent relational
 * operators (aggregates and joins against the sub-query's relational
 * expression) built with a {@link RelBuilder}.
 *
 * <p>Sub-query kinds handled by {@link #apply}: scalar sub-query, SOME, IN and
 * EXISTS. Any other kind raises an {@link AssertionError}.
 *
 * <p>Three concrete rules are provided as nested classes and exposed as the
 * {@link #PROJECT}, {@link #FILTER} and {@link #JOIN} instances.
 *
 * <p>NOTE(review): this file is CFR-decompiled output (raw types, explicit
 * casts). The rewrite methods drive {@link RelBuilder} as a stack (push / as /
 * join), so statement order is significant throughout; several switches fall
 * through on purpose.
 */
public abstract class SubQueryRemoveRule
extends RelOptRule {
    /** Rule instance for sub-queries in a {@link Project}; uses {@link RelFactories#LOGICAL_BUILDER}. */
    public static final SubQueryRemoveRule PROJECT = new SubQueryProjectRemoveRule(RelFactories.LOGICAL_BUILDER);
    /** Rule instance for sub-queries in a {@link Filter}; uses {@link RelFactories#LOGICAL_BUILDER}. */
    public static final SubQueryRemoveRule FILTER = new SubQueryFilterRemoveRule(RelFactories.LOGICAL_BUILDER);
    /** Rule instance for sub-queries in a {@link Join} condition; uses {@link RelFactories#LOGICAL_BUILDER}. */
    public static final SubQueryRemoveRule JOIN = new SubQueryJoinRemoveRule(RelFactories.LOGICAL_BUILDER);

    /**
     * Creates a SubQueryRemoveRule.
     *
     * @param operand root operand the rule matches
     * @param relBuilderFactory factory passed to {@link RelOptRule}
     * @param description rule description passed to {@link RelOptRule}
     */
    public SubQueryRemoveRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Rewrites a single sub-query, dispatching on its kind.
     *
     * <p>On entry the rule's input(s) must already be on {@code builder}'s
     * stack. Each rewrite pushes the sub-query's relational expression,
     * reshapes it, and joins it onto the relation on top of the stack, leaving
     * the widened relation on the stack.
     *
     * @param e the sub-query to rewrite
     * @param variablesSet correlation ids handed to the join(s) created by the
     *     rewrite (callers pass an empty set for Project and Join)
     * @param logic the logic required by the sub-query's context, as computed
     *     by {@link LogicVisitor#find} in the callers
     * @param builder builder whose stack holds the input(s)
     * @param inputCount number of inputs on the stack that result field
     *     references are resolved against (callers pass 1 for Project/Filter,
     *     2 for Join)
     * @param offset number of fields preceding the columns the rewrite appends
     *     (callers pass the field count of their input(s))
     * @return an expression to substitute for {@code e}
     * @throws AssertionError if {@code e}'s kind is not SCALAR_QUERY, SOME, IN
     *     or EXISTS
     */
    protected RexNode apply(RexSubQuery e, Set<CorrelationId> variablesSet, RelOptUtil.Logic logic, RelBuilder builder, int inputCount, int offset) {
        switch (e.getKind()) {
            case SCALAR_QUERY: {
                return this.rewriteScalarQuery(e, variablesSet, builder, inputCount, offset);
            }
            case SOME: {
                return this.rewriteSome(e, builder);
            }
            case IN: {
                return this.rewriteIn(e, variablesSet, logic, builder, offset);
            }
            case EXISTS: {
                return this.rewriteExists(e, variablesSet, logic, builder);
            }
        }
        throw new AssertionError((Object)e.getKind());
    }

    /**
     * Rewrites a scalar sub-query.
     *
     * <p>The sub-query is left-joined (on {@code TRUE}) to the input. If
     * metadata cannot show that the sub-query yields at most one row, its
     * first column is first wrapped in a {@code SINGLE_VALUE} aggregate with an
     * empty group key. The result references the first sub-query column in the
     * joined row.
     */
    private RexNode rewriteScalarQuery(RexSubQuery e, Set<CorrelationId> variablesSet, RelBuilder builder, int inputCount, int offset) {
        builder.push(e.rel);
        RelMetadataQuery mq = e.rel.getCluster().getMetadataQuery();
        // Uniqueness over the empty column set means "at most one row"; null = unknown.
        Boolean unique = mq.areColumnsUnique(builder.peek(), ImmutableBitSet.of());
        if (unique == null || !unique.booleanValue()) {
            builder.aggregate(builder.groupKey(), builder.aggregateCall(SqlStdOperatorTable.SINGLE_VALUE, false, false, null, null, builder.field(0)));
        }
        builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
        // offset == width of the original input(s), i.e. the first column contributed by the sub-query.
        return this.field(builder, inputCount, offset);
    }

    /**
     * Rewrites a quantified comparison {@code x op SOME (sub-query)}, where
     * {@code x} is the sub-query's first operand.
     *
     * <p>The sub-query is collapsed (empty group key, so a single row) into a
     * relation aliased {@code q} with columns {@code m} (MIN of the first
     * column for {@code >} / {@code >=}, MAX otherwise), {@code c}
     * (COUNT(*)) and {@code d} (COUNT of the first column). It is then
     * inner-joined to the input with an empty USING list, presumably a cross
     * join. The result is
     * <pre>
     * CASE WHEN q.c = 0 THEN FALSE
     *      WHEN (x op q.m) IS TRUE THEN TRUE
     *      WHEN q.c &gt; q.d THEN NULL
     *      ELSE x op q.m
     * END
     * </pre>
     * where {@code q.c > q.d} means the first column contained NULLs.
     *
     * <p>NOTE(review): no correlation variables are passed to this join, so
     * correlated SOME sub-queries are not given correlation ids here; confirm
     * that is intended. Also, for comparison kinds other than the four ordering
     * comparisons (e.g. {@code <>}), comparing only against MAX would not be
     * equivalent; presumably such kinds never reach this method. Verify
     * against the callers.
     */
    private RexNode rewriteSome(RexSubQuery e, RelBuilder builder) {
        SqlQuantifyOperator op = (SqlQuantifyOperator)e.op;
        builder.push(e.rel).aggregate(builder.groupKey(), op.comparisonKind == SqlKind.GREATER_THAN || op.comparisonKind == SqlKind.GREATER_THAN_OR_EQUAL ? builder.min("m", builder.field(0)) : builder.max("m", builder.field(0)), builder.count(false, "c", new RexNode[0]), builder.count(false, "d", builder.field(0))).as("q").join(JoinRelType.INNER, new String[0]);
        return builder.call((SqlOperator)SqlStdOperatorTable.CASE, builder.call((SqlOperator)SqlStdOperatorTable.EQUALS, builder.field("q", "c"), builder.literal(0)), builder.literal(false), builder.call((SqlOperator)SqlStdOperatorTable.IS_TRUE, builder.call(RelOptUtil.op(op.comparisonKind, null), (RexNode)e.operands.get(0), builder.field("q", "m"))), builder.literal(true), builder.call((SqlOperator)SqlStdOperatorTable.GREATER_THAN, builder.field("q", "c"), builder.field("q", "d")), builder.literal(null), builder.call(RelOptUtil.op(op.comparisonKind, null), (RexNode)e.operands.get(0), builder.field("q", "m")));
    }

    /**
     * Rewrites {@code EXISTS (sub-query)}.
     *
     * <p>The sub-query is projected to a single column {@code i} holding
     * {@code TRUE} and aliased {@code dt}.
     * <ul>
     * <li>If {@code logic} is {@link RelOptUtil.Logic#TRUE}, {@code dt} is
     * grouped on {@code i} and inner-joined to the input. Input rows with no
     * match are therefore dropped, and the sub-query is replaced by the literal
     * {@code TRUE}.</li>
     * <li>Otherwise the distinct {@code dt} rows are left-joined and the result
     * is {@code i IS NOT NULL}, where {@code i} is the last field of the
     * join.</li>
     * </ul>
     */
    private RexNode rewriteExists(RexSubQuery e, Set<CorrelationId> variablesSet, RelOptUtil.Logic logic, RelBuilder builder) {
        builder.push(e.rel);
        builder.project(builder.alias(builder.literal(true), "i"));
        switch (logic) {
            case TRUE: {
                builder.aggregate(builder.groupKey(0), new RelBuilder.AggCall[0]);
                builder.as("dt");
                builder.join(JoinRelType.INNER, builder.literal(true), variablesSet);
                return builder.literal(true);
            }
        }
        // Any logic other than TRUE: outer join and test whether a match was found.
        builder.distinct();
        builder.as("dt");
        builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
        return builder.isNotNull(Util.last(builder.fields()));
    }

    /**
     * Rewrites {@code (x1, ..., xn) IN (sub-query)}.
     *
     * <p>Two strategies are used.
     *
     * <p><b>All operands are literals.</b> The comparisons {@code xk = fieldk}
     * are pushed into a filter on the sub-query, which is reduced to a boolean
     * column {@code cs}.
     * <ul>
     * <li>For logic TRUE / TRUE_FALSE, {@code cs} is the constant TRUE over the
     * distinct matching rows.</li>
     * <li>For any other logic, the filter also admits rows where any column or
     * nullable operand is NULL, and {@code cs = (all columns IS NOT NULL)}.
     * When uncorrelated, the rows are grouped on {@code cs} with a count
     * {@code c} and limited to the first row ordered by {@code cs DESC}.</li>
     * </ul>
     * The operand/field lists are then cleared, so the join condition has no
     * conjuncts.
     *
     * <p><b>Otherwise.</b>
     * <ul>
     * <li>For logic TRUE, the sub-query is grouped on all its columns.</li>
     * <li>For any other logic, it is made distinct with an extra {@code TRUE}
     * column {@code i}.</li>
     * <li>For TRUE_FALSE_UNKNOWN / UNKNOWN_AS_TRUE, a one-row relation
     * {@code ct} with {@code c = COUNT(*)} and {@code ck = COUNT(all columns)}
     * is joined in first. This lets the final CASE detect an empty sub-query
     * ({@code c = 0}) and rows with NULL columns ({@code ck < c}, assuming the
     * multi-argument COUNT skips rows with any NULL argument).</li>
     * </ul>
     *
     * <p>The reshaped sub-query is aliased {@code dt} and joined on
     * {@code xk = dt.fieldk}. With logic TRUE the join is INNER and the result
     * is the literal {@code TRUE}. Otherwise the join is LEFT and the result is
     * a CASE expression whose arms are assembled at the end of this method.
     *
     * @param offset number of fields preceding the {@code dt} columns in the
     *     joined row; used to shift {@code dt} field references in the join
     *     condition (incremented by 2 when {@code ct}'s columns are added)
     */
    private RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet, RelOptUtil.Logic logic, RelBuilder builder, int offset) {
        builder.push(e.rel);
        // References to every column of the sub-query.
        ArrayList<RexNode> fields = new ArrayList<RexNode>(builder.fields());
        boolean allLiterals = RexUtil.allLiterals(e.getOperands());
        ArrayList<RexNode> expressionOperands = new ArrayList<RexNode>(e.getOperands());
        // "xk IS NULL" for each nullable left-hand operand.
        List keyIsNulls = e.getOperands().stream().filter(operand -> operand.getType().isNullable()).map(builder::isNull).collect(Collectors.toList());
        if (allLiterals) {
            List conditions = Pair.zip(expressionOperands, fields).stream().map(pair -> builder.equals((RexNode)pair.left, (RexNode)pair.right)).collect(Collectors.toList());
            switch (logic) {
                case TRUE: 
                case TRUE_FALSE: {
                    builder.filter(conditions);
                    builder.project(builder.alias(builder.literal(true), "cs"));
                    builder.distinct();
                    break;
                }
                default: {
                    List isNullOpperands = fields.stream().map(builder::isNull).collect(Collectors.toList());
                    // Also keep rows when a left-hand operand is NULL, so the filter is not empty in that case.
                    isNullOpperands.addAll(keyIsNulls);
                    builder.filter(builder.or(builder.and(conditions), builder.or(isNullOpperands)));
                    RexNode project = builder.and(fields.stream().map(builder::isNotNull).collect(Collectors.toList()));
                    builder.project(builder.alias(project, "cs"));
                    if (variablesSet.isEmpty()) {
                        builder.aggregate(builder.groupKey(builder.field("cs")), builder.count(false, "c", new RexNode[0]));
                        // Keep one row, preferring cs = TRUE over FALSE (DESC order).
                        builder.sortLimit(0, 1, ImmutableList.of(builder.call((SqlOperator)SqlStdOperatorTable.DESC, builder.field("cs"))));
                        break;
                    }
                    builder.distinct();
                }
            }
            // The comparisons were consumed by the filter above; nothing is left for the join condition.
            expressionOperands.clear();
            fields.clear();
        } else {
            switch (logic) {
                case TRUE: {
                    builder.aggregate(builder.groupKey(fields), new RelBuilder.AggCall[0]);
                    break;
                }
                case TRUE_FALSE_UNKNOWN: 
                case UNKNOWN_AS_TRUE: {
                    // Build ct = (COUNT(*) AS c, COUNT(all columns) AS ck) and join it to the input.
                    builder.aggregate(builder.groupKey(), builder.count(false, "c", new RexNode[0]), builder.aggregateCall(SqlStdOperatorTable.COUNT, false, false, null, "ck", builder.fields()));
                    builder.as("ct");
                    if (!variablesSet.isEmpty()) {
                        builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
                    } else {
                        builder.join(JoinRelType.INNER, builder.literal(true), variablesSet);
                    }
                    // ct added two columns (c, ck) ahead of the dt columns.
                    offset += 2;
                    // Push the sub-query again to build dt below.
                    builder.push(e.rel);
                    // fall through
                }
                default: {
                    fields.add(builder.alias(builder.literal(true), "i"));
                    builder.project(fields);
                    builder.distinct();
                }
            }
        }
        builder.as("dt");
        // Effectively-final copy of offset for use inside the lambda below.
        int refOffset = offset;
        List conditions = Pair.zip(expressionOperands, builder.fields()).stream().map(pair -> builder.equals((RexNode)pair.left, RexUtil.shift((RexNode)pair.right, refOffset))).collect(Collectors.toList());
        switch (logic) {
            case TRUE: {
                builder.join(JoinRelType.INNER, builder.and(conditions), variablesSet);
                return builder.literal(true);
            }
        }
        builder.join(JoinRelType.LEFT, builder.and(conditions), variablesSet);
        // Alternating (condition, result) pairs for CASE, followed by a final ELSE value.
        ImmutableList.Builder operands = ImmutableList.builder();
        // Value returned when a NULL on the right side makes the answer unknown:
        // NULL for TRUE_FALSE_UNKNOWN, TRUE for UNKNOWN_AS_TRUE.
        Boolean b = true;
        switch (logic) {
            case TRUE_FALSE_UNKNOWN: {
                b = null;
                // fall through
            }
            case UNKNOWN_AS_TRUE: {
                if (allLiterals) {
                    if (variablesSet.isEmpty()) {
                        // c IS NULL: the (uncorrelated) sub-query produced no rows.
                        operands.add(new RexNode[]{builder.isNull(builder.field("c")), builder.literal(false)});
                    }
                    operands.add(new RexNode[]{builder.equals(builder.field("cs"), builder.literal(false)), builder.literal(b)});
                    break;
                }
                // ct.c = 0: the sub-query is empty, so IN is FALSE.
                operands.add(new RexNode[]{builder.equals(builder.field("ct", "c"), builder.literal(0)), builder.literal(false)});
            }
        }
        if (!keyIsNulls.isEmpty()) {
            operands.add(new RexNode[]{builder.or(keyIsNulls), builder.literal(null)});
        }
        // A matching dt row was found.
        if (allLiterals) {
            operands.add(new RexNode[]{builder.isNotNull(builder.field("cs")), builder.literal(true)});
        } else {
            operands.add(new RexNode[]{builder.isNotNull(Util.last(builder.fields())), builder.literal(true)});
        }
        if (!allLiterals) {
            switch (logic) {
                case TRUE_FALSE_UNKNOWN: 
                case UNKNOWN_AS_TRUE: {
                    // ck < c: some sub-query row has a NULL column, so "no match" is not definitely FALSE.
                    operands.add(new RexNode[]{builder.call((SqlOperator)SqlStdOperatorTable.LESS_THAN, builder.field("ct", "ck"), builder.field("ct", "c")), builder.literal(b)});
                }
            }
        }
        operands.add(builder.literal(false));
        return builder.call((SqlOperator)SqlStdOperatorTable.CASE, operands.build());
    }

    /**
     * Returns a reference to the field at flat position {@code offset} across
     * the {@code inputCount} inputs on top of the builder's stack.
     *
     * <p>The inputs are walked left to right, subtracting each input's field
     * count from {@code offset} until the position falls inside one of them.
     */
    private RexInputRef field(RelBuilder builder, int inputCount, int offset) {
        int inputOrdinal = 0;
        RelNode r;
        while (offset >= (r = builder.peek(inputCount, inputOrdinal)).getRowType().getFieldCount()) {
            ++inputOrdinal;
            offset -= r.getRowType().getFieldCount();
        }
        return builder.field(inputCount, inputOrdinal, offset);
    }

    /**
     * Returns references to the first {@code fieldCount} fields of the relation
     * on top of the builder's stack.
     *
     * <p>The rules use it to project away the columns appended by a rewrite,
     * keeping only the original ones.
     */
    private static List<RexNode> fields(RelBuilder builder, int fieldCount) {
        ArrayList<RexNode> projects = new ArrayList<RexNode>();
        for (int i = 0; i < fieldCount; ++i) {
            projects.add(builder.field(i));
        }
        return projects;
    }

    /**
     * Shuttle that replaces every sub-query equal (per {@link RexUtil#eq}) to a
     * given sub-query with a replacement expression.
     *
     * <p>Other sub-queries are returned unchanged, and their operands are not
     * visited.
     */
    private static class ReplaceSubQueryShuttle
    extends RexShuttle {
        private final RexSubQuery subQuery;
        private final RexNode replacement;

        ReplaceSubQueryShuttle(RexSubQuery subQuery, RexNode replacement) {
            this.subQuery = subQuery;
            this.replacement = replacement;
        }

        @Override
        public RexNode visitSubQuery(RexSubQuery subQuery) {
            return RexUtil.eq(subQuery, this.subQuery) ? this.replacement : subQuery;
        }
    }

    /**
     * Rule that rewrites a sub-query in a {@link Join}'s condition.
     *
     * <p>Only the first sub-query found is rewritten per firing. Its logic is
     * derived starting from {@link RelOptUtil.Logic#TRUE}. The sub-query is
     * joined onto the join's right input (the top of the stack), the join is
     * rebuilt with the rewritten condition, and a projection restores the
     * original field count.
     *
     * <p>NOTE(review): an empty correlation set is passed to {@link #apply};
     * confirm correlated sub-queries in join conditions are handled elsewhere.
     */
    public static class SubQueryJoinRemoveRule
    extends SubQueryRemoveRule {
        public SubQueryJoinRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(SubQueryJoinRemoveRule.operandJ(Join.class, null, RexUtil.SubQueryFinder::containsSubQuery, SubQueryJoinRemoveRule.any()), relBuilderFactory, "SubQueryRemoveRule:Join");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Join join = (Join)call.rel(0);
            RelBuilder builder = call.builder();
            RexSubQuery e = RexUtil.SubQueryFinder.find(join.getCondition());
            // The operand predicate guarantees the condition contains a sub-query.
            assert (e != null);
            RelOptUtil.Logic logic = LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(join.getCondition()), e);
            builder.push(join.getLeft());
            builder.push(join.getRight());
            int fieldCount = join.getRowType().getFieldCount();
            RexNode target = this.apply(e, ImmutableSet.of(), logic, builder, 2, fieldCount);
            ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
            builder.join(join.getJoinType(), shuttle.apply(join.getCondition()));
            builder.project(SubQueryRemoveRule.fields(builder, join.getRowType().getFieldCount()));
            call.transformTo(builder.build());
        }
    }

    /**
     * Rule that rewrites sub-queries in a {@link Filter}'s condition.
     *
     * <p>All sub-queries in the condition are rewritten in one firing, one per
     * loop iteration. Each uses the correlation variables referenced by its own
     * relational expression, and the current width of the relation on top of
     * the stack (which grows with each rewrite) as the offset. The filter is
     * then rebuilt with the rewritten condition, and a projection restores the
     * original columns.
     */
    public static class SubQueryFilterRemoveRule
    extends SubQueryRemoveRule {
        public SubQueryFilterRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(SubQueryFilterRemoveRule.operandJ(Filter.class, null, RexUtil.SubQueryFinder::containsSubQuery, SubQueryFilterRemoveRule.any()), relBuilderFactory, "SubQueryRemoveRule:Filter");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Filter filter = (Filter)call.rel(0);
            RelBuilder builder = call.builder();
            builder.push(filter.getInput());
            int count = 0;
            RexNode c = filter.getCondition();
            while (true) {
                RexSubQuery e;
                if ((e = RexUtil.SubQueryFinder.find(c)) == null) {
                    // The operand predicate guarantees at least one sub-query was rewritten.
                    assert (count > 0);
                    break;
                }
                ++count;
                RelOptUtil.Logic logic = LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(c), e);
                Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
                RexNode target = this.apply(e, variablesSet, logic, builder, 1, builder.peek().getRowType().getFieldCount());
                ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
                c = c.accept(shuttle);
            }
            builder.filter(c);
            builder.project(SubQueryRemoveRule.fields(builder, filter.getRowType().getFieldCount()));
            call.transformTo(builder.build());
        }
    }

    /**
     * Rule that rewrites a sub-query in a {@link Project}'s expressions.
     *
     * <p>Only the first sub-query found is rewritten per firing (presumably the
     * planner fires the rule again for any that remain). Its logic is derived
     * starting from {@link RelOptUtil.Logic#TRUE_FALSE_UNKNOWN}. The project is
     * rebuilt with the rewritten expressions and the original field names.
     *
     * <p>NOTE(review): an empty correlation set is passed to {@link #apply};
     * confirm correlated sub-queries in projections are handled elsewhere.
     */
    public static class SubQueryProjectRemoveRule
    extends SubQueryRemoveRule {
        public SubQueryProjectRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(SubQueryProjectRemoveRule.operandJ(Project.class, null, RexUtil.SubQueryFinder::containsSubQuery, SubQueryProjectRemoveRule.any()), relBuilderFactory, "SubQueryRemoveRule:Project");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Project project = (Project)call.rel(0);
            RelBuilder builder = call.builder();
            RexSubQuery e = RexUtil.SubQueryFinder.find(project.getProjects());
            // The operand predicate guarantees the projects contain a sub-query.
            assert (e != null);
            RelOptUtil.Logic logic = LogicVisitor.find(RelOptUtil.Logic.TRUE_FALSE_UNKNOWN, project.getProjects(), e);
            builder.push(project.getInput());
            int fieldCount = builder.peek().getRowType().getFieldCount();
            RexNode target = this.apply(e, ImmutableSet.of(), logic, builder, 1, fieldCount);
            ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
            builder.project(shuttle.apply(project.getProjects()), project.getRowType().getFieldNames());
            call.transformTo(builder.build());
        }
    }
}

