package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/* loaded from: input_file:com/facebook/presto/sql/planner/EffectivePredicateExtractor.class */
public class EffectivePredicateExtractor {
    /**
     * Matches entries whose value equals a {@link SymbolReference} built from the key's own name,
     * i.e. entries for which the derived equality would be of the form {@code x = x}.
     */
    private static final Predicate<Map.Entry<VariableReferenceExpression, ? extends Expression>> VARIABLE_MATCHES_EXPRESSION =
            entry -> {
                SymbolReference selfReference = new SymbolReference(entry.getKey().getName());
                return entry.getValue().equals(selfReference);
            };
    /**
     * Converts an entry into the comparison {@code <key name> = <value expression>}.
     *
     * NOTE(review): a plain EQUAL comparison does not evaluate to true when the value is NULL —
     * confirm that consumers of the derived predicate tolerate this.
     */
    private static final Function<Map.Entry<VariableReferenceExpression, ? extends Expression>, Expression> VARIABLE_ENTRY_TO_EQUALITY =
            entry -> {
                SymbolReference variableReference = new SymbolReference(entry.getKey().getName());
                Expression assignedExpression = entry.getValue();
                return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, variableReference, assignedExpression);
            };
    // Translator passed to each Visitor created by extract(); the Visitor calls toPredicate(...) on it
    // when it reaches a table scan. Never null (checked in the constructor).
    private final ExpressionDomainTranslator domainTranslator;

    /* loaded from: input_file:com/facebook/presto/sql/planner/EffectivePredicateExtractor$Visitor.class */
    private static class Visitor extends InternalPlanVisitor<Expression, Void> {
        // Used by visitTableScan to turn the scan's current constraint into an Expression.
        private final ExpressionDomainTranslator domainTranslator;
        // Type information handed to VariablesExtractor, ExpressionUtils and EqualityInference calls below.
        private final TypeProvider types;

        /**
         * Creates a visitor bound to the given translator and type information.
         *
         * @param expressionDomainTranslator translator used for table scan constraints
         * @param typeProvider type information for the variables appearing in the plan
         * @throws NullPointerException if either argument is null
         */
        public Visitor(ExpressionDomainTranslator expressionDomainTranslator, TypeProvider typeProvider) {
            // Validate both arguments up front, then assign.
            Objects.requireNonNull(expressionDomainTranslator, "domainTranslator is null");
            Objects.requireNonNull(typeProvider, "types is null");
            this.domainTranslator = expressionDomainTranslator;
            this.types = typeProvider;
        }

        /**
         * Generic case: no predicate is claimed, so the trivially true literal is returned.
         * Presumably reached for node types that have no dedicated override in this visitor.
         */
        @Override
        public Expression visitPlan(PlanNode planNode, Void context) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        /**
         * Returns the source's effective predicate pulled through the grouping keys, or TRUE when
         * the aggregation has no grouping keys.
         */
        @Override
        public Expression visitAggregation(AggregationNode aggregationNode, Void context) {
            // NOTE(review): presumably a global aggregation still emits a row for empty input,
            // so nothing can be claimed about its output — confirm.
            if (aggregationNode.getGroupingKeys().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            Expression sourcePredicate = aggregationNode.getSource().accept(this, context);
            return pullExpressionThroughVariables(sourcePredicate, aggregationNode.getGroupingKeys());
        }

        /**
         * Combines the deterministic conjuncts of the filter's own predicate with the effective
         * predicate of the filter's source.
         */
        @Override
        public Expression visitFilter(FilterNode filterNode, Void context) {
            Expression filterPredicate = OriginalExpressionUtils.castToExpression(filterNode.getPredicate());
            // Only deterministic conjuncts are kept from the filter itself.
            Expression deterministicPredicate = ExpressionUtils.filterDeterministicConjuncts(filterPredicate);
            Expression sourcePredicate = filterNode.getSource().accept(this, context);
            return ExpressionUtils.combineConjuncts(deterministicPredicate, sourcePredicate);
        }

        /**
         * Derives the predicates common to all exchange sources. For each source, output variable
         * {@code i} is mapped to a reference to that source's input variable {@code i}.
         */
        @Override
        public Expression visitExchange(ExchangeNode exchangeNode, Void context) {
            return deriveCommonPredicates(exchangeNode, sourceIndex -> {
                List<VariableReferenceExpression> sourceInputs = exchangeNode.getInputs().get(sourceIndex);
                Map<VariableReferenceExpression, SymbolReference> outputToInput = new HashMap<>();
                for (int column = 0; column < sourceInputs.size(); column++) {
                    SymbolReference inputReference = new SymbolReference(sourceInputs.get(column).getName());
                    outputToInput.put(exchangeNode.getOutputVariables().get(column), inputReference);
                }
                return outputToInput.entrySet();
            });
        }

        /**
         * Effective predicate of a projection: the source's predicate plus one equality
         * {@code output = expression} per non-identity assignment, pulled through the
         * projection's output variables.
         */
        @Override
        public Expression visitProject(ProjectNode projectNode, Void context) {
            // Identity assignments (x := x) are skipped; every other assignment contributes an equality.
            List<Expression> projectionEqualities = Maps.transformValues(projectNode.getAssignments().getMap(), OriginalExpressionUtils::castToExpression)
                    .entrySet().stream()
                    .filter(EffectivePredicateExtractor.VARIABLE_MATCHES_EXPRESSION.negate())
                    .map(EffectivePredicateExtractor.VARIABLE_ENTRY_TO_EQUALITY)
                    .collect(ImmutableList.toImmutableList());

            Expression sourcePredicate = projectNode.getSource().accept(this, context);

            Expression combined = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                    .addAll(projectionEqualities)
                    .add(sourcePredicate)
                    .build());
            return pullExpressionThroughVariables(combined, projectNode.getOutputVariables());
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitTopN(TopNNode topNNode, Void context) {
            PlanNode source = topNNode.getSource();
            return source.accept(this, context);
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitLimit(LimitNode limitNode, Void context) {
            PlanNode source = limitNode.getSource();
            return source.accept(this, context);
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitAssignUniqueId(AssignUniqueId assignUniqueId, Void context) {
            PlanNode source = assignUniqueId.getSource();
            return source.accept(this, context);
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitDistinctLimit(DistinctLimitNode distinctLimitNode, Void context) {
            PlanNode source = distinctLimitNode.getSource();
            return source.accept(this, context);
        }

        /**
         * Translates the scan's simplified current constraint into an expression. Each column handle
         * is replaced by the name of the variable assigned to it; a handle with no assigned variable
         * is mapped to {@code null}.
         */
        @Override
        public Expression visitTableScan(TableScanNode tableScanNode, Void context) {
            // Inverse view of the scan assignments: column handle -> output variable.
            Map<?, VariableReferenceExpression> variableByColumn = ImmutableBiMap.copyOf(tableScanNode.getAssignments()).inverse();
            return this.domainTranslator.toPredicate(
                    tableScanNode.getCurrentConstraint()
                            .simplify()
                            .transform(column -> variableByColumn.containsKey(column) ? variableByColumn.get(column).getName() : null));
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitSort(SortNode sortNode, Void context) {
            PlanNode source = sortNode.getSource();
            return source.accept(this, context);
        }

        /** Passes the source's effective predicate through unchanged. */
        @Override
        public Expression visitWindow(WindowNode windowNode, Void context) {
            PlanNode source = windowNode.getSource();
            return source.accept(this, context);
        }

        /**
         * Derives the predicates common to all union sources. For each source, the node's
         * {@code outputMap} for that source is used, with every mapped variable turned into a
         * {@link SymbolReference} by name.
         */
        @Override
        public Expression visitUnion(UnionNode unionNode, Void context) {
            return deriveCommonPredicates(
                    unionNode,
                    sourceIndex -> Multimaps.transformValues(
                            unionNode.outputMap(sourceIndex),
                            sourceVariable -> new SymbolReference(sourceVariable.getName()))
                            .entries());
        }

        /**
         * Computes the effective predicate of a join from the predicates of both inputs and the
         * equi-join criteria.
         *
         * <ul>
         * <li>INNER: left predicate, right predicate, join criteria and the optional join filter are
         * combined and pulled through the join's output variables.</li>
         * <li>LEFT / RIGHT: the predicate of the preserved side is pulled through as is; conjuncts of
         * the other side and the join criteria go through
         * {@code pullNullableConjunctsThroughOuterJoin}, scoped to that other side's output variables.</li>
         * <li>FULL: conjuncts of both sides are treated as nullable, each scoped to its own side;
         * the join criteria are scoped to both sides.</li>
         * </ul>
         *
         * @throws UnsupportedOperationException for any other join type
         */
        @Override
        public Expression visitJoin(JoinNode joinNode, Void context) {
            Expression leftPredicate = joinNode.getLeft().accept(this, context);
            Expression rightPredicate = joinNode.getRight().accept(this, context);
            List<Expression> joinConjuncts = joinNode.getCriteria().stream()
                    .map(JoinNodeUtils::toExpression)
                    .collect(ImmutableList.toImmutableList());

            switch (joinNode.getType()) {
                case INNER:
                    return pullExpressionThroughVariables(
                            ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                                    .add(leftPredicate)
                                    .add(rightPredicate)
                                    .add(ExpressionUtils.combineConjuncts(joinConjuncts))
                                    .add(joinNode.getFilter().map(OriginalExpressionUtils::castToExpression).orElse(BooleanLiteral.TRUE_LITERAL))
                                    .build()),
                            joinNode.getOutputVariables());
                case LEFT:
                    // Right-side variables are the ones passed as the "null scope" for the nullable conjuncts.
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, joinNode.getOutputVariables()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    joinNode.getOutputVariables(),
                                    joinNode.getRight().getOutputVariables()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputVariables(),
                                    joinNode.getRight().getOutputVariables()::contains))
                            .build());
                case RIGHT:
                    // Mirror image of LEFT: left-side variables form the null scope.
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughVariables(rightPredicate, joinNode.getOutputVariables()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(leftPredicate),
                                    joinNode.getOutputVariables(),
                                    joinNode.getLeft().getOutputVariables()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputVariables(),
                                    joinNode.getLeft().getOutputVariables()::contains))
                            .build());
                case FULL:
                    // Both sides form a null scope; the join criteria get both scopes (left, then right).
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(leftPredicate),
                                    joinNode.getOutputVariables(),
                                    joinNode.getLeft().getOutputVariables()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    joinNode.getOutputVariables(),
                                    joinNode.getRight().getOutputVariables()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputVariables(),
                                    joinNode.getLeft().getOutputVariables()::contains,
                                    joinNode.getRight().getOutputVariables()::contains))
                            .build());
                default:
                    throw new UnsupportedOperationException("Unknown join type: " + joinNode.getType());
            }
        }

        /**
         * Pulls each conjunct through the given output variables and then passes it through
         * {@code ExpressionUtils.expressionOrNullVariables} for the supplied variable scopes.
         * A pulled conjunct that references no variables at all is replaced by TRUE before that step.
         *
         * @param list conjuncts to pull through the outer join
         * @param collection output variables of the join
         * @param predicateArr scopes forwarded to {@code expressionOrNullVariables}
         * @return the transformed conjuncts, in input order
         */
        private Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> list, Collection<VariableReferenceExpression> collection, Predicate<VariableReferenceExpression>... predicateArr) {
            return list.stream()
                    .map(conjunct -> {
                        Expression pulled = pullExpressionThroughVariables(conjunct, collection);
                        if (VariablesExtractor.extractAll(pulled, this.types).isEmpty()) {
                            return BooleanLiteral.TRUE_LITERAL;
                        }
                        return pulled;
                    })
                    .map(ExpressionUtils.expressionOrNullVariables(this.types, predicateArr))
                    .collect(ImmutableList.toImmutableList());
        }

        /** Returns the effective predicate of the semi join's source; nothing else is added. */
        @Override
        public Expression visitSemiJoin(SemiJoinNode semiJoinNode, Void context) {
            PlanNode source = semiJoinNode.getSource();
            return source.accept(this, context);
        }

        /**
         * Computes the effective predicate of a spatial join from the predicates of its inputs.
         *
         * <ul>
         * <li>INNER: both input predicates are pulled through the join's output variables and combined.</li>
         * <li>LEFT: the left predicate is pulled through as is; conjuncts of the right predicate go
         * through {@code pullNullableConjunctsThroughOuterJoin}, scoped to the right side's output variables.</li>
         * </ul>
         *
         * @throws IllegalArgumentException for any other spatial join type
         */
        @Override
        public Expression visitSpatialJoin(SpatialJoinNode spatialJoinNode, Void context) {
            Expression leftPredicate = spatialJoinNode.getLeft().accept(this, context);
            Expression rightPredicate = spatialJoinNode.getRight().accept(this, context);

            switch (spatialJoinNode.getType()) {
                case INNER:
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, spatialJoinNode.getOutputVariables()))
                            .add(pullExpressionThroughVariables(rightPredicate, spatialJoinNode.getOutputVariables()))
                            .build());
                case LEFT:
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, spatialJoinNode.getOutputVariables()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    spatialJoinNode.getOutputVariables(),
                                    spatialJoinNode.getRight().getOutputVariables()::contains))
                            .build());
                default:
                    throw new IllegalArgumentException("Unsupported spatial join type: " + spatialJoinNode.getType());
            }
        }

        /**
         * Computes the conjuncts that hold for every source of a multi-source node.
         *
         * For each source, the source's effective predicate is combined with the equalities derived
         * from {@code function} for that source index, pulled through the node's output variables and
         * split into conjuncts. The result is the conjunction of the conjuncts present for all sources.
         *
         * @param planNode node whose sources are inspected; sources are visited with a null context
         * @param function maps a source index to (output variable, source symbol reference) entries
         * @return conjunction of the conjuncts common to every source
         * @throws java.util.NoSuchElementException if the node has no sources
         */
        private Expression deriveCommonPredicates(PlanNode planNode, Function<Integer, Collection<Map.Entry<VariableReferenceExpression, SymbolReference>>> function) {
            // Conjuncts that can be pulled up from each individual source.
            List<Set<Expression>> conjunctsPerSource = new ArrayList<>();
            for (int i = 0; i < planNode.getSources().size(); i++) {
                List<Expression> equalities = function.apply(i).stream()
                        .filter(EffectivePredicateExtractor.VARIABLE_MATCHES_EXPRESSION.negate())
                        .map(EffectivePredicateExtractor.VARIABLE_ENTRY_TO_EQUALITY)
                        .collect(ImmutableList.toImmutableList());
                Expression sourcePredicate = planNode.getSources().get(i).accept(this, null);

                Expression pulled = pullExpressionThroughVariables(
                        ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                                .addAll(equalities)
                                .add(sourcePredicate)
                                .build()),
                        planNode.getOutputVariables());
                conjunctsPerSource.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(pulled)));
            }

            // Intersect across sources; matching is by Expression equality only.
            Iterator<Set<Expression>> remaining = conjunctsPerSource.iterator();
            Set<Expression> commonConjuncts = remaining.next();
            while (remaining.hasNext()) {
                commonConjuncts = Sets.intersection(commonConjuncts, remaining.next());
            }
            return ExpressionUtils.combineConjuncts(commonConjuncts);
        }

        /**
         * Restricts an expression to the given variables using equality inference.
         *
         * Each deterministic non-inferrable conjunct is rewritten via
         * {@code EqualityInference.rewriteExpression} for the given variables and kept when the
         * rewrite returns non-null; the scope equalities generated for those variables are appended.
         *
         * @param expression predicate to pull through
         * @param collection variables the resulting predicate is scoped to
         * @return conjunction of the rewritten conjuncts and scope equalities
         */
        private Expression pullExpressionThroughVariables(Expression expression, Collection<VariableReferenceExpression> collection) {
            EqualityInference equalityInference = EqualityInference.createEqualityInference(expression);

            ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder();
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
                if (!ExpressionDeterminismEvaluator.isDeterministic(conjunct)) {
                    continue;
                }
                Expression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(collection), this.types);
                // A null rewrite means the conjunct is dropped.
                if (rewritten != null) {
                    effectiveConjuncts.add(rewritten);
                }
            }

            effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(Predicates.in(collection), this.types).getScopeEqualities());
            return ExpressionUtils.combineConjuncts(effectiveConjuncts.build());
        }
    }

    /**
     * Creates an extractor that uses the given translator.
     *
     * @param expressionDomainTranslator translator passed on to the plan visitor
     * @throws NullPointerException if the translator is null
     */
    public EffectivePredicateExtractor(ExpressionDomainTranslator expressionDomainTranslator) {
        Objects.requireNonNull(expressionDomainTranslator, "domainTranslator is null");
        this.domainTranslator = expressionDomainTranslator;
    }

    /**
     * Computes the effective predicate of the plan rooted at the given node.
     *
     * @param planNode root of the plan to inspect
     * @param typeProvider type information for the plan's variables
     * @return the expression produced by visiting the node with a fresh visitor
     */
    public Expression extract(PlanNode planNode, TypeProvider typeProvider) {
        Visitor visitor = new Visitor(this.domainTranslator, typeProvider);
        return planNode.accept(visitor, null);
    }
}
