package com.facebook.presto.sql.planner;

import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionNodeInliner;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.util.DisjointSet;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.SetMultimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:com/facebook/presto/sql/planner/RowExpressionEqualityInference.class */
public class RowExpressionEqualityInference {
    /**
     * Ordering used to pick the canonical member of an equivalence class. Prefers, in order:
     * fewer variable references, fewer unique sub-expressions, and finally the lexicographically
     * smallest {@code toString()} so that the choice is deterministic.
     */
    private static final Ordering<RowExpression> CANONICAL_ORDERING = Ordering.from((RowExpression left, RowExpression right) -> ComparisonChain.start()
            .compare(VariablesExtractor.extractAll(left).size(), VariablesExtractor.extractAll(right).size())
            .compare(Expressions.uniqueSubExpressions(left).size(), Expressions.uniqueSubExpressions(right).size())
            .compare(left.toString(), right.toString())
            .result());

    // Equivalence classes, keyed by the canonical (CANONICAL_ORDERING-minimal) member of each class.
    private final SetMultimap<RowExpression, RowExpression> equalitySets;
    // Maps every known expression to the canonical member of its class.
    private final Map<RowExpression, RowExpression> canonicalMap;
    // Expressions synthesized by Builder#generateMoreEquivalences rather than supplied by the caller.
    private final Set<RowExpression> derivedExpressions;
    private final RowExpressionDeterminismEvaluator determinismEvaluator;
    private final FunctionManager functionManager;

    /**
     * Accumulates equalities and produces an immutable {@link RowExpressionEqualityInference}.
     * The builder is mutable: {@link #build()} first adds derived equalities to it.
     */
    public static class Builder {
        private final DisjointSet<RowExpression> equalities;
        private final Set<RowExpression> derivedExpressions;
        private final FunctionManager functionManager;
        private final NullabilityAnalyzer nullabilityAnalyzer;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;

        public Builder(FunctionManager functionManager, TypeManager typeManager) {
            this.equalities = new DisjointSet<>();
            this.derivedExpressions = new LinkedHashSet<>();
            this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager);
            this.functionManager = functionManager;
            this.nullabilityAnalyzer = new NullabilityAnalyzer(functionManager, typeManager);
        }

        public Builder(Metadata metadata) {
            this(metadata.getFunctionManager(), metadata.getTypeManager());
        }

        /**
         * Returns a predicate that accepts expressions usable as equality inputs. An expression
         * qualifies when it satisfies all of the following after a single-value IN has been
         * normalized to {@code =}:
         * <ul>
         * <li>it is a deterministic EQUAL call;</li>
         * <li>it cannot return null on non-null input;</li>
         * <li>its two sides are different expressions.</li>
         * </ul>
         */
        public Predicate<RowExpression> isInferenceCandidate() {
            return expression -> {
                RowExpression normalized = normalizeInPredicateToEquality(expression);
                return isOperation(normalized, OperatorType.EQUAL)
                        && determinismEvaluator.isDeterministic(normalized)
                        && !nullabilityAnalyzer.mayReturnNullOnNonNullInput(normalized)
                        && !getLeft(normalized).equals(getRight(normalized));
            };
        }

        public static Predicate<RowExpression> isInferenceCandidate(Metadata metadata) {
            return new Builder(metadata).isInferenceCandidate();
        }

        /**
         * Rewrites {@code x IN (y)} (exactly one value) as {@code x = y}. Returns any other
         * expression unchanged.
         *
         * @throws IllegalArgumentException if the expression is an IN with an empty value list
         */
        private RowExpression normalizeInPredicateToEquality(RowExpression expression) {
            if (isInPredicate(expression)) {
                List<RowExpression> arguments = ((SpecialFormExpression) expression).getArguments();
                // The first argument is the probe value; the rest are the IN-list values.
                int inListSize = arguments.size() - 1;
                Preconditions.checkArgument(inListSize >= 1, "InList cannot be empty");
                if (inListSize == 1) {
                    return buildEqualsExpression(functionManager, arguments.get(0), arguments.get(1));
                }
            }
            return expression;
        }

