package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.class */
/**
 * Rewrites a {@link LateralJoinNode} whose subquery is a scalar {@link AggregationNode} into an
 * uncorrelated plan of the shape
 * {@code Project <- Aggregation <- LEFT Join(AssignUniqueId(input), Project(decorrelated source + non_null))}.
 *
 * <p>The rewrite only happens when {@link PlanNodeDecorrelator#decorrelateFilters} can pull the
 * correlated predicates out of the aggregation's source; otherwise the original lateral join is
 * returned unchanged.
 */
public class ScalarAggregationToJoinRewriter {
    private final FunctionResolution functionResolution;
    private final PlanVariableAllocator variableAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;
    private final PlanNodeDecorrelator planNodeDecorrelator;

    /**
     * Creates a rewriter.
     *
     * @param functionManager used to resolve the {@code count} function handle
     * @param planVariableAllocator allocates the helper {@code non_null} and {@code unique} variables
     * @param planNodeIdAllocator allocates ids for every plan node created by the rewrite
     * @param lookup used to resolve plan nodes while searching and decorrelating the subquery
     * @throws NullPointerException if any argument is null
     */
    public ScalarAggregationToJoinRewriter(FunctionManager functionManager, PlanVariableAllocator planVariableAllocator, PlanNodeIdAllocator planNodeIdAllocator, Lookup lookup) {
        Objects.requireNonNull(functionManager, "functionManager is null");
        this.functionResolution = new FunctionResolution(functionManager);
        this.variableAllocator = Objects.requireNonNull(planVariableAllocator, "variableAllocator is null");
        this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        this.planNodeDecorrelator = new PlanNodeDecorrelator(this.idAllocator, this.variableAllocator, this.lookup);
    }

    /**
     * Attempts to rewrite {@code lateralJoinNode}, whose subquery is the scalar aggregation
     * {@code aggregationNode}, into a LEFT join followed by an aggregation.
     *
     * @param lateralJoinNode the correlated lateral join to rewrite
     * @param aggregationNode the scalar aggregation found in the lateral join's subquery
     * @return the rewritten plan, or {@code lateralJoinNode} itself if the aggregation's source
     *     could not be decorrelated
     */
    public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planNodeDecorrelator.decorrelateFilters(
                lookup.resolve(aggregationNode.getSource()),
                lateralJoinNode.getCorrelation());
        if (!decorrelated.isPresent()) {
            return lateralJoinNode;
        }
        PlanNode decorrelatedSource = decorrelated.get().getNode();

        // Tag every row of the decorrelated source with a constant TRUE. After the LEFT join this
        // column is NULL for input rows that found no matching subquery row, which is what lets
        // count aggregations be rewritten to count(non_null) in createAggregationNode.
        VariableReferenceExpression nonNullVariable = variableAllocator.newVariable("non_null", BooleanType.BOOLEAN);
        Assignments sourceAssignments = Assignments.builder()
                .putAll(AssignmentUtils.identitiesAsSymbolReferences(decorrelatedSource.getOutputVariables()))
                .put(nonNullVariable, OriginalExpressionUtils.castToRowExpression(BooleanLiteral.TRUE_LITERAL))
                .build();
        ProjectNode sourceWithNonNull = new ProjectNode(idAllocator.getNextId(), decorrelatedSource, sourceAssignments);

