package com.facebook.presto.sql.planner;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.metadata.OperatorNotFoundException;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.RowExpressionEqualityInference;
import com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Derives an "effective predicate" for the output of a plan node: a conjunction of row
 * expressions over the node's output variables, assembled from the filters, projections,
 * join criteria, union/exchange inputs and table-scan constraints found in its subtree.
 *
 * <p>Plan node types without a dedicated handler contribute {@code TRUE}, i.e. nothing is
 * assumed about their output.
 */
public class RowExpressionPredicateExtractor {
    private final RowExpressionDomainTranslator domainTranslator;
    private final FunctionManager functionManager;
    private final TypeManager typeManager;

    /**
     * Creates an extractor.
     *
     * @param rowExpressionDomainTranslator translates table-scan constraints into row expressions
     * @param functionManager resolves the {@code =} operator and backs the determinism evaluator
     * @param typeManager handed to the equality-inference builder
     * @throws NullPointerException if any argument is {@code null}
     */
    public RowExpressionPredicateExtractor(RowExpressionDomainTranslator rowExpressionDomainTranslator, FunctionManager functionManager, TypeManager typeManager) {
        this.domainTranslator = Objects.requireNonNull(rowExpressionDomainTranslator, "domainTranslator is null");
        // Fail fast: otherwise a null would only surface on the first extract() call.
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
        this.typeManager = Objects.requireNonNull(typeManager, "typeManager is null");
    }

    /**
     * Returns the effective predicate for the output of {@code planNode}.
     *
     * @param planNode root of the subtree to analyze
     * @return a predicate over {@code planNode}'s output variables ({@code TRUE} when nothing is known)
     */
    public RowExpression extract(PlanNode planNode) {
        return planNode.accept(new Visitor(domainTranslator, functionManager, typeManager), null);
    }

    /**
     * Recursive visitor: each handler first computes its source(s)' predicate, then adapts it
     * to the node's own semantics and output variables.
     */
    private static class Visitor extends InternalPlanVisitor<RowExpression, Void> {
        private final RowExpressionDomainTranslator domainTranslator;
        private final LogicalRowExpressions logicalRowExpressions;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;
        private final TypeManager typeManager;
        private final FunctionManager functionManager;

