/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.ImmutableList;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.linq4j.Ord;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.RelNode;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Aggregate;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.Union;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.SqlAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlAnyValueAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;

/**
 * Planner rule that pushes an {@link Aggregate} past a {@code UNION ALL} {@link Union}.
 *
 * <p>Each union input is pre-aggregated on the original group key. A second aggregate over
 * the new union then rolls the partial results up: COUNT partials are added with SUM0, and
 * every other supported function is re-applied to its own partials.
 *
 * <p>The rewrite is attempted only when the union is {@code ALL}, no aggregate call is
 * {@code DISTINCT}, and every aggregate function's class is registered in
 * {@code SUPPORTED_AGGREGATES}.
 */
public class AggregateUnionTransposeRule extends RelOptRule {
    public static final AggregateUnionTransposeRule INSTANCE = new AggregateUnionTransposeRule(LogicalAggregate.class, LogicalUnion.class, RelFactories.LOGICAL_BUILDER);

    /**
     * Aggregate function classes whose calls can be rebuilt from per-input partial results.
     * Lookup is by the exact runtime class ({@code getClass()}) in an identity map, so
     * subclasses of these classes are not admitted.
     */
    private static final Map<Class<? extends SqlAggFunction>, Boolean> SUPPORTED_AGGREGATES = new IdentityHashMap<Class<? extends SqlAggFunction>, Boolean>();