        return rewriteScalarAggregation(
                lateralJoinNode,
                aggregationNode,
                sourceWithNonNull,
                decorrelated.get().getCorrelatedPredicates(),
                nonNullVariable);
    }

    /**
     * Builds {@code Project(Aggregation(LEFT Join(AssignUniqueId(input), scalarAggregationSource)))}.
     *
     * @param lateralJoinNode the lateral join being replaced
     * @param aggregationNode the original scalar aggregation
     * @param scalarAggregationSource the decorrelated aggregation source, already carrying {@code nonNullVariable}
     * @param joinFilter the correlated predicates pulled out of the source, used as the LEFT join filter
     * @param nonNullVariable the constant-TRUE marker column produced by {@code scalarAggregationSource}
     * @return the rewritten plan, or {@code lateralJoinNode} if no aggregation could be built
     */
    private PlanNode rewriteScalarAggregation(
            LateralJoinNode lateralJoinNode,
            AggregationNode aggregationNode,
            PlanNode scalarAggregationSource,
            Optional<Expression> joinFilter,
            VariableReferenceExpression nonNullVariable) {
        // Attach a unique id to each input row so that grouping on the join's left outputs yields
        // one group per input row, even when the input itself contains duplicate rows.
        AssignUniqueId inputWithUniqueId = new AssignUniqueId(
                idAllocator.getNextId(),
                lateralJoinNode.getInput(),
                variableAllocator.newVariable("unique", BigintType.BIGINT));

        List<VariableReferenceExpression> joinOutputs = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(inputWithUniqueId.getOutputVariables())
                .addAll(scalarAggregationSource.getOutputVariables())
                .build();
        JoinNode leftOuterJoin = new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                inputWithUniqueId,
                scalarAggregationSource,
                ImmutableList.of(),
                joinOutputs,
                joinFilter.map(OriginalExpressionUtils::castToRowExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());

        Optional<AggregationNode> aggregation = createAggregationNode(aggregationNode, leftOuterJoin, nonNullVariable);
        if (!aggregation.isPresent()) {
            return lateralJoinNode;
        }

        // If the subquery has a projection above the aggregation (looking through
        // EnforceSingleRow nodes only), re-apply its assignments on top of the new aggregation.
        Optional<ProjectNode> subqueryProjection = PlanNodeSearcher.searchFrom(lateralJoinNode.getSubquery(), lookup)
                .where(ProjectNode.class::isInstance)
                .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
                .findFirst();

        List<VariableReferenceExpression> aggregationOutputs = getTruncatedAggregationVariables(lateralJoinNode, aggregation.get());

        if (!subqueryProjection.isPresent()) {
            return new ProjectNode(
                    idAllocator.getNextId(),
                    aggregation.get(),
                    AssignmentUtils.identityAssignmentsAsSymbolReferences(aggregationOutputs));
        }

        Assignments assignments = Assignments.builder()
                .putAll(AssignmentUtils.identitiesAsSymbolReferences(aggregationOutputs))
                .putAll(subqueryProjection.get().getAssignments())
                .build();
        return new ProjectNode(idAllocator.getNextId(), aggregation.get(), assignments);
    }

    /**
     * Returns the aggregation's output variables that are also outputs of the lateral join,
     * preserving the aggregation's output order.
     */
    private static List<VariableReferenceExpression> getTruncatedAggregationVariables(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        Set<VariableReferenceExpression> lateralOutputs = new HashSet<>(lateralJoinNode.getOutputVariables());
        return aggregationNode.getOutputVariables().stream()
                .filter(lateralOutputs::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Re-creates {@code aggregationNode} on top of {@code joinNode}, grouped by the join's left
     * (input + unique id) outputs.
     *
     * <p>Every aggregation recognised as {@code count} is replaced by {@code count(nonNullVariable)},
     * so that input rows with no matching subquery row (where the marker column is NULL after the
     * LEFT join) contribute a count of 0. All other aggregations are kept as-is.
     * NOTE(review): this also replaces count(x) over a nullable column; confirm that callers only
     * hand in aggregations for which that substitution is correct.
     *
     * @return the new aggregation (currently always present)
     */
    private Optional<AggregationNode> createAggregationNode(AggregationNode aggregationNode, JoinNode joinNode, VariableReferenceExpression nonNullVariable) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (functionResolution.isCountFunction(aggregation.getFunctionHandle())) {
                CallExpression countNonNull = new CallExpression(
                        "count",
                        functionResolution.countFunction(nonNullVariable.getType()),
                        BigintType.BIGINT,
                        ImmutableList.of(OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(nonNullVariable))));
                aggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                        countNonNull,
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        aggregation.getMask()));
            }
            else {
                aggregations.put(entry.getKey(), aggregation);
            }
        }

        return Optional.of(new AggregationNode(
                idAllocator.getNextId(),
                joinNode,
                aggregations.build(),
                AggregationNode.singleGroupingSet(joinNode.getLeft().getOutputVariables()),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashVariable(),
                Optional.empty()));
    }
}