        public Visitor(RowExpressionDomainTranslator rowExpressionDomainTranslator, FunctionManager functionManager, TypeManager typeManager) {
            this.domainTranslator = Objects.requireNonNull(rowExpressionDomainTranslator, "domainTranslator is null");
            this.typeManager = Objects.requireNonNull(typeManager, "typeManager is null");
            this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
            this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager);
            this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, new FunctionResolution(functionManager), functionManager);
        }

        /** Fallback for node types without a dedicated handler: nothing is assumed. */
        @Override
        public RowExpression visitPlan(PlanNode node, Void context) {
            return LogicalRowExpressions.TRUE_CONSTANT;
        }

        @Override
        public RowExpression visitAggregation(AggregationNode node, Void context) {
            // Without grouping keys this is a global aggregation, which (per SQL semantics)
            // produces a row even for empty input, so no source predicate can be assumed.
            if (node.getGroupingKeys().isEmpty()) {
                return LogicalRowExpressions.TRUE_CONSTANT;
            }
            RowExpression sourcePredicate = node.getSource().accept(this, context);
            // Only facts expressible over the grouping keys survive the aggregation.
            return pullExpressionThroughVariables(sourcePredicate, node.getGroupingKeys());
        }

        @Override
        public RowExpression visitFilter(FilterNode node, Void context) {
            RowExpression sourcePredicate = node.getSource().accept(this, context);
            // Non-deterministic conjuncts are dropped: they cannot be re-asserted on the output.
            RowExpression filterPredicate = logicalRowExpressions.filterDeterministicConjuncts(node.getPredicate());
            return logicalRowExpressions.combineConjuncts(filterPredicate, sourcePredicate);
        }

        @Override
        public RowExpression visitExchange(ExchangeNode node, Void context) {
            return deriveCommonPredicates(node, sourceIndex -> {
                // Pair the i-th output variable with the i-th input variable of this source.
                List<VariableReferenceExpression> inputs = node.getInputs().get(sourceIndex);
                Map<VariableReferenceExpression, VariableReferenceExpression> mappings = new HashMap<>();
                for (int i = 0; i < inputs.size(); i++) {
                    mappings.put(node.getOutputVariables().get(i), inputs.get(i));
                }
                return mappings.entrySet();
            });
        }

        @Override
        public RowExpression visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            // NOTE(review): propagation is restricted to ProjectNode sources; the reason is not
            // visible here — any other source yields TRUE.
            if (node.getSource() instanceof ProjectNode) {
                return node.getSource().accept(this, context);
            }
            return LogicalRowExpressions.TRUE_CONSTANT;
        }

        @Override
        public RowExpression visitProject(ProjectNode node, Void context) {
            RowExpression sourcePredicate = node.getSource().accept(this, context);
            // Each non-identity assignment "out := expr" contributes the fact "out = expr".
            List<RowExpression> projectionEqualities = toEqualities(node.getAssignments().getMap().entrySet());
            return pullExpressionThroughVariables(
                    logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .addAll(projectionEqualities)
                            .add(sourcePredicate)
                            .build()),
                    node.getOutputVariables());
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitTopN(TopNNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitLimit(LimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitAssignUniqueId(AssignUniqueId node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitDistinctLimit(DistinctLimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public RowExpression visitTableScan(TableScanNode node, Void context) {
            // Rename constraint columns to the variables that expose them. Columns with no
            // assigned variable map to null (presumably dropped by TupleDomain.transform — confirm).
            return domainTranslator.toPredicate(node.getCurrentConstraint()
                    .simplify()
                    .transform(inverseLookup(node.getAssignments())));
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitSort(SortNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Passes the source predicate through unchanged. */
        @Override
        public RowExpression visitWindow(WindowNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public RowExpression visitUnion(UnionNode node, Void context) {
            return deriveCommonPredicates(node, sourceIndex -> SetOperationNodeUtils.outputMap(node, sourceIndex).entries());
        }

        @Override
        public RowExpression visitJoin(JoinNode node, Void context) {
            RowExpression leftPredicate = node.getLeft().accept(this, context);
            RowExpression rightPredicate = node.getRight().accept(this, context);
            List<RowExpression> joinConjuncts = node.getCriteria().stream()
                    .map(this::toRowExpression)
                    .collect(ImmutableList.toImmutableList());
            List<VariableReferenceExpression> outputVariables = node.getOutputVariables();
            // Scopes identifying which variables come from each (possibly null-extended) side.
            Predicate<VariableReferenceExpression> fromLeft = node.getLeft().getOutputVariables()::contains;
            Predicate<VariableReferenceExpression> fromRight = node.getRight().getOutputVariables()::contains;

            switch (node.getType()) {
                case INNER:
                    return pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .add(leftPredicate)
                            .add(rightPredicate)
                            .add(logicalRowExpressions.combineConjuncts(joinConjuncts))
                            .add(node.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT))
                            .build()), outputVariables);
                case LEFT:
                    // Right-side facts (and the join criteria) only hold as "fact OR right vars IS NULL".
                    return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, outputVariables))
                            .addAll(pullNullableConjunctsThroughOuterJoin(LogicalRowExpressions.extractConjuncts(rightPredicate), outputVariables, fromRight))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputVariables, fromRight))
                            .build());
                case RIGHT:
                    // Mirror image of LEFT: left-side facts are weakened with IS NULL alternatives.
                    return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .add(pullExpressionThroughVariables(rightPredicate, outputVariables))
                            .addAll(pullNullableConjunctsThroughOuterJoin(LogicalRowExpressions.extractConjuncts(leftPredicate), outputVariables, fromLeft))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputVariables, fromLeft))
                            .build());
                case FULL:
                    // Either side may be null-extended, so every fact is weakened accordingly.
                    return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .addAll(pullNullableConjunctsThroughOuterJoin(LogicalRowExpressions.extractConjuncts(leftPredicate), outputVariables, fromLeft))
                            .addAll(pullNullableConjunctsThroughOuterJoin(LogicalRowExpressions.extractConjuncts(rightPredicate), outputVariables, fromRight))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputVariables, fromLeft, fromRight))
                            .build());
                default:
                    throw new UnsupportedOperationException("Unknown join type: " + node.getType());
            }
        }

        /**
         * Rewrites each conjunct onto {@code outputVariables} and weakens it so it still holds for
         * null-extended rows of an outer join.
         *
         * @param conjuncts facts from one side of the join (or the join criteria)
         * @param outputVariables variables the result may reference
         * @param nullVariableScopes one predicate per side whose variables may be null-extended
         * @return the weakened conjuncts; a conjunct that references no variables after rewriting
         *         is replaced by {@code TRUE}
         */
        private Iterable<RowExpression> pullNullableConjunctsThroughOuterJoin(List<RowExpression> conjuncts, Collection<VariableReferenceExpression> outputVariables, Predicate<VariableReferenceExpression>... nullVariableScopes) {
            return conjuncts.stream()
                    .map(conjunct -> pullExpressionThroughVariables(conjunct, outputVariables))
                    // A variable-free conjunct (e.g. a literal) cannot be tied to the null-extended
                    // rows, so it is discarded by replacing it with TRUE.
                    .map(conjunct -> VariablesExtractor.extractAll(conjunct).isEmpty() ? LogicalRowExpressions.TRUE_CONSTANT : conjunct)
                    .map(expressionOrNullVariables(nullVariableScopes))
                    .collect(ImmutableList.toImmutableList());
        }

        /**
         * Returns a function mapping {@code expr} to
         * {@code expr OR (v1 IS NULL AND v2 IS NULL ...) OR ...}: one IS-NULL conjunction per scope,
         * over the distinct variables of {@code expr} that the scope accepts. Scopes matching no
         * variable of {@code expr} add no disjunct.
         */
        public Function<RowExpression, RowExpression> expressionOrNullVariables(Predicate<VariableReferenceExpression>... nullVariableScopes) {
            return expression -> {
                ImmutableList.Builder<RowExpression> disjuncts = ImmutableList.builder();
                disjuncts.add(expression);
                for (Predicate<VariableReferenceExpression> scope : nullVariableScopes) {
                    List<VariableReferenceExpression> scopedVariables = VariablesExtractor.extractUnique(expression).stream()
                            .filter(scope)
                            .collect(ImmutableList.toImmutableList());
                    if (scopedVariables.isEmpty()) {
                        continue;
                    }
                    ImmutableList.Builder<RowExpression> nullChecks = ImmutableList.builder();
                    for (VariableReferenceExpression variable : scopedVariables) {
                        nullChecks.add(Expressions.specialForm(SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, variable));
                    }
                    disjuncts.add(LogicalRowExpressions.and(nullChecks.build()));
                }
                return LogicalRowExpressions.or(disjuncts.build());
            };
        }

        /** Only the predicate of the semi join's {@code getSource()} input is propagated. */
        @Override
        public RowExpression visitSemiJoin(SemiJoinNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public RowExpression visitSpatialJoin(SpatialJoinNode node, Void context) {
            RowExpression leftPredicate = node.getLeft().accept(this, context);
            RowExpression rightPredicate = node.getRight().accept(this, context);
            List<VariableReferenceExpression> outputVariables = node.getOutputVariables();

            switch (node.getType()) {
                case INNER:
                    return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, outputVariables))
                            .add(pullExpressionThroughVariables(rightPredicate, outputVariables))
                            .build());
                case LEFT:
                    return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                            .add(pullExpressionThroughVariables(leftPredicate, outputVariables))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    LogicalRowExpressions.extractConjuncts(rightPredicate),
                                    outputVariables,
                                    node.getRight().getOutputVariables()::contains))
                            .build());
                default:
                    throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
            }
        }

        /** Converts an equi-join clause into a {@code left = right} call expression. */
        private RowExpression toRowExpression(JoinNode.EquiJoinClause clause) {
            return buildEqualsExpression(functionManager, clause.getLeft(), clause.getRight());
        }

        /**
         * Computes the conjuncts that hold on every source of a multi-source node (union, exchange).
         *
         * @param node the multi-source node
         * @param mapping for source index {@code i}, the (output variable, source variable) pairs
         * @return the conjunction of conjuncts common to all sources; {@code TRUE} if there are no sources
         */
        private RowExpression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<VariableReferenceExpression, VariableReferenceExpression>>> mapping) {
            List<PlanNode> sources = node.getSources();
            if (sources.isEmpty()) {
                // Nothing to intersect; TRUE is always a sound effective predicate.
                return LogicalRowExpressions.TRUE_CONSTANT;
            }
            Set<RowExpression> commonConjuncts = null;
            for (int i = 0; i < sources.size(); i++) {
                RowExpression sourcePredicate = sources.get(i).accept(this, null);
                List<RowExpression> equalities = toEqualities(mapping.apply(i));
                Set<RowExpression> sourceConjuncts = ImmutableSet.copyOf(LogicalRowExpressions.extractConjuncts(pullExpressionThroughVariables(
                        logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
                                .addAll(equalities)
                                .add(sourcePredicate)
                                .build()),
                        node.getOutputVariables())));
                // Syntactic intersection only: equivalent but differently-written conjuncts are lost.
                commonConjuncts = (commonConjuncts == null) ? sourceConjuncts : Sets.intersection(commonConjuncts, sourceConjuncts);
            }
            return logicalRowExpressions.combineConjuncts(commonConjuncts);
        }

        /**
         * Turns assignments into {@code key = value} equalities, skipping identity assignments and
         * pairs whose types have no resolvable {@code =} operator.
         */
        private List<RowExpression> toEqualities(Collection<? extends Map.Entry<VariableReferenceExpression, ? extends RowExpression>> assignments) {
            return assignments.stream()
                    .filter(this::notIdentityAssignment)
                    .filter(this::canCompareEquality)
                    .map(this::toEquality)
                    .collect(ImmutableList.toImmutableList());
        }

        private boolean notIdentityAssignment(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry) {
            return !entry.getKey().equals(entry.getValue());
        }

        /** Returns whether an {@code =} operator resolves for the entry's key and value types. */
        private boolean canCompareEquality(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry) {
            try {
                functionManager.resolveOperator(OperatorType.EQUAL, TypeSignatureProvider.fromTypes(entry.getKey().getType(), entry.getValue().getType()));
                return true;
            }
            catch (OperatorNotFoundException ignored) {
                // Incomparable types simply contribute no equality.
                return false;
            }
        }

        private RowExpression toEquality(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry) {
            return buildEqualsExpression(functionManager, entry.getKey(), entry.getValue());
        }

        /** Builds the call expression {@code left = right} using the resolved EQUAL operator. */
        private static CallExpression buildEqualsExpression(FunctionManager functionManager, RowExpression left, RowExpression right) {
            return Expressions.call(
                    OperatorType.EQUAL.getFunctionName().getFunctionName(),
                    functionManager.resolveOperator(OperatorType.EQUAL, TypeSignatureProvider.fromTypes(left.getType(), right.getType())),
                    BooleanType.BOOLEAN,
                    left,
                    right);
        }

        /**
         * Returns a function that maps each value of {@code forward} back to its key, or
         * {@code null} when the value is absent. Throws (via {@link ImmutableBiMap#copyOf}) if
         * {@code forward} contains duplicate values.
         */
        private static <K, V> Function<K, V> inverseLookup(Map<V, K> forward) {
            Map<K, V> inverse = ImmutableBiMap.copyOf(forward).inverse();
            return inverse::get;
        }

        /**
         * Rewrites {@code expression} so it references only {@code variables}. Equalities are
         * re-derived within that scope. Non-inferrable conjuncts are kept only if they are
         * deterministic and can be rewritten into the scope.
         */
        private RowExpression pullExpressionThroughVariables(RowExpression expression, Collection<VariableReferenceExpression> variables) {
            RowExpressionEqualityInference equalityInference = new RowExpressionEqualityInference.Builder(functionManager, typeManager)
                    .addEqualityInference(expression)
                    .build();

            ImmutableList.Builder<RowExpression> effectiveConjuncts = ImmutableList.builder();
            for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(expression)) {
                if (!determinismEvaluator.isDeterministic(conjunct)) {
                    continue;
                }
                RowExpression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(variables));
                // null means the conjunct cannot be expressed over the requested variables.
                if (rewritten != null) {
                    effectiveConjuncts.add(rewritten);
                }
            }
            effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(Predicates.in(variables)).getScopeEqualities());
            return logicalRowExpressions.combineConjuncts(effectiveConjuncts.build());
        }
    }
}
