package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.metadata.InternalNodeManager;
import com.facebook.presto.spi.Node;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.IntSupplier;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * {@link CostCalculator} that estimates the cost of a single plan node from the output sizes
 * reported by a {@link StatsProvider}, taking the number of nodes into account for replicated
 * joins and replicating exchanges.
 *
 * <p>Node types without a dedicated {@code visit*} override in {@link CostEstimator} yield
 * {@link PlanNodeCostEstimate#UNKNOWN_COST}.
 *
 * <p>NOTE(review): the three-argument {@code PlanNodeCostEstimate} constructor is assumed to take
 * {@code (cpu, memory, network)} in that order; local variable names below follow that assumption.
 * Confirm against {@code PlanNodeCostEstimate}.
 */
@ThreadSafe
public class CostCalculatorUsingExchanges implements CostCalculator {
    /** Supplies the node count; queried once per {@link #calculateCost} invocation. */
    private final IntSupplier numberOfNodes;

    /**
     * Visitor that produces the cost estimate of the visited node only (children are consulted
     * solely for their statistics, they are not visited).
     */
    private static class CostEstimator extends PlanVisitor<PlanNodeCostEstimate, Void> {
        private final int numberOfNodes;
        private final StatsProvider stats;
        private final TypeProvider types;

        /**
         * @param numberOfNodes node count used as the multiplier for replicated joins and exchanges
         * @param stats source of per-node statistics; must not be null
         * @param types symbol types used when computing output sizes; must not be null
         * @throws NullPointerException if {@code stats} or {@code types} is null
         */
        CostEstimator(int numberOfNodes, StatsProvider stats, TypeProvider types) {
            this.numberOfNodes = numberOfNodes;
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.types = Objects.requireNonNull(types, "types is null");
        }

        /** Fallback for node types without a dedicated estimate: the cost is unknown. */
        @Override
        public PlanNodeCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanNodeCostEstimate.UNKNOWN_COST;
        }

        /**
         * Group references are not supported by this estimator.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** CPU cost equal to the output size of the generated id column only. */
        @Override
        public PlanNodeCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) {
            double idColumnSize = getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types);
            return PlanNodeCostEstimate.cpuCost(idColumnSize);
        }

        /** Output nodes are treated as free. */
        @Override
        public PlanNodeCostEstimate visitOutput(OutputNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /** CPU cost equal to the size of the scan's output. */
        @Override
        public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return PlanNodeCostEstimate.cpuCost(outputSize);
        }

        /**
         * CPU cost based on the statistics of the filter's <em>source</em> (i.e. the data that
         * has to be examined), measured over the filter's output symbols.
         */
        @Override
        public PlanNodeCostEstimate visitFilter(FilterNode node, Void context) {
            double inputSize = getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return PlanNodeCostEstimate.cpuCost(inputSize);
        }

        /** CPU cost equal to the size of the projection's output. */
        @Override
        public PlanNodeCostEstimate visitProject(ProjectNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return PlanNodeCostEstimate.cpuCost(outputSize);
        }

        /**
         * First component: size of the aggregation's input; second component: size of the
         * aggregation's output; third component: zero.
         */
        @Override
        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context) {
            PlanNode source = node.getSource();
            double inputSize = getStats(source).getOutputSizeInBytes(source.getOutputSymbols(), types);
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return new PlanNodeCostEstimate(inputSize, outputSize, 0.0d);
        }

        /** Join cost with the left input as probe side and the right input as build side. */
        @Override
        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(
                    node.getDistributionType(),
                    Optional.of(JoinNode.DistributionType.REPLICATED));
            return calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
        }

        /**
         * Shared cost formula for joins and semi-joins.
         *
         * @param join the join node itself (its statistics give the join output size)
         * @param probe the probe-side input
         * @param build the build-side input
         * @param replicated whether the build side is replicated; if so its size is multiplied
         *        by the number of nodes
         * @return estimate whose first component is
         *         {@code probe + build * multiplier + output} (plus
         *         {@code build * (multiplier - 1)} when replicated), whose second component is
         *         {@code build * multiplier}, and whose third component is zero
         */
        private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) {
            int nodeMultiplier = replicated ? numberOfNodes : 1;

            // Statistics are fetched in the same order as before (probe, build, join output).
            PlanNodeStatsEstimate probeStats = getStats(probe);
            PlanNodeStatsEstimate buildStats = getStats(build);
            PlanNodeStatsEstimate outputStats = getStats(join);

            double buildSideSize = buildStats.getOutputSizeInBytes(build.getOutputSymbols(), types);
            double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputSymbols(), types);
            double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols(), types);

            double cpuCost = probeSideSize + (buildSideSize * nodeMultiplier) + joinOutputSize;
            if (replicated) {
                // NOTE(review): extra charge for all build-side copies but one; presumably the
                // first copy is accounted for elsewhere (e.g. by the exchange) -- confirm.
                cpuCost += buildSideSize * (nodeMultiplier - 1);
            }

            double memoryCost = buildSideSize * nodeMultiplier;
            return new PlanNodeCostEstimate(cpuCost, memoryCost, 0.0d);
        }

        /** Delegates to {@link CostCalculatorUsingExchanges#calculateExchangeCost}. */
        @Override
        public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context) {
            return CostCalculatorUsingExchanges.calculateExchangeCost(
                    numberOfNodes,
                    getStats(node),
                    node.getOutputSymbols(),
                    node.getType(),
                    node.getScope(),
                    types);
        }

        /**
         * Join cost with the source as probe side and the filtering source as build side.
         * A missing distribution type is treated as {@code PARTITIONED}.
         */
        @Override
        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            SemiJoinNode.DistributionType distribution =
                    node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED);
            boolean replicated = distribution.equals(SemiJoinNode.DistributionType.REPLICATED);
            return calculateJoinCost(node, node.getSource(), node.getFilteringSource(), replicated);
        }

        /** Values nodes are treated as free. */
        @Override
        public PlanNodeCostEstimate visitValues(ValuesNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /** Enforce-single-row nodes are treated as free. */
        @Override
        public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /** CPU cost equal to the size of the limit's output. */
        @Override
        public PlanNodeCostEstimate visitLimit(LimitNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return PlanNodeCostEstimate.cpuCost(outputSize);
        }

        /** Looks up the statistics of {@code node} in the configured {@link StatsProvider}. */
        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }
    }

    /**
     * Injection constructor: the node count is re-read from {@code internalNodeManager} on every
     * cost calculation, counting the coordinator only if the scheduler config includes it.
     *
     * @throws NullPointerException if {@code nodeSchedulerConfig} or {@code internalNodeManager} is null
     */
    @Inject
    public CostCalculatorUsingExchanges(NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager internalNodeManager) {
        this(currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), internalNodeManager));
    }

    /**
     * Creates a supplier that counts the currently active nodes each time it is invoked.
     *
     * @param includeCoordinator when {@code true} every active node is counted; otherwise nodes
     *        reporting {@code isCoordinator()} are excluded
     * @param internalNodeManager source of the active node set; must not be null
     * @return a supplier evaluated lazily against {@code internalNodeManager}
     * @throws NullPointerException if {@code internalNodeManager} is null
     */
    public static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager internalNodeManager) {
        Objects.requireNonNull(internalNodeManager, "nodeManager is null");
        return () -> {
            Set<Node> activeNodes = internalNodeManager.getAllNodes().getActiveNodes();
            if (includeCoordinator) {
                return activeNodes.size();
            }
            // toIntExact guards the long -> int narrowing of Stream.count().
            return Math.toIntExact(activeNodes.stream()
                    .filter(node -> !node.isCoordinator())
                    .count());
        };
    }

    /**
     * @param numberOfNodes supplier of the node count; must not be null
     * @throws NullPointerException if {@code numberOfNodes} is null
     */
    public CostCalculatorUsingExchanges(IntSupplier numberOfNodes) {
        this.numberOfNodes = Objects.requireNonNull(numberOfNodes, "numberOfNodes is null");
    }

    /**
     * Estimates the cost of {@code planNode} itself by dispatching to a fresh {@link CostEstimator}
     * built with the current node count. {@code lookup} and {@code session} are not used by this
     * implementation.
     */
    @Override
    public PlanNodeCostEstimate calculateCost(PlanNode planNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
        CostEstimator estimator = new CostEstimator(numberOfNodes.getAsInt(), statsProvider, typeProvider);
        return (PlanNodeCostEstimate) planNode.accept(estimator, null);
    }

    /**
     * Computes the cost of an exchange from the size of the exchanged data.
     *
     * <ul>
     *   <li>{@code GATHER}: network = size</li>
     *   <li>{@code REPARTITION}: network = size, cpu = size</li>
     *   <li>{@code REPLICATE}: network = size * {@code numberOfNodes}</li>
     * </ul>
     * For a {@code LOCAL} scope the network component is reset to zero. The memory component is
     * always zero.
     *
     * @param numberOfNodes multiplier applied to replicating exchanges
     * @param exchangeStats statistics of the exchange node
     * @param symbols symbols whose sizes are summed into the exchanged data size
     * @param type exchange type
     * @param scope exchange scope
     * @param types symbol types used when computing the output size
     * @throws UnsupportedOperationException if {@code type} is none of the types listed above
     */
    public static PlanNodeCostEstimate calculateExchangeCost(int numberOfNodes, PlanNodeStatsEstimate exchangeStats, List<Symbol> symbols, ExchangeNode.Type type, ExchangeNode.Scope scope, TypeProvider types) {
        double exchangeSize = exchangeStats.getOutputSizeInBytes(symbols, types);

        double networkCost;
        double cpuCost = 0.0d;
        switch (type) {
            case GATHER:
                networkCost = exchangeSize;
                break;
            case REPARTITION:
                networkCost = exchangeSize;
                cpuCost = exchangeSize;
                break;
            case REPLICATE:
                networkCost = exchangeSize * numberOfNodes;
                break;
            default:
                throw new UnsupportedOperationException(String.format("Unsupported type [%s] of the exchange", type));
        }

        if (scope == ExchangeNode.Scope.LOCAL) {
            // Local exchanges never leave the node, so no network transfer is charged.
            networkCost = 0.0d;
        }
        return new PlanNodeCostEstimate(cpuCost, 0.0d, networkCost);
    }
}
