package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.LogicalRowExpressions;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.MoreMath;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.google.common.collect.UnmodifiableIterator;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Estimates output statistics of a {@link JoinNode} for INNER, LEFT, RIGHT and FULL joins.
 *
 * <p>Every estimate starts from the cross-join statistics of both inputs. The inner-join estimate
 * narrows them with the equi-join criteria and the optional join filter. Outer joins add a "join
 * complement": the estimated outer-side rows that find no match on the other side.
 */
public class JoinStatsRule extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5d;

    /**
     * Heuristic row-count factor applied for a predicate whose selectivity is not estimated
     * precisely: each auxiliary equi-join clause, or a join filter with unknown selectivity.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final FilterStatsCalculator filterStatsCalculator;
    private final StatsNormalizer normalizer;
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer statsNormalizer) {
        this(filterStatsCalculator, statsNormalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    /**
     * @param filterStatsCalculator calculator used to estimate predicate selectivity; must not be null
     * @param statsNormalizer normalizer applied to derived estimates
     * @param unmatchedJoinComplementNdvsCoefficient multiplier applied to the other side's distinct-value
     *        count before comparing it with the outer side's distinct-value count when estimating the
     *        join complement
     */
    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer statsNormalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(statsNormalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.normalizer = statsNormalizer;
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Computes join statistics according to the join type.
     *
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(JoinNode joinNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate leftStats = statsProvider.getStats(joinNode.getLeft());
        PlanNodeStatsEstimate rightStats = statsProvider.getStats(joinNode.getRight());
        PlanNodeStatsEstimate crossJoinStats = crossJoinStats(joinNode, leftStats, rightStats);
        switch (joinNode.getType()) {
            case INNER:
                return Optional.of(computeInnerJoinStats(joinNode, crossJoinStats, session, typeProvider));
            case LEFT:
                return Optional.of(computeLeftJoinStats(joinNode, leftStats, rightStats, crossJoinStats, session, typeProvider));
            case RIGHT:
                return Optional.of(computeRightJoinStats(joinNode, leftStats, rightStats, crossJoinStats, session, typeProvider));
            case FULL:
                return Optional.of(computeFullJoinStats(joinNode, leftStats, rightStats, crossJoinStats, session, typeProvider));
            default:
                throw new IllegalStateException("Unknown join type: " + joinNode.getType());
        }
    }

    /** FULL join: the LEFT join estimate plus the complement of the right side (flipped criteria). */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode joinNode, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate leftJoinStats = computeLeftJoinStats(joinNode, leftStats, rightStats, crossJoinStats, session, typeProvider);
        PlanNodeStatsEstimate rightComplementStats = calculateJoinComplementStats(joinNode.getFilter(), flippedCriteria(joinNode), rightStats, leftStats);
        return addJoinComplementStats(rightStats, leftJoinStats, rightComplementStats);
    }

    /** LEFT join: the INNER join estimate plus the complement of the left side. */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode joinNode, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(joinNode, crossJoinStats, session, typeProvider);
        PlanNodeStatsEstimate leftComplementStats = calculateJoinComplementStats(joinNode.getFilter(), joinNode.getCriteria(), leftStats, rightStats);
        return addJoinComplementStats(leftStats, innerJoinStats, leftComplementStats);
    }

    /** RIGHT join: the INNER join estimate plus the complement of the right side (flipped criteria). */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode joinNode, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(joinNode, crossJoinStats, session, typeProvider);
        PlanNodeStatsEstimate rightComplementStats = calculateJoinComplementStats(joinNode.getFilter(), flippedCriteria(joinNode), rightStats, leftStats);
        return addJoinComplementStats(rightStats, innerJoinStats, rightComplementStats);
    }

    /**
     * Estimates INNER join output by applying the equi-join criteria and then the optional join
     * filter to the cross-join statistics.
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode joinNode, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider typeProvider) {
        Optional<RowExpression> filter = joinNode.getFilter();
        if (joinNode.getCriteria().isEmpty()) {
            // No equi-join criteria: only the optional filter narrows the cross join.
            return filter.isPresent() ? applyJoinFilter(crossJoinStats, filter.get(), session, typeProvider) : crossJoinStats;
        }

        PlanNodeStatsEstimate equiJoinStats = filterByEquiJoinClauses(crossJoinStats, joinNode.getCriteria(), session, typeProvider);
        if (equiJoinStats.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (!filter.isPresent()) {
            return equiJoinStats;
        }

        PlanNodeStatsEstimate filteredStats = applyJoinFilter(equiJoinStats, filter.get(), session, typeProvider);
        if (filteredStats.isOutputRowCountUnknown()) {
            // The filter's selectivity could not be estimated; fall back to a fixed heuristic factor.
            return normalizer.normalize(equiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return filteredStats;
    }

    /** Applies a join filter, dispatching on whether it wraps an original AST expression or is a RowExpression. */
    private PlanNodeStatsEstimate applyJoinFilter(PlanNodeStatsEstimate stats, RowExpression filter, Session session, TypeProvider typeProvider) {
        if (OriginalExpressionUtils.isExpression(filter)) {
            return filterStatsCalculator.filterStats(stats, OriginalExpressionUtils.castToExpression(filter), session, typeProvider);
        }
        return filterStatsCalculator.filterStats(stats, filter, session, typeProvider);
    }

    /**
     * Applies all equi-join clauses to {@code stats}.
     *
     * <p>Selectivities of all clauses are not simply multiplied together. Each clause in turn is tried as
     * the "driving" clause, estimated by the filter calculator. The remaining clauses are applied as
     * auxiliary clauses, each with a fixed heuristic factor. The candidate with the lowest known row
     * count wins. The result is unknown only if every candidate is unknown.
     *
     * @throws IllegalArgumentException if {@code clauses} is empty
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, Collection<JoinNode.EquiJoinClause> clauses, Session session, TypeProvider typeProvider) {
        Preconditions.checkArgument(!clauses.isEmpty(), "clauses is empty");
        PlanNodeStatsEstimate best = PlanNodeStatsEstimate.unknown();
        LinkedList<JoinNode.EquiJoinClause> auxiliaryClauses = new LinkedList<>(clauses);
        JoinNode.EquiJoinClause drivingClause = auxiliaryClauses.poll();
        for (int i = 0; i < clauses.size(); i++) {
            PlanNodeStatsEstimate candidate = filterByDrivingClause(stats, drivingClause, auxiliaryClauses, session, typeProvider);
            if (best.isOutputRowCountUnknown() || (!candidate.isOutputRowCountUnknown() && candidate.getOutputRowCount() < best.getOutputRowCount())) {
                best = candidate;
            }
            // Rotate: the current driving clause becomes auxiliary and the next clause drives.
            auxiliaryClauses.add(drivingClause);
            drivingClause = auxiliaryClauses.poll();
        }
        return best;
    }

    /** Estimates {@code drivingClause} as an equality predicate, then applies each auxiliary clause heuristically. */
    private PlanNodeStatsEstimate filterByDrivingClause(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause drivingClause, Collection<JoinNode.EquiJoinClause> auxiliaryClauses, Session session, TypeProvider typeProvider) {
        ComparisonExpression drivingPredicate = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                new SymbolReference(drivingClause.getLeft().getName()),
                new SymbolReference(drivingClause.getRight().getName()));
        PlanNodeStatsEstimate result = filterStatsCalculator.filterStats(stats, drivingPredicate, session, typeProvider);
        for (JoinNode.EquiJoinClause auxiliaryClause : auxiliaryClauses) {
            result = filterByAuxiliaryClause(result, auxiliaryClause);
        }
        return result;
    }

    /**
     * Applies an auxiliary equi-join clause heuristically.
     *
     * <p>The row count is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}. Both join columns lose their
     * nulls and are restricted to the intersection of their ranges. Both columns get the smaller of the
     * two in-range distinct-value counts.
     */
    private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause clause) {
        VariableStatsEstimate leftStats = stats.getVariableStatistics(clause.getLeft());
        VariableStatsEstimate rightStats = stats.getVariableStatistics(clause.getRight());
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);

        // When the overlap percentage is NaN, assume the whole range overlaps.
        double leftNdvInRange = firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0d) * leftRange.getDistinctValuesCount();
        double rightNdvInRange = firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0d) * rightRange.getDistinctValuesCount();
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

        VariableStatsEstimate newLeftStats = VariableStatsEstimate.buildFrom(leftStats)
                .setNullsFraction(0.0d)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        VariableStatsEstimate newRightStats = VariableStatsEstimate.buildFrom(rightStats)
                .setNullsFraction(0.0d)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();

        return normalizer.normalize(PlanNodeStatsEstimate.buildFrom(stats)
                .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT)
                .addVariableStatistics(clause.getLeft(), newLeftStats)
                .addVariableStatistics(clause.getRight(), newRightStats)
                .build());
    }

    /**
     * Returns the first argument that is not NaN.
     *
     * @throws IllegalArgumentException if every argument is NaN (or none is given)
     */
    private static double firstNonNaN(double... values) {
        for (double value : values) {
            if (!Double.isNaN(value)) {
                return value;
            }
        }
        throw new IllegalArgumentException("All values are NaN");
    }

    /**
     * Estimates the rows of {@code leftStats} that find no match in {@code rightStats}.
     *
     * <p>Each criterion is tried as the driving clause. The largest known complement estimate is
     * normalized and returned. If no estimate is known, the result is unknown.
     *
     * @param filter optional non-equi join filter; its conjuncts count as additional remaining clauses
     * @param criteria equi-join clauses oriented so that {@code getLeft()} belongs to {@code leftStats}
     * @param leftStats statistics of the outer side whose unmatched rows are estimated
     * @param rightStats statistics of the other side
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<RowExpression> filter, List<JoinNode.EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        if (rightStats.getOutputRowCount() == 0.0d) {
            // The other side is empty, so no left row can be matched.
            return leftStats;
        }

        if (criteria.isEmpty()) {
            // Without criteria this is a cross join. With no filter, every left row is matched by the
            // non-empty right side. With a filter, the complement cannot be estimated here.
            if (filter.isPresent()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0d));
        }

        int filterConjunctCount = filter.map(JoinStatsRule::countConjuncts).orElse(0);
        int remainingClauseCount = criteria.size() - 1 + filterConjunctCount;
        return criteria.stream()
                .map(drivingClause -> calculateDrivingClauseComplementStats(leftStats, rightStats, drivingClause, remainingClauseCount))
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .map(estimate -> normalizer.normalize(estimate))
                .orElse(PlanNodeStatsEstimate.unknown());
    }

    /** Counts the top-level AND conjuncts of a filter, whichever expression representation it uses. */
    private static int countConjuncts(RowExpression filter) {
        if (OriginalExpressionUtils.isExpression(filter)) {
            return ExpressionUtils.extractConjuncts(OriginalExpressionUtils.castToExpression(filter)).size();
        }
        return LogicalRowExpressions.extractConjuncts(filter).size();
    }

    /**
     * Estimates the left-side complement using a single driving clause.
     *
     * @param remainingClauseCount number of other criteria plus filter conjuncts
     * @return the complement estimate, or unknown when either distinct-value count is NaN
     */
    private PlanNodeStatsEstimate calculateDrivingClauseComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, JoinNode.EquiJoinClause drivingClause, int remainingClauseCount) {
        VariableStatsEstimate leftColumnStats = leftStats.getVariableStatistics(drivingClause.getLeft());
        VariableStatsEstimate rightColumnStats = rightStats.getVariableStatistics(drivingClause.getRight());
        double leftNdv = leftColumnStats.getDistinctValuesCount();
        double matchingRightNdv = rightColumnStats.getDistinctValuesCount() * unmatchedJoinComplementNdvsCoefficient;

        PlanNodeStatsEstimate complement;
        if (leftNdv > matchingRightNdv) {
            // The "excess" left distinct values and all left null rows are assumed unmatched.
            double unmatchedNdv = leftNdv - matchingRightNdv;
            double scaleFactor = leftColumnStats.getValuesFraction() * unmatchedNdv / leftNdv + leftColumnStats.getNullsFraction();
            double newNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            complement = leftStats
                    .mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> VariableStatsEstimate.buildFrom(columnStats)
                            .setLowValue(leftColumnStats.getLowValue())
                            .setHighValue(leftColumnStats.getHighValue())
                            .setNullsFraction(newNullsFraction)
                            .setDistinctValuesCount(unmatchedNdv)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        } else if (leftNdv <= matchingRightNdv) {
            // Every non-null left value is assumed matched, so only rows with a null key are unmatched.
            complement = leftStats
                    .mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> VariableStatsEstimate.buildFrom(columnStats)
                            .setLowValue(Double.NaN)
                            .setHighValue(Double.NaN)
                            .setNullsFraction(1.0d)
                            .setDistinctValuesCount(0.0d)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        } else {
            // Both comparisons are false only when an NDV is NaN: the complement is unknown.
            return PlanNodeStatsEstimate.unknown();
        }

        // Each remaining clause is assumed to reject further matches. Grow the complement by
        // 1/UNKNOWN_FILTER_COEFFICIENT per remaining clause, capped at the left input's row count.
        return complement.mapOutputRowCount(rowCount -> Math.min(leftStats.getOutputRowCount(), rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, remainingClauseCount)));
    }

    /**
     * Merges inner-join statistics with join-complement statistics into outer-join statistics.
     *
     * <p>Row counts are summed. For variables with complement statistics, the nulls fraction is the
     * row-count-weighted average of both parts, while the range and NDV are taken from {@code sourceStats}.
     * Variables without complement statistics are counted as null in every complement row.
     *
     * @param sourceStats statistics of the outer side's input
     * @param innerJoinStats statistics of the matched (inner join) rows
     * @param joinComplementStats statistics of the unmatched outer-side rows
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        double innerJoinRowCount = innerJoinStats.getOutputRowCount();
        double complementRowCount = joinComplementStats.getOutputRowCount();
        if (complementRowCount == 0.0d) {
            return innerJoinStats;
        }

        double outputRowCount = innerJoinRowCount + complementRowCount;
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
        result.setOutputRowCount(outputRowCount);

        for (VariableReferenceExpression variable : joinComplementStats.getVariablesWithKnownStatistics()) {
            VariableStatsEstimate sourceVariableStats = sourceStats.getVariableStatistics(variable);
            VariableStatsEstimate innerJoinVariableStats = innerJoinStats.getVariableStatistics(variable);
            VariableStatsEstimate complementVariableStats = joinComplementStats.getVariableStatistics(variable);
            double nullsFraction = (innerJoinVariableStats.getNullsFraction() * innerJoinRowCount
                    + complementVariableStats.getNullsFraction() * complementRowCount) / outputRowCount;
            result.addVariableStatistics(variable, VariableStatsEstimate.buildFrom(innerJoinVariableStats)
                    .setLowValue(sourceVariableStats.getLowValue())
                    .setHighValue(sourceVariableStats.getHighValue())
                    .setDistinctValuesCount(sourceVariableStats.getDistinctValuesCount())
                    .setNullsFraction(nullsFraction)
                    .build());
        }

        for (VariableReferenceExpression variable : Sets.difference(innerJoinStats.getVariablesWithKnownStatistics(), joinComplementStats.getVariablesWithKnownStatistics())) {
            VariableStatsEstimate innerJoinVariableStats = innerJoinStats.getVariableStatistics(variable);
            // All complement rows contribute nulls to this variable.
            double nullsFraction = (innerJoinVariableStats.getNullsFraction() * innerJoinRowCount + complementRowCount) / outputRowCount;
            result.addVariableStatistics(variable, innerJoinVariableStats.mapNullsFraction(ignored -> nullsFraction));
        }

        return result.build();
    }

    /** Cross-join estimate: product of input row counts, with each input's output-variable statistics carried over. */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode joinNode, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        joinNode.getLeft().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, leftStats.getVariableStatistics(variable)));
        joinNode.getRight().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, rightStats.getVariableStatistics(variable)));
        return normalizer.normalize(builder.build());
    }

    /** Returns the join criteria with left/right swapped, for estimating the right side's complement. */
    private List<JoinNode.EquiJoinClause> flippedCriteria(JoinNode joinNode) {
        return joinNode.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }
}
