package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil;
import com.facebook.presto.sql.planner.optimizations.ScalarAggregationToJoinRewriter;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.util.MorePredicates;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.class */
/**
 * Iterative optimizer rule that matches a {@link LateralJoinNode} with a non-empty
 * correlation list and, when its subquery is a scalar aggregation without grouping keys,
 * delegates to {@link ScalarAggregationToJoinRewriter} to rewrite the plan.
 *
 * <p>The rule does not fire (returns an empty result) when any of the following holds:
 * <ul>
 *   <li>the subquery is not scalar according to {@link QueryCardinalityUtil#isScalar};</li>
 *   <li>no {@link AggregationNode} is reachable from the subquery root through
 *       {@link ProjectNode} / {@link EnforceSingleRowNode} nodes only;</li>
 *   <li>the aggregation that was found has grouping keys;</li>
 *   <li>the rewriter hands back a {@link LateralJoinNode}, which is treated here as
 *       "nothing was rewritten".</li>
 * </ul>
 */
public class TransformCorrelatedScalarAggregationToJoin implements Rule<LateralJoinNode> {
    /** Matches lateral joins whose correlation list is non-empty. */
    private static final Pattern<LateralJoinNode> PATTERN = Patterns.lateralJoin().with(Pattern.nonEmpty(Patterns.LateralJoin.correlation()));

    /** Passed through to {@link ScalarAggregationToJoinRewriter} on every rewrite. */
    private final FunctionRegistry functionRegistry;

    /**
     * @return the pattern selecting correlated lateral joins
     */
    @Override
    public Pattern<LateralJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * @param functionRegistry registry handed to the rewriter; must not be null
     * @throws NullPointerException if {@code functionRegistry} is null
     */
    public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegistry) {
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry is null");
    }

    /**
     * Attempts to rewrite the matched lateral join.
     *
     * @param lateralJoinNode the node matched by {@link #getPattern()}
     * @param captures pattern captures (unused by this rule)
     * @param context supplies the lookup, symbol allocator and plan node id allocator
     * @return the rewritten plan, or an empty result when the rule does not apply
     */
    @Override
    public Rule.Result apply(LateralJoinNode lateralJoinNode, Captures captures, Rule.Context context) {
        PlanNode subquery = lateralJoinNode.getSubquery();
        if (!QueryCardinalityUtil.isScalar(subquery, context.getLookup())) {
            return Rule.Result.empty();
        }

        Optional<AggregationNode> aggregation = findAggregation(subquery, context.getLookup());
        // Only an aggregation with no grouping keys is handed to the rewriter.
        if (!aggregation.isPresent() || !aggregation.get().getGroupingKeys().isEmpty()) {
            return Rule.Result.empty();
        }

        ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(
                this.functionRegistry,
                context.getSymbolAllocator(),
                context.getIdAllocator(),
                context.getLookup());
        PlanNode rewritten = rewriter.rewriteScalarAggregation(lateralJoinNode, aggregation.get());

        // A LateralJoinNode coming back is treated as "no rewrite happened"; reporting it as
        // a new plan node would presumably make the optimizer re-apply this rule — TODO confirm.
        if (rewritten instanceof LateralJoinNode) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(rewritten);
    }

    /**
     * Searches from {@code planNode} for the first {@link AggregationNode}, descending only
     * through {@link ProjectNode} and {@link EnforceSingleRowNode} nodes.
     *
     * @param planNode root of the search
     * @param lookup used by the searcher to resolve plan nodes
     * @return the first aggregation found, or empty if none is reachable
     */
    private static Optional<AggregationNode> findAggregation(PlanNode planNode, Lookup lookup) {
        return PlanNodeSearcher.searchFrom(planNode, lookup)
                .where(AggregationNode.class::isInstance)
                .recurseOnlyWhen(MorePredicates.isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class))
                .findFirst();
    }
}
