package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.class */
/**
 * Stats rule for a {@link FilterNode} placed directly on a {@link SemiJoinNode}, or on an identity
 * {@link ProjectNode} whose source is a {@link SemiJoinNode}.
 *
 * <p>The rule applies only when exactly one conjunct of the filter predicate references the semi-join
 * output symbol, either bare or wrapped in a single {@code NOT}. A bare reference is estimated with
 * {@link SemiJoinStatsCalculator#computeSemiJoin} and a negated one with
 * {@link SemiJoinStatsCalculator#computeAntiJoin}. The remaining conjuncts are then applied to that
 * estimate through {@link FilterStatsCalculator}.
 */
public class SimpleFilterProjectSemiJoinStatsRule extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Row-count multiplier applied to the semi-join estimate when the effect of the remaining
     * predicate cannot be estimated (its filtered row count is unknown).
     *
     * <p>NOTE(review): presumably meant to match the coefficient FilterStatsCalculator uses for
     * unknown filters. Confirm this, and reference that constant directly if one exists.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final FilterStatsCalculator filterStatsCalculator;

    /**
     * A filter predicate split into its semi-join output test and everything else.
     *
     * <p>The output test itself is recorded only as whether it was negated.
     *
     * <p>NOTE(review): the decompiler reported this type was originally private. It is kept public
     * here so the visible API is not narrowed.
     */
    public static class SemiJoinOutputFilter {
        private final boolean negated;
        private final Expression remainingPredicate;

        /**
         * @param negated whether the semi-join output reference was wrapped in {@code NOT}
         * @param remainingPredicate conjunction of all conjuncts other than the semi-join output test
         * @throws NullPointerException if {@code remainingPredicate} is null
         */
        public SemiJoinOutputFilter(boolean negated, Expression remainingPredicate) {
            this.negated = negated;
            this.remainingPredicate = Objects.requireNonNull(remainingPredicate, "remainingPredicate can not be null");
        }

        /** Returns whether the semi-join output reference appeared under {@code NOT}. */
        public boolean isNegated() {
            return negated;
        }

        /** Returns the conjunction of every conjunct except the semi-join output test. */
        public Expression getRemainingPredicate() {
            return remainingPredicate;
        }
    }

    /**
     * @param statsNormalizer passed through to {@link SimpleStatsRule}
     * @param filterStatsCalculator used to apply the non-semi-join part of the predicate
     * @throws NullPointerException if {@code filterStatsCalculator} is null
     */
    public SimpleFilterProjectSemiJoinStatsRule(StatsNormalizer statsNormalizer, FilterStatsCalculator filterStatsCalculator) {
        super(statsNormalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator can not be null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Locates the semi-join under {@code filterNode} and estimates the filter's output stats.
     *
     * <p>The filter's source must be a {@link SemiJoinNode}, or an identity {@link ProjectNode} whose
     * source is a {@link SemiJoinNode}. Any other plan shape yields {@link Optional#empty()}.
     *
     * <p>NOTE(review): the decompiler reported this method was originally protected. It is kept
     * public here so the visible API is not narrowed.
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode filterNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
        PlanNode source = lookup.resolve(filterNode.getSource());
        if (source instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) source;
            // Only identity projections are looked through; a nested projection will not match below.
            if (!projectNode.isIdentity()) {
                return Optional.empty();
            }
            source = lookup.resolve(projectNode.getSource());
        }
        if (!(source instanceof SemiJoinNode)) {
            return Optional.empty();
        }
        return calculate(filterNode, (SemiJoinNode) source, statsProvider, session, typeProvider);
    }

    /**
     * Estimates the filter's output stats once its semi-join source is known.
     *
     * @return one of the following:
     *     <ul>
     *     <li>empty, if the predicate does not contain exactly one semi-join output reference;</li>
     *     <li>unknown stats, if the semi-join (or anti-join) row count is unknown;</li>
     *     <li>otherwise, that estimate narrowed by the remaining predicate. If the narrowed row count
     *     is unknown, the semi-join row count scaled by {@link #UNKNOWN_FILTER_COEFFICIENT} is
     *     returned instead.</li>
     *     </ul>
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider typeProvider) {
        // Input stats are requested before the predicate is inspected, matching the original order.
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource());
        PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
        Symbol filteringSourceJoinSymbol = semiJoinNode.getFilteringSourceJoinSymbol();
        Symbol sourceJoinSymbol = semiJoinNode.getSourceJoinSymbol();

        Optional<SemiJoinOutputFilter> semiJoinOutputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), semiJoinNode.getSemiJoinOutput());
        if (!semiJoinOutputFilter.isPresent()) {
            return Optional.empty();
        }
        SemiJoinOutputFilter outputFilter = semiJoinOutputFilter.get();

        // A negated output reference is estimated as an anti-join, a bare one as a semi-join.
        PlanNodeStatsEstimate semiJoinStats;
        if (outputFilter.isNegated()) {
            semiJoinStats = SemiJoinStatsCalculator.computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol);
        } else {
            semiJoinStats = SemiJoinStatsCalculator.computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol);
        }
        if (semiJoinStats.isOutputRowCountUnknown()) {
            return Optional.of(PlanNodeStatsEstimate.unknown());
        }

        PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(semiJoinStats, outputFilter.getRemainingPredicate(), session, typeProvider);
        if (filteredStats.isOutputRowCountUnknown()) {
            return Optional.of(semiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return Optional.of(filteredStats);
    }

    /**
     * Splits {@code predicate} into its semi-join output conjunct and the remaining conjuncts.
     *
     * @param predicate the filter predicate to inspect
     * @param semiJoinOutput the semi-join node's output symbol
     * @return empty unless exactly one conjunct satisfies {@link #isSemiJoinOutputReference}.
     *     Otherwise, a filter recording whether that conjunct is a {@code NOT} and the conjunction of
     *     all other conjuncts.
     */
    private static Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression predicate, Symbol semiJoinOutput) {
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(predicate);
        List<Expression> semiJoinOutputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        if (semiJoinOutputReferences.size() != 1) {
            return Optional.empty();
        }

        Expression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences);
        // The reference comparison removes exactly the matched conjunct instance.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != semiJoinOutputReference)
                .collect(ImmutableList.toImmutableList());
        boolean negated = semiJoinOutputReference instanceof NotExpression;
        return Optional.of(new SemiJoinOutputFilter(negated, ExpressionUtils.combineConjuncts(remainingConjuncts)));
    }

    /**
     * Tells whether {@code conjunct} is a reference to {@code semiJoinOutput}, either bare or directly
     * wrapped in a single {@code NOT}.
     *
     * <p>NOTE(review): the decompiler reported this method was originally private. It is kept public
     * here so the visible API is not narrowed.
     */
    public static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput) {
        SymbolReference reference = semiJoinOutput.toSymbolReference();
        if (conjunct.equals(reference)) {
            return true;
        }
        return conjunct instanceof NotExpression && ((NotExpression) conjunct).getValue().equals(reference);
    }
}
