package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Stats rule for {@link AggregationNode}.
 *
 * <p>Produces an estimate only for aggregations that have exactly one grouping set and run in a
 * single step ({@link AggregationNode.Step#SINGLE}); other shapes yield no estimate.
 */
public class AggregationStatsRule extends SimpleStatsRule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation();

    public AggregationStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates the output statistics of {@code aggregationNode}.
     *
     * @return the estimate for a single-grouping-set, single-step aggregation; otherwise
     *     {@link Optional#empty()}
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode aggregationNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
        if (aggregationNode.getGroupingSetCount() != 1 || aggregationNode.getStep() != AggregationNode.Step.SINGLE) {
            return Optional.empty();
        }
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(aggregationNode.getSource());
        return Optional.of(groupBy(sourceStats, aggregationNode.getGroupingKeys(), aggregationNode.getAggregations()));
    }

    /**
     * Estimates the statistics produced by grouping {@code sourceStats} on {@code groupByVariables}.
     *
     * <p>For each grouping key:
     * <ul>
     *   <li>the nulls fraction stays {@code 0} when the source key has no nulls; otherwise it becomes
     *       {@code 1 / (distinctValues + 1)}, i.e. NULL is treated as one extra group;</li>
     *   <li>the key contributes a factor of {@code distinctValues} (plus one when it has nulls) to the
     *       output row count.</li>
     * </ul>
     * The output row count is the product of those factors, capped at the source row count. With no
     * grouping keys the product is {@code 1}. Statistics for aggregation outputs are unknown.
     *
     * @param sourceStats statistics of the aggregation input
     * @param groupByVariables the grouping keys
     * @param aggregations aggregation output variables mapped to their aggregations
     * @return the estimated statistics of the aggregation output
     */
    public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        double rowCount = 1.0;
        for (VariableReferenceExpression groupByVariable : groupByVariables) {
            VariableStatsEstimate keyStats = sourceStats.getVariableStatistics(groupByVariable);

            // A key that contains NULLs produces one extra group beyond its distinct values.
            int nullGroup = keyStats.getNullsFraction() == 0.0 ? 0 : 1;
            rowCount *= keyStats.getDistinctValuesCount() + nullGroup;

            result.addVariableStatistics(groupByVariable, keyStats.mapNullsFraction(nullsFraction ->
                    nullsFraction == 0.0 ? 0.0 : 1.0 / (keyStats.getDistinctValuesCount() + 1.0)));
        }
        result.setOutputRowCount(Math.min(rowCount, sourceStats.getOutputRowCount()));

        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
            result.addVariableStatistics(entry.getKey(), estimateAggregationStats(entry.getValue(), sourceStats));
        }
        return result.build();
    }

    /**
     * Returns statistics for a single aggregation output. Per-aggregation estimation is not
     * implemented, so this always returns {@link VariableStatsEstimate#unknown()}.
     *
     * @throws NullPointerException if {@code aggregation} or {@code sourceStats} is null
     */
    private static VariableStatsEstimate estimateAggregationStats(AggregationNode.Aggregation aggregation, PlanNodeStatsEstimate sourceStats) {
        Objects.requireNonNull(aggregation, "aggregation is null");
        Objects.requireNonNull(sourceStats, "sourceStats is null");
        return VariableStatsEstimate.unknown();
    }
}
