package org.apache.flink.runtime.executiongraph.failover;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
import org.apache.flink.runtime.io.network.partition.PartitionException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex;
import org.apache.flink.runtime.scheduler.strategy.SchedulingPipelinedRegion;
import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
import org.apache.flink.runtime.scheduler.strategy.SchedulingTopology;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.IterableUtils;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/executiongraph/failover/RestartPipelinedRegionFailoverStrategy.class */
/**
 * A {@link FailoverStrategy} that restarts the pipelined region of a failed task together with all
 * regions that have to be restarted along with it.
 *
 * <p>Starting from the failed task's region, the set of regions to restart grows transitively.
 * For every region already in the set:
 *
 * <ul>
 *   <li>the producer region of each consumed partition that the availability checker reports as
 *       unavailable is added, and
 *   <li>the region of each consumer of the region's produced partitions is added.
 * </ul>
 */
public class RestartPipelinedRegionFailoverStrategy implements FailoverStrategy {

    /** Topology whose pipelined regions are used to compute the tasks to restart. */
    private final SchedulingTopology topology;

    /**
     * Availability checker used while expanding the restart set. It additionally treats
     * temporarily-failed partitions and non-re-consumable partitions as unavailable.
     */
    private final RegionFailoverResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker;

    /** Factory creating {@link RestartPipelinedRegionFailoverStrategy} instances. */
    public static class Factory implements FailoverStrategy.Factory {
        @Override
        public FailoverStrategy create(
                SchedulingTopology schedulingTopology,
                ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker) {
            return new RestartPipelinedRegionFailoverStrategy(
                    schedulingTopology, resultPartitionAvailabilityChecker);
        }
    }

    /**
     * A {@link ResultPartitionAvailabilityChecker} decorator used during region failover.
     *
     * <p>A partition is reported as available only if all of the following hold:
     *
     * <ul>
     *   <li>it is not currently marked via {@link #markResultPartitionFailed};
     *   <li>the wrapped checker reports it as available;
     *   <li>its result type is re-consumable or {@link ResultPartitionType#PIPELINED_APPROXIMATE}.
     * </ul>
     *
     * <p>NOTE(review): the failed-partition set is a plain {@link HashSet}. This assumes calls come
     * from a single thread; confirm against the scheduler's threading model. The decompiler
     * reported this class as originally {@code private}. Its visibility is left unchanged to stay
     * compatible with any existing references.
     */
    public static class RegionFailoverResultPartitionAvailabilityChecker
            implements ResultPartitionAvailabilityChecker {

        /** The wrapped checker consulted for partitions that are not marked as failed. */
        private final ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker;

        /** Partitions currently marked as failed. */
        private final Set<IntermediateResultPartitionID> failedPartitions = new HashSet<>();

        /** Looks up the {@link ResultPartitionType} of a partition. */
        private final Function<IntermediateResultPartitionID, ResultPartitionType>
                resultPartitionTypeRetriever;

        /**
         * @param resultPartitionAvailabilityChecker the checker to decorate; must not be null
         * @param resultPartitionTypeRetriever looks up a partition's result type; must not be null
         */
        RegionFailoverResultPartitionAvailabilityChecker(
                ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker,
                Function<IntermediateResultPartitionID, ResultPartitionType>
                        resultPartitionTypeRetriever) {
            this.resultPartitionAvailabilityChecker =
                    Preconditions.checkNotNull(resultPartitionAvailabilityChecker);
            this.resultPartitionTypeRetriever =
                    Preconditions.checkNotNull(resultPartitionTypeRetriever);
        }

        @Override
        public boolean isAvailable(IntermediateResultPartitionID resultPartitionID) {
            return !failedPartitions.contains(resultPartitionID)
                    && resultPartitionAvailabilityChecker.isAvailable(resultPartitionID)
                    && isReConsumableOrPipelinedApproximate(resultPartitionID);
        }

        /** Marks the given partition as failed so that {@link #isAvailable} returns false. */
        public void markResultPartitionFailed(IntermediateResultPartitionID resultPartitionID) {
            failedPartitions.add(resultPartitionID);
        }

        /** Removes a mark previously set by {@link #markResultPartitionFailed}. */
        public void removeResultPartitionFromFailedState(
                IntermediateResultPartitionID resultPartitionID) {
            failedPartitions.remove(resultPartitionID);
        }

        /** Returns whether the partition's result type allows it to be consumed again. */
        private boolean isReConsumableOrPipelinedApproximate(
                IntermediateResultPartitionID resultPartitionID) {
            final ResultPartitionType resultType =
                    resultPartitionTypeRetriever.apply(resultPartitionID);
            return resultType.isReconsumable()
                    || resultType == ResultPartitionType.PIPELINED_APPROXIMATE;
        }
    }

