package com.facebook.presto.execution.scheduler;

import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.StageState;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.concurrent.NotThreadSafe;
import org.jgrapht.DirectedGraph;
import org.jgrapht.alg.StrongConnectivityInspector;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.traverse.TopologicalOrderIterator;

/**
 * An {@link ExecutionSchedule} that releases stages in dependency-ordered "phases".
 * <p>
 * The plan fragments are arranged into a directed graph whose edges express a
 * "before → after" relationship. For example, the fragments feeding a join's build side
 * get edges to the fragments feeding its probe side. Successive inputs of a union or
 * local exchange are chained so that they are started one after another. The strongly
 * connected components of that graph become the phases, which are handed out in
 * topological order.
 * <p>
 * Phases are only released while the set of active stages contains no stage with
 * partitioned sources. Once such a stage is active, no further phase is added until it
 * leaves the active set.
 * <p>
 * Not thread-safe: an instance must be confined to one thread or synchronized externally.
 */
@NotThreadSafe
public class PhasedExecutionSchedule implements ExecutionSchedule {
    /** Remaining phases, consumed from the front; each phase is a mutable set of stages. */
    private final List<Set<SqlStageExecution>> schedulePhases;
    /** Stages released for scheduling that have not yet reached SCHEDULED/RUNNING/done. */
    private final Set<SqlStageExecution> activeSources = new HashSet<>();

    /**
     * Builds the phase list for the given stages.
     *
     * @param stages all stages of the query; each stage's fragment id is expected to be unique
     *               (a duplicate makes {@code toImmutableMap} throw)
     */
    public PhasedExecutionSchedule(Collection<SqlStageExecution> stages) {
        List<Set<PlanFragmentId>> phases = extractPhases(stages.stream()
                .map(SqlStageExecution::getFragment)
                .collect(ImmutableList.toImmutableList()));

        Map<PlanFragmentId, SqlStageExecution> stagesByFragmentId = stages.stream()
                .collect(ImmutableMap.toImmutableMap(stage -> stage.getFragment().getId(), Function.identity()));

        // A mutable list of mutable sets, because phases are removed from the front as they are released.
        List<Set<SqlStageExecution>> mutablePhases = new ArrayList<>();
        for (Set<PlanFragmentId> phase : phases) {
            mutablePhases.add(phase.stream()
                    .map(stagesByFragmentId::get)
                    .collect(Collectors.toCollection(HashSet::new)));
        }
        this.schedulePhases = mutablePhases;
    }

    /**
     * Drops stages that no longer need scheduling, releases further phases if needed,
     * and returns the stages that should currently be scheduled.
     *
     * @return an empty set when the schedule is finished; otherwise the live internal set
     *         of active stages (callers should not modify it)
     */
    @Override
    public Set<SqlStageExecution> getStagesToSchedule() {
        removeCompletedStages();
        addPhasesIfNecessary();
        if (isFinished()) {
            return ImmutableSet.of();
        }
        return activeSources;
    }

    /** Removes active stages that are SCHEDULED, RUNNING, or in a terminal (done) state. */
    private void removeCompletedStages() {
        activeSources.removeIf(stage -> {
            StageState state = stage.getState();
            return state == StageState.SCHEDULED || state == StageState.RUNNING || state.isDone();
        });
    }

    /**
     * Moves phases into the active set until it contains a stage with partitioned sources,
     * or until no phases remain. Does nothing if such a stage is already active.
     */
    private void addPhasesIfNecessary() {
        if (hasSourceDistributedStage(activeSources)) {
            return;
        }
        while (!schedulePhases.isEmpty()) {
            Set<SqlStageExecution> phase = schedulePhases.remove(0);
            activeSources.addAll(phase);
            if (hasSourceDistributedStage(phase)) {
                return;
            }
        }
    }

    /** Returns true if any stage in {@code phase} has a fragment with partitioned sources. */
    private static boolean hasSourceDistributedStage(Set<SqlStageExecution> phase) {
        return phase.stream().anyMatch(stage -> !stage.getFragment().getPartitionedSources().isEmpty());
    }

    /** Returns true once no stage is active and no phase remains to be released. */
    @Override
    public boolean isFinished() {
        return activeSources.isEmpty() && schedulePhases.isEmpty();
    }

