package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.MoreMath;
import com.google.common.collect.ImmutableList;
import java.util.Iterator;
import java.util.Objects;
import java.util.OptionalDouble;
import javax.inject.Inject;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Estimates {@link SymbolStatsEstimate} statistics (value range, distinct values count, nulls
 * fraction, average row size) for a scalar {@link Expression} evaluated over the rows described by
 * a {@link PlanNodeStatsEstimate}.
 */
public class ScalarStatsCalculator {
    private final Metadata metadata;

    /**
     * Bottom-up visitor: the statistics of a composite expression are derived from the statistics
     * of its operands and of the input plan node.
     */
    private class Visitor extends AstVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;

        Visitor(PlanNodeStatsEstimate input, Session session) {
            this.input = input;
            this.session = session;
        }

        /**
         * Fallback for expression kinds without a dedicated handler: nothing is known about them.
         */
        @Override
        protected SymbolStatsEstimate visitNode(Node node, Void context) {
            return SymbolStatsEstimate.UNKNOWN_STATS;
        }

        /** A plain column reference simply reuses the statistics recorded for that symbol. */
        @Override
        protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) {
            return input.getSymbolStatistics(Symbol.from(node));
        }

        /** {@code NULL} is always null and contributes no distinct values. */
        @Override
        protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) {
            return SymbolStatsEstimate.builder()
                    .setDistinctValuesCount(0.0)
                    .setNullsFraction(1.0)
                    .build();
        }

        /**
         * A non-null literal has exactly one distinct value. When its type has a numeric stats
         * representation, the range collapses to that single point.
         */
        @Override
        protected SymbolStatsEstimate visitLiteral(Literal node, Void context) {
            Object value = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), node);
            OptionalDouble doubleValue = StatsUtil.toStatsRepresentation(
                    metadata,
                    session,
                    ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableList.of())
                            .analyze(node, Scope.create()),
                    value);

            SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(1.0);
            if (doubleValue.isPresent()) {
                estimate.setLowValue(doubleValue.getAsDouble());
                estimate.setHighValue(doubleValue.getAsDouble());
            }
            return estimate.build();
        }

        /**
         * Casts keep the source nulls fraction and range.
         * <p>
         * When the target is an integral type, finite bounds are rounded. The distinct values count
         * is then capped by the number of integers that fit in the rounded range.
         * <p>
         * The average row size of the source is not propagated.
         */
        @Override
        protected SymbolStatsEstimate visitCast(Cast node, Void context) {
            SymbolStatsEstimate sourceStats = process(node.getExpression());
            TypeSignature targetType = TypeSignature.parseTypeSignature(node.getType());

            double distinctValuesCount = sourceStats.getDistinctValuesCount();
            double lowValue = sourceStats.getLowValue();
            double highValue = sourceStats.getHighValue();

            if (isIntegralType(targetType)) {
                // NOTE(review): narrowing casts (e.g. BIGINT -> SMALLINT) are not reflected in the range.
                if (Double.isFinite(lowValue)) {
                    lowValue = Math.round(lowValue);
                }
                if (Double.isFinite(highValue)) {
                    highValue = Math.round(highValue);
                }
                if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                    double integersInRange = highValue - lowValue + 1.0;
                    if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                        distinctValuesCount = integersInRange;
                    }
                }
            }

            return SymbolStatsEstimate.builder()
                    .setNullsFraction(sourceStats.getNullsFraction())
                    .setLowValue(lowValue)
                    .setHighValue(highValue)
                    .setDistinctValuesCount(distinctValuesCount)
                    .build();
        }

        /**
         * Returns true for fixed-width integer types, and for {@code decimal} types with scale 0.
         */
        private boolean isIntegralType(TypeSignature typeSignature) {
            switch (typeSignature.getBase()) {
                case "bigint":
                case "integer":
                case "smallint":
                case "tinyint":
                    return true;
                case "decimal":
                    return ((DecimalType) metadata.getType(typeSignature)).getScale() == 0;
                default:
                    return false;
            }
        }

        /** Unary plus is the identity; unary minus mirrors the range around zero. */
        @Override
        protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            SymbolStatsEstimate operandStats = process(node.getValue());
            switch (node.getSign()) {
                case PLUS:
                    return operandStats;
                case MINUS:
                    return SymbolStatsEstimate.buildFrom(operandStats)
                            .setLowValue(-operandStats.getHighValue())
                            .setHighValue(-operandStats.getLowValue())
                            .build();
                default:
                    throw new IllegalStateException("Unexpected sign: " + node.getSign());
            }
        }

        /**
         * Combines operand statistics for {@code + - * / %}.
         * <p>
         * The result is null when either operand is null; nulls are treated as independent.
         * Distinct values are bounded by the product of the operands' counts and by the row count.
         * <p>
         * The range is computed as follows:
         * <ul>
         *   <li>Division by a range that strictly straddles zero is unbounded.</li>
         *   <li>Modulus keeps the sign of the dividend. Its magnitude is bounded by the largest
         *       absolute divisor.</li>
         *   <li>Every other case takes the extremes over all four low/high bound combinations.</li>
         * </ul>
         */
        @Override
        protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate left = process(node.getLeft());
            SymbolStatsEstimate right = process(node.getRight());

            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
                    .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    // P(left null OR right null), assuming independence
                    .setNullsFraction(left.getNullsFraction() + right.getNullsFraction()
                            - left.getNullsFraction() * right.getNullsFraction())
                    .setDistinctValuesCount(MoreMath.min(
                            left.getDistinctValuesCount() * right.getDistinctValuesCount(),
                            input.getOutputRowCount()));

            double leftLow = left.getLowValue();
            double leftHigh = left.getHighValue();
            double rightLow = right.getLowValue();
            double rightHigh = right.getHighValue();

            if (node.getType() == ArithmeticBinaryExpression.Type.DIVIDE && rightLow < 0.0 && rightHigh > 0.0) {
                // divisor can get arbitrarily close to zero from both sides
                result.setLowValue(Double.NEGATIVE_INFINITY)
                        .setHighValue(Double.POSITIVE_INFINITY);
            }
            else if (node.getType() == ArithmeticBinaryExpression.Type.MODULUS) {
                double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
                if (leftHigh <= 0.0) {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow))
                            .setHighValue(0.0);
                }
                else if (leftLow >= 0.0) {
                    result.setLowValue(0.0)
                            .setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
                else {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow))
                            .setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
            }
            else {
                double lowLow = operate(node.getType(), leftLow, rightLow);
                double lowHigh = operate(node.getType(), leftLow, rightHigh);
                double highLow = operate(node.getType(), leftHigh, rightLow);
                double highHigh = operate(node.getType(), leftHigh, rightHigh);
                result.setLowValue(MoreMath.min(lowLow, lowHigh, highLow, highHigh))
                        .setHighValue(MoreMath.max(lowLow, lowHigh, highLow, highHigh));
            }
            return result.build();
        }

        /** Applies the arithmetic operator to two bound values. */
        private double operate(ArithmeticBinaryExpression.Type type, double left, double right) {
            switch (type) {
                case ADD:
                    return left + right;
                case SUBTRACT:
                    return left - right;
                case MULTIPLY:
                    return left * right;
                case DIVIDE:
                    return left / right;
                case MODULUS:
                    return left % right;
                default:
                    throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Type: " + type);
            }
        }

        /** Folds the operand statistics pairwise, left to right, using {@link #estimateCoalesce}. */
        @Override
        protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate result = null;
            for (Expression operand : node.getOperands()) {
                SymbolStatsEstimate operandStats = process(operand);
                result = (result == null) ? operandStats : estimateCoalesce(result, operandStats);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        /**
         * Estimates {@code COALESCE(left, right)}.
         * <p>
         * If {@code left} is never null the result is {@code left}. If it is always null the result
         * is {@code right}.
         * <p>
         * Otherwise the ranges are merged. The right-hand side can add at most one distinct value
         * per row where {@code left} is null. The result is null only when both sides are null,
         * treating the two as independent.
         */
        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) {
            if (left.getNullsFraction() == 0.0) {
                return left;
            }
            if (left.getNullsFraction() == 1.0) {
                return right;
            }
            return SymbolStatsEstimate.builder()
                    .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                    .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                    .setDistinctValuesCount(left.getDistinctValuesCount()
                            + MoreMath.min(right.getDistinctValuesCount(),
                                    input.getOutputRowCount() * left.getNullsFraction()))
                    .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
                    .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .build();
        }
    }

    /**
     * @param metadata used to evaluate literals and resolve types; must not be null
     */
    @Inject
    public ScalarStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata can not be null");
    }

    /**
     * Estimates the statistics of {@code expression} when evaluated over the rows described by
     * {@code planNodeStatsEstimate}.
     */
    public SymbolStatsEstimate calculate(Expression expression, PlanNodeStatsEstimate planNodeStatsEstimate, Session session) {
        return new Visitor(planNodeStatsEstimate, session).process(expression);
    }
}