        /** Returns the conjuncts of {@code expression} that are NOT inference candidates. */
        public Iterable<RowExpression> nonInferrableConjuncts(RowExpression expression) {
            return Iterables.filter(LogicalRowExpressions.extractConjuncts(expression), Predicates.not(isInferenceCandidate()));
        }

        public static Iterable<RowExpression> nonInferrableConjuncts(Metadata metadata, RowExpression expression) {
            return new Builder(metadata).nonInferrableConjuncts(expression);
        }

        /** Adds every inference-candidate conjunct of each given expression. */
        public Builder addEqualityInference(RowExpression... expressions) {
            for (RowExpression expression : expressions) {
                extractInferenceCandidates(expression);
            }
            return this;
        }

        /** Adds the conjuncts of {@code expression} that pass {@link #isInferenceCandidate()}. */
        public Builder extractInferenceCandidates(RowExpression expression) {
            return addAllEqualities(Iterables.filter(LogicalRowExpressions.extractConjuncts(expression), isInferenceCandidate()));
        }

        public Builder addAllEqualities(Iterable<RowExpression> expressions) {
            for (RowExpression expression : expressions) {
                addEquality(expression);
            }
            return this;
        }

        /**
         * Adds a single equality expression, normalizing a single-value IN first.
         *
         * @throws IllegalArgumentException if the expression is not an inference candidate
         */
        public Builder addEquality(RowExpression expression) {
            RowExpression normalized = normalizeInPredicateToEquality(expression);
            Preconditions.checkArgument(isInferenceCandidate().apply(normalized), "RowExpression must be a simple equality: " + normalized);
            addEquality(getLeft(normalized), getRight(normalized));
            return this;
        }

        /**
         * Records that {@code left} and {@code right} are equal.
         *
         * @throws IllegalArgumentException if the two are equal or either is non-deterministic
         */
        public Builder addEquality(RowExpression left, RowExpression right) {
            Preconditions.checkArgument(!left.equals(right), "Need to provide equality between different expressions");
            Preconditions.checkArgument(determinismEvaluator.isDeterministic(left), "RowExpression must be deterministic: " + left);
            Preconditions.checkArgument(determinismEvaluator.isDeterministic(right), "RowExpression must be deterministic: " + right);
            equalities.findAndUnion(left, right);
            return this;
        }

        /**
         * Derives additional equalities by substitution. For every non-derived expression, each proper
         * sub-expression that has known equivalents is replaced by each of those equivalents. The
         * rewritten expression is unioned with the original and recorded as derived.
         */
        private void generateMoreEquivalences() {
            Collection<Set<RowExpression>> equivalentClasses = equalities.getEquivalentClasses();

            // Snapshot: every expression -> the equivalence class it belongs to.
            ImmutableMap.Builder<RowExpression, Set<RowExpression>> classByExpressionBuilder = ImmutableMap.builder();
            for (Set<RowExpression> equivalentClass : equivalentClasses) {
                for (RowExpression member : equivalentClass) {
                    classByExpressionBuilder.put(member, equivalentClass);
                }
            }
            Map<RowExpression, Set<RowExpression>> classByExpression = classByExpressionBuilder.build();

            for (RowExpression expression : classByExpression.keySet()) {
                if (derivedExpressions.contains(expression)) {
                    continue;
                }
                for (RowExpression subExpression : Iterables.filter(Expressions.uniqueSubExpressions(expression), Predicates.not(Predicates.equalTo(expression)))) {
                    Set<RowExpression> equivalents = classByExpression.get(subExpression);
                    if (equivalents == null) {
                        continue;
                    }
                    for (RowExpression equivalent : Iterables.filter(equivalents, Predicates.not(Predicates.equalTo(subExpression)))) {
                        RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(new RowExpressionNodeInliner(ImmutableMap.of(subExpression, equivalent)), expression);
                        equalities.findAndUnion(expression, rewritten);
                        derivedExpressions.add(rewritten);
                    }
                }
            }
        }