    /**
     * Groups the fragments into phases and orders those phases.
     * <p>
     * First a fragment-level graph is built with {@link Visitor}. Its strongly connected
     * components are the phases: fragments that must be started together. The phases are
     * then ordered topologically on the condensed component graph.
     *
     * @param fragments all plan fragments of the query
     * @return the phases, each a set of fragment ids, in topological order
     */
    @VisibleForTesting
    static List<Set<PlanFragmentId>> extractPhases(Collection<PlanFragment> fragments) {
        DirectedGraph<PlanFragmentId, DefaultEdge> graph = new DefaultDirectedGraph<>(DefaultEdge.class);
        fragments.forEach(fragment -> graph.addVertex(fragment.getId()));

        Visitor visitor = new Visitor(fragments, graph);
        for (PlanFragment fragment : fragments) {
            visitor.processFragment(fragment.getId());
        }

        // Each strongly connected component is one phase.
        List<Set<PlanFragmentId>> components = new StrongConnectivityInspector<>(graph).stronglyConnectedSets();

        Map<PlanFragmentId, Set<PlanFragmentId>> componentMembership = new HashMap<>();
        for (Set<PlanFragmentId> component : components) {
            for (PlanFragmentId fragmentId : component) {
                componentMembership.put(fragmentId, component);
            }
        }

        // Condense the fragment graph into a graph of components (phases).
        DirectedGraph<Set<PlanFragmentId>, DefaultEdge> componentGraph = new DefaultDirectedGraph<>(DefaultEdge.class);
        components.forEach(componentGraph::addVertex);
        for (DefaultEdge edge : graph.edgeSet()) {
            Set<PlanFragmentId> from = componentMembership.get(graph.getEdgeSource(edge));
            Set<PlanFragmentId> to = componentMembership.get(graph.getEdgeTarget(edge));
            // Skip intra-component edges: they would become self-loops in the component graph.
            if (!from.equals(to)) {
                componentGraph.addEdge(from, to);
            }
        }

        return ImmutableList.copyOf(new TopologicalOrderIterator<>(componentGraph));
    }

    /**
     * Walks a fragment's plan tree and records ordering edges in the fragment graph.
     * <p>
     * Each visit returns the set of fragment ids that are the "sources" of the visited
     * subtree. A leaf contributes the fragment that contains it, and a remote source
     * contributes the sources of the fragment it reads from. Results per fragment are
     * memoized.
     */
    public static class Visitor extends PlanVisitor<Set<PlanFragmentId>, PlanFragmentId> {
        private final Map<PlanFragmentId, PlanFragment> fragments;
        private final DirectedGraph<PlanFragmentId, DefaultEdge> graph;
        /** Memoized result of {@link #processFragment(PlanFragmentId)} per fragment id. */
        private final Map<PlanFragmentId, Set<PlanFragmentId>> fragmentSources = new HashMap<>();

        /**
         * @param fragments all fragments, indexed by id (ids must be unique)
         * @param graph     graph to add ordering edges to; all fragment ids must already be vertices
         */
        public Visitor(Collection<PlanFragment> fragments, DirectedGraph<PlanFragmentId, DefaultEdge> graph) {
            this.fragments = fragments.stream()
                    .collect(ImmutableMap.toImmutableMap(PlanFragment::getId, Function.identity()));
            this.graph = graph;
        }

        /**
         * Returns the id of {@code fragmentId} together with all of its sources. The fragment
         * is processed on first request and the result is memoized.
         */
        public Set<PlanFragmentId> processFragment(PlanFragmentId fragmentId) {
            Set<PlanFragmentId> cached = fragmentSources.get(fragmentId);
            if (cached != null) {
                return cached;
            }
            Set<PlanFragmentId> result = processFragment(fragments.get(fragmentId));
            fragmentSources.put(fragmentId, result);
            return result;
        }

        private Set<PlanFragmentId> processFragment(PlanFragment fragment) {
            Set<PlanFragmentId> sources = fragment.getRoot().accept(this, fragment.getId());
            return ImmutableSet.<PlanFragmentId>builder()
                    .add(fragment.getId())
                    .addAll(sources)
                    .build();
        }

