package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.relational.ProjectNodeUtils;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.class */
/**
 * Stats rule for a {@link FilterNode} whose source is a {@link SemiJoinNode}, optionally reached
 * through an identity {@link ProjectNode}.
 *
 * <p>The rule applies only when the filter predicate has exactly one conjunct that references the
 * semi-join output variable, either directly or wrapped in a NOT. That conjunct selects between
 * semi-join and anti-join estimation via {@link SemiJoinStatsCalculator}. The remaining conjuncts are
 * then applied to that estimate with the {@link FilterStatsCalculator}.
 *
 * <p>Predicates are handled in both representations: original AST {@link Expression}s wrapped as row
 * expressions, and native {@link RowExpression}s.
 */
public class SimpleFilterProjectSemiJoinStatsRule extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Fraction of semi-join output rows assumed to pass the remaining predicate when the filter
     * calculator cannot estimate its selectivity.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final FilterStatsCalculator filterStatsCalculator;
    private final LogicalRowExpressions logicalRowExpressions;
    private final FunctionResolution functionResolution;

    /**
     * Result of splitting a filter predicate around its single semi-join output reference.
     */
    public static class SemiJoinOutputFilter {
        /** True when the semi-join output reference was wrapped in NOT (anti-join semantics). */
        private final boolean negated;
        /** Conjunction of every conjunct other than the semi-join output reference. */
        private final RowExpression remainingPredicate;

        /**
         * @param negated whether the semi-join output reference was negated
         * @param remainingPredicate the predicate left after removing the reference; must not be null
         * @throws NullPointerException if {@code remainingPredicate} is null
         */
        public SemiJoinOutputFilter(boolean negated, RowExpression remainingPredicate) {
            this.negated = negated;
            this.remainingPredicate = Objects.requireNonNull(remainingPredicate, "remainingPredicate can not be null");
        }

        public boolean isNegated() {
            return negated;
        }

        public RowExpression getRemainingPredicate() {
            return remainingPredicate;
        }
    }

    /**
     * @param statsNormalizer passed through to {@link SimpleStatsRule}
     * @param filterStatsCalculator used to apply the remaining (non semi-join) predicate
     * @param functionManager used to resolve the NOT function and build logical row expressions
     * @throws NullPointerException if {@code filterStatsCalculator} or {@code functionManager} is null
     */
    public SimpleFilterProjectSemiJoinStatsRule(StatsNormalizer statsNormalizer, FilterStatsCalculator filterStatsCalculator, FunctionManager functionManager) {
        super(statsNormalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator can not be null");
        Objects.requireNonNull(functionManager, "functionManager can not be null");
        // One resolver is enough; it is shared with LogicalRowExpressions instead of building two.
        this.functionResolution = new FunctionResolution(functionManager);
        this.logicalRowExpressions = new LogicalRowExpressions(new RowExpressionDeterminismEvaluator(functionManager), this.functionResolution, functionManager);
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates stats for {@code filterNode} when its source resolves to a semi-join.
     *
     * @return the estimate, or {@link Optional#empty()} when this rule does not apply (the source is
     *     neither a semi-join nor an identity projection over one, or the predicate does not contain
     *     exactly one semi-join output reference)
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode filterNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
        PlanNode source = lookup.resolve(filterNode.getSource());
        if (source instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) source;
            // Only an identity projection can be looked through without changing the output variables.
            if (!ProjectNodeUtils.isIdentity(projectNode)) {
                return Optional.empty();
            }
            source = lookup.resolve(projectNode.getSource());
        }
        if (!(source instanceof SemiJoinNode)) {
            return Optional.empty();
        }
        return calculate(filterNode, (SemiJoinNode) source, statsProvider, session, typeProvider);
    }

    /**
     * Computes semi-join or anti-join stats, then applies the remaining filter predicate.
     *
     * <p>If the semi-join estimate has an unknown row count, an unknown estimate is returned. If
     * only the remaining-predicate estimate is unknown, the semi-join row count is scaled by
     * {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource());
        PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
        VariableReferenceExpression filteringSourceJoinVariable = semiJoinNode.getFilteringSourceJoinVariable();
        VariableReferenceExpression sourceJoinVariable = semiJoinNode.getSourceJoinVariable();
        VariableReferenceExpression semiJoinOutput = semiJoinNode.getSemiJoinOutput();

        RowExpression predicate = filterNode.getPredicate();
        boolean isOriginalExpression = OriginalExpressionUtils.isExpression(predicate);

        Optional<SemiJoinOutputFilter> maybeOutputFilter = isOriginalExpression
                ? extractSemiJoinOutputFilter(OriginalExpressionUtils.castToExpression(predicate), semiJoinOutput)
                : extractSemiJoinOutputFilter(predicate, semiJoinOutput);
        if (!maybeOutputFilter.isPresent()) {
            return Optional.empty();
        }
        SemiJoinOutputFilter outputFilter = maybeOutputFilter.get();

        // A negated output reference keeps rows with no match (anti-join); otherwise rows with a match.
        PlanNodeStatsEstimate semiJoinStats = outputFilter.isNegated()
                ? SemiJoinStatsCalculator.computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable)
                : SemiJoinStatsCalculator.computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable);
        if (semiJoinStats.isOutputRowCountUnknown()) {
            return Optional.of(PlanNodeStatsEstimate.unknown());
        }

        RowExpression remainingPredicate = outputFilter.getRemainingPredicate();
        PlanNodeStatsEstimate filteredStats = isOriginalExpression
                ? filterStatsCalculator.filterStats(semiJoinStats, OriginalExpressionUtils.castToExpression(remainingPredicate), session, typeProvider)
                : filterStatsCalculator.filterStats(semiJoinStats, remainingPredicate, session);
        if (filteredStats.isOutputRowCountUnknown()) {
            return Optional.of(semiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return Optional.of(filteredStats);
    }

    /**
     * Splits an original-expression predicate into its semi-join output reference and the rest.
     *
     * @return the split, or empty unless exactly one conjunct references {@code semiJoinOutput}
     */
    private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression predicate, VariableReferenceExpression semiJoinOutput) {
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(predicate);
        List<Expression> outputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        if (outputReferences.size() != 1) {
            return Optional.empty();
        }
        Expression outputReference = Iterables.getOnlyElement(outputReferences);
        // Identity (not equals) comparison: exactly one conjunct matched, so only that occurrence is dropped.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != outputReference)
                .collect(ImmutableList.toImmutableList());
        Expression remainingPredicate = ExpressionUtils.combineConjuncts(remainingConjuncts);
        return Optional.of(new SemiJoinOutputFilter(outputReference instanceof NotExpression, OriginalExpressionUtils.castToRowExpression(remainingPredicate)));
    }

    /**
     * Splits a native row-expression predicate into its semi-join output reference and the rest.
     *
     * @return the split, or empty unless exactly one conjunct references {@code semiJoinOutput}
     * @throws IllegalStateException if {@code predicate} is a wrapped original expression
     */
    private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(RowExpression predicate, RowExpression semiJoinOutput) {
        Preconditions.checkState(!OriginalExpressionUtils.isExpression(predicate), "predicate must not be an original expression");
        List<RowExpression> conjuncts = LogicalRowExpressions.extractConjuncts(predicate);
        List<RowExpression> outputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        if (outputReferences.size() != 1) {
            return Optional.empty();
        }
        RowExpression outputReference = Iterables.getOnlyElement(outputReferences);
        // Identity (not equals) comparison: exactly one conjunct matched, so only that occurrence is dropped.
        List<RowExpression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != outputReference)
                .collect(ImmutableList.toImmutableList());
        return Optional.of(new SemiJoinOutputFilter(isNotFunction(outputReference), logicalRowExpressions.combineConjuncts(remainingConjuncts)));
    }

    /** True if {@code expression} is {@code semiJoinOutput} itself or a NOT call whose first argument is it. */
    private boolean isSemiJoinOutputReference(RowExpression expression, RowExpression semiJoinOutput) {
        if (expression.equals(semiJoinOutput)) {
            return true;
        }
        return isNotFunction(expression) && ((CallExpression) expression).getArguments().get(0).equals(semiJoinOutput);
    }

    /** True if {@code expression} is a symbol reference to {@code semiJoinOutput}, or a NOT of one. */
    public static boolean isSemiJoinOutputReference(Expression expression, VariableReferenceExpression semiJoinOutput) {
        SymbolReference outputSymbol = new SymbolReference(semiJoinOutput.getName());
        if (expression.equals(outputSymbol)) {
            return true;
        }
        return expression instanceof NotExpression && ((NotExpression) expression).getValue().equals(outputSymbol);
    }

    /** True if {@code expression} is a call to the function the resolver identifies as NOT. */
    private boolean isNotFunction(RowExpression expression) {
        return expression instanceof CallExpression && functionResolution.isNotFunction(((CallExpression) expression).getFunctionHandle());
    }
}
