package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:com/facebook/presto/cost/FilterStatsCalculator.class */
/**
 * Estimates the statistics of the rows that survive a filter predicate, given the statistics of the
 * filter's input.
 *
 * <p>Boolean literals, NOT, AND, OR, IS [NOT] NULL, BETWEEN, IN lists and comparisons are estimated
 * structurally by {@link FilterExpressionStatsCalculatingVisitor}. When no estimate can be produced,
 * the input's row count is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}. Every estimate returned is
 * passed through the configured {@link StatsNormalizer}.
 */
public class FilterStatsCalculator {
    /** Fraction of input rows assumed to pass a predicate whose selectivity cannot be estimated. */
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final Metadata metadata;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;

    /**
     * @param metadata used to resolve expression types and to interpret literal values
     * @param scalarStatsCalculator used to estimate statistics of expressions that are not plain symbols
     * @param statsNormalizer applied to every estimate this calculator returns
     * @throws NullPointerException if any argument is null
     */
    public FilterStatsCalculator(Metadata metadata, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer statsNormalizer) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(statsNormalizer, "normalizer is null");
    }

    /**
     * Estimates the statistics of the rows of {@code input} that satisfy {@code predicate}.
     *
     * @param input statistics of the filter's input
     * @param predicate the filter condition
     * @param session the current session
     * @param types types of the symbols referenced by {@code predicate}
     * @return the normalized estimate; when the predicate cannot be estimated, the normalized input
     *     with its row count multiplied by {@link #UNKNOWN_FILTER_COEFFICIENT}
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate input, Expression predicate, Session session, TypeProvider types) {
        return new FilterExpressionStatsCalculatingVisitor(input, session, types)
                .process(predicate)
                .orElseGet(() -> normalizer.normalize(filterStatsForUnknownExpression(input), types));
    }

    /**
     * Fallback estimate for a predicate of unknown selectivity: {@code input} with its output row
     * count multiplied by {@link #UNKNOWN_FILTER_COEFFICIENT}.
     *
     * <p>NOTE(review): the decompiler reports this method was originally private; it is kept public
     * so that no existing caller breaks.
     */
    public static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate input) {
        return input.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
    }

    /**
     * Walks a predicate and estimates the statistics of the rows of {@code input} that satisfy it.
     * Every visit method returns {@link Optional#empty()} when it cannot produce an estimate.
     */
    private class FilterExpressionStatsCalculatingVisitor extends AstVisitor<Optional<PlanNodeStatsEstimate>, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Dispatches to the matching visit method and normalizes any estimate it produces. */
        @Override
        public Optional<PlanNodeStatsEstimate> process(Node node, @Nullable Void context) {
            return super.process(node, context)
                    .map(estimate -> normalizer.normalize(estimate, types));
        }

        /** Fallback for expression kinds without a dedicated visit method: no estimate. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitExpression(Expression node, Void context) {
            return Optional.empty();
        }

        /**
         * Estimate for a predicate that is never true. The row count is zero, and every symbol with
         * known statistics gets {@link SymbolStatsEstimate#ZERO_STATS}.
         */
        private Optional<PlanNodeStatsEstimate> filterForFalseExpression() {
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
            input.getSymbolsWithKnownStatistics()
                    .forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.ZERO_STATS));
            return Optional.of(result.setOutputRowCount(0.0).build());
        }

        /**
         * {@code NOT (x IS NULL)} is rewritten to {@code x IS NOT NULL}. Any other negation is
         * estimated as the input minus the estimate of the negated predicate.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitNotExpression(NotExpression node, Void context) {
            Expression negated = node.getValue();
            if (negated instanceof IsNullPredicate) {
                return process(new IsNotNullPredicate(((IsNullPredicate) negated).getValue()));
            }
            return process(negated)
                    .map(negatedStats -> PlanNodeStatsEstimateMath.differenceInStats(input, negatedStats));
        }

        /** Routes AND / OR to their dedicated estimators. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) {
            switch (node.getOperator()) {
                case AND:
                    return visitLogicalBinaryAnd(node.getLeft(), node.getRight());
                case OR:
                    return visitLogicalBinaryOr(node.getLeft(), node.getRight());
                default:
                    throw new IllegalStateException("Unimplemented logical binary operator expression " + node.getOperator());
            }
        }

        /**
         * Estimates {@code left AND right} left to right. The right conjunct is estimated against the
         * rows that survive the left one. If either side cannot be estimated, the other side's
         * estimate is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
         */
        private Optional<PlanNodeStatsEstimate> visitLogicalBinaryAnd(Expression left, Expression right) {
            Optional<PlanNodeStatsEstimate> leftStats = process(left);
            if (!leftStats.isPresent()) {
                // Only the right conjunct may be estimable; discount it for the unknown left one.
                return process(right).map(FilterStatsCalculator::filterStatsForUnknownExpression);
            }
            Optional<PlanNodeStatsEstimate> andStats =
                    new FilterExpressionStatsCalculatingVisitor(leftStats.get(), session, types).process(right);
            if (andStats.isPresent()) {
                return andStats;
            }
            return leftStats.map(FilterStatsCalculator::filterStatsForUnknownExpression);
        }

        /**
         * Estimates {@code left OR right} by inclusion-exclusion. The left and right estimates are
         * summed with {@code addStatsAndSumDistinctValues}. The estimate of {@code left AND right}
         * is then removed with {@code differenceInNonRangeStats}. If any of the three estimates is
         * missing, the result is empty.
         */
        private Optional<PlanNodeStatsEstimate> visitLogicalBinaryOr(Expression left, Expression right) {
            Optional<PlanNodeStatsEstimate> leftStats = process(left);
            if (!leftStats.isPresent()) {
                return Optional.empty();
            }
            Optional<PlanNodeStatsEstimate> rightStats = process(right);
            if (!rightStats.isPresent()) {
                return Optional.empty();
            }
            // Rows matching both sides: the right predicate applied to the left side's output.
            Optional<PlanNodeStatsEstimate> andStats =
                    new FilterExpressionStatsCalculatingVisitor(leftStats.get(), session, types).process(right);
            if (!andStats.isPresent()) {
                return Optional.empty();
            }
            PlanNodeStatsEstimate sum = PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(leftStats.get(), rightStats.get());
            return Optional.of(PlanNodeStatsEstimateMath.differenceInNonRangeStats(sum, andStats.get()));
        }

        /** TRUE keeps the input unchanged; any other boolean literal yields the all-zero FALSE estimate. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.equals(BooleanLiteral.TRUE_LITERAL)) {
                return Optional.of(input);
            }
            return filterForFalseExpression();
        }

        /**
         * {@code symbol IS NOT NULL} keeps the symbol's non-null fraction of the rows and sets the
         * symbol's nulls fraction to 0. Non-symbol operands are not estimated.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            double nonNullFraction = 1.0d - symbolStats.getNullsFraction();
            return Optional.of(input
                    .mapOutputRowCount(rowCount -> rowCount * nonNullFraction)
                    .mapSymbolColumnStatistics(symbol, stats -> stats.mapNullsFraction(nullsFraction -> 0.0)));
        }

        /**
         * {@code symbol IS NULL} keeps the symbol's null fraction of the rows. The symbol's statistics
         * are replaced by "all null": nulls fraction 1, NaN low/high bounds and 0 distinct values.
         * Non-symbol operands are not estimated.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            SymbolStatsEstimate allNull = SymbolStatsEstimate.builder()
                    .setNullsFraction(1.0d)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setDistinctValuesCount(0.0)
                    .build();
            return Optional.of(input
                    .mapOutputRowCount(rowCount -> rowCount * symbolStats.getNullsFraction())
                    .mapSymbolColumnStatistics(symbol, stats -> allNull));
        }

        /**
         * Rewrites {@code symbol BETWEEN min AND max} into two range comparisons combined with AND.
         * The rewrite is only done when both bounds are literals or single-valued expressions;
         * otherwise no estimate is produced.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)
                    || !isLiteralOrSingleValue(node.getMin())
                    || !isLiteralOrSingleValue(node.getMax())) {
                return visitExpression(node, context);
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // AND is estimated left to right (see visitLogicalBinaryAnd). When the value's low bound is
            // infinite, the lower-bound comparison goes first; otherwise the upper bound goes first.
            Expression rewritten = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(rewritten);
        }

        /** True if {@code bound} is a literal or an expression whose statistics describe a single value. */
        private boolean isLiteralOrSingleValue(Expression bound) {
            return bound instanceof Literal || isSingleValue(getExpressionStats(bound));
        }

        /**
         * Estimates {@code value IN (e1, ..., en)} from the equality estimates {@code value = ei}.
         *
         * <p>The estimates are summed, starting from the all-zero FALSE estimate. The row count is
         * capped by the input's non-null row count for {@code value}. When {@code value} is a symbol,
         * its distinct-values count is capped by its original count. Non-list IN forms, missing
         * estimates and NaN row counts produce no estimate.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValueList() instanceof InListExpression)) {
                return Optional.empty();
            }
            InListExpression inList = (InListExpression) node.getValueList();
            ImmutableList<Optional<PlanNodeStatsEstimate>> equalityEstimates = inList.getValues().stream()
                    .map(item -> process(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, node.getValue(), item)))
                    .collect(ImmutableList.toImmutableList());
            if (!equalityEstimates.stream().allMatch(Optional::isPresent)) {
                return Optional.empty();
            }
            PlanNodeStatsEstimate inEstimate = equalityEstimates.stream()
                    .map(Optional::get)
                    .reduce(filterForFalseExpression().get(), PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues);
            if (Double.isNaN(inEstimate.getOutputRowCount())) {
                return Optional.empty();
            }

            Optional<Symbol> valueSymbol = asSymbol(node.getValue());
            SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
            if (Objects.equals(valueStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return Optional.empty();
            }

            // NULL never matches an IN list, so at most the non-null rows of the input can survive.
            double notNullRowsBeforeIn = input.getOutputRowCount() * (1.0d - valueStats.getNullsFraction());
            PlanNodeStatsEstimate result = input.mapOutputRowCount(
                    rowCount -> Math.min(inEstimate.getOutputRowCount(), notNullRowsBeforeIn));
            if (valueSymbol.isPresent()) {
                SymbolStatsEstimate cappedStats = inEstimate.getSymbolStatistics(valueSymbol.get())
                        .mapDistinctValuesCount(distinct -> Math.min(distinct, valueStats.getDistinctValuesCount()));
                result = result.mapSymbolColumnStatistics(valueSymbol.get(), stats -> cappedStats);
            }
            return Optional.of(result);
        }

        /**
         * Estimates a binary comparison. The operands are normalized so that a symbol, if any, is on
         * the left and a literal, if any, is on the right. The comparison is then delegated to
         * {@link ComparisonStatsCalculator}. Operands with unknown statistics produce no estimate.
         *
         * @throws IllegalArgumentException if both operands are literals
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpression.Operator operator = node.getOperator();
            Expression left = node.getLeft();
            Expression right = node.getRight();
            Preconditions.checkArgument(!(left instanceof Literal && right instanceof Literal), "Literal-to-literal not supported here, should be eliminated earlier");

            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // Normalize so that the symbol is on the left.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof Literal && !(right instanceof Literal)) {
                // Normalize so that the literal is on the right.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }

            Optional<Symbol> leftSymbol = asSymbol(left);
            SymbolStatsEstimate leftStats = getExpressionStats(left);
            if (Objects.equals(leftStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return visitExpression(node, context);
            }
            if (right instanceof Literal) {
                OptionalDouble literalValue = doubleValueFromLiteral(getType(left), (Literal) right);
                return ComparisonStatsCalculator.comparisonExpressionToLiteralStats(input, leftSymbol, leftStats, literalValue, operator);
            }

            Optional<Symbol> rightSymbol = asSymbol(right);
            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (Objects.equals(rightStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return visitExpression(node, context);
            }
            if (left instanceof SymbolReference && Objects.equals(left, right)) {
                // A symbol compared with itself is estimated as "symbol IS NOT NULL".
                return process(new IsNotNullPredicate(left));
            }
            if (isSingleValue(rightStats)) {
                // A single-valued right side is treated like a literal; a NaN low value means "unknown value".
                OptionalDouble rightValue = Double.isNaN(rightStats.getLowValue())
                        ? OptionalDouble.empty()
                        : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.comparisonExpressionToLiteralStats(input, leftSymbol, leftStats, rightValue, operator);
            }
            return ComparisonStatsCalculator.comparisonExpressionToExpressionStats(input, leftSymbol, leftStats, rightSymbol, rightStats, operator);
        }

        /** The symbol named by {@code expression} if it is a symbol reference, otherwise empty. */
        private Optional<Symbol> asSymbol(Expression expression) {
            if (expression instanceof SymbolReference) {
                return Optional.of(Symbol.from(expression));
            }
            return Optional.empty();
        }

        /**
         * True when the statistics describe exactly one value: one distinct value, equal low and high
         * bounds (compared with {@link Double#compare}), and a finite low bound.
         */
        private boolean isSingleValue(SymbolStatsEstimate stats) {
            return stats.getDistinctValuesCount() == 1.0d
                    && Double.compare(stats.getLowValue(), stats.getHighValue()) == 0
                    && !Double.isInfinite(stats.getLowValue());
        }

        /**
         * Type of {@code expression}. Symbols are looked up in {@code types}; any other expression is
         * analyzed, and subqueries are rejected.
         *
         * @throws NullPointerException if a symbol has no registered type
         * @throws IllegalStateException if the expression contains a subquery
         */
        private Type getType(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(types.get(symbol), () -> String.format("No type for symbol %s", symbol));
            }
            ExpressionAnalyzer analyzer = ExpressionAnalyzer.createWithoutSubqueries(
                    metadata.getFunctionRegistry(),
                    metadata.getTypeManager(),
                    session,
                    types,
                    ImmutableList.of(),
                    subquery -> new IllegalStateException("Unexpected Subquery"),
                    false);
            return analyzer.analyze(expression, Scope.create());
        }

        /**
         * Statistics of {@code expression}. Symbols are read from {@code input}; any other expression
         * is estimated with the scalar stats calculator.
         *
         * @throws NullPointerException if {@code input} has no statistics for a referenced symbol
         */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol));
            }
            return scalarStatsCalculator.calculate(expression, input, session);
        }

        /** Evaluates {@code literal} and converts it to the double representation used by statistics. */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object value = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), literal);
            return StatsUtil.toStatsRepresentation(metadata, session, type, value);
        }
    }
}
