package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.hive.jdbc.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.sql.util.AstUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.annotation.Nullable;

/**
 * Rewrites a correlated {@code IN} subquery into a LEFT join followed by an aggregation.
 * <p>
 * The rule matches an {@link ApplyNode} with a non-empty correlation whose single subquery
 * assignment is an {@link InPredicate}. It fires only when the subquery can be decorrelated by
 * pulling {@link FilterNode} predicates up through {@link ProjectNode}s (see
 * {@link DecorrelatingVisitor}); otherwise it returns {@link Rule.Result#empty()}.
 * <p>
 * For {@code value IN (subquery producing item)} the produced plan is:
 * <pre>
 * Project(input outputs..., result := CASE WHEN countMatches > 0 THEN true
 *                                          WHEN countNullMatches > 0 THEN CAST(null AS boolean)
 *                                          ELSE false END)
 *   Aggregation(group by AssignUniqueId outputs,
 *               countMatches     := count() FILTER (value IS NOT NULL AND item IS NOT NULL),
 *               countNullMatches := count() FILTER (buildSideKnownNonNull IS NOT NULL
 *                                                   AND NOT (value IS NOT NULL AND item IS NOT NULL)))
 *     LeftJoin(ON (value IS NULL OR value = item OR item IS NULL) AND correlated predicates)
 *       AssignUniqueId(input)
 *       Project(decorrelated subquery..., buildSideKnownNonNull := CAST(0 AS bigint))
 * </pre>
 */