        @Override
        public Set<PlanFragmentId> visitJoin(JoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getRight(), node.getLeft(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getFilteringSource(), node.getSource(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitIndexJoin(IndexJoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getIndexSource(), node.getProbeSource(), currentFragmentId);
        }

        /**
         * Adds an edge from every build-side source fragment to every probe-side source
         * fragment. This places build inputs in an earlier (or the same) phase as probe inputs.
         *
         * @return the union of build-side and probe-side sources
         */
        private Set<PlanFragmentId> processJoin(PlanNode build, PlanNode probe, PlanFragmentId currentFragmentId) {
            Set<PlanFragmentId> buildSources = build.accept(this, currentFragmentId);
            Set<PlanFragmentId> probeSources = probe.accept(this, currentFragmentId);

            for (PlanFragmentId buildSource : buildSources) {
                for (PlanFragmentId probeSource : probeSources) {
                    graph.addEdge(buildSource, probeSource);
                }
            }

            return ImmutableSet.<PlanFragmentId>builder()
                    .addAll(buildSources)
                    .addAll(probeSources)
                    .build();
        }

        /**
         * Adds an edge from the current fragment to each remote fragment it reads from.
         * Successive remote fragments' sources are chained so that they are started one at a time.
         */
        @Override
        public Set<PlanFragmentId> visitRemoteSource(RemoteSourceNode node, PlanFragmentId currentFragmentId) {
            ImmutableSet.Builder<PlanFragmentId> sources = ImmutableSet.builder();
            Set<PlanFragmentId> previousFragmentSources = ImmutableSet.of();
            for (PlanFragmentId remoteFragment : node.getSourceFragmentIds()) {
                graph.addEdge(currentFragmentId, remoteFragment);

                Set<PlanFragmentId> remoteFragmentSources = processFragment(remoteFragment);
                sources.addAll(remoteFragmentSources);

                // Multiple inputs (e.g. UNION): link previous input to this one.
                addEdges(previousFragmentSources, remoteFragmentSources);
                previousFragmentSources = remoteFragmentSources;
            }
            return sources.build();
        }

        /**
         * Chains the sources of successive exchange inputs.
         *
         * @throws IllegalArgumentException if the exchange is not LOCAL in scope
         */
        @Override
        public Set<PlanFragmentId> visitExchange(ExchangeNode node, PlanFragmentId currentFragmentId) {
            Preconditions.checkArgument(node.getScope() == ExchangeNode.Scope.LOCAL, "Only local exchanges are supported in the phased execution scheduler");
            return chainSources(node.getSources(), currentFragmentId);
        }

        /** Chains the sources of successive union inputs. */
        @Override
        public Set<PlanFragmentId> visitUnion(UnionNode node, PlanFragmentId currentFragmentId) {
            return chainSources(node.getSources(), currentFragmentId);
        }

        /**
         * Visits each child in order and links each child's sources to the next child's
         * sources, so the inputs are started one at a time.
         *
         * @return the union of all children's sources
         */
        private Set<PlanFragmentId> chainSources(List<PlanNode> children, PlanFragmentId currentFragmentId) {
            ImmutableSet.Builder<PlanFragmentId> allSources = ImmutableSet.builder();
            Set<PlanFragmentId> previousSources = ImmutableSet.of();
            for (PlanNode child : children) {
                Set<PlanFragmentId> currentSources = child.accept(this, currentFragmentId);
                allSources.addAll(currentSources);
                addEdges(previousSources, currentSources);
                previousSources = currentSources;
            }
            return allSources.build();
        }

        /**
         * Default handling for other node types. A leaf yields the current fragment. A
         * single-input node passes its child's result through.
         *
         * @throws UnsupportedOperationException for other multi-input node types
         */
        @Override
        public Set<PlanFragmentId> visitPlan(PlanNode node, PlanFragmentId currentFragmentId) {
            List<PlanNode> sources = node.getSources();
            if (sources.isEmpty()) {
                return ImmutableSet.of(currentFragmentId);
            }
            if (sources.size() == 1) {
                return sources.get(0).accept(this, currentFragmentId);
            }
            throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName());
        }

        /** Adds an edge from every fragment in {@code sourceFragments} to every fragment in {@code targetFragments}. */
        private void addEdges(Set<PlanFragmentId> sourceFragments, Set<PlanFragmentId> targetFragments) {
            for (PlanFragmentId target : targetFragments) {
                for (PlanFragmentId source : sourceFragments) {
                    graph.addEdge(source, target);
                }
            }
        }
    }
}