    /**
     * Creates a rule matching an aggregate of {@code aggregateClass} directly on top of a
     * union of {@code unionClass}.
     *
     * @param aggregateClass    aggregate type to match
     * @param unionClass        union type to match
     * @param relBuilderFactory factory for the builder used to create the rewritten tree
     */
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Union> unionClass, RelBuilderFactory relBuilderFactory) {
        super(operand(aggregateClass, operand(unionClass, any())), relBuilderFactory, null);
    }

    /**
     * Creates the rule from individual node factories.
     *
     * @deprecated Use
     *     {@link #AggregateUnionTransposeRule(Class, Class, RelBuilderFactory)}; the factories
     *     passed here are simply wrapped by {@link RelBuilder#proto}.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Union> unionClass, RelFactories.SetOpFactory setOpFactory) {
        this(aggregateClass, unionClass, RelBuilder.proto(aggregateFactory, setOpFactory));
    }

    /**
     * Rewrites {@code Aggregate(Union ALL(in_1..in_n))} into
     * {@code Aggregate'(Union ALL(Aggregate(in_1)..Aggregate(in_n)))}. Here
     * {@code Aggregate'} groups on the key positions of the pushed-down aggregates and rolls
     * up their partial results.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggRel = (Aggregate) call.rel(0);
        final Union union = (Union) call.rel(1);
        if (!union.all) {
            // Only valid for UNION ALL. A distinct UNION removes duplicates across inputs
            // before aggregating, so per-input partials would see those duplicates.
            // Example: SUM(i) over (5),(5) UNION (5),(10) is 15, but the partials are
            // 10 and 15, which roll up to 25.
            return;
        }

        final ImmutableBitSet groupSet = aggRel.getGroupSet();
        final int groupCount = groupSet.cardinality();
        final List<AggregateCall> transformedAggCalls = this.transformAggCalls(
                aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), false, groupSet, null, aggRel.getAggCallList()),
                groupCount, aggRel.getAggCallList());
        if (transformedAggCalls == null) {
            // At least one call cannot be rebuilt from partial results.
            return;
        }

        // Bail out unless some input actually benefits from pre-aggregation. This also keeps
        // the rule from firing endlessly on its own output: the new union's inputs are
        // aggregates grouped on exactly the new top-level key.
        final RelMetadataQuery mq = call.getMetadataQuery();
        boolean anyInputNeedsAggregation = false;
        for (RelNode input : union.getInputs()) {
            if (!RelMdUtil.areColumnsDefinitelyUnique(mq, input, groupSet)) {
                anyInputNeedsAggregation = true;
                break;
            }
        }
        if (!anyInputNeedsAggregation) {
            return;
        }

        // Pre-aggregate EVERY input, including inputs already unique on the group key.
        // Pushing such an input unchanged would union its raw row type with the
        // (group keys, partial results) row type of the other inputs. The roll-up calls
        // below read partial results at fixed positions.
        final RelBuilder relBuilder = call.builder();
        for (RelNode input : union.getInputs()) {
            relBuilder.push(input);
            relBuilder.aggregate(relBuilder.groupKey(groupSet), aggRel.getAggCallList());
        }
        relBuilder.union(true, union.getInputs().size());

        // The pushed-down aggregates emit their group keys as fields 0..groupCount-1,
        // followed by one field per aggregate call. The top aggregate must therefore group
        // on those positions, not on the original input ordinals. The mapping is monotonic,
        // so the relative order of the grouping sets is preserved.
        final Map<Integer, Integer> keyToPosition = new HashMap<Integer, Integer>();
        int position = 0;
        for (int key : groupSet) {
            keyToPosition.put(key, position++);
        }
        final ImmutableList.Builder<ImmutableBitSet> topGroupSets = ImmutableList.builder();
        for (ImmutableBitSet groupingSet : aggRel.getGroupSets()) {
            topGroupSets.add(groupingSet.permute(keyToPosition));
        }
        relBuilder.aggregate(
                relBuilder.groupKey(groupSet.permute(keyToPosition), topGroupSets.build()),
                transformedAggCalls);
        call.transformTo(relBuilder.build());
    }

    /**
     * Builds the calls of the top-level aggregate, which rolls up the partial results of
     * the pushed-down aggregates.
     *
     * @param input      expression whose row type is the group keys followed by one field per
     *                   original call; passed to {@link AggregateCall#create} for calls whose
     *                   result type must be derived
     * @param groupCount number of group keys; the partial result of call {@code i} is read
     *                   from field {@code groupCount + i}
     * @param origCalls  calls of the original aggregate
     * @return the roll-up calls, or {@code null} if some call cannot be rebuilt from partial
     *         results (a DISTINCT call, an unsupported function class, or a count-class
     *         function other than {@code COUNT})
     */
    private List<AggregateCall> transformAggCalls(RelNode input, int groupCount, List<AggregateCall> origCalls) {
        final List<AggregateCall> newCalls = new ArrayList<AggregateCall>(origCalls.size());
        for (Ord<AggregateCall> ord : Ord.zip(origCalls)) {
            final AggregateCall origCall = ord.e;
            final SqlAggFunction origFun = origCall.getAggregation();
            if (origCall.isDistinct() || !SUPPORTED_AGGREGATES.containsKey(origFun.getClass())) {
                return null;
            }
            final SqlAggFunction rollupFun;
            final RelDataType rollupType;
            if (origFun == SqlStdOperatorTable.COUNT) {
                // Partial counts are added up. SUM0 yields 0 rather than NULL for an empty
                // group, as COUNT does. The result type is left for AggregateCall to derive.
                rollupFun = SqlStdOperatorTable.SUM0;
                rollupType = null;
            } else if (origFun instanceof SqlCountAggFunction) {
                // Another count-class function: re-applying it to partial counts would count
                // the partials instead of combining them, so do not transform.
                return null;
            } else {
                // MIN/MAX/SUM/SUM0/ANY_VALUE can be re-applied to their own partials.
                rollupFun = origFun;
                rollupType = origCall.getType();
            }
            // filterArg = -1: any FILTER was already applied by the pushed-down aggregate.
            newCalls.add(AggregateCall.create(rollupFun, origCall.isDistinct(), origCall.isApproximate(),
                    ImmutableList.of(groupCount + ord.i), -1, origCall.collation, groupCount, input,
                    rollupType, origCall.getName()));
        }
        return newCalls;
    }

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlAnyValueAggFunction.class, true);
    }
}

