/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptPredicateList;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.RelNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Aggregate;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.logical.LogicalProject;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rex.RexNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Pair;

/**
 * Planner rule that removes constant keys from the GROUP BY list of an
 * {@link Aggregate}.
 *
 * <p>Constant keys are detected with
 * {@link RelMetadataQuery#getPulledUpPredicates(RelNode)} on the aggregate's
 * input. The input therefore does not have to be a {@link LogicalProject};
 * {@link #INSTANCE2} matches an aggregate over any relational expression.
 *
 * <p>The rule never empties the GROUP BY list. An aggregate with a single key
 * is left alone. When every key is constant, the key with the smallest input
 * ordinal is kept. This matters because grouping by the empty set produces one
 * row even over empty input, whereas a non-empty GROUP BY produces none.
 *
 * <p>The rewritten expression must have the original row type. The removed
 * constants are therefore re-created, in their original positions, by a
 * projection placed on top of the reduced aggregate. Where the constant's type
 * differs from the original field type, it is cast to that type.
 */
public class AggregateProjectPullUpConstantsRule
extends RelOptRule {
    /** Matches a {@link LogicalAggregate} directly on top of a {@link LogicalProject}. */
    public static final AggregateProjectPullUpConstantsRule INSTANCE = new AggregateProjectPullUpConstantsRule(LogicalAggregate.class, LogicalProject.class, RelFactories.LOGICAL_BUILDER, "AggregateProjectPullUpConstantsRule");
    /** Matches a {@link LogicalAggregate} on top of any {@link RelNode}. */
    public static final AggregateProjectPullUpConstantsRule INSTANCE2 = new AggregateProjectPullUpConstantsRule(LogicalAggregate.class, RelNode.class, RelFactories.LOGICAL_BUILDER, "AggregatePullUpConstantsRule");

    /**
     * Creates an AggregateProjectPullUpConstantsRule.
     *
     * @param aggregateClass    aggregate class to match; only aggregates accepted
     *                          by {@code Aggregate::isSimple} are matched
     * @param inputClass        class of the aggregate's input to match
     * @param relBuilderFactory factory for the builder that creates the
     *                          replacement expression
     * @param description       description of the rule
     */
    public AggregateProjectPullUpConstantsRule(Class<? extends Aggregate> aggregateClass, Class<? extends RelNode> inputClass, RelBuilderFactory relBuilderFactory, String description) {
        super(operandJ(aggregateClass, null, Aggregate::isSimple,
                operand(inputClass, any()), new RelOptRuleOperand[0]),
            relBuilderFactory, description);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        final RelNode input = call.rel(1);

        // The operand predicate (Aggregate::isSimple) excludes grouping sets.
        assert !aggregate.indicator : "predicate ensured no grouping sets";

        final int groupCount = aggregate.getGroupCount();
        if (groupCount == 1) {
            // No room to optimize: a single key cannot be dropped, because
            // turning a non-empty GROUP BY into GROUP BY () changes the result.
            return;
        }

        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final RelOptPredicateList predicates = mq.getPulledUpPredicates(aggregate.getInput());
        if (predicates == null) {
            return;
        }

        // Group keys (input ordinals) whose value is a known constant, sorted
        // by ordinal.
        final TreeMap<Integer, RexNode> constants = new TreeMap<>();
        for (int key : aggregate.getGroupSet()) {
            final RexInputRef ref = rexBuilder.makeInputRef(aggregate.getInput(), key);
            if (predicates.constantMap.containsKey(ref)) {
                constants.put(key, predicates.constantMap.get(ref));
            }
        }

        if (constants.isEmpty()) {
            // None of the group keys is constant; nothing to pull up.
            return;
        }

        if (constants.size() == groupCount) {
            // At least one key must stay, otherwise "GROUP BY a, b" would become
            // "GROUP BY ()". Keep the key with the smallest ordinal.
            constants.remove(constants.firstKey());
        }

        ImmutableBitSet newGroupSet = aggregate.getGroupSet();
        for (int key : constants.keySet()) {
            newGroupSet = newGroupSet.clear(key);
        }
        final int newGroupCount = newGroupSet.cardinality();

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(input);

        // Re-target each aggregate call from the old group count to the new one;
        // arguments and filter are unchanged.
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            newAggCalls.add(aggCall.adaptTo(input, aggCall.getArgList(), aggCall.filterArg, groupCount, newGroupCount));
        }
        relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls);

        // Project back to the original row type: constants are re-created in
        // place, surviving keys and aggregate calls are read from the new
        // aggregate.
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        int source = 0; // next surviving group key in the reduced aggregate's output
        for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
            final int i = field.getIndex();
            final RexNode expr;
            if (i >= groupCount) {
                // Aggregate calls keep their relative order; they only shift left
                // by the number of removed keys.
                expr = relBuilder.field(i - constants.size());
            } else {
                final int pos = aggregate.getGroupSet().nth(i);
                if (constants.containsKey(pos)) {
                    final RexNode constant = constants.get(pos);
                    final RelDataType originalType = field.getType();
                    expr = originalType.equals(constant.getType())
                        ? constant
                        : rexBuilder.makeCast(originalType, constant, true);
                } else {
                    expr = relBuilder.field(source);
                    ++source;
                }
            }
            projects.add(Pair.of(expr, field.getName()));
        }
        relBuilder.project(Pair.left(projects), Pair.right(projects));
        call.transformTo(relBuilder.build());
    }
}