    /**
     * Creates a strategy whose underlying availability checker treats every partition as available.
     *
     * @param schedulingTopology the topology to compute failover regions on
     */
    @VisibleForTesting
    public RestartPipelinedRegionFailoverStrategy(SchedulingTopology schedulingTopology) {
        this(schedulingTopology, partitionId -> true);
    }

    /**
     * @param schedulingTopology the topology to compute failover regions on; must not be null
     * @param resultPartitionAvailabilityChecker decides whether a produced partition can still be
     *     consumed; it is wrapped by a {@link RegionFailoverResultPartitionAvailabilityChecker}
     */
    public RestartPipelinedRegionFailoverStrategy(
            SchedulingTopology schedulingTopology,
            ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker) {
        this.topology = Preconditions.checkNotNull(schedulingTopology);
        this.resultPartitionAvailabilityChecker =
                new RegionFailoverResultPartitionAvailabilityChecker(
                        resultPartitionAvailabilityChecker,
                        partitionId ->
                                schedulingTopology.getResultPartition(partitionId).getResultType());
    }

    /**
     * Returns the vertices to restart after the given vertex failed.
     *
     * <p>If {@code th} contains a {@link PartitionException}, the partition it names is marked as
     * failed for the duration of this call only. This pulls the partition's producer region into
     * the restart set. Vertices still in {@link ExecutionState#CREATED} are not included.
     *
     * @param executionVertexID the failed vertex
     * @param th the failure cause
     * @return the ids of the vertices to restart
     * @throws IllegalStateException if no pipelined region is found for the vertex
     */
    @Override
    public Set<ExecutionVertexID> getTasksNeedingRestart(
            ExecutionVertexID executionVertexID, Throwable th) {
        final SchedulingPipelinedRegion failedRegion =
                topology.getPipelinedRegionOfVertex(executionVertexID);
        if (failedRegion == null) {
            throw new IllegalStateException(
                    "Can not find the failover region for task " + executionVertexID, th);
        }

        final Optional<PartitionException> dataConsumptionException =
                ExceptionUtils.findThrowable(th, PartitionException.class);
        final boolean hasFailedPartition = dataConsumptionException.isPresent();
        if (hasFailedPartition) {
            resultPartitionAvailabilityChecker.markResultPartitionFailed(
                    dataConsumptionException.get().getPartitionId().getPartitionId());
        }

        try {
            final Set<ExecutionVertexID> tasksToRestart = new HashSet<>();
            for (SchedulingPipelinedRegion region : getRegionsToRestart(failedRegion)) {
                for (SchedulingExecutionVertex vertex : region.getVertices()) {
                    // Vertices still in the CREATED state are skipped.
                    if (vertex.getState() != ExecutionState.CREATED) {
                        tasksToRestart.add(vertex.getId());
                    }
                }
            }
            return tasksToRestart;
        } finally {
            // The mark only applies to this computation. Always clear it, even if the
            // computation above throws, so that it cannot affect later failovers.
            if (hasFailedPartition) {
                resultPartitionAvailabilityChecker.removeResultPartitionFromFailedState(
                        dataConsumptionException.get().getPartitionId().getPartitionId());
            }
        }
    }