public class TransformCorrelatedInPredicateToJoin implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode().with(Pattern.nonEmpty(Patterns.Apply.correlation()));

    /** Name of the aggregate function used to count matching and null-matching build rows. */
    private static final String COUNT_FUNCTION_NAME = "count";

    private final FunctionManager functionManager;

    /**
     * @param functionManager used to resolve the {@code count} aggregate; must not be null
     */
    public TransformCorrelatedInPredicateToJoin(FunctionManager functionManager) {
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Applies the rewrite when the apply node has exactly one subquery assignment and that
     * assignment is an {@link InPredicate}; otherwise returns an empty result.
     */
    @Override
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = applyNode.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        Expression assignmentExpression = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        Symbol inPredicateOutputSymbol = Iterables.getOnlyElement(subqueryAssignments.getSymbols());
        return apply(
                context.getSession(),
                applyNode,
                (InPredicate) assignmentExpression,
                inPredicateOutputSymbol,
                context.getLookup(),
                context.getIdAllocator(),
                context.getSymbolAllocator());
    }

    /**
     * Decorrelates the subquery and, if that succeeds, replaces the apply node with the
     * join-based equivalent; returns an empty result when decorrelation is not possible.
     */
    private Rule.Result apply(
            Session session,
            ApplyNode applyNode,
            InPredicate inPredicate,
            Symbol inPredicateOutputSymbol,
            Lookup lookup,
            PlanNodeIdAllocator idAllocator,
            SymbolAllocator symbolAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, applyNode.getCorrelation())
                .decorrelate(applyNode.getSubquery());
        if (!decorrelated.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode rewritten = buildInPredicateEquivalent(
                session, applyNode, inPredicate, inPredicateOutputSymbol, decorrelated.get(), idAllocator, symbolAllocator);
        return Rule.Result.ofPlanNode(rewritten);
    }

    /**
     * Builds the LEFT join + aggregation + projection plan described in the class Javadoc.
     * The returned projection outputs all symbols of the apply node's input plus
     * {@code inPredicateOutputSymbol}, which carries the three-valued IN result.
     */
    private PlanNode buildInPredicateEquivalent(
            Session session,
            ApplyNode applyNode,
            InPredicate inPredicate,
            Symbol inPredicateOutputSymbol,
            Decorrelated decorrelated,
            PlanNodeIdAllocator idAllocator,
            SymbolAllocator symbolAllocator) {
        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // Tag each input row with a unique id. The aggregation groups by it, so it emits one row
        // per input row even when the input contains duplicate rows.
        AssignUniqueId probeSide = new AssignUniqueId(
                idAllocator.getNextId(),
                applyNode.getInput(),
                symbolAllocator.newSymbol("unique", BigintType.BIGINT));

        // A constant non-null column on the build side. After the LEFT join it is NULL only on
        // null-padded rows, i.e. where no build row was joined.
        Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putIdentities(decorrelatedBuildSource.getOutputSymbols())
                        .put(buildSideKnownNonNull, bigint(0L))
                        .build());

        Symbol probeSideSymbol = Symbol.from(inPredicate.getValue());
        Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList());

        // Join equal pairs, and also pairs where either side is NULL. Keeping the NULL pairs lets
        // the aggregation below distinguish "no match" (FALSE) from "unknown" (NULL).
        Expression joinExpression = ExpressionUtils.and(
                ExpressionUtils.or(
                        new IsNullPredicate(probeSideSymbol.toSymbolReference()),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()),
                        new IsNullPredicate(buildSideSymbol.toSymbolReference())),
                correlationCondition);
        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

        Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BigintType.BIGINT);
        Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BigintType.BIGINT);

        // A joined row with both sides non-null can only have satisfied the equality branch.
        Expression matchCondition = ExpressionUtils.and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));
        // A build row was actually joined, but at least one side of the comparison was NULL.
        Expression nullMatchCondition = ExpressionUtils.and(isNotNull(buildSideKnownNonNull), not(matchCondition));

        AggregationNode aggregation = new AggregationNode(
                idAllocator.getNextId(),
                leftOuterJoin,
                ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                        .put(countMatchesSymbol, countWithFilter(session, matchCondition))
                        .put(countNullMatchesSymbol, countWithFilter(session, nullMatchCondition))
                        .build(),
                AggregationNode.singleGroupingSet(probeSide.getOutputSymbols()),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());

        // TRUE if any build row matched; otherwise NULL if some comparison involved NULL;
        // otherwise FALSE.
        SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(
                ImmutableList.of(
                        new WhenClause(isGreaterThan(countMatchesSymbol, 0L), booleanConstant(true)),
                        new WhenClause(isGreaterThan(countNullMatchesSymbol, 0L), booleanConstant(null))),
                Optional.of(booleanConstant(false)));

        return new ProjectNode(
                idAllocator.getNextId(),
                aggregation,
                Assignments.builder()
                        .putIdentities(applyNode.getInput().getOutputSymbols())
                        .put(inPredicateOutputSymbol, inPredicateEquivalent)
                        .build());
    }

    /**
     * Creates a LEFT join of {@code probeSide} and {@code buildSide} with no equi-join
     * criteria. All matching logic is in {@code joinExpression}, and the output is the probe
     * symbols followed by the build symbols.
     */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) {
        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                ImmutableList.<Symbol>builder()
                        .addAll(probeSide.getOutputSymbols())
                        .addAll(buildSide.getOutputSymbols())
                        .build(),
                Optional.of(joinExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Builds a zero-argument {@code count()} aggregation restricted by a {@code FILTER}
     * clause. It is not distinct and has no window, ordering or mask.
     */
    private AggregationNode.Aggregation countWithFilter(Session session, Expression condition) {
        QualifiedName countName = QualifiedName.of(COUNT_FUNCTION_NAME);
        FunctionCall countCall = new FunctionCall(
                countName,
                Optional.empty(),
                Optional.of(condition),
                Optional.empty(),
                false,
                ImmutableList.of());
        return new AggregationNode.Aggregation(
                countCall,
                functionManager.resolveFunction(session, countName, ImmutableList.of()),
                Optional.empty());
    }

    /** Returns {@code symbol > CAST(value AS bigint)}. */
    private static Expression isGreaterThan(Symbol symbol, long value) {
        return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, symbol.toSymbolReference(), bigint(value));
    }

    /** Returns {@code NOT expression}. */
    private static Expression not(Expression expression) {
        return new NotExpression(expression);
    }

    /** Returns {@code symbol IS NOT NULL}. */
    private static Expression isNotNull(Symbol symbol) {
        return new IsNotNullPredicate(symbol.toSymbolReference());
    }

    /** Returns a bigint literal expression, {@code CAST(value AS bigint)}. */
    private static Expression bigint(long value) {
        return new Cast(new LongLiteral(String.valueOf(value)), BigintType.BIGINT.toString());
    }

    /**
     * Returns a boolean literal for a non-null {@code value}. For {@code null} it returns a
     * typed SQL {@code NULL}, {@code CAST(NULL AS boolean)}.
     */
    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast(new NullLiteral(), BooleanType.BOOLEAN.toString());
        }
        return new BooleanLiteral(value.toString());
    }

    /**
     * Result of decorrelating a subquery. It holds the predicates pulled out of the subquery,
     * which may reference correlation symbols, and the remaining plan with those predicates
     * removed.
     */
    private static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<Expression> getCorrelatedPredicates() {
            return correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode() {
            return decorrelatedNode;
        }
    }

    /**
     * Walks a subquery plan and pulls {@link FilterNode} predicates up through
     * {@link ProjectNode}s. The visitor context is the unresolved reference to the node being
     * visited.
     * <p>
     * It returns {@link Optional#empty()} in two cases:
     * <ul>
     * <li>a projection itself references a correlation symbol;</li>
     * <li>any other node type references one, in itself or in any of its (resolved) sources.</li>
     * </ul>
     */
    private static class DecorrelatingVisitor extends PlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<Symbol> correlation;

        public DecorrelatingVisitor(Lookup lookup, Iterable<Symbol> correlation) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
        }

        /** Resolves {@code reference} through the lookup and visits the resolved node. */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (isCorrelatedShallowly(node)) {
                // Correlated projections are not handled.
                return Optional.empty();
            }
            return decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // Predicates pulled up from below still reference source symbols. Keep those
                // visible above this projection, except correlation symbols from the outer query.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(AstUtils::preOrder)
                        .filter(SymbolReference.class::isInstance)
                        .map(SymbolReference.class::cast)
                        .filter(symbolReference -> !correlation.contains(Symbol.from(symbolReference)))
                        .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference)));
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // Drop the filter node and carry its predicate upward with the others.
            return decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<Expression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(node.getPredicate())
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override
        protected Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            // An uncorrelated subtree is kept as-is, via its original reference.
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /** True if {@code node} or any node below it (resolved via the lookup) uses a correlation symbol. */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /** True if {@code node} itself (not its sources) references a correlation symbol. */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return SymbolsExtractor.extractUniqueNonRecursive(node).stream()
                    .anyMatch(correlation::contains);
        }
    }
}