        /** Generates derived equivalences, then builds the immutable inference. */
        public RowExpressionEqualityInference build() {
            generateMoreEquivalences();
            return new RowExpressionEqualityInference(equalities.getEquivalentClasses(), derivedExpressions, determinismEvaluator, functionManager);
        }

        /** Returns true if {@code expression} is a call whose function metadata has {@code operatorType}. */
        private boolean isOperation(RowExpression expression, OperatorType operatorType) {
            if (!(expression instanceof CallExpression)) {
                return false;
            }
            Optional<OperatorType> actual = functionManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getOperatorType();
            return actual.isPresent() && actual.get() == operatorType;
        }
    }

    /**
     * Result of {@link #generateEqualitiesPartitionedBy}. Holds three groups of equalities:
     * those entirely inside a variable scope, those entirely inside its complement, and those
     * connecting the two.
     */
    public static class EqualityPartition {
        private final List<RowExpression> scopeEqualities;
        private final List<RowExpression> scopeComplementEqualities;
        private final List<RowExpression> scopeStraddlingEqualities;

        public EqualityPartition(Iterable<RowExpression> scopeEqualities, Iterable<RowExpression> scopeComplementEqualities, Iterable<RowExpression> scopeStraddlingEqualities) {
            this.scopeEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeEqualities, "scopeEqualities is null"));
            this.scopeComplementEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeComplementEqualities, "scopeComplementEqualities is null"));
            this.scopeStraddlingEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeStraddlingEqualities, "scopeStraddlingEqualities is null"));
        }

        public List<RowExpression> getScopeEqualities() {
            return scopeEqualities;
        }

        public List<RowExpression> getScopeComplementEqualities() {
            return scopeComplementEqualities;
        }

        public List<RowExpression> getScopeStraddlingEqualities() {
            return scopeStraddlingEqualities;
        }
    }

    private RowExpressionEqualityInference(Iterable<Set<RowExpression>> equalityGroups, Set<RowExpression> derivedExpressions, RowExpressionDeterminismEvaluator determinismEvaluator, FunctionManager functionManager) {
        this.determinismEvaluator = determinismEvaluator;
        this.functionManager = functionManager;

        ImmutableSetMultimap.Builder<RowExpression, RowExpression> setBuilder = ImmutableSetMultimap.builder();
        for (Set<RowExpression> group : equalityGroups) {
            if (!group.isEmpty()) {
                setBuilder.putAll(CANONICAL_ORDERING.min(group), group);
            }
        }
        this.equalitySets = setBuilder.build();

        // Invert equalitySets: member -> canonical. Each expression belongs to exactly one group.
        ImmutableMap.Builder<RowExpression, RowExpression> canonicalBuilder = ImmutableMap.builder();
        for (Map.Entry<RowExpression, RowExpression> entry : equalitySets.entries()) {
            canonicalBuilder.put(entry.getValue(), entry.getKey());
        }
        this.canonicalMap = canonicalBuilder.build();

        this.derivedExpressions = ImmutableSet.copyOf(derivedExpressions);
    }

    public static RowExpressionEqualityInference createEqualityInference(Metadata metadata, RowExpression... expressions) {
        return new Builder(metadata).addEqualityInference(expressions).build();
    }

    /**
     * Attempts to rewrite a deterministic expression using only variables accepted by
     * {@code variableScope}. The rewrite is based on the known equalities.
     *
     * @return the rewritten expression, or null if no in-scope rewrite was found
     * @throws IllegalArgumentException if {@code expression} is not deterministic
     */
    public RowExpression rewriteExpression(RowExpression expression, Predicate<VariableReferenceExpression> variableScope) {
        Preconditions.checkArgument(determinismEvaluator.isDeterministic(expression), "Only deterministic expressions may be considered for rewrite");
        return rewriteExpression(expression, variableScope, true);
    }

    /**
     * Same as {@link #rewriteExpression(RowExpression, Predicate)} but without the determinism check.
     *
     * @return the rewritten expression, or null if no in-scope rewrite was found
     */
    public RowExpression rewriteExpressionAllowNonDeterministic(RowExpression expression, Predicate<VariableReferenceExpression> variableScope) {
        return rewriteExpression(expression, variableScope, true);
    }

    /**
     * Replaces each sub-expression that has an in-scope canonical equivalent, then checks that the
     * result references only in-scope variables.
     *
     * @param allowFullReplacement whether {@code expression} itself (not just its proper
     *        sub-expressions) may be replaced
     * @return the rewritten expression, or null if it still references out-of-scope variables
     */
    private RowExpression rewriteExpression(RowExpression expression, Predicate<VariableReferenceExpression> variableScope, boolean allowFullReplacement) {
        Iterable<RowExpression> subExpressions = Expressions.uniqueSubExpressions(expression);
        if (!allowFullReplacement) {
            subExpressions = Iterables.filter(subExpressions, Predicates.not(Predicates.equalTo(expression)));
        }

        ImmutableMap.Builder<RowExpression, RowExpression> replacements = ImmutableMap.builder();
        for (RowExpression subExpression : subExpressions) {
            RowExpression canonical = getScopedCanonical(subExpression, variableScope);
            if (canonical != null) {
                replacements.put(subExpression, canonical);
            }
        }

        RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(new RowExpressionNodeInliner(replacements.build()), expression);
        if (!variableToExpressionPredicate(variableScope).apply(rewritten)) {
            return null;
        }
        return rewritten;
    }

    /**
     * Splits the known, non-derived equalities into three groups:
     * <ul>
     * <li>equalities expressible entirely in {@code variableScope};</li>
     * <li>equalities expressible entirely in its complement;</li>
     * <li>equalities that connect the two sides, including expressions that could be
     *     rewritten into neither side.</li>
     * </ul>
     */
    public EqualityPartition generateEqualitiesPartitionedBy(Predicate<VariableReferenceExpression> variableScope) {
        ImmutableSet.Builder<RowExpression> scopeEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<RowExpression> scopeComplementEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<RowExpression> scopeStraddlingEqualities = ImmutableSet.builder();

        for (Collection<RowExpression> equalitySet : equalitySets.asMap().values()) {
            Set<RowExpression> scopeExpressions = new LinkedHashSet<>();
            Set<RowExpression> scopeComplementExpressions = new LinkedHashSet<>();
            Set<RowExpression> scopeStraddlingExpressions = new LinkedHashSet<>();

            // Derived expressions are excluded; they are implied by the originals.
            for (RowExpression expression : Iterables.filter(equalitySet, Predicates.not(Predicates.in(derivedExpressions)))) {
                RowExpression scopeRewritten = rewriteExpression(expression, variableScope, false);
                if (scopeRewritten != null) {
                    scopeExpressions.add(scopeRewritten);
                }
                RowExpression complementRewritten = rewriteExpression(expression, Predicates.not(variableScope), false);
                if (complementRewritten != null) {
                    scopeComplementExpressions.add(complementRewritten);
                }
                if (scopeRewritten == null && complementRewritten == null) {
                    scopeStraddlingExpressions.add(expression);
                }
            }

            RowExpression scopeCanonical = getCanonical(scopeExpressions);
            addEqualitiesToCanonical(scopeEqualities, scopeCanonical, scopeExpressions);

            RowExpression complementCanonical = getCanonical(scopeComplementExpressions);
            addEqualitiesToCanonical(scopeComplementEqualities, complementCanonical, scopeComplementExpressions);

            // Connect the two sides through their canonicals, and attach whatever fits neither side.
            List<RowExpression> connecting = new ArrayList<>();
            if (scopeCanonical != null) {
                connecting.add(scopeCanonical);
            }
            if (complementCanonical != null) {
                connecting.add(complementCanonical);
            }
            connecting.addAll(scopeStraddlingExpressions);
            addEqualitiesToCanonical(scopeStraddlingEqualities, getCanonical(connecting), connecting);
        }
        return new EqualityPartition(scopeEqualities.build(), scopeComplementEqualities.build(), scopeStraddlingEqualities.build());
    }

    /**
     * Adds {@code canonical = member} to {@code output} for every element of {@code members} that is
     * not equal to {@code canonical}. Does nothing when {@code canonical} is null (an empty group).
     */
    private void addEqualitiesToCanonical(ImmutableSet.Builder<RowExpression> output, RowExpression canonical, Iterable<RowExpression> members) {
        if (canonical == null) {
            return;
        }
        for (RowExpression member : Iterables.filter(members, Predicates.not(Predicates.equalTo(canonical)))) {
            output.add(buildEqualsExpression(functionManager, canonical, member));
        }
    }

    /** Returns the {@link #CANONICAL_ORDERING}-minimal element, or null if {@code expressions} is empty. */
    private static RowExpression getCanonical(Iterable<RowExpression> expressions) {
        if (Iterables.isEmpty(expressions)) {
            return null;
        }
        return CANONICAL_ORDERING.min(expressions);
    }

    /**
     * Returns the canonical in-scope equivalent of {@code expression}: the preferred member of its
     * equivalence class that references only variables accepted by {@code variableScope}.
     *
     * @return the equivalent, or null if the expression is unknown or has no in-scope equivalent
     */
    @VisibleForTesting
    RowExpression getScopedCanonical(RowExpression expression, Predicate<VariableReferenceExpression> variableScope) {
        RowExpression canonicalIndex = canonicalMap.get(expression);
        if (canonicalIndex == null) {
            return null;
        }
        return getCanonical(Iterables.filter(equalitySets.get(canonicalIndex), variableToExpressionPredicate(variableScope)));
    }

    /** Lifts a variable predicate to an expression predicate: true iff every referenced variable passes. */
    private static Predicate<RowExpression> variableToExpressionPredicate(Predicate<VariableReferenceExpression> variableScope) {
        return expression -> Iterables.all(VariablesExtractor.extractUnique(expression), variableScope);
    }

    /**
     * Returns the first argument of a two-argument call.
     *
     * @throws IllegalArgumentException if {@code expression} is not a two-argument call
     */
    public static RowExpression getLeft(RowExpression expression) {
        Preconditions.checkArgument((expression instanceof CallExpression) && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression");
        return ((CallExpression) expression).getArguments().get(0);
    }

    /**
     * Returns the second argument of a two-argument call.
     *
     * @throws IllegalArgumentException if {@code expression} is not a two-argument call
     */
    public static RowExpression getRight(RowExpression expression) {
        Preconditions.checkArgument((expression instanceof CallExpression) && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression");
        return ((CallExpression) expression).getArguments().get(1);
    }

    /** Returns true if {@code expression} is an IN special form. */
    public static boolean isInPredicate(RowExpression expression) {
        return (expression instanceof SpecialFormExpression) && ((SpecialFormExpression) expression).getForm() == SpecialFormExpression.Form.IN;
    }

    /** Builds {@code left = right} using the EQUAL operator resolved for the two argument types. */
    public static CallExpression buildEqualsExpression(FunctionManager functionManager, RowExpression left, RowExpression right) {
        return binaryOperation(functionManager, OperatorType.EQUAL, left, right);
    }

    private static CallExpression binaryOperation(FunctionManager functionManager, OperatorType operatorType, RowExpression left, RowExpression right) {
        return Expressions.call(
                operatorType.getFunctionName().getFunctionName(),
                functionManager.resolveOperator(operatorType, TypeSignatureProvider.fromTypes(left.getType(), right.getType())),
                BooleanType.BOOLEAN,
                left,
                right);
    }
}