    /**
     * Computes every region that has to be restarted together with {@code failedRegion}, using a
     * breadth-first traversal.
     */
    private Set<SchedulingPipelinedRegion> getRegionsToRestart(
            SchedulingPipelinedRegion failedRegion) {
        final Set<SchedulingPipelinedRegion> regionsToRestart =
                Collections.newSetFromMap(new IdentityHashMap<>());
        final Set<SchedulingPipelinedRegion> visitedRegions =
                Collections.newSetFromMap(new IdentityHashMap<>());
        final Set<ConsumedPartitionGroup> visitedConsumedPartitionGroups =
                Collections.newSetFromMap(new IdentityHashMap<>());
        final Set<ConsumerVertexGroup> visitedConsumerVertexGroups =
                Collections.newSetFromMap(new IdentityHashMap<>());

        final Queue<SchedulingPipelinedRegion> regionsToVisit = new ArrayDeque<>();
        visitedRegions.add(failedRegion);
        regionsToVisit.add(failedRegion);

        while (!regionsToVisit.isEmpty()) {
            final SchedulingPipelinedRegion regionToRestart = regionsToVisit.poll();
            regionsToRestart.add(regionToRestart);

            // An unavailable consumed partition adds its producer's region to the restart set.
            for (IntermediateResultPartitionID consumedPartitionId :
                    getConsumedPartitionsToVisit(regionToRestart, visitedConsumedPartitionGroups)) {
                if (!resultPartitionAvailabilityChecker.isAvailable(consumedPartitionId)) {
                    final SchedulingPipelinedRegion producerRegion =
                            topology.getPipelinedRegionOfVertex(
                                    topology.getResultPartition(consumedPartitionId)
                                            .getProducer()
                                            .getId());
                    // Set#add returns false when the region was already visited.
                    if (visitedRegions.add(producerRegion)) {
                        regionsToVisit.add(producerRegion);
                    }
                }
            }

            // All consumer regions of a restarted region are restarted as well.
            for (ExecutionVertexID consumerVertexId :
                    getConsumerVerticesToVisit(regionToRestart, visitedConsumerVertexGroups)) {
                final SchedulingPipelinedRegion consumerRegion =
                        topology.getPipelinedRegionOfVertex(consumerVertexId);
                if (visitedRegions.add(consumerRegion)) {
                    regionsToVisit.add(consumerRegion);
                }
            }
        }
        return regionsToRestart;
    }

    /**
     * Returns the partitions consumed by {@code region} that belong to a group not yet visited,
     * and records those groups in {@code visitedGroups}.
     */
    private Iterable<IntermediateResultPartitionID> getConsumedPartitionsToVisit(
            SchedulingPipelinedRegion region, Set<ConsumedPartitionGroup> visitedGroups) {
        final List<ConsumedPartitionGroup> groupsToVisit = new ArrayList<>();
        for (SchedulingExecutionVertex vertex : region.getVertices()) {
            for (ConsumedPartitionGroup group : vertex.getConsumedPartitionGroups()) {
                if (visitedGroups.add(group)) {
                    groupsToVisit.add(group);
                }
            }
        }
        return IterableUtils.flatMap(groupsToVisit, Function.identity());
    }

    /**
     * Returns the consumer vertices of partitions produced by {@code region} that belong to a
     * consumer group not yet visited, and records those groups in {@code visitedGroups}.
     */
    private Iterable<ExecutionVertexID> getConsumerVerticesToVisit(
            SchedulingPipelinedRegion region, Set<ConsumerVertexGroup> visitedGroups) {
        final List<ConsumerVertexGroup> groupsToVisit = new ArrayList<>();
        for (SchedulingExecutionVertex vertex : region.getVertices()) {
            for (SchedulingResultPartition producedPartition : vertex.getProducedResults()) {
                for (ConsumerVertexGroup group : producedPartition.getConsumerVertexGroups()) {
                    if (visitedGroups.add(group)) {
                        groupsToVisit.add(group);
                    }
                }
            }
        }
        return IterableUtils.flatMap(groupsToVisit, Function.identity());
    }

    /** Returns the pipelined region that contains the given vertex. */
    @VisibleForTesting
    public SchedulingPipelinedRegion getFailoverRegion(ExecutionVertexID executionVertexID) {
        return topology.getPipelinedRegionOfVertex(executionVertexID);
    }
}
