package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.hive.jdbc.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.sql.util.AstUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.annotation.Nullable;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.class */
/**
 * Rewrites a correlated {@code IN} subquery into an uncorrelated left outer join followed by an aggregation.
 *
 * <p>The rule fires on an {@link ApplyNode} with a non-empty correlation whose single subquery assignment is an
 * {@link InPredicate} of the form {@code value IN valueList} (both sides must be {@link SymbolReference}s). If the
 * subquery can be decorrelated (only {@link FilterNode}s and uncorrelated {@link ProjectNode}s sit between the apply
 * and the uncorrelated remainder), the apply is replaced by:
 *
 * <pre>
 * Project[input columns, result := CASE WHEN countMatches &gt; 0 THEN true
 *                                       WHEN countNullMatches &gt; 0 THEN CAST(NULL AS boolean)
 *                                       ELSE false END]
 *   Aggregation[grouped by the AssignUniqueId outputs;
 *               countMatches     := count() FILTER (value IS NOT NULL AND valueList IS NOT NULL),
 *               countNullMatches := count() FILTER (buildSideKnownNonNull IS NOT NULL
 *                                                   AND NOT (value IS NOT NULL AND valueList IS NOT NULL))]
 *     LeftJoin[(value IS NULL OR value = valueList OR valueList IS NULL) AND pulled-up correlated predicates]
 *       AssignUniqueId[input]
 *       Project[decorrelated subquery columns, buildSideKnownNonNull := CAST(0 AS bigint)]
 *         decorrelated subquery
 * </pre>
 *
 * <p>{@code buildSideKnownNonNull} is a constant on every build-side row, so it is NULL only on the null-padded row
 * the left join emits for probe rows with no match; this lets the second count ignore that padding row.
 */
