package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.class */
/**
 * Pushes a {@code PARTIAL} {@link AggregationNode} below the {@link JoinNode} directly beneath it.
 * This applies only when the join is an inner join and every aggregation call reads its input
 * symbols from a single side of the join.
 * <p>
 * The pushed-down aggregation groups by:
 * <ul>
 *   <li>the original grouping keys that are available on that side, and</li>
 *   <li>every symbol of that side the join still needs (equi-join criteria, filter symbols, hash symbols).</li>
 * </ul>
 * This lets the join be evaluated on top of the partial aggregation. The rebuilt join's outputs
 * are then restricted to the original aggregation's output symbols.
 */
public class PushPartialAggregationThroughJoin implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN_NODE = Capture.newCapture();

    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
            .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN_NODE)));

    /**
     * Returns whether the aggregation is a candidate for this rule.
     * It must be a non-streamable {@code PARTIAL} aggregation with no hash symbol and exactly
     * one grouping set.
     */
    private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) {
        // The rewrite rebuilds the aggregation with no pre-grouped symbols and a new grouping set,
        // so streamable or multi-grouping-set aggregations, or ones with a precomputed hash, are
        // not handled.
        return !aggregationNode.isStreamable()
                && !aggregationNode.getHashSymbol().isPresent()
                && aggregationNode.getStep() == AggregationNode.Step.PARTIAL
                && aggregationNode.getGroupingSetCount() == 1;
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /** The rule runs only when the push-aggregation-through-join session property is set. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushAggregationThroughJoin(session);
    }

    /**
     * Rewrites {@code Aggregation(Join(L, R))} as {@code Join(Aggregation(L), R)} or
     * {@code Join(L, Aggregation(R))}. The left side is tried first. The plan is left unchanged
     * when the join is not INNER or the aggregation inputs span both sides.
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(JOIN_NODE);
        if (joinNode.getType() != JoinNode.Type.INNER) {
            return Rule.Result.empty();
        }
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context));
        }
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context));
        }
        return Rule.Result.empty();
    }

    /**
     * Returns true if every symbol referenced by the aggregation calls is contained in {@code symbols}.
     * <p>
     * NOTE(review): only {@code Aggregation#getCall()} is inspected. If an Aggregation can reference
     * other symbols (e.g. a mask), confirm they are also guaranteed to come from the same side.
     */
    private static boolean allAggregationsOn(Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> symbols) {
        Set<Symbol> inputs = SymbolsExtractor.extractUnique(aggregations.values().stream()
                .map(AggregationNode.Aggregation::getCall)
                .collect(ImmutableList.toImmutableList()));
        return symbols.containsAll(inputs);
    }

    /** Places a rebuilt partial aggregation on the join's left input and keeps the right input as-is. */
    private PlanNode pushPartialToLeftChild(AggregationNode aggregationNode, JoinNode joinNode, Rule.Context context) {
        Set<Symbol> leftSymbols = ImmutableSet.copyOf(joinNode.getLeft().getOutputSymbols());
        List<Symbol> groupingSet = getPushedDownGroupingSet(
                aggregationNode,
                leftSymbols,
                Sets.intersection(getJoinRequiredSymbols(joinNode), leftSymbols));
        AggregationNode pushedAggregation = replaceAggregationSource(aggregationNode, joinNode.getLeft(), groupingSet);
        return pushPartialToJoin(aggregationNode, joinNode, pushedAggregation, joinNode.getRight(), context);
    }

    /** Places a rebuilt partial aggregation on the join's right input and keeps the left input as-is. */
    private PlanNode pushPartialToRightChild(AggregationNode aggregationNode, JoinNode joinNode, Rule.Context context) {
        Set<Symbol> rightSymbols = ImmutableSet.copyOf(joinNode.getRight().getOutputSymbols());
        List<Symbol> groupingSet = getPushedDownGroupingSet(
                aggregationNode,
                rightSymbols,
                Sets.intersection(getJoinRequiredSymbols(joinNode), rightSymbols));
        AggregationNode pushedAggregation = replaceAggregationSource(aggregationNode, joinNode.getRight(), groupingSet);
        return pushPartialToJoin(aggregationNode, joinNode, joinNode.getLeft(), pushedAggregation, context);
    }

    /**
     * Collects every symbol the join itself reads: left then right equi-join keys, filter symbols,
     * and the left/right hash symbols, in that order.
     */
    private static Set<Symbol> getJoinRequiredSymbols(JoinNode joinNode) {
        ImmutableSet.Builder<Symbol> requiredSymbols = ImmutableSet.builder();
        // Two passes keep the original ordering (all left keys before all right keys); the
        // iteration order flows into the pushed-down grouping-key order.
        joinNode.getCriteria().forEach(clause -> requiredSymbols.add(clause.getLeft()));
        joinNode.getCriteria().forEach(clause -> requiredSymbols.add(clause.getRight()));
        joinNode.getFilter().ifPresent(filter -> requiredSymbols.addAll(SymbolsExtractor.extractUnique(filter)));
        joinNode.getLeftHashSymbol().ifPresent(symbol -> requiredSymbols.add(symbol));
        joinNode.getRightHashSymbol().ifPresent(symbol -> requiredSymbols.add(symbol));
        return requiredSymbols.build();
    }

    /**
     * Builds the grouping keys for the pushed-down aggregation.
     *
     * @param aggregationNode     the original aggregation
     * @param availableSymbols    symbols produced by the join side receiving the aggregation
     * @param requiredJoinSymbols symbols of that side the join needs
     * @return the original grouping keys that are available on that side (in their original order),
     *         followed by any required join symbols not already present
     */
    private static List<Symbol> getPushedDownGroupingSet(AggregationNode aggregationNode, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols) {
        List<Symbol> pushedDownGroupingSet = aggregationNode.getGroupingKeys().stream()
                .filter(availableSymbols::contains)
                .collect(Collectors.toList());

        // Append the join's required symbols without duplicating keys that are already grouped on.
        Set<Symbol> existingSymbols = new HashSet<>(pushedDownGroupingSet);
        for (Symbol symbol : requiredJoinSymbols) {
            if (existingSymbols.add(symbol)) {
                pushedDownGroupingSet.add(symbol);
            }
        }
        return pushedDownGroupingSet;
    }

    /**
     * Copies {@code aggregationNode} onto a new source with a single grouping set of
     * {@code groupingKeys}. The id, aggregations, step, hash symbol and group-id symbol are kept,
     * and the pre-grouped symbols are cleared.
     */
    private static AggregationNode replaceAggregationSource(AggregationNode aggregationNode, PlanNode source, List<Symbol> groupingKeys) {
        return new AggregationNode(
                aggregationNode.getId(),
                source,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());
    }

    /**
     * Rebuilds {@code joinNode} over the given children and narrows its outputs.
     * The rebuilt join keeps the same id, type, criteria, filter, hash symbols and distribution
     * type, and outputs all left then all right child symbols. Those outputs are then restricted
     * to {@code aggregationNode}'s output symbols when {@code Util.restrictOutputs} produces a
     * restricted plan.
     */
    private PlanNode pushPartialToJoin(AggregationNode aggregationNode, JoinNode joinNode, PlanNode leftChild, PlanNode rightChild, Rule.Context context) {
        List<Symbol> joinOutputs = ImmutableList.<Symbol>builder()
                .addAll(leftChild.getOutputSymbols())
                .addAll(rightChild.getOutputSymbols())
                .build();
        JoinNode pushedJoin = new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                leftChild,
                rightChild,
                joinNode.getCriteria(),
                joinOutputs,
                joinNode.getFilter(),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType());
        return Util.restrictOutputs(context.getIdAllocator(), pushedJoin, ImmutableSet.copyOf(aggregationNode.getOutputSymbols()))
                .orElse(pushedJoin);
    }
}
