package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils;
import com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.class */
/**
 * Iterative optimizer rule that reorders a tree of INNER joins.
 *
 * <p>It only matches joins that have no distribution type assigned yet and whose filter is
 * deterministic. Nested qualifying joins are flattened into a {@link MultiJoinNode}, whose size is
 * bounded by {@link SystemSessionProperties#getMaxReorderedJoins}. {@link JoinEnumerator} then
 * costs every two-way split of the sources recursively and keeps the cheapest plan according to
 * {@link CostComparator}. If no plan is produced (an unknown cost was met, or no split has an
 * equi-join condition), the rule returns {@link Rule.Result#empty()}.
 *
 * <p>The rule is enabled only when the join reordering strategy is
 * {@link FeaturesConfig.JoinReorderingStrategy#AUTOMATIC}.
 */
public class ReorderJoins implements Rule<JoinNode> {
    private static final Logger log = Logger.get(ReorderJoins.class);

    private final CostComparator costComparator;
    private final Metadata metadata;
    private final FunctionResolution functionResolution;
    private final DeterminismEvaluator determinismEvaluator;

    // The lambda is evaluated lazily (after the constructor has run), so reading the
    // determinismEvaluator field through `this.` is safe here.
    private final Pattern<JoinNode> joinNodePattern = Patterns.join().matching(joinNode ->
            !joinNode.getDistributionType().isPresent()
                    && joinNode.getType() == JoinNode.Type.INNER
                    && this.determinismEvaluator.isDeterministic(joinNode.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT)));

    /**
     * Outcome of costing one candidate (sub-)plan during join enumeration.
     *
     * <p>A result carries a plan if and only if its cost is known, i.e. has no unknown components
     * and is not infinite. Results with unknown or infinite cost are always represented by the
     * shared {@link #UNKNOWN_COST_RESULT} / {@link #INFINITE_COST_RESULT} instances. Callers
     * compare against them by identity, because {@code equals} is not overridden.
     */
    @VisibleForTesting
    public static class JoinEnumerationResult {
        public static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
        public static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());

        private final Optional<PlanNode> planNode;
        private final PlanCostEstimate cost;

        private JoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.cost = Objects.requireNonNull(cost, "cost is null");
            // Enforce the invariant exactly as the message states. The previous formulation
            // accepted any present plan, even one whose cost was unknown or infinite.
            Preconditions.checkArgument(planNode.isPresent() == isKnown(cost), "planNode should be present if and only if cost is known");
        }

        /** Returns true when the cost has no unknown components and is not infinite. */
        private static boolean isKnown(PlanCostEstimate cost) {
            return !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
        }

        public Optional<PlanNode> getPlanNode() {
            return this.planNode;
        }

        public PlanCostEstimate getCost() {
            return this.cost;
        }

        /**
         * Wraps a costed plan.
         *
         * <p>If the cost has unknown components, returns the shared {@link #UNKNOWN_COST_RESULT}.
         * If it is infinite, returns the shared {@link #INFINITE_COST_RESULT}. In both cases the
         * plan is dropped.
         */
        static JoinEnumerationResult createJoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            if (cost.hasUnknownComponents()) {
                return UNKNOWN_COST_RESULT;
            }
            if (cost.equals(PlanCostEstimate.infinite())) {
                return INFINITE_COST_RESULT;
            }
            return new JoinEnumerationResult(planNode, cost);
        }
    }

    /**
     * Exhaustively enumerates join orders for a set of sources, memoizing the best result per
     * source set.
     */
    @VisibleForTesting
    public static class JoinEnumerator {
        private final Session session;
        private final CostProvider costProvider;
        // Orders results by estimated cost using this session's cost comparator.
        private final Ordering<JoinEnumerationResult> resultComparator;
        private final PlanNodeIdAllocator idAllocator;
        private final Metadata metadata;
        private final EqualityInference allFilterInference;
        // Conjuncts of the combined filter that are not candidates for equality inference. They
        // depend only on constructor arguments, so they are computed once rather than for every
        // candidate join.
        private final List<RowExpression> nonInferrableConjuncts;
        private final LogicalRowExpressions logicalRowExpressions;
        private final Lookup lookup;
        private final Rule.Context context;
        // Best result per source set.
        // NOTE(review): the key ignores the requested output variables. Confirm that a given
        // source set is always enumerated with the same required outputs.
        private final Map<Set<PlanNode>, JoinEnumerationResult> memo = new HashMap<>();

        /**
         * @param costComparator ranks candidate plans by estimated cost
         * @param filter the conjunction of all join criteria and filters of the flattened joins
         * @param context rule context supplying session, cost provider, id allocator and lookup
         * @param determinismEvaluator used when combining conjuncts
         * @param functionResolution used when combining conjuncts
         * @param metadata used for equality inference
         */
        @VisibleForTesting
        JoinEnumerator(CostComparator costComparator, RowExpression filter, Rule.Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata) {
            this.context = Objects.requireNonNull(context, "context is null");
            this.session = Objects.requireNonNull(context.getSession(), "session is null");
            this.costProvider = Objects.requireNonNull(context.getCostProvider(), "costProvider is null");
            this.resultComparator = costComparator.forSession(this.session).onResultOf(result -> result.cost);
            this.idAllocator = Objects.requireNonNull(context.getIdAllocator(), "idAllocator is null");
            RowExpression allFilter = Objects.requireNonNull(filter, "filter is null");
            this.lookup = Objects.requireNonNull(context.getLookup(), "lookup is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.allFilterInference = EqualityInference.createEqualityInference(metadata, allFilter);
            this.nonInferrableConjuncts = ImmutableList.copyOf(new EqualityInference.Builder(metadata).nonInferrableConjuncts(allFilter));
            this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, metadata.getFunctionAndTypeManager());
        }

        /**
         * Returns the cheapest join tree over {@code sources} that produces {@code outputVariables}.
         *
         * <p>Every two-way partition of the sources is tried. If any partition has an unknown
         * cost, enumeration stops and {@link JoinEnumerationResult#UNKNOWN_COST_RESULT} is
         * returned. If every partition is infinite (e.g. no equi-join condition exists),
         * {@link JoinEnumerationResult#INFINITE_COST_RESULT} is returned.
         *
         * @throws IllegalStateException if fewer than two sources are given and no memoized
         *     result exists
         */
        public JoinEnumerationResult chooseJoinOrder(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables) {
            this.context.checkTimeoutNotExhausted();

            Set<PlanNode> multiJoinKey = ImmutableSet.copyOf(sources);
            JoinEnumerationResult bestResult = this.memo.get(multiJoinKey);
            if (bestResult == null) {
                Preconditions.checkState(sources.size() > 1, "sources size is less than or equal to one");
                ImmutableList.Builder<JoinEnumerationResult> candidates = ImmutableList.builder();
                for (Set<Integer> partition : generatePartitions(sources.size())) {
                    JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputVariables, partition);
                    if (result.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                        // Stop at the first unknown-cost candidate and cache it for this source set.
                        this.memo.put(multiJoinKey, result);
                        return result;
                    }
                    if (!result.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                        candidates.add(result);
                    }
                }

                List<JoinEnumerationResult> results = candidates.build();
                if (results.isEmpty()) {
                    this.memo.put(multiJoinKey, JoinEnumerationResult.INFINITE_COST_RESULT);
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }

                bestResult = this.resultComparator.min(results);
                this.memo.put(multiJoinKey, bestResult);
            }

            bestResult.planNode.ifPresent(planNode -> log.debug("Least cost join was: %s", planNode));
            return bestResult;
        }

        /**
         * Returns every proper subset of {@code {0, ..., totalNodes - 1}} that contains 0.
         *
         * <p>Each subset is the left side of a two-way split. Requiring index 0 on the left
         * yields each unordered split exactly once, and excluding the full set keeps the right
         * side non-empty.
         */
        @VisibleForTesting
        static Set<Set<Integer>> generatePartitions(int totalNodes) {
            Preconditions.checkArgument(totalNodes > 1, "totalNodes must be greater than 1");
            Set<Integer> allIndexes = IntStream.range(0, totalNodes).boxed().collect(ImmutableSet.toImmutableSet());
            return Sets.powerSet(allIndexes).stream()
                    .filter(subset -> subset.contains(0))
                    .filter(subset -> subset.size() < allIndexes.size())
                    .collect(ImmutableSet.toImmutableSet());
        }

        /**
         * Splits {@code sources} into the sources at the positions in {@code partitioning} (left)
         * and the remaining sources (right), and costs the join of the two sides.
         */
        @VisibleForTesting
        JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables, Set<Integer> partitioning) {
            List<PlanNode> sourceList = ImmutableList.copyOf(sources);
            LinkedHashSet<PlanNode> leftSources = partitioning.stream()
                    .map(sourceList::get)
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            LinkedHashSet<PlanNode> rightSources = sources.stream()
                    .filter(source -> !leftSources.contains(source))
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            return createJoin(leftSources, rightSources, outputVariables);
        }

        /**
         * Builds and costs an INNER join between the best plans for the two source sets.
         *
         * <p>Returns {@link JoinEnumerationResult#INFINITE_COST_RESULT} when no equi-join
         * condition connects the two sides, so cross joins are never produced here.
         */
        private JoinEnumerationResult createJoin(LinkedHashSet<PlanNode> leftSources, LinkedHashSet<PlanNode> rightSources, List<VariableReferenceExpression> outputVariables) {
            Set<VariableReferenceExpression> leftVariables = leftSources.stream()
                    .flatMap(node -> node.getOutputVariables().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Set<VariableReferenceExpression> rightVariables = rightSources.stream()
                    .flatMap(node -> node.getOutputVariables().stream())
                    .collect(ImmutableSet.toImmutableSet());

            List<RowExpression> joinPredicates = getJoinPredicates(leftVariables, rightVariables);
            List<JoinNode.EquiJoinClause> joinConditions = joinPredicates.stream()
                    .filter(JoinEnumerator::isJoinEqualityCondition)
                    .map(predicate -> toEquiJoinClause((CallExpression) predicate, leftVariables, this.context.getVariableAllocator()))
                    .collect(ImmutableList.toImmutableList());
            if (joinConditions.isEmpty()) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }

            List<RowExpression> joinFilters = joinPredicates.stream()
                    .filter(predicate -> !isJoinEqualityCondition(predicate))
                    .collect(ImmutableList.toImmutableList());

            // Each side must produce the requested outputs plus every variable the join predicates use.
            Set<VariableReferenceExpression> requiredJoinVariables = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(outputVariables)
                    .addAll(VariablesExtractor.extractUnique(joinPredicates))
                    .build();

            JoinEnumerationResult leftResult = getJoinSource(
                    leftSources,
                    requiredJoinVariables.stream().filter(leftVariables::contains).collect(ImmutableList.toImmutableList()));
            if (leftResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            if (leftResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }
            PlanNode left = leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present"));

            JoinEnumerationResult rightResult = getJoinSource(
                    rightSources,
                    requiredJoinVariables.stream().filter(rightVariables::contains).collect(ImmutableList.toImmutableList()));
            if (rightResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            if (rightResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }
            PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present"));

            // Output order follows left outputs then right outputs, restricted to the requested variables.
            List<VariableReferenceExpression> joinOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream())
                    .filter(outputVariables::contains)
                    .collect(ImmutableList.toImmutableList());
            Optional<RowExpression> joinFilter = joinFilters.isEmpty()
                    ? Optional.empty()
                    : Optional.of(LogicalRowExpressions.and(joinFilters));

            return setJoinNodeProperties(new JoinNode(
                    this.idAllocator.getNextId(),
                    JoinNode.Type.INNER,
                    left,
                    right,
                    joinConditions,
                    joinOutputVariables,
                    joinFilter,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty(),
                    ImmutableMap.of()));
        }

        /**
         * Collects the predicates that must be evaluated by a join between the two variable sets.
         *
         * <p>There are two kinds:
         * <ul>
         *   <li>non-inferrable conjuncts that can be rewritten over the union of both sides but
         *       not over either side alone;</li>
         *   <li>equalities implied by the whole filter that straddle the left/right boundary.</li>
         * </ul>
         */
        private List<RowExpression> getJoinPredicates(Set<VariableReferenceExpression> leftVariables, Set<VariableReferenceExpression> rightVariables) {
            ImmutableList.Builder<RowExpression> joinPredicates = ImmutableList.builder();

            // rewriteExpression returns null when the conjunct cannot be expressed in the given scope.
            this.nonInferrableConjuncts.stream()
                    .map(conjunct -> this.allFilterInference.rewriteExpression(conjunct, variable -> leftVariables.contains(variable) || rightVariables.contains(variable)))
                    .filter(Objects::nonNull)
                    .filter(conjunct -> this.allFilterInference.rewriteExpression(conjunct, leftVariables::contains) == null)
                    .filter(conjunct -> this.allFilterInference.rewriteExpression(conjunct, rightVariables::contains) == null)
                    .forEach(joinPredicates::add);

            // Restrict the inferred equalities to the joined variables, then keep only those that
            // connect a left variable with a right variable.
            RowExpression[] scopeEqualities = this.allFilterInference
                    .generateEqualitiesPartitionedBy(variable -> leftVariables.contains(variable) || rightVariables.contains(variable))
                    .getScopeEqualities()
                    .toArray(new RowExpression[0]);
            EqualityInference joinInference = EqualityInference.createEqualityInference(this.metadata, scopeEqualities);
            joinPredicates.addAll(joinInference.generateEqualitiesPartitionedBy(Predicates.in(leftVariables)).getScopeStraddlingEqualities());

            return joinPredicates.build();
        }

        /**
         * Returns the best plan for one side of a join.
         *
         * <p>Multiple sources are enumerated recursively via {@link #chooseJoinOrder}. A single
         * source is wrapped in a {@link FilterNode} carrying every predicate that can be expressed
         * over {@code outputVariables}; the filter is omitted when those predicates combine to TRUE.
         */
        private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<VariableReferenceExpression> outputVariables) {
            if (nodes.size() != 1) {
                return chooseJoinOrder(nodes, outputVariables);
            }

            PlanNode planNode = Iterables.getOnlyElement(nodes);
            ImmutableList.Builder<RowExpression> localPredicates = ImmutableList.builder();
            localPredicates.addAll(this.allFilterInference.generateEqualitiesPartitionedBy(outputVariables::contains).getScopeEqualities());
            this.nonInferrableConjuncts.stream()
                    .map(conjunct -> this.allFilterInference.rewriteExpression(conjunct, outputVariables::contains))
                    .filter(Objects::nonNull)
                    .forEach(localPredicates::add);

            RowExpression filter = this.logicalRowExpressions.combineConjuncts(localPredicates.build());
            if (!LogicalRowExpressions.TRUE_CONSTANT.equals(filter)) {
                planNode = new FilterNode(this.idAllocator.getNextId(), planNode, filter);
            }
            return createJoinEnumerationResult(planNode);
        }

        /** True for a two-argument {@code =} call whose arguments are both variable references. */
        private static boolean isJoinEqualityCondition(RowExpression expression) {
            if (!(expression instanceof CallExpression)) {
                return false;
            }
            CallExpression call = (CallExpression) expression;
            return call.getDisplayName().equals(OperatorType.EQUAL.getFunctionName().getFunctionName())
                    && call.getArguments().size() == 2
                    && call.getArguments().get(0) instanceof VariableReferenceExpression
                    && call.getArguments().get(1) instanceof VariableReferenceExpression;
        }

        /**
         * Converts {@code a = b} into an equi-join clause.
         *
         * <p>The clause is oriented so that its left variable belongs to {@code leftVariables}; it
         * is flipped when the first argument is not a left-side variable.
         */
        private static JoinNode.EquiJoinClause toEquiJoinClause(CallExpression equality, Set<VariableReferenceExpression> leftVariables, PlanVariableAllocator variableAllocator) {
            Preconditions.checkArgument(equality.getArguments().size() == 2, "Unexpected number of arguments in binary operator equals");
            VariableReferenceExpression first = (VariableReferenceExpression) equality.getArguments().get(0);
            VariableReferenceExpression second = (VariableReferenceExpression) equality.getArguments().get(1);
            JoinNode.EquiJoinClause clause = new JoinNode.EquiJoinClause(first, second);
            return leftVariables.contains(first) ? clause : clause.flip();
        }

        /**
         * Chooses the distribution type and child order for {@code joinNode}.
         *
         * <p>A side with at most one row is always replicated (moved to the right if needed).
         * Otherwise every candidate allowed by the session's join distribution type is costed.
         * If any candidate has an unknown cost, {@link JoinEnumerationResult#UNKNOWN_COST_RESULT}
         * is returned; otherwise the cheapest candidate is returned.
         */
        private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) {
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), this.lookup)) {
                return createJoinEnumerationResult(joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED));
            }
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft(), this.lookup)) {
                return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(JoinNode.DistributionType.REPLICATED));
            }
            List<JoinEnumerationResult> possibleJoinNodes = getPossibleJoinNodes(joinNode, SystemSessionProperties.getJoinDistributionType(this.session));
            Verify.verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty");
            if (possibleJoinNodes.contains(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            return this.resultComparator.min(possibleJoinNodes);
        }

        /**
         * Lists the costed join variants permitted by {@code distributionType}.
         *
         * <p>Cross joins are always replicated. AUTOMATIC considers partitioned joins, plus
         * replicated joins when {@link DetermineJoinDistributionType#isBelowMaxBroadcastSize}
         * allows it.
         */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, FeaturesConfig.JoinDistributionType distributionType) {
            Preconditions.checkArgument(joinNode.getType() == JoinNode.Type.INNER, "unexpected join node type: %s", joinNode.getType());

            if (joinNode.isCrossJoin()) {
                return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
            }

            switch (distributionType) {
                case PARTITIONED:
                    return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED);
                case BROADCAST:
                    return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
                case AUTOMATIC: {
                    ImmutableList.Builder<JoinEnumerationResult> results = ImmutableList.builder();
                    results.addAll(getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED));
                    if (DetermineJoinDistributionType.isBelowMaxBroadcastSize(joinNode, this.context)) {
                        results.addAll(getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED));
                    }
                    return results.build();
                }
                default:
                    throw new IllegalArgumentException("unexpected join distribution type: " + distributionType);
            }
        }

        /** Costs {@code joinNode} and its flipped form with the given distribution type. */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinNode.DistributionType distributionType) {
            return ImmutableList.of(
                    createJoinEnumerationResult(joinNode.withDistributionType(distributionType)),
                    createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(distributionType)));
        }

        private JoinEnumerationResult createJoinEnumerationResult(PlanNode planNode) {
            return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(planNode), this.costProvider.getCost(planNode));
        }
    }

    /**
     * A flattened tree of inner joins.
     *
     * <p>It holds the ordered set of leaf sources, the conjunction of all join criteria and
     * filters, and the output variables of the original root join.
     */
    @VisibleForTesting
    public static class MultiJoinNode {
        private final LinkedHashSet<PlanNode> sources;
        private final RowExpression filter;
        private final List<VariableReferenceExpression> outputVariables;

        /** Test helper for constructing a {@link MultiJoinNode} from varargs. */
        static class Builder {
            private List<PlanNode> sources;
            private RowExpression filter;
            private List<VariableReferenceExpression> outputVariables;

            Builder() {
            }

            public Builder setSources(PlanNode... sources) {
                this.sources = ImmutableList.copyOf(sources);
                return this;
            }

            public Builder setFilter(RowExpression filter) {
                this.filter = filter;
                return this;
            }

            public Builder setOutputVariables(VariableReferenceExpression... outputVariables) {
                this.outputVariables = ImmutableList.copyOf(outputVariables);
                return this;
            }

            public MultiJoinNode build() {
                return new MultiJoinNode(new LinkedHashSet<>(this.sources), this.filter, this.outputVariables);
            }
        }

        /**
         * Walks a join tree and collects its leaf sources and predicates.
         *
         * <p>Flattening stops at any node that is not an INNER join, has a non-deterministic
         * filter, already has a distribution type, or would exceed the source limit.
         */
        public static class JoinNodeFlattener {
            private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
            private final List<RowExpression> filters = new ArrayList<>();
            private final List<VariableReferenceExpression> outputVariables;
            private final FunctionResolution functionResolution;
            private final DeterminismEvaluator determinismEvaluator;
            private final Lookup lookup;

            JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) {
                Objects.requireNonNull(node, "node is null");
                Preconditions.checkState(node.getType() == JoinNode.Type.INNER, "join type must be INNER");
                this.outputVariables = node.getOutputVariables();
                this.lookup = Objects.requireNonNull(lookup, "lookup is null");
                this.functionResolution = Objects.requireNonNull(functionResolution, "functionResolution is null");
                this.determinismEvaluator = Objects.requireNonNull(determinismEvaluator, "determinismEvaluator is null");
                flattenNode(node, sourceLimit);
            }

            private void flattenNode(PlanNode node, int limit) {
                PlanNode resolved = this.lookup.resolve(node);

                // (limit - 2) reserves room for the left and right children that would be added.
                if (!(resolved instanceof JoinNode) || this.sources.size() > limit - 2) {
                    this.sources.add(node);
                    return;
                }

                JoinNode joinNode = (JoinNode) resolved;
                boolean reorderable = joinNode.getType() == JoinNode.Type.INNER
                        && this.determinismEvaluator.isDeterministic(joinNode.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT))
                        && !joinNode.getDistributionType().isPresent();
                if (!reorderable) {
                    this.sources.add(node);
                    return;
                }

                // The left side gets limit - 1 to account for the node on the right.
                flattenNode(joinNode.getLeft(), limit - 1);
                flattenNode(joinNode.getRight(), limit);
                for (JoinNode.EquiJoinClause clause : joinNode.getCriteria()) {
                    this.filters.add(JoinNodeUtils.toRowExpression(clause, this.functionResolution));
                }
                joinNode.getFilter().ifPresent(this.filters::add);
            }

            MultiJoinNode toMultiJoinNode() {
                return new MultiJoinNode(this.sources, LogicalRowExpressions.and(this.filters), this.outputVariables);
            }
        }

        /**
         * @throws NullPointerException if any argument is null
         * @throws IllegalArgumentException if there are fewer than two sources, or if the sources
         *     do not produce every output variable
         */
        public MultiJoinNode(LinkedHashSet<PlanNode> sources, RowExpression filter, List<VariableReferenceExpression> outputVariables) {
            Objects.requireNonNull(sources, "sources is null");
            Preconditions.checkArgument(sources.size() > 1, "sources size is <= 1");
            this.sources = sources;
            this.filter = Objects.requireNonNull(filter, "filter is null");
            this.outputVariables = ImmutableList.copyOf(Objects.requireNonNull(outputVariables, "outputVariables is null"));

            Set<VariableReferenceExpression> inputVariables = sources.stream()
                    .flatMap(source -> source.getOutputVariables().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Preconditions.checkArgument(inputVariables.containsAll(outputVariables), "inputs do not contain all output variables");
        }

        public RowExpression getFilter() {
            return this.filter;
        }

        public LinkedHashSet<PlanNode> getSources() {
            return this.sources;
        }

        public List<VariableReferenceExpression> getOutputVariables() {
            return this.outputVariables;
        }

        public static Builder builder() {
            return new Builder();
        }

        // Filters are compared as sets of conjuncts, so conjunct order does not matter.
        @Override
        public int hashCode() {
            return Objects.hash(this.sources, ImmutableSet.copyOf(LogicalRowExpressions.extractConjuncts(this.filter)), this.outputVariables);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof MultiJoinNode)) {
                return false;
            }
            MultiJoinNode other = (MultiJoinNode) obj;
            return this.sources.equals(other.sources)
                    && ImmutableSet.copyOf(LogicalRowExpressions.extractConjuncts(this.filter)).equals(ImmutableSet.copyOf(LogicalRowExpressions.extractConjuncts(other.filter)))
                    && this.outputVariables.equals(other.outputVariables);
        }

        /** Flattens {@code joinNode}, passing {@code joinLimit + 1} as the flattener's source limit. */
        static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) {
            return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, functionResolution, determinismEvaluator).toMultiJoinNode();
        }
    }

    public ReorderJoins(CostComparator costComparator, Metadata metadata) {
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager());
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return this.joinNodePattern;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.getJoinReorderingStrategy(session) == FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
    }

    /**
     * Flattens the matched join tree and replaces it with the cheapest enumerated order.
     *
     * <p>Returns {@link Rule.Result#empty()} when enumeration produced no plan.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        MultiJoinNode multiJoinNode = MultiJoinNode.toMultiJoinNode(
                joinNode,
                context.getLookup(),
                SystemSessionProperties.getMaxReorderedJoins(context.getSession()),
                this.functionResolution,
                this.determinismEvaluator);
        JoinEnumerator enumerator = new JoinEnumerator(this.costComparator, multiJoinNode.getFilter(), context, this.determinismEvaluator, this.functionResolution, this.metadata);
        JoinEnumerationResult result = enumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables());
        if (!result.getPlanNode().isPresent()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(result.getPlanNode().get());
    }
}
