package com.facebook.presto.cost;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.util.MoreMath;
import java.util.Optional;
import java.util.OptionalDouble;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:com/facebook/presto/cost/ComparisonStatsCalculator.class */
/**
 * Static helpers that derive a {@link PlanNodeStatsEstimate} for a filter made of a single
 * comparison, either between an expression and a numeric literal or between two expressions.
 *
 * <p>Every method returns {@link Optional#empty()} when it produces no estimate (for example for
 * {@code IS_DISTINCT_FROM}, or when the distinct-value counts needed are unknown).
 */
public final class ComparisonStatsCalculator {
    private ComparisonStatsCalculator() {
    }

    /**
     * Estimates the statistics of {@code inputStatistics} after filtering by
     * {@code <expression> <operator> <literal>}.
     *
     * <p>{@code LESS_THAN} and {@code LESS_THAN_OR_EQUAL} share one estimate, as do
     * {@code GREATER_THAN} and {@code GREATER_THAN_OR_EQUAL}. {@code IS_DISTINCT_FROM} and any
     * other operator yield no estimate.
     *
     * @param inputStatistics statistics of the rows before the filter
     * @param symbol the symbol the expression refers to, if any; when present its column
     *     statistics are overwritten in the result
     * @param expressionStats statistics of the compared expression
     * @param literal numeric value of the literal; when empty the literal is treated as unbounded
     *     (NOTE(review): presumably "empty" means the literal has no numeric representation —
     *     confirm against callers)
     * @param operator the comparison operator
     * @return the filtered estimate, or empty when the operator is not supported
     */
    public static Optional<PlanNodeStatsEstimate> comparisonExpressionToLiteralStats(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            OptionalDouble literal,
            ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return expressionToLiteralEquality(inputStatistics, symbol, expressionStats, literal);
            case NOT_EQUAL:
                return expressionToLiteralNonEquality(inputStatistics, symbol, expressionStats, literal);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return expressionToLiteralLessThan(inputStatistics, symbol, expressionStats, literal);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return expressionToLiteralGreaterThan(inputStatistics, symbol, expressionStats, literal);
            case IS_DISTINCT_FROM:
            default:
                return Optional.empty();
        }
    }

    /**
     * Shared estimate for comparisons that keep the part of the expression's range overlapping
     * {@code literalRange}.
     *
     * <p>The row count is scaled by the overlap fraction and by the non-null fraction (the code
     * assumes null values never satisfy the comparison). If {@code symbol} is present, its
     * statistics are replaced by the intersected range with a nulls fraction of 0, keeping the
     * original average row size.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralRangeComparison(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            StatisticRange literalRange) {
        StatisticRange expressionRange = StatisticRange.from(expressionStats);
        StatisticRange intersectRange = expressionRange.intersect(literalRange);
        double filterFactor = expressionRange.overlapPercentWith(intersectRange);
        double nonNullFraction = 1.0 - expressionStats.getNullsFraction();

        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(
                rowCount -> filterFactor * nonNullFraction * rowCount);
        if (symbol.isPresent()) {
            SymbolStatsEstimate symbolStats = SymbolStatsEstimate.builder()
                    .setAverageRowSize(expressionStats.getAverageRowSize())
                    .setStatisticsRange(intersectRange)
                    .setNullsFraction(0.0)
                    .build();
            estimate = estimate.mapSymbolColumnStatistics(symbol.get(), oldStats -> symbolStats);
        }
        return Optional.of(estimate);
    }

    /**
     * Returns the range containing only {@code literal}, with one distinct value. When the
     * literal is absent, returns the unbounded range (still with one distinct value).
     */
    private static StatisticRange literalPointRange(OptionalDouble literal) {
        if (literal.isPresent()) {
            double value = literal.getAsDouble();
            return new StatisticRange(value, value, 1.0);
        }
        return new StatisticRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
    }

    /** Estimate for {@code expression = literal}: keep the rows overlapping the literal's point range. */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralEquality(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            OptionalDouble literal) {
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, literalPointRange(literal));
    }

    /**
     * Estimate for {@code expression <> literal}.
     *
     * <p>The row count is scaled by one minus the fraction that overlaps the literal's point
     * range, and by the non-null fraction. If {@code symbol} is present, its statistics keep
     * everything from {@code expressionStats} except for a nulls fraction of 0 and one fewer
     * distinct value (never below 0).
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralNonEquality(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            OptionalDouble literal) {
        StatisticRange expressionRange = StatisticRange.from(expressionStats);
        StatisticRange intersectRange = expressionRange.intersect(literalPointRange(literal));
        double filterFactor = 1.0 - expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        estimate.setOutputRowCount(
                filterFactor * (1.0 - expressionStats.getNullsFraction()) * inputStatistics.getOutputRowCount());
        if (symbol.isPresent()) {
            SymbolStatsEstimate symbolStats = SymbolStatsEstimate.buildFrom(expressionStats)
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(MoreMath.max(expressionStats.getDistinctValuesCount() - 1.0, 0.0))
                    .build();
            estimate = estimate.addSymbolStatistics(symbol.get(), symbolStats);
        }
        return Optional.of(estimate.build());
    }

    /**
     * Estimate for {@code expression < literal} and {@code expression <= literal}. The upper bound
     * is {@code +infinity} when the literal is absent; the range's distinct count is left
     * unknown (NaN).
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralLessThan(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            OptionalDouble literal) {
        StatisticRange belowLiteral = new StatisticRange(
                Double.NEGATIVE_INFINITY, literal.orElse(Double.POSITIVE_INFINITY), Double.NaN);
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, belowLiteral);
    }

    /**
     * Estimate for {@code expression > literal} and {@code expression >= literal}. The lower bound
     * is {@code -infinity} when the literal is absent; the range's distinct count is left
     * unknown (NaN).
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralGreaterThan(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> symbol,
            SymbolStatsEstimate expressionStats,
            OptionalDouble literal) {
        StatisticRange aboveLiteral = new StatisticRange(
                literal.orElse(Double.NEGATIVE_INFINITY), Double.POSITIVE_INFINITY, Double.NaN);
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, aboveLiteral);
    }

    /**
     * Estimates the statistics of {@code inputStatistics} after filtering by
     * {@code <left> <operator> <right>}, where both sides are expressions.
     *
     * <p>Only {@code EQUAL} and {@code NOT_EQUAL} are estimated; every other operator yields no
     * estimate.
     *
     * @param inputStatistics statistics of the rows before the filter
     * @param left the symbol on the left side, if any; when present its statistics are updated
     * @param leftStats statistics of the left expression
     * @param right the symbol on the right side, if any; when present its statistics are updated
     * @param rightStats statistics of the right expression
     * @param operator the comparison operator
     * @return the filtered estimate, or empty when it cannot be computed
     */
    public static Optional<PlanNodeStatsEstimate> comparisonExpressionToExpressionStats(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> left,
            SymbolStatsEstimate leftStats,
            Optional<Symbol> right,
            SymbolStatsEstimate rightStats,
            ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return expressionToExpressionEquality(inputStatistics, left, leftStats, right, rightStats);
            case NOT_EQUAL:
                return expressionToExpressionNonEquality(inputStatistics, left, leftStats, right, rightStats);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case IS_DISTINCT_FROM:
            default:
                return Optional.empty();
        }
    }

    /**
     * Estimate for {@code left = right}.
     *
     * <p>Returns empty when either side's distinct-value count is NaN. Otherwise it proceeds as
     * follows:
     * <ul>
     *   <li>It intersects the two ranges and computes each side's overlap fraction with that
     *       intersection (falling back to 1.0 via {@code MoreMath.firstNonNaN}, presumably when
     *       the overlap is NaN).</li>
     *   <li>It scales the row count by both overlap fractions, divided by the larger in-range
     *       distinct count (at least 1), and by the fraction of rows where both sides are
     *       non-null.</li>
     *   <li>Both symbols, when present, receive the same statistics: the intersected range, no
     *       nulls, the smaller in-range distinct count, and the average of the non-NaN average
     *       row sizes.</li>
     * </ul>
     */
    private static Optional<PlanNodeStatsEstimate> expressionToExpressionEquality(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> left,
            SymbolStatsEstimate leftStats,
            Optional<Symbol> right,
            SymbolStatsEstimate rightStats) {
        if (Double.isNaN(leftStats.getDistinctValuesCount()) || Double.isNaN(rightStats.getDistinctValuesCount())) {
            return Optional.empty();
        }
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersectRange = leftRange.intersect(rightRange);

        // Fraction of rows in which neither side is null.
        double nonNullFraction = (1.0 - leftStats.getNullsFraction()) * (1.0 - rightStats.getNullsFraction());
        double leftFilterValue = MoreMath.firstNonNaN(leftRange.overlapPercentWith(intersectRange), 1.0);
        double rightFilterValue = MoreMath.firstNonNaN(rightRange.overlapPercentWith(intersectRange), 1.0);
        double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount();
        double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount();
        double filterFactor = leftFilterValue * rightFilterValue / MoreMath.max(leftNdvInRange, rightNdvInRange, 1.0);
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nonNullFraction * filterFactor);
        // Builder call order is kept as-is: the distinct count is set after the range.
        SymbolStatsEstimate joinedStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(averageExcludingNaNs(leftStats.getAverageRowSize(), rightStats.getAverageRowSize()))
                .setNullsFraction(0.0)
                .setStatisticsRange(intersectRange)
                .setDistinctValuesCount(retainedNdv)
                .build();
        left.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, joinedStats));
        right.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, joinedStats));
        return Optional.of(estimate.build());
    }

    /**
     * Returns the mean of {@code first} and {@code second}, ignoring a NaN operand. Returns NaN
     * only when both are NaN.
     */
    private static double averageExcludingNaNs(double first, double second) {
        if (Double.isNaN(first)) {
            return second;
        }
        if (Double.isNaN(second)) {
            return first;
        }
        return (first + second) / 2.0;
    }

    /**
     * Estimate for {@code left <> right}.
     *
     * <p>First restricts the input to the rows where both sides are non-null. Then it removes the
     * fraction of those rows that the equality estimate would keep. Returns empty when the
     * equality estimate is empty. Present symbols keep their statistics, with the nulls fraction
     * set to 0.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToExpressionNonEquality(
            PlanNodeStatsEstimate inputStatistics,
            Optional<Symbol> left,
            SymbolStatsEstimate leftStats,
            Optional<Symbol> right,
            SymbolStatsEstimate rightStats) {
        // Fraction of rows in which neither side is null.
        double nonNullFraction = (1.0 - leftStats.getNullsFraction()) * (1.0 - rightStats.getNullsFraction());
        PlanNodeStatsEstimate nonNullEstimate = inputStatistics.mapOutputRowCount(rowCount -> rowCount * nonNullFraction);
        SymbolStatsEstimate leftNonNullStats = leftStats.mapNullsFraction(nullsFraction -> 0.0);
        SymbolStatsEstimate rightNonNullStats = rightStats.mapNullsFraction(nullsFraction -> 0.0);

        Optional<PlanNodeStatsEstimate> equalityEstimate =
                expressionToExpressionEquality(nonNullEstimate, left, leftNonNullStats, right, rightNonNullStats);
        if (!equalityEstimate.isPresent()) {
            return Optional.empty();
        }
        // Share of the non-null rows for which the sides are equal. A non-finite ratio (for
        // example when the non-null row count is zero or unknown) is treated as 0.
        double equalityFraction = equalityEstimate.get().getOutputRowCount() / nonNullEstimate.getOutputRowCount();
        double retainedFraction = 1.0 - (Double.isFinite(equalityFraction) ? equalityFraction : 0.0);

        PlanNodeStatsEstimate estimate = nonNullEstimate.mapOutputRowCount(rowCount -> rowCount * retainedFraction);
        if (left.isPresent()) {
            estimate = estimate.mapSymbolColumnStatistics(left.get(), oldStats -> leftNonNullStats);
        }
        if (right.isPresent()) {
            estimate = estimate.mapSymbolColumnStatistics(right.get(), oldStats -> rightNonNullStats);
        }
        return Optional.of(estimate);
    }
}
