/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adapter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.EdgeManager;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.failover.SchedulingPipelinedRegionComputeUtil;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalPipelinedRegion;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology;
import org.apache.flink.runtime.jobgraph.topology.LogicalEdge;
import org.apache.flink.runtime.jobgraph.topology.LogicalVertex;
import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint;
import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
import org.apache.flink.runtime.scheduler.SchedulingTopologyListener;
import org.apache.flink.runtime.scheduler.adapter.DefaultExecutionVertex;
import org.apache.flink.runtime.scheduler.adapter.DefaultResultPartition;
import org.apache.flink.runtime.scheduler.adapter.DefaultSchedulingPipelinedRegion;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex;
import org.apache.flink.runtime.scheduler.strategy.SchedulingTopology;
import org.apache.flink.util.IterableUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapter that exposes an {@link ExecutionGraph} as a {@link SchedulingTopology}.
 *
 * <p>The topology is built incrementally. {@link #fromExecutionGraph} creates it from the job
 * vertices that are already initialized. {@link
 * #notifyExecutionGraphUpdatedWithInitializedJobVertices} then adds scheduling vertices, result
 * partitions and pipelined regions for job vertices that become initialized later. Every
 * registered {@link SchedulingTopologyListener} is told about each batch of new execution
 * vertices.
 *
 * <p>NOTE(review): none of the internal collections are synchronized. Presumably all calls come
 * from a single thread — confirm with callers.
 */
public class DefaultExecutionTopology implements SchedulingTopology {

    private static final Logger LOG = LoggerFactory.getLogger(DefaultExecutionTopology.class);

    /** Scheduling vertices created so far, keyed by execution vertex id. */
    private final Map<ExecutionVertexID, DefaultExecutionVertex> executionVerticesById;

    /** Scheduling vertices in the order produced by {@link #sortedExecutionVertexIds}. */
    private final List<DefaultExecutionVertex> executionVerticesList;

    /** Scheduling result partitions created so far, keyed by partition id. */
    private final Map<IntermediateResultPartitionID, DefaultResultPartition> resultPartitionsById;

    /** Pipelined region of every scheduling vertex created so far. */
    private final Map<ExecutionVertexID, DefaultSchedulingPipelinedRegion> pipelinedRegionsByVertex;

    /** All pipelined regions built so far, in creation order. */
    private final List<DefaultSchedulingPipelinedRegion> pipelinedRegions;

    private final EdgeManager edgeManager;

    /**
     * Supplies the ordered ids of the execution vertices. It is re-queried on every update to
     * rebuild {@link #executionVerticesList}.
     */
    private final Supplier<List<ExecutionVertexID>> sortedExecutionVertexIds;

    /**
     * Logical pipelined region of each job vertex. The whole map is replaced whenever new job
     * vertices are added to the graph.
     */
    private Map<JobVertexID, DefaultLogicalPipelinedRegion> logicalPipelinedRegionsByJobVertexId;

    private final List<SchedulingTopologyListener> schedulingTopologyListeners = new ArrayList<>();

    private DefaultExecutionTopology(
            Supplier<List<ExecutionVertexID>> sortedExecutionVertexIds,
            EdgeManager edgeManager,
            Map<JobVertexID, DefaultLogicalPipelinedRegion> logicalPipelinedRegionsByJobVertexId) {
        this.sortedExecutionVertexIds = Preconditions.checkNotNull(sortedExecutionVertexIds);
        this.edgeManager = Preconditions.checkNotNull(edgeManager);
        this.logicalPipelinedRegionsByJobVertexId =
                Preconditions.checkNotNull(logicalPipelinedRegionsByJobVertexId);
        this.executionVerticesById = new HashMap<>();
        this.executionVerticesList = new ArrayList<>();
        this.resultPartitionsById = new HashMap<>();
        this.pipelinedRegionsByVertex = new HashMap<>();
        this.pipelinedRegions = new ArrayList<>();
    }

    /** Returns a read-only view of all scheduling vertices, in execution-graph order. */
    @Override
    public Iterable<DefaultExecutionVertex> getVertices() {
        return Collections.unmodifiableList(executionVerticesList);
    }

    /**
     * Looks up a scheduling vertex by id.
     *
     * @throws IllegalArgumentException if no vertex with this id has been added
     */
    @Override
    public DefaultExecutionVertex getVertex(ExecutionVertexID executionVertexId) {
        final DefaultExecutionVertex executionVertex = executionVerticesById.get(executionVertexId);
        if (executionVertex == null) {
            throw new IllegalArgumentException("can not find vertex: " + executionVertexId);
        }
        return executionVertex;
    }

    /**
     * Looks up a scheduling result partition by id.
     *
     * @throws IllegalArgumentException if no partition with this id has been added
     */
    @Override
    public DefaultResultPartition getResultPartition(
            IntermediateResultPartitionID intermediateResultPartitionId) {
        final DefaultResultPartition resultPartition =
                resultPartitionsById.get(intermediateResultPartitionId);
        if (resultPartition == null) {
            throw new IllegalArgumentException(
                    "can not find partition: " + intermediateResultPartitionId);
        }
        return resultPartition;
    }

    /** Registers a listener that is notified whenever new execution vertices are added. */
    @Override
    public void registerSchedulingTopologyListener(SchedulingTopologyListener listener) {
        Preconditions.checkNotNull(listener);
        schedulingTopologyListeners.add(listener);
    }

    /** Returns a read-only view of all pipelined regions built so far. */
    @Override
    public Iterable<DefaultSchedulingPipelinedRegion> getAllPipelinedRegions() {
        return Collections.unmodifiableCollection(pipelinedRegions);
    }

    /**
     * Returns the pipelined region that contains the given vertex.
     *
     * @throws IllegalArgumentException if the vertex is not part of any region
     */
    @Override
    public DefaultSchedulingPipelinedRegion getPipelinedRegionOfVertex(ExecutionVertexID vertexId) {
        final DefaultSchedulingPipelinedRegion pipelinedRegion = pipelinedRegionsByVertex.get(vertexId);
        if (pipelinedRegion == null) {
            throw new IllegalArgumentException("Unknown execution vertex " + vertexId);
        }
        return pipelinedRegion;
    }

    /** Returns the edge manager of the underlying execution graph. */
    public EdgeManager getEdgeManager() {
        return edgeManager;
    }

    /**
     * Computes the logical pipelined regions of the given job vertices.
     *
     * @param topologicallySortedJobVertices job vertices in topological order
     * @return a map from each job vertex id to the logical pipelined region containing it
     */
    public static Map<JobVertexID, DefaultLogicalPipelinedRegion>
            computeLogicalPipelinedRegionsByJobVertexId(
                    List<JobVertex> topologicallySortedJobVertices) {
        final Iterable<DefaultLogicalPipelinedRegion> logicalPipelinedRegions =
                DefaultLogicalTopology.fromTopologicallySortedJobVertices(
                                topologicallySortedJobVertices)
                        .getAllPipelinedRegions();

        final Map<JobVertexID, DefaultLogicalPipelinedRegion> regionsByJobVertexId =
                new HashMap<>();
        for (DefaultLogicalPipelinedRegion logicalPipelinedRegion : logicalPipelinedRegions) {
            for (LogicalVertex logicalVertex : logicalPipelinedRegion.getVertices()) {
                regionsByJobVertexId.put(logicalVertex.getId(), logicalPipelinedRegion);
            }
        }
        return regionsByJobVertexId;
    }

    /**
     * Recomputes the logical pipelined regions after job vertices were added to the graph.
     *
     * @param topologicallySortedJobVertices all job vertices of the graph, in topological order
     */
    public void notifyExecutionGraphUpdatedWithNewJobVertices(
            List<JobVertex> topologicallySortedJobVertices) {
        logicalPipelinedRegionsByJobVertexId =
                computeLogicalPipelinedRegionsByJobVertexId(topologicallySortedJobVertices);
    }

    /**
     * Extends the topology with the execution vertices of newly initialized job vertices.
     *
     * <p>This creates scheduling vertices, result partitions and pipelined regions for the new
     * vertices. It then checks that co-located vertices share a region and notifies the
     * registered listeners.
     *
     * @param executionGraph the execution graph the new job vertices belong to
     * @param newlyInitializedJobVertices the job vertices initialized since the last update
     * @throws IllegalStateException if a result that must be pipelined-consumed by a new vertex
     *     is produced by a job vertex outside this batch, or if co-located vertices end up in
     *     different pipelined regions
     */
    public void notifyExecutionGraphUpdatedWithInitializedJobVertices(
            DefaultExecutionGraph executionGraph,
            List<ExecutionJobVertex> newlyInitializedJobVertices) {
        Preconditions.checkNotNull(executionGraph, "execution graph can not be null");

        final Set<JobVertexID> newJobVertexIds =
                newlyInitializedJobVertices.stream()
                        .map(ExecutionJobVertex::getJobVertexId)
                        .collect(Collectors.toSet());

        // generateNewPipelinedRegions only builds regions out of the new vertices and never
        // extends existing regions. Every input that must be pipelined-consumed therefore has to
        // come from a producer in this same batch.
        newlyInitializedJobVertices.stream()
                .map(ExecutionJobVertex::getJobVertex)
                .flatMap(jobVertex -> jobVertex.getInputs().stream())
                .map(JobEdge::getSource)
                .filter(result -> result.getResultType().mustBePipelinedConsumed())
                .map(IntermediateDataSet::getProducer)
                .map(JobVertex::getID)
                .forEach(
                        producerId ->
                                Preconditions.checkState(
                                        newJobVertexIds.contains(producerId),
                                        "Job vertex %s produces a result that must be pipelined-consumed "
                                                + "by a newly initialized job vertex, but was not "
                                                + "initialized in the same batch.",
                                        producerId));

        final List<ExecutionVertex> newExecutionVertices =
                newlyInitializedJobVertices.stream()
                        .flatMap(jobVertex -> Stream.of(jobVertex.getTaskVertices()))
                        .collect(Collectors.toList());

        generateNewExecutionVerticesAndResultPartitions(newExecutionVertices);
        generateNewPipelinedRegions(newExecutionVertices);
        ensureCoLocatedVerticesInSameRegion(pipelinedRegions, executionGraph);
        notifySchedulingTopologyUpdated(newExecutionVertices);
    }

    /** Passes the ids of the newly added execution vertices to every registered listener. */
    private void notifySchedulingTopologyUpdated(Iterable<ExecutionVertex> newExecutionVertices) {
        final List<ExecutionVertexID> newVertexIds =
                IterableUtils.toStream(newExecutionVertices)
                        .map(ExecutionVertex::getID)
                        .collect(Collectors.toList());
        for (SchedulingTopologyListener listener : schedulingTopologyListeners) {
            listener.notifySchedulingTopologyUpdated(this, newVertexIds);
        }
    }

    /**
     * Creates a scheduling topology for the given execution graph.
     *
     * <p>Only job vertices that are already initialized contribute execution vertices. The
     * logical pipelined regions are computed over all job vertices of the graph.
     *
     * @param executionGraph the execution graph to adapt
     * @return the new scheduling topology
     */
    public static DefaultExecutionTopology fromExecutionGraph(DefaultExecutionGraph executionGraph) {
        Preconditions.checkNotNull(executionGraph, "execution graph can not be null");

        final EdgeManager edgeManager = executionGraph.getEdgeManager();
        // Traverse the job vertices once and reuse the result for both derived lists.
        final List<ExecutionJobVertex> topologicallySortedJobVertices =
                IterableUtils.toStream(executionGraph.getVerticesTopologically())
                        .collect(Collectors.toList());

        final DefaultExecutionTopology schedulingTopology =
                new DefaultExecutionTopology(
                        () ->
                                IterableUtils.toStream(executionGraph.getAllExecutionVertices())
                                        .map(ExecutionVertex::getID)
                                        .collect(Collectors.toList()),
                        edgeManager,
                        computeLogicalPipelinedRegionsByJobVertexId(
                                topologicallySortedJobVertices.stream()
                                        .map(ExecutionJobVertex::getJobVertex)
                                        .collect(Collectors.toList())));

        schedulingTopology.notifyExecutionGraphUpdatedWithInitializedJobVertices(
                executionGraph,
                topologicallySortedJobVertices.stream()
                        .filter(ExecutionJobVertex::isInitialized)
                        .collect(Collectors.toList()));
        return schedulingTopology;
    }

    /**
     * Creates and registers the scheduling vertices and produced partitions for the given
     * execution vertices. Afterwards it rebuilds the ordered vertex list.
     */
    private void generateNewExecutionVerticesAndResultPartitions(
            Iterable<ExecutionVertex> newExecutionVertices) {
        for (ExecutionVertex vertex : newExecutionVertices) {
            final List<DefaultResultPartition> producedPartitions =
                    generateProducedSchedulingResultPartition(
                            vertex.getProducedPartitions(),
                            edgeManager::getConsumerVertexGroupsForPartition);
            for (DefaultResultPartition partition : producedPartitions) {
                resultPartitionsById.put(partition.getId(), partition);
            }

            final DefaultExecutionVertex schedulingVertex =
                    generateSchedulingExecutionVertex(
                            vertex,
                            producedPartitions,
                            edgeManager.getConsumedPartitionGroupsForVertex(vertex.getID()),
                            resultPartitionsById::get);
            executionVerticesById.put(schedulingVertex.getId(), schedulingVertex);
        }

        // Rebuild the list from scratch so that new vertices take their place in the supplier's
        // ordering instead of simply being appended.
        executionVerticesList.clear();
        for (ExecutionVertexID vertexId : sortedExecutionVertexIds.get()) {
            executionVerticesList.add(executionVerticesById.get(vertexId));
        }
    }

    /**
     * Wraps each produced intermediate partition in a {@link DefaultResultPartition}.
     *
     * <p>The partition state and the consumer vertex groups are passed as suppliers, so they are
     * re-evaluated every time the scheduling partition queries them.
     */
    private static List<DefaultResultPartition> generateProducedSchedulingResultPartition(
            Map<IntermediateResultPartitionID, IntermediateResultPartition>
                    producedIntermediatePartitions,
            Function<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
                    partitionConsumerVertexGroupsRetriever) {
        final List<DefaultResultPartition> producedSchedulingPartitions =
                new ArrayList<>(producedIntermediatePartitions.size());
        for (IntermediateResultPartition irp : producedIntermediatePartitions.values()) {
            producedSchedulingPartitions.add(
                    new DefaultResultPartition(
                            irp.getPartitionId(),
                            irp.getIntermediateResult().getId(),
                            irp.getResultType(),
                            () ->
                                    irp.hasDataAllProduced()
                                            ? ResultPartitionState.ALL_DATA_PRODUCED
                                            : ResultPartitionState.CREATED,
                            () -> partitionConsumerVertexGroupsRetriever.apply(irp.getPartitionId()),
                            irp::getConsumedPartitionGroups));
        }
        return producedSchedulingPartitions;
    }

    /**
     * Creates the scheduling vertex for {@code vertex} and records it as the producer of its
     * produced partitions.
     */
    private static DefaultExecutionVertex generateSchedulingExecutionVertex(
            ExecutionVertex vertex,
            List<DefaultResultPartition> producedPartitions,
            List<ConsumedPartitionGroup> consumedPartitionGroups,
            Function<IntermediateResultPartitionID, DefaultResultPartition> resultPartitionRetriever) {
        final DefaultExecutionVertex schedulingVertex =
                new DefaultExecutionVertex(
                        vertex.getID(),
                        producedPartitions,
                        vertex::getExecutionState,
                        consumedPartitionGroups,
                        resultPartitionRetriever);
        for (DefaultResultPartition partition : producedPartitions) {
            partition.setProducer(schedulingVertex);
        }
        return schedulingVertex;
    }

    /**
     * Builds the pipelined regions of the given (already registered) execution vertices and
     * appends them to this topology.
     */
    private void generateNewPipelinedRegions(Iterable<ExecutionVertex> newExecutionVertices) {
        final List<DefaultExecutionVertex> newSchedulingExecutionVertices =
                IterableUtils.toStream(newExecutionVertices)
                        .map(ExecutionVertex::getID)
                        .map(executionVerticesById::get)
                        .collect(Collectors.toList());

        // Group the new vertices by logical pipelined region, keyed by region instance identity.
        final Map<DefaultLogicalPipelinedRegion, List<DefaultExecutionVertex>>
                sortedExecutionVerticesInPipelinedRegion = new IdentityHashMap<>();
        for (DefaultExecutionVertex schedulingVertex : newSchedulingExecutionVertices) {
            final JobVertexID jobVertexId = schedulingVertex.getId().getJobVertexId();
            final DefaultLogicalPipelinedRegion logicalRegion =
                    Preconditions.checkNotNull(
                            logicalPipelinedRegionsByJobVertexId.get(jobVertexId),
                            "No logical pipelined region found for job vertex %s.",
                            jobVertexId);
            sortedExecutionVerticesInPipelinedRegion
                    .computeIfAbsent(logicalRegion, ignored -> new ArrayList<>())
                    .add(schedulingVertex);
        }

        final long buildRegionsStartTime = System.nanoTime();

        final Set<Set<SchedulingExecutionVertex>> rawPipelinedRegions =
                Collections.newSetFromMap(new IdentityHashMap<>());
        for (Map.Entry<DefaultLogicalPipelinedRegion, List<DefaultExecutionVertex>> entry :
                sortedExecutionVerticesInPipelinedRegion.entrySet()) {
            final DefaultLogicalPipelinedRegion logicalPipelinedRegion = entry.getKey();
            final List<DefaultExecutionVertex> schedulingExecutionVertices = entry.getValue();

            if (containsIntraRegionAllToAllEdge(logicalPipelinedRegion)) {
                // An all-to-all edge inside the logical region puts all of its execution vertices
                // into one scheduling region. Presumably this avoids scheduling deadlocks for
                // intra-region blocking all-to-all edges (FLINK-17330) — confirm.
                rawPipelinedRegions.add(new HashSet<>(schedulingExecutionVertices));
            } else {
                // The region has only point-wise internal edges, so split it at vertex level.
                rawPipelinedRegions.addAll(
                        SchedulingPipelinedRegionComputeUtil.computePipelinedRegions(
                                schedulingExecutionVertices,
                                executionVerticesById::get,
                                resultPartitionsById::get));
            }
        }

        for (Set<? extends SchedulingExecutionVertex> rawPipelinedRegion : rawPipelinedRegions) {
            // The sets hold DefaultExecutionVertex instances: those from this topology's map, or
            // the ones passed to / retrieved by computePipelinedRegions (assumed not to wrap them).
            @SuppressWarnings("unchecked")
            final Set<DefaultExecutionVertex> regionVertices =
                    (Set<DefaultExecutionVertex>) rawPipelinedRegion;
            final DefaultSchedulingPipelinedRegion pipelinedRegion =
                    new DefaultSchedulingPipelinedRegion(regionVertices, resultPartitionsById::get);
            pipelinedRegions.add(pipelinedRegion);
            for (SchedulingExecutionVertex executionVertex : rawPipelinedRegion) {
                pipelinedRegionsByVertex.put(executionVertex.getId(), pipelinedRegion);
            }
        }

        final long buildRegionsDuration = (System.nanoTime() - buildRegionsStartTime) / 1_000_000L;
        LOG.info(
                "Built {} new pipelined regions in {} ms, total {} pipelined regions currently.",
                rawPipelinedRegions.size(),
                buildRegionsDuration,
                pipelinedRegions.size());
    }

    /**
     * Returns whether any vertex of the logical region has an ALL_TO_ALL input edge whose
     * producer also belongs to that region.
     */
    private static boolean containsIntraRegionAllToAllEdge(
            DefaultLogicalPipelinedRegion logicalPipelinedRegion) {
        for (LogicalVertex logicalVertex : logicalPipelinedRegion.getVertices()) {
            for (LogicalEdge inputEdge : logicalVertex.getInputs()) {
                if (inputEdge.getDistributionPattern() == DistributionPattern.ALL_TO_ALL
                        && logicalPipelinedRegion.contains(inputEdge.getProducerVertexId())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Verifies that all vertices sharing a co-location constraint belong to the same pipelined
     * region.
     *
     * @throws IllegalStateException if co-located vertices are found in different regions
     */
    private static void ensureCoLocatedVerticesInSameRegion(
            List<DefaultSchedulingPipelinedRegion> pipelinedRegions, ExecutionGraph executionGraph) {
        final Map<CoLocationConstraint, DefaultSchedulingPipelinedRegion> constraintToRegion =
                new HashMap<>();
        for (DefaultSchedulingPipelinedRegion region : pipelinedRegions) {
            for (DefaultExecutionVertex vertex : region.getVertices()) {
                final CoLocationConstraint constraint =
                        getCoLocationConstraint(vertex.getId(), executionGraph);
                if (constraint == null) {
                    continue;
                }
                final DefaultSchedulingPipelinedRegion regionOfConstraint =
                        constraintToRegion.get(constraint);
                // Compare by identity: each constraint must map to one and the same region.
                Preconditions.checkState(
                        regionOfConstraint == null || regionOfConstraint == region,
                        "co-located tasks must be in the same pipelined region");
                constraintToRegion.putIfAbsent(constraint, region);
            }
        }
    }

    /**
     * Returns the co-location constraint of the given execution vertex, or {@code null} if its
     * job vertex has no co-location group.
     *
     * @throws NullPointerException if the job vertex is unknown to the execution graph
     */
    private static CoLocationConstraint getCoLocationConstraint(
            ExecutionVertexID executionVertexId, ExecutionGraph executionGraph) {
        final CoLocationGroup coLocationGroup =
                Objects.requireNonNull(executionGraph.getJobVertex(executionVertexId.getJobVertexId()))
                        .getCoLocationGroup();
        return coLocationGroup == null
                ? null
                : coLocationGroup.getLocationConstraint(executionVertexId.getSubtaskIndex());
    }
}

