package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.class */
/**
 * Optimizer rule that inserts {@link AggregationNode.Step#INTERMEDIATE} aggregation steps between a
 * {@code FINAL} aggregation and the {@code PARTIAL} aggregation(s) feeding it.
 *
 * <p>The rule only matches {@code FINAL} aggregations that have no grouping columns and no orderings
 * (see {@link #PATTERN}). Below the matched node it walks through {@link ExchangeNode}s and
 * {@link ProjectNode}s; every {@code PARTIAL} aggregation it reaches is wrapped as
 * {@code INTERMEDIATE(LOCAL gather(PARTIAL))}. When the session's task concurrency is greater than one,
 * an additional {@code LOCAL gather(INTERMEDIATE(LOCAL round-robin(...)))} stage is placed directly
 * under the {@code FINAL} aggregation.
 */
public class AddIntermediateAggregations implements Rule<AggregationNode> {
    // Matches FINAL aggregations with an empty grouping-column list and no orderings.
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .with(Patterns.Aggregation.step().equalTo(AggregationNode.Step.FINAL))
            .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
            .matching(aggregationNode -> !aggregationNode.hasOrderings());

    /**
     * Returns the pattern selecting the aggregation nodes this rule is applied to.
     */
    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Returns whether the rule is switched on for the given session, as reported by
     * {@link SystemSessionProperties#isEnableIntermediateAggregations(Session)}.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isEnableIntermediateAggregations(session);
    }

    /**
     * Rewrites the sub-plan under {@code aggregationNode} so that it contains intermediate aggregations.
     *
     * @param aggregationNode the matched FINAL aggregation
     * @param captures pattern captures (unused)
     * @param context provides the lookup, id allocator, session and variable types
     * @return the aggregation with its rewritten source, or an empty result when no PARTIAL aggregation
     *         can be reached exclusively through exchange and project nodes
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();
        TypeProvider types = context.getVariableAllocator().getTypes();

        Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregationNode.getSource()), lookup, idAllocator, types);
        if (!rewrittenSource.isPresent()) {
            return Rule.Result.empty();
        }

        PlanNode source = rewrittenSource.get();
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            // The calls are intentionally kept nested: arguments are evaluated left to right, so node ids
            // are handed out to the gathering exchange, the aggregation and the round-robin exchange in
            // that order. Splitting this into statements would change which node receives which id.
            source = ExchangeNode.gatheringExchange(
                    idAllocator.getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    new AggregationNode(
                            idAllocator.getNextId(),
                            ExchangeNode.roundRobinExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source),
                            inputsAsOutputs(aggregationNode.getAggregations(), types),
                            aggregationNode.getGroupingSets(),
                            aggregationNode.getPreGroupedVariables(),
                            AggregationNode.Step.INTERMEDIATE,
                            aggregationNode.getHashVariable(),
                            aggregationNode.getGroupIdVariable()));
        }

        return Rule.Result.ofPlanNode(aggregationNode.replaceChildren(ImmutableList.of(source)));
    }

    /**
     * Descends through {@link ExchangeNode}s and {@link ProjectNode}s looking for PARTIAL aggregations
     * and wraps each one found via {@link #addGatheringIntermediate}.
     *
     * @param planNode the (already resolved) node to inspect
     * @param lookup used to resolve the sources of {@code planNode} before recursing
     * @param planNodeIdAllocator allocator for the ids of newly created nodes
     * @param typeProvider variable types, passed through unchanged
     * @return the rewritten node, or empty if any branch reaches a node that is neither a PARTIAL
     *         aggregation, an exchange nor a projection
     */
    private Optional<PlanNode> recurseToPartial(PlanNode planNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, TypeProvider typeProvider) {
        if ((planNode instanceof AggregationNode) && ((AggregationNode) planNode).getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(addGatheringIntermediate((AggregationNode) planNode, planNodeIdAllocator, typeProvider));
        }
        if (!(planNode instanceof ExchangeNode) && !(planNode instanceof ProjectNode)) {
            return Optional.empty();
        }

        ImmutableList.Builder<PlanNode> rewrittenChildren = ImmutableList.builder();
        for (PlanNode child : planNode.getSources()) {
            Optional<PlanNode> rewrittenChild = recurseToPartial(lookup.resolve(child), lookup, planNodeIdAllocator, typeProvider);
            if (!rewrittenChild.isPresent()) {
                // One unsupported branch aborts the rewrite of the whole sub-plan.
                return Optional.empty();
            }
            rewrittenChildren.add(rewrittenChild.get());
        }
        return Optional.of(planNode.replaceChildren(rewrittenChildren.build()));
    }