public class TransformCorrelatedInPredicateToJoin implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode()
            .with(Pattern.nonEmpty(Patterns.Apply.correlation()));

    private final StandardFunctionResolution functionResolution;

    /**
     * Result of decorrelating a subquery plan: the predicates that were pulled out of it (to be evaluated in the
     * join condition) and the remaining plan with those filters removed.
     */
    public static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        /**
         * @param correlatedPredicates predicates removed from the subquery; defensively copied
         * @param decorrelatedNode the subquery plan with those predicates removed
         * @throws NullPointerException if either argument is null
         */
        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(
                    Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        /** Returns the (immutable) list of predicates pulled out of the subquery. */
        public List<Expression> getCorrelatedPredicates() {
            return correlatedPredicates;
        }

        /** Returns the subquery plan with the pulled-out predicates removed. */
        public PlanNode getDecorrelatedNode() {
            return decorrelatedNode;
        }
    }

    /**
     * Walks a subquery plan top-down, pulling every {@link FilterNode} predicate upwards and threading the variables
     * those predicates reference through intervening {@link ProjectNode}s. Returns {@link Optional#empty()} when a
     * node that cannot be handled references a correlation variable.
     */
    public static class DecorrelatingVisitor extends InternalPlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<VariableReferenceExpression> correlation;
        private final TypeProvider types;

        /**
         * @param lookup resolves group references in the memo
         * @param correlation the correlation variables of the enclosing {@link ApplyNode}
         * @param types type provider used to type pulled-up symbol references
         */
        public DecorrelatingVisitor(Lookup lookup, Iterable<VariableReferenceExpression> correlation, TypeProvider types) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
            this.types = Objects.requireNonNull(types, "types is null");
        }

        /**
         * Decorrelates the plan rooted at {@code reference}, resolving it through the lookup first.
         *
         * @return the decorrelation result, or empty if the plan cannot be decorrelated
         */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (isCorrelatedShallowly(node)) {
                // A projection that itself uses correlation variables is not supported.
                return Optional.empty();
            }
            return decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // The pulled-up predicates are evaluated above this projection, so every non-correlation variable
                // they reference must be projected through as an identity assignment.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(AstUtils::preOrder)
                        .filter(SymbolReference.class::isInstance)
                        .map(SymbolReference.class::cast)
                        .map(symbol -> new VariableReferenceExpression(symbol.getName(), types.get(symbol)))
                        .filter(variable -> !correlation.contains(variable))
                        .map(AssignmentUtils::identityAsSymbolReference)
                        .forEach(assignments::put);
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // The filter node is dropped from the decorrelated plan; its whole predicate moves into the join condition.
            return decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<Expression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(OriginalExpressionUtils.castToExpression(node.getPredicate()))
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override
        public Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            // Any other node is kept as-is, provided nothing in its subtree references a correlation variable.
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /** Returns true if {@code node} or any node beneath it (resolved through the lookup) uses a correlation variable. */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /** Returns true if {@code node} itself (not its sources) references any correlation variable. */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return VariablesExtractor.extractUniqueNonRecursive(node, types).stream()
                    .anyMatch(correlation::contains);
        }
    }

    /**
     * @param functionManager used to resolve the {@code count} aggregate function
     * @throws NullPointerException if {@code functionManager} is null
     */
    public TransformCorrelatedInPredicateToJoin(FunctionManager functionManager) {
        Objects.requireNonNull(functionManager, "functionManager is null");
        this.functionResolution = new FunctionResolution(functionManager);
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = applyNode.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        Expression assignedExpression = OriginalExpressionUtils.castToExpression(
                Iterables.getOnlyElement(subqueryAssignments.getExpressions()));
        if (!(assignedExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        VariableReferenceExpression inPredicateOutputVariable = Iterables.getOnlyElement(subqueryAssignments.getVariables());
        return apply(
                applyNode,
                (InPredicate) assignedExpression,
                inPredicateOutputVariable,
                context.getLookup(),
                context.getIdAllocator(),
                context.getVariableAllocator());
    }

    /** Decorrelates the subquery and, if that succeeds, returns the join/aggregation equivalent of the apply. */
    private Rule.Result apply(
            ApplyNode applyNode,
            InPredicate inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Lookup lookup,
            PlanNodeIdAllocator idAllocator,
            PlanVariableAllocator variableAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, applyNode.getCorrelation(), variableAllocator.getTypes())
                .decorrelate(applyNode.getSubquery());
        if (!decorrelated.isPresent()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(buildInPredicateEquivalent(
                applyNode, inPredicate, inPredicateOutputVariable, decorrelated.get(), idAllocator, variableAllocator));
    }

    /**
     * Builds the plan documented on this class. Plan node ids and variables are allocated in a fixed order
     * (AssignUniqueId, build projection, join, result projection, aggregation; unique, buildSideKnownNonNull,
     * countMatches, countNullMatches).
     *
     * @throws IllegalArgumentException if either side of the IN predicate is not a symbol reference
     */
    private PlanNode buildInPredicateEquivalent(
            ApplyNode applyNode,
            InPredicate inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Decorrelated decorrelated,
            PlanNodeIdAllocator idAllocator,
            PlanVariableAllocator variableAllocator) {
        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        AssignUniqueId probeSide = new AssignUniqueId(
                idAllocator.getNextId(),
                applyNode.getInput(),
                variableAllocator.newVariable("unique", BigintType.BIGINT));

        // Constant non-null marker on the build side; NULL after the join only on null-padded (unmatched) rows.
        VariableReferenceExpression buildSideKnownNonNull = variableAllocator.newVariable("buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putAll(AssignmentUtils.identitiesAsSymbolReferences(decorrelatedBuildSource.getOutputVariables()))
                        .put(buildSideKnownNonNull, OriginalExpressionUtils.castToRowExpression(bigint(0L)))
                        .build());

        Preconditions.checkArgument(inPredicate.getValue() instanceof SymbolReference, "Unexpected expression: %s", inPredicate.getValue());
        SymbolReference probeValue = (SymbolReference) inPredicate.getValue();
        Preconditions.checkArgument(inPredicate.getValueList() instanceof SymbolReference, "Unexpected expression: %s", inPredicate.getValueList());
        SymbolReference buildValue = (SymbolReference) inPredicate.getValueList();

        // NULLs on either side must still join so that the three-valued IN result (NULL) can be computed below.
        Expression joinCondition = ExpressionUtils.and(
                ExpressionUtils.or(
                        new IsNullPredicate(probeValue),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeValue, buildValue),
                        new IsNullPredicate(buildValue)),
                correlationCondition);
        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinCondition);

        VariableReferenceExpression countMatches = variableAllocator.newVariable("countMatches", BigintType.BIGINT);
        VariableReferenceExpression countNullMatches = variableAllocator.newVariable("countNullMatches", BigintType.BIGINT);

        Expression matchCondition = ExpressionUtils.and(new IsNotNullPredicate(probeValue), new IsNotNullPredicate(buildValue));
        Expression nullMatchCondition = ExpressionUtils.and(
                new IsNotNullPredicate(new SymbolReference(buildSideKnownNonNull.getName())),
                new NotExpression(matchCondition));
        ImmutableMap<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.of(
                countMatches, countWithFilter(matchCondition),
                countNullMatches, countWithFilter(nullMatchCondition));

        Expression inPredicateEquivalent = new SearchedCaseExpression(
                ImmutableList.of(
                        new WhenClause(isGreaterThan(countMatches, 0L), booleanConstant(true)),
                        new WhenClause(isGreaterThan(countNullMatches, 0L), booleanConstant(null))),
                Optional.of(booleanConstant(false)));
        Assignments resultAssignments = Assignments.builder()
                .putAll(AssignmentUtils.identitiesAsSymbolReferences(applyNode.getInput().getOutputVariables()))
                .put(inPredicateOutputVariable, OriginalExpressionUtils.castToRowExpression(inPredicateEquivalent))
                .build();

        // Arguments are evaluated left to right: the result projection's id is drawn before the aggregation's.
        return new ProjectNode(
                idAllocator.getNextId(),
                new AggregationNode(
                        idAllocator.getNextId(),
                        leftOuterJoin,
                        aggregations,
                        AggregationNode.singleGroupingSet(probeSide.getOutputVariables()),
                        ImmutableList.of(),
                        AggregationNode.Step.SINGLE,
                        Optional.empty(),
                        Optional.empty()),
                resultAssignments);
    }

    /** Builds a LEFT join of probe and build sides on {@code joinCondition}, outputting both sides' variables. */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinCondition) {
        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(probeSide.getOutputVariables())
                        .addAll(buildSide.getOutputVariables())
                        .build(),
                Optional.of(OriginalExpressionUtils.castToRowExpression(joinCondition)),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }

    /** Returns a {@code count()} aggregation restricted by the FILTER clause {@code condition}. */
    private AggregationNode.Aggregation countWithFilter(Expression condition) {
        return new AggregationNode.Aggregation(
                new CallExpression("count", functionResolution.countFunction(), BigintType.BIGINT, ImmutableList.of()),
                Optional.of(OriginalExpressionUtils.castToRowExpression(condition)),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /** Returns {@code variable > CAST(value AS bigint)}. */
    private static Expression isGreaterThan(VariableReferenceExpression variable, long value) {
        return new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN,
                new SymbolReference(variable.getName()),
                bigint(value));
    }

    /** Returns the literal {@code value} cast to BIGINT. */
    private static Expression bigint(long value) {
        return new Cast(new LongLiteral(String.valueOf(value)), BigintType.BIGINT.toString());
    }

    /** Returns a boolean literal, or {@code CAST(NULL AS boolean)} when {@code value} is null. */
    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast(new NullLiteral(), BooleanType.BOOLEAN.toString());
        }
        return new BooleanLiteral(value.toString());
    }
}
