package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:com/facebook/presto/cost/FilterStatsCalculator.class */
/**
 * Estimates the statistics of the rows that pass a filter predicate, given the statistics of the
 * filter's input.
 *
 * <p>The predicate's expression tree is walked by {@link FilterExpressionStatsCalculatingVisitor}.
 * Predicate shapes without a dedicated handler fall back to
 * {@link #filterStatsForUnknownExpression(PlanNodeStatsEstimate)}, which scales the input row count
 * by {@link #UNKNOWN_FILTER_COEFFICIENT}.
 */
public class FilterStatsCalculator {
    /** Fraction of input rows assumed to survive a predicate whose selectivity cannot be estimated. */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final Metadata metadata;

    /**
     * Computes, for a visited predicate, the statistics estimate of the rows of {@code input} that
     * satisfy it.
     */
    private class FilterExpressionStatsCalculatingVisitor extends AstVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final Map<Symbol, Type> types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, Map<Symbol, Type> types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Fallback for every expression kind that has no dedicated handler below. */
        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return filterForUnknownExpression();
        }

        private PlanNodeStatsEstimate filterForUnknownExpression() {
            return FilterStatsCalculator.filterStatsForUnknownExpression(input);
        }

        /**
         * Returns the estimate for a predicate that is always false: zero output rows, and
         * {@link SymbolStatsEstimate#ZERO_STATS} for every symbol with known statistics.
         */
        private PlanNodeStatsEstimate filterForFalseExpression() {
            PlanNodeStatsEstimate.Builder falseStats = PlanNodeStatsEstimate.builder();
            input.getSymbolsWithKnownStatistics()
                    .forEach(symbol -> falseStats.addSymbolStatistics(symbol, SymbolStatsEstimate.ZERO_STATS));
            return falseStats.setOutputRowCount(0.0).build();
        }

        /** {@code NOT p}: the difference between the input estimate and the estimate for {@code p}. */
        @Override
        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) {
            return PlanNodeStatsEstimateMath.differenceInStats(input, process(node.getValue()));
        }

        /**
         * {@code AND} applies the right operand on top of the left operand's estimate.
         * {@code OR} sums the two independent operand estimates and subtracts the {@code AND}
         * estimate (inclusion-exclusion).
         */
        @Override
        protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) {
            PlanNodeStatsEstimate leftStats = process(node.getLeft());
            // Filtering the left operand's output by the right operand yields the conjunction estimate.
            PlanNodeStatsEstimate andStats = new FilterExpressionStatsCalculatingVisitor(leftStats, session, types)
                    .process(node.getRight());
            switch (node.getType()) {
                case AND:
                    return andStats;
                case OR: {
                    PlanNodeStatsEstimate rightStats = process(node.getRight());
                    PlanNodeStatsEstimate sumStats = PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(leftStats, rightStats);
                    return PlanNodeStatsEstimateMath.differenceInNonRangeStats(sumStats, andStats);
                }
                default:
                    throw new IllegalStateException("Unimplemented logical binary operator expression " + node.getType());
            }
        }

        /** {@code TRUE} keeps the input unchanged; any other boolean literal yields the empty estimate. */
        @Override
        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) {
            return node.equals(BooleanLiteral.TRUE_LITERAL) ? input : filterForFalseExpression();
        }

        /**
         * {@code symbol IS NOT NULL}: scales the row count by the symbol's non-null fraction and
         * sets the symbol's nulls fraction to zero. Non-symbol operands use the unknown fallback.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            return input.mapOutputRowCount(rowCount -> rowCount * (1.0 - symbolStats.getNullsFraction()))
                    .mapSymbolColumnStatistics(symbol, stats -> stats.mapNullsFraction(nullsFraction -> 0.0));
        }

        /**
         * {@code symbol IS NULL}: scales the row count by the symbol's nulls fraction. The symbol's
         * statistics become all-null, with NaN low/high values and zero distinct values.
         * Non-symbol operands use the unknown fallback.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            return input.mapOutputRowCount(rowCount -> rowCount * symbolStats.getNullsFraction())
                    .mapSymbolColumnStatistics(symbol, stats -> SymbolStatsEstimate.builder()
                            .setNullsFraction(1.0d)
                            .setLowValue(Double.NaN)
                            .setHighValue(Double.NaN)
                            .setDistinctValuesCount(0.0)
                            .build());
        }

        /**
         * {@code symbol BETWEEN literal AND literal} is rewritten as a conjunction of a {@code >=}
         * and a {@code <=} comparison. Any other operand shapes use the unknown fallback.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)
                    || !(node.getMin() instanceof Literal)
                    || !(node.getMax() instanceof Literal)) {
                return visitExpression(node, context);
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            ComparisonExpression lowerBound = new ComparisonExpression(ComparisonExpressionType.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            ComparisonExpression upperBound = new ComparisonExpression(ComparisonExpressionType.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // The AND handler filters by the left conjunct first and then applies the right one to
            // that result, so the order matters. When the symbol's low value is infinite, the
            // lower-bound comparison is applied first; otherwise the upper-bound comparison is.
            Expression conjunction = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(conjunction);
        }

        /**
         * {@code symbol IN (v1, v2, ...)} is estimated as the sum of the estimates for
         * {@code symbol = v1}, {@code symbol = v2}, and so on. The result is capped at the input's
         * non-null row count for the symbol and at the symbol's input distinct-value count.
         * Falls back to the unknown estimate when the summed row count is NaN.
         */
        @Override
        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference) || !(node.getValueList() instanceof InListExpression)) {
                return visitExpression(node, context);
            }
            InListExpression inList = (InListExpression) node.getValueList();
            PlanNodeStatsEstimate statsSum = inList.getValues().stream()
                    .map(inValue -> process(new ComparisonExpression(ComparisonExpressionType.EQUAL, node.getValue(), inValue)))
                    .reduce(filterForFalseExpression(), PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues);
            if (Double.isNaN(statsSum.getOutputRowCount())) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            double notNullRowCount = input.getOutputRowCount() * (1.0d - symbolStats.getNullsFraction());
            SymbolStatsEstimate newSymbolStats = statsSum.getSymbolStatistics(symbol)
                    .mapDistinctValuesCount(distinctValues -> Double.min(distinctValues, symbolStats.getDistinctValuesCount()));
            return input.mapOutputRowCount(rowCount -> Double.min(statsSum.getOutputRowCount(), notNullRowCount))
                    .mapSymbolColumnStatistics(symbol, oldStats -> newSymbolStats);
        }

        /**
         * Comparisons are normalized so that a symbol sits on the left. They are then estimated as
         * symbol-vs-literal or symbol-vs-symbol; any other shape uses the unknown fallback.
         */
        @Override
        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpressionType type = node.getType();
            Expression left = node.getLeft();
            Expression right = node.getRight();

            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // Normalize so that the symbol is on the left-hand side.
                return process(new ComparisonExpression(type.flip(), right, left));
            }
            if (left instanceof SymbolReference && right instanceof Literal) {
                Symbol symbol = Symbol.from(left);
                OptionalDouble literalValue = doubleValueFromLiteral(types.get(symbol), (Literal) right);
                return ComparisonStatsCalculator.comparisonSymbolToLiteralStats(input, symbol, literalValue, type);
            }
            if (right instanceof SymbolReference) {
                // The normalization above guarantees the left side is also a symbol here.
                return ComparisonStatsCalculator.comparisonSymbolToSymbolStats(input, Symbol.from(left), Symbol.from(right), type);
            }
            return filterForUnknownExpression();
        }

        /** Evaluates {@code literal} and converts it to the stats double representation for {@code type}. */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object literalValue = LiteralInterpreter.evaluate(FilterStatsCalculator.this.metadata, session.toConnectorSession(), literal);
            return StatsUtil.toStatsRepresentation(FilterStatsCalculator.this.metadata, session, type, literalValue);
        }
    }

    /**
     * @param metadata metadata used to evaluate literals in predicates
     * @throws NullPointerException if {@code metadata} is null
     */
    public FilterStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Estimates the statistics of the rows of {@code statsEstimate} that satisfy {@code predicate}.
     *
     * @param statsEstimate statistics of the filter's input
     * @param predicate the filter predicate
     * @param session session used when evaluating literals
     * @param types types of the symbols referenced by {@code predicate}
     * @return the estimated statistics after filtering
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, Map<Symbol, Type> types) {
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(predicate);
    }

    /**
     * Fallback estimate for a predicate whose selectivity cannot be derived: the input's row count
     * scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}, with symbol statistics left unchanged.
     */
    public static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate inputStatistics) {
        return inputStatistics.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
    }
}