    /**
     * Wraps a PARTIAL aggregation as {@code INTERMEDIATE(LOCAL gathering exchange(aggregationNode))}.
     *
     * @param aggregationNode the PARTIAL aggregation; must have no grouping keys
     * @param planNodeIdAllocator allocator for the ids of the two new nodes
     * @param typeProvider variable types (not used by this method)
     * @return the new INTERMEDIATE aggregation node
     * @throws com.google.common.base.VerifyException if the aggregation has grouping keys
     */
    private PlanNode addGatheringIntermediate(AggregationNode aggregationNode, PlanNodeIdAllocator planNodeIdAllocator, TypeProvider typeProvider) {
        Verify.verify(aggregationNode.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
        // Kept as a single expression so that the aggregation's id is allocated before the exchange's id.
        return new AggregationNode(
                planNodeIdAllocator.getNextId(),
                ExchangeNode.gatheringExchange(planNodeIdAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregationNode),
                outputsAsInputs(aggregationNode.getAggregations()),
                aggregationNode.getGroupingSets(),
                aggregationNode.getPreGroupedVariables(),
                AggregationNode.Step.INTERMEDIATE,
                aggregationNode.getHashVariable(),
                aggregationNode.getGroupIdVariable());
    }

    /**
     * Rewrites each assignment so that the aggregation's only argument is its own output variable,
     * e.g. {@code a := sum(b)} becomes {@code a := sum(a)}. The display name, function handle and type
     * of the call are carried over.
     *
     * <p>NOTE(review): the trailing {@code Optional.empty(), Optional.empty(), false, Optional.empty()}
     * constructor arguments presumably clear filter, ordering, distinct and mask - confirm against
     * {@code AggregationNode.Aggregation}.
     *
     * @param map aggregations keyed by output variable
     * @return a new immutable map with the same keys and rewritten aggregations
     * @throws IllegalStateException if any aggregation has an ORDER BY
     */
    private static Map<VariableReferenceExpression, AggregationNode.Aggregation> outputsAsInputs(Map<VariableReferenceExpression, AggregationNode.Aggregation> map) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : map.entrySet()) {
            VariableReferenceExpression output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            Preconditions.checkState(!aggregation.getOrderBy().isPresent(), "Intermediate aggregation does not support ORDER BY");
            builder.put(
                    output,
                    new AggregationNode.Aggregation(
                            new CallExpression(
                                    aggregation.getCall().getDisplayName(),
                                    aggregation.getCall().getFunctionHandle(),
                                    aggregation.getCall().getType(),
                                    ImmutableList.of(output)),
                            Optional.empty(),
                            Optional.empty(),
                            false,
                            Optional.empty()));
        }
        return builder.build();
    }

    /**
     * Re-keys each aggregation by its single input variable, leaving the aggregation itself unchanged,
     * e.g. {@code a := sum(b)} becomes {@code b := sum(b)}.
     *
     * @param map aggregations keyed by output variable
     * @param typeProvider variable types, forwarded to
     *        {@link AggregationNodeUtils#extractAggregationUniqueVariables}
     * @return a new immutable map keyed by each aggregation's only referenced variable
     * @throws IllegalArgumentException if an aggregation does not have exactly one argument, or has an
     *         ORDER BY or a filter
     */
    private static Map<VariableReferenceExpression, AggregationNode.Aggregation> inputsAsOutputs(Map<VariableReferenceExpression, AggregationNode.Aggregation> map, TypeProvider typeProvider) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (AggregationNode.Aggregation aggregation : map.values()) {
            Preconditions.checkArgument(
                    aggregation.getArguments().size() == 1 && !aggregation.getOrderBy().isPresent() && !aggregation.getFilter().isPresent(),
                    "Aggregation should only have one argument and should have no order by  or filter to be able to rewritten to intermediate form");
            VariableReferenceExpression input = Iterables.getOnlyElement(AggregationNodeUtils.extractAggregationUniqueVariables(aggregation, typeProvider));
            builder.put(input, aggregation);
        }
        return builder.build();
    }
}
