/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.jet.sql.impl.opt.physical;

import com.hazelcast.function.BiConsumerEx;
import com.hazelcast.function.FunctionEx;
import com.hazelcast.function.SupplierEx;
import com.hazelcast.jet.aggregate.AggregateOperation;
import com.hazelcast.jet.sql.impl.aggregate.AvgSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.CountSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.MaxSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.MinSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SumSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.ValueSqlAggregation;
import com.hazelcast.jet.sql.impl.opt.JetConventions;
import com.hazelcast.jet.sql.impl.opt.OptUtils;
import com.hazelcast.jet.sql.impl.opt.logical.AggregateLogicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregateAccumulateByKeyPhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregateAccumulatePhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregateByKeyPhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregateCombineByKeyPhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregateCombinePhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.AggregatePhysicalRel;
import com.hazelcast.sql.impl.QueryException;
import com.hazelcast.sql.impl.type.QueryDataType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;

final class AggregatePhysicalRule
extends RelOptRule {
    static final RelOptRule INSTANCE = new AggregatePhysicalRule();

    private AggregatePhysicalRule() {
        super(AggregatePhysicalRule.operand(AggregateLogicalRel.class, (RelTrait)JetConventions.LOGICAL, (RelOptRuleOperandChildren)AggregatePhysicalRule.some((RelOptRuleOperand)AggregatePhysicalRule.operand(RelNode.class, (RelOptRuleOperandChildren)AggregatePhysicalRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0])), AggregatePhysicalRule.class.getSimpleName());
    }

    public void onMatch(RelOptRuleCall call) {
        AggregateLogicalRel logicalAggregate = (AggregateLogicalRel)call.rel(0);
        RelNode input = logicalAggregate.getInput();
        assert (logicalAggregate.getGroupType() == Aggregate.Group.SIMPLE);
        RelNode convertedInput = OptUtils.toPhysicalInput(input);
        Collection<RelNode> transformedInputs = OptUtils.extractPhysicalRelsFromSubset(convertedInput);
        for (RelNode transformedInput : transformedInputs) {
            call.transformTo(AggregatePhysicalRule.optimize(logicalAggregate, transformedInput));
        }
    }

    private static RelNode optimize(AggregateLogicalRel logicalAggregate, RelNode physicalInput) {
        return logicalAggregate.getGroupSet().cardinality() == 0 ? AggregatePhysicalRule.toAggregate(logicalAggregate, physicalInput) : AggregatePhysicalRule.toAggregateByKey(logicalAggregate, physicalInput);
    }

    private static RelNode toAggregate(AggregateLogicalRel logicalAggregate, RelNode physicalInput) {
        AggregateOperation<?, Object[]> aggrOp = AggregatePhysicalRule.aggregateOperation(physicalInput.getRowType(), logicalAggregate.getGroupSet(), logicalAggregate.getAggCallList());
        if (logicalAggregate.containsDistinctCall()) {
            return new AggregatePhysicalRel(physicalInput.getCluster(), physicalInput.getTraitSet(), physicalInput, logicalAggregate.getGroupSet(), (List<ImmutableBitSet>)logicalAggregate.getGroupSets(), logicalAggregate.getAggCallList(), aggrOp);
        }
        AggregateAccumulatePhysicalRel rel = new AggregateAccumulatePhysicalRel(physicalInput.getCluster(), physicalInput.getTraitSet(), physicalInput, aggrOp);
        return new AggregateCombinePhysicalRel(rel.getCluster(), rel.getTraitSet(), rel, logicalAggregate.getGroupSet(), (List<ImmutableBitSet>)logicalAggregate.getGroupSets(), logicalAggregate.getAggCallList(), aggrOp);
    }

    private static RelNode toAggregateByKey(AggregateLogicalRel logicalAggregate, RelNode physicalInput) {
        AggregateOperation<?, Object[]> aggrOp = AggregatePhysicalRule.aggregateOperation(physicalInput.getRowType(), logicalAggregate.getGroupSet(), logicalAggregate.getAggCallList());
        if (logicalAggregate.containsDistinctCall()) {
            return new AggregateByKeyPhysicalRel(physicalInput.getCluster(), physicalInput.getTraitSet(), physicalInput, logicalAggregate.getGroupSet(), (List<ImmutableBitSet>)logicalAggregate.getGroupSets(), logicalAggregate.getAggCallList(), aggrOp);
        }
        AggregateAccumulateByKeyPhysicalRel rel = new AggregateAccumulateByKeyPhysicalRel(physicalInput.getCluster(), physicalInput.getTraitSet(), physicalInput, logicalAggregate.getGroupSet(), aggrOp);
        return new AggregateCombineByKeyPhysicalRel(rel.getCluster(), rel.getTraitSet(), rel, logicalAggregate.getGroupSet(), (List<ImmutableBitSet>)logicalAggregate.getGroupSets(), logicalAggregate.getAggCallList(), aggrOp);
    }

    private static AggregateOperation<?, Object[]> aggregateOperation(RelDataType inputType, ImmutableBitSet groupSet, List<AggregateCall> aggregateCalls) {
        List operandTypes = OptUtils.schema(inputType).getTypes();
        ArrayList<Object> aggregationProviders = new ArrayList<Object>();
        ArrayList<Object> valueProviders = new ArrayList<Object>();
        for (Integer groupIndex : groupSet.toList()) {
            aggregationProviders.add(ValueSqlAggregation::new);
            valueProviders.add((FunctionEx & Serializable)row -> row[groupIndex]);
        }
        block8: for (AggregateCall aggregateCall : aggregateCalls) {
            boolean distinct = aggregateCall.isDistinct();
            List aggregateCallArguments = aggregateCall.getArgList();
            SqlKind kind = aggregateCall.getAggregation().getKind();
            switch (kind) {
                case COUNT: {
                    int countIndex;
                    if (distinct) {
                        countIndex = (Integer)aggregateCallArguments.get(0);
                        aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(true, true));
                        valueProviders.add((FunctionEx & Serializable)row -> row[countIndex]);
                        continue block8;
                    }
                    if (aggregateCallArguments.size() == 1) {
                        countIndex = (Integer)aggregateCallArguments.get(0);
                        aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(true, false));
                        valueProviders.add((FunctionEx & Serializable)row -> row[countIndex]);
                        continue block8;
                    }
                    aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(false, false));
                    valueProviders.add((FunctionEx & Serializable)row -> null);
                    continue block8;
                }
                case MIN: {
                    int minIndex = (Integer)aggregateCallArguments.get(0);
                    aggregationProviders.add(MinSqlAggregation::new);
                    valueProviders.add((FunctionEx & Serializable)row -> row[minIndex]);
                    continue block8;
                }
                case MAX: {
                    int maxIndex = (Integer)aggregateCallArguments.get(0);
                    aggregationProviders.add(MaxSqlAggregation::new);
                    valueProviders.add((FunctionEx & Serializable)row -> row[maxIndex]);
                    continue block8;
                }
                case SUM: {
                    int sumIndex = (Integer)aggregateCallArguments.get(0);
                    QueryDataType sumOperandType = (QueryDataType)operandTypes.get(sumIndex);
                    aggregationProviders.add((SupplierEx & Serializable)() -> SumSqlAggregations.from(sumOperandType, distinct));
                    valueProviders.add((FunctionEx & Serializable)row -> row[sumIndex]);
                    continue block8;
                }
                case AVG: {
                    int avgIndex = (Integer)aggregateCallArguments.get(0);
                    QueryDataType avgOperandType = (QueryDataType)operandTypes.get(avgIndex);
                    aggregationProviders.add((SupplierEx & Serializable)() -> AvgSqlAggregations.from(avgOperandType, distinct));
                    valueProviders.add((FunctionEx & Serializable)row -> row[avgIndex]);
                    continue block8;
                }
            }
            throw QueryException.error((String)("Unsupported aggregation function: " + kind));
        }
        return AggregateOperation.withCreate((SupplierEx & Serializable)() -> {
            ArrayList<Object> aggregations = new ArrayList<Object>(aggregationProviders.size());
            for (SupplierEx aggregationProvider : aggregationProviders) {
                aggregations.add(aggregationProvider.get());
            }
            return aggregations;
        }).andAccumulate((BiConsumerEx & Serializable)(aggregations, row) -> {
            for (int i = 0; i < aggregations.size(); ++i) {
                ((SqlAggregation)aggregations.get(i)).accumulate(((FunctionEx)valueProviders.get(i)).apply(row));
            }
        }).andCombine((BiConsumerEx & Serializable)(lefts, rights) -> {
            assert (lefts.size() == rights.size());
            for (int i = 0; i < lefts.size(); ++i) {
                ((SqlAggregation)lefts.get(i)).combine((SqlAggregation)rights.get(i));
            }
        }).andExportFinish((FunctionEx & Serializable)aggregations -> {
            Object[] values = new Object[aggregations.size()];
            for (int i = 0; i < aggregations.size(); ++i) {
                values[i] = ((SqlAggregation)aggregations.get(i)).collect();
            }
            return values;
        });
    }
}

