package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * {@link CostCalculator} decorator that adds an estimate of exchange (data movement) cost on top of
 * the cost computed by a delegate calculator.
 *
 * <p>The delegate's {@link PlanCostEstimate} is combined with a {@link LocalCostEstimate} produced by
 * a plan visitor ({@code ExchangeCostEstimator}). That visitor charges exchange costs for
 * aggregations, joins, semi joins, spatial joins and unions, and zero for every other node type.
 *
 * <p>Both fields are final and a fresh visitor is created for every {@link #calculateCost} call, so
 * no mutable state is shared between invocations.
 *
 * <p>NOTE(review): {@code LocalCostEstimate.of} is assumed throughout to take
 * {@code (cpu, maxMemory, network)} in that order — confirm against {@code LocalCostEstimate}.
 */
@ThreadSafe
public class CostCalculatorWithEstimatedExchanges implements CostCalculator {
    private final CostCalculator costCalculator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param costCalculator delegate whose estimate is extended with exchange costs
     * @param taskCountEstimator supplies the task count used when costing replicated joins
     * @throws NullPointerException if either argument is {@code null}
     */
    @Inject
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, TaskCountEstimator taskCountEstimator) {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Returns the delegate's estimate for {@code node} plus the estimated exchange cost of that node.
     */
    @Override
    public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session, TypeProvider types) {
        // Delegate first, then estimate exchanges (same evaluation order as before).
        PlanCostEstimate costEstimate = costCalculator.calculateCost(node, stats, sourcesCosts, session, types);
        LocalCostEstimate exchangeCost = node.accept(new ExchangeCostEstimator(stats, types, taskCountEstimator), null);
        return addExchangeCost(costEstimate, exchangeCost);
    }

    /**
     * Adds an exchange cost to a plan cost. The exchange's memory is added both to the peak memory
     * and to the peak memory while outputting.
     */
    private static PlanCostEstimate addExchangeCost(PlanCostEstimate costEstimate, LocalCostEstimate exchangeCost) {
        return new PlanCostEstimate(
                costEstimate.getCpuCost() + exchangeCost.getCpuCost(),
                costEstimate.getMaxMemory() + exchangeCost.getMaxMemory(),
                costEstimate.getMaxMemoryWhenOutputting() + exchangeCost.getMaxMemory(),
                costEstimate.getNetworkCost() + exchangeCost.getNetworkCost());
    }

    /**
     * Cost of gathering {@code sizeInBytes} over the network: network only.
     */
    public static LocalCostEstimate calculateRemoteGatherCost(double sizeInBytes) {
        return LocalCostEstimate.ofNetwork(sizeInBytes);
    }

    /**
     * Cost of remotely repartitioning {@code inputSizeInBytes}: CPU and network proportional to the
     * input size, and no memory component.
     */
    public static LocalCostEstimate calculateRemoteRepartitionCost(double inputSizeInBytes) {
        // Explicit zero memory; previously spelled via an unrelated commons-math3 constant equal to 0.
        return LocalCostEstimate.of(inputSizeInBytes, 0, inputSizeInBytes);
    }

    /**
     * Cost of locally repartitioning {@code inputSizeInBytes}: CPU only.
     */
    public static LocalCostEstimate calculateLocalRepartitionCost(double inputSizeInBytes) {
        return LocalCostEstimate.ofCpu(inputSizeInBytes);
    }

    /**
     * Cost of sending a full copy of {@code inputSizeInBytes} to each of
     * {@code destinationTaskCount} tasks: network only.
     */
    public static LocalCostEstimate calculateRemoteReplicateCost(double inputSizeInBytes, int destinationTaskCount) {
        return LocalCostEstimate.ofNetwork(inputSizeInBytes * destinationTaskCount);
    }

    /**
     * Join cost excluding the join's output: exchange cost of both inputs plus the cost of consuming
     * them.
     *
     * @param probe probe (left/source) side of the join
     * @param build build (right/filtering) side of the join
     * @param stats statistics provider for both inputs
     * @param types type provider (currently unused by the calculation)
     * @param replicated whether the build side is replicated rather than partitioned
     * @param estimatedSourceDistributedTaskCount number of build-side copies when replicated
     */
    public static LocalCostEstimate calculateJoinCostWithoutOutput(
            PlanNode probe,
            PlanNode build,
            StatsProvider stats,
            TypeProvider types,
            boolean replicated,
            int estimatedSourceDistributedTaskCount) {
        LocalCostEstimate exchangesCost = calculateJoinExchangeCost(probe, build, stats, types, replicated, estimatedSourceDistributedTaskCount);
        LocalCostEstimate inputCost = calculateJoinInputCost(probe, build, stats, types, replicated, estimatedSourceDistributedTaskCount);
        return LocalCostEstimate.addPartialComponents(exchangesCost, inputCost);
    }

    /**
     * Exchange cost of a join's inputs.
     *
     * <p>When {@code replicated}, the build side is replicated to
     * {@code estimatedSourceDistributedTaskCount} tasks and locally repartitioned once. Otherwise
     * both sides are remotely repartitioned and the build side is also locally repartitioned.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code private}. It is kept
     * {@code public} so existing callers keep compiling — consider narrowing it.
     */
    public static LocalCostEstimate calculateJoinExchangeCost(
            PlanNode probe,
            PlanNode build,
            StatsProvider stats,
            TypeProvider types,
            boolean replicated,
            int estimatedSourceDistributedTaskCount) {
        // Both sizes are looked up up front (probe first) even if only the build size is used.
        double probeSizeInBytes = outputSizeInBytes(probe, stats);
        double buildSizeInBytes = outputSizeInBytes(build, stats);
        if (replicated) {
            return LocalCostEstimate.addPartialComponents(
                    calculateRemoteReplicateCost(buildSizeInBytes, estimatedSourceDistributedTaskCount),
                    calculateLocalRepartitionCost(buildSizeInBytes));
        }
        return LocalCostEstimate.addPartialComponents(
                calculateRemoteRepartitionCost(probeSizeInBytes),
                calculateRemoteRepartitionCost(buildSizeInBytes),
                calculateLocalRepartitionCost(buildSizeInBytes));
    }

    /**
     * Cost of consuming a join's inputs.
     *
     * <p>The build side counts once, or once per copy when {@code replicated}. CPU covers the probe
     * side plus every build copy. Memory holds every build copy. There is no network component.
     */
    public static LocalCostEstimate calculateJoinInputCost(
            PlanNode probe,
            PlanNode build,
            StatsProvider stats,
            TypeProvider types,
            boolean replicated,
            int estimatedSourceDistributedTaskCount) {
        int buildSizeMultiplier = replicated ? estimatedSourceDistributedTaskCount : 1;

        PlanNodeStatsEstimate probeStats = stats.getStats(probe);
        double buildSideSize = outputSizeInBytes(build, stats);
        double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputVariables());

        double cpuCost = probeSideSize + (buildSideSize * buildSizeMultiplier);
        if (replicated) {
            // NOTE(review): appears to charge local repartitioning of the extra build copies; a single
            // copy's local repartition is already charged in calculateJoinExchangeCost — confirm intent.
            cpuCost += buildSideSize * (buildSizeMultiplier - 1);
        }
        double memoryCost = buildSideSize * buildSizeMultiplier;

        return LocalCostEstimate.of(cpuCost, memoryCost, 0);
    }

    /**
     * Estimated output size in bytes of {@code node}, from {@code stats}.
     */
    private static double outputSizeInBytes(PlanNode node, StatsProvider stats) {
        return stats.getStats(node).getOutputSizeInBytes(node.getOutputVariables());
    }

    /**
     * Plan visitor that returns only the exchange cost implied by a node. Node types without a
     * dedicated override cost zero.
     */
    private static class ExchangeCostEstimator extends InternalPlanVisitor<LocalCostEstimate, Void> {
        private final StatsProvider stats;
        private final TypeProvider types;
        private final TaskCountEstimator taskCountEstimator;

        ExchangeCostEstimator(StatsProvider stats, TypeProvider types, TaskCountEstimator taskCountEstimator) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        }

        @Override
        public LocalCostEstimate visitPlan(PlanNode node, Void context) {
            return LocalCostEstimate.zero();
        }

        /** Not supported: always throws {@link UnsupportedOperationException}. */
        @Override
        public LocalCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** Remote plus local repartitioning of the aggregation's input (its source's output). */
        @Override
        public LocalCostEstimate visitAggregation(AggregationNode node, Void context) {
            PlanNode source = node.getSource();
            double inputSizeInBytes = getStats(source).getOutputSizeInBytes(source.getOutputVariables());
            return LocalCostEstimate.addPartialComponents(
                    calculateRemoteRepartitionCost(inputSizeInBytes),
                    calculateLocalRepartitionCost(inputSizeInBytes));
        }

        @Override
        public LocalCostEstimate visitJoin(JoinNode node, Void context) {
            return calculateJoinExchangeCost(
                    node.getLeft(),
                    node.getRight(),
                    stats,
                    types,
                    Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED)),
                    taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public LocalCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            return calculateJoinExchangeCost(
                    node.getSource(),
                    node.getFilteringSource(),
                    stats,
                    types,
                    Objects.equals(node.getDistributionType(), Optional.of(SemiJoinNode.DistributionType.REPLICATED)),
                    taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public LocalCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            return calculateJoinExchangeCost(
                    node.getLeft(),
                    node.getRight(),
                    stats,
                    types,
                    node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED,
                    taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        /** Remote gather of the union's own output size. */
        @Override
        public LocalCostEstimate visitUnion(UnionNode node, Void context) {
            return calculateRemoteGatherCost(getStats(node).getOutputSizeInBytes(node.getOutputVariables()));
        }

        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }
    }
}
