package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.core.testutils.FlinkAssertions;
import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for {@link DefaultCheckpointPlanCalculator}.
 *
 * <p>Most scenarios declare a small topology (per-vertex parallelism, the subtasks to mark
 * finished, and the edges between vertices), build and start an {@link ExecutionGraph} for it,
 * and compare the {@link CheckpointPlan} computed by the calculator with the expected tasks to
 * trigger, tasks to commit to / wait for, finished tasks and fully finished job vertices.
 */
class DefaultCheckpointPlanCalculatorTest {

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
            TestingUtils.defaultExecutorExtension();

    @Test
    void testComputeAllRunningGraph() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, Collections.emptySet()),
                        new VertexDeclaration(4, Collections.emptySet()),
                        new VertexDeclaration(5, Collections.emptySet()),
                        new VertexDeclaration(6, Collections.emptySet())),
                Arrays.asList(
                        new EdgeDeclaration(0, 2, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 2, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 3, DistributionPattern.ALL_TO_ALL)),
                Arrays.asList(
                        new TaskDeclaration(0, range(0, 3)), new TaskDeclaration(1, range(0, 4))));
    }

    @Test
    void testAllToAllEdgeWithSomeSourcesFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, range(0, 2)),
                        new VertexDeclaration(4, Collections.emptySet())),
                Collections.singletonList(
                        new EdgeDeclaration(0, 1, DistributionPattern.ALL_TO_ALL)),
                Collections.singletonList(new TaskDeclaration(0, range(2, 3))));
    }

    @Test
    void testOneToOneEdgeWithSomeSourcesFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(4, range(0, 2)),
                        new VertexDeclaration(4, Collections.emptySet())),
                Collections.singletonList(
                        new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(0, 2))));
    }

    @Test
    void testOneToOnEdgeWithSomeSourcesAndTargetsFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(4, range(0, 2)), new VertexDeclaration(4, of(0))),
                Collections.singletonList(
                        new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(1, 2))));
    }

    @Test
    void testComputeWithMultipleInputs() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, range(0, 3)),
                        new VertexDeclaration(5, of(0, 2, 3)),
                        new VertexDeclaration(5, of(2, 4)),
                        new VertexDeclaration(5, of(2))),
                Arrays.asList(
                        new EdgeDeclaration(0, 3, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 3, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(1, of(1, 4)), new TaskDeclaration(2, of(0, 1, 3))));
    }

    @Test
    void testComputeWithMultipleLevels() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(16, range(0, 4)),
                        new VertexDeclaration(16, range(0, 16)),
                        new VertexDeclaration(16, range(0, 2)),
                        new VertexDeclaration(16, Collections.emptySet()),
                        new VertexDeclaration(16, Collections.emptySet())),
                Arrays.asList(
                        new EdgeDeclaration(0, 2, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(0, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(1, 2, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 4, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(3, 4, DistributionPattern.ALL_TO_ALL)),
                Arrays.asList(
                        new TaskDeclaration(0, range(4, 16)),
                        new TaskDeclaration(2, range(2, 4)),
                        new TaskDeclaration(3, range(0, 4))));
    }

    @Test
    void testPlanCalculationWhenOneTaskNotRunning() throws Exception {
        // Exercise every combination of the two per-vertex builder flags.
        runWithNotRunningTask(true, true);
        runWithNotRunningTask(true, false);
        runWithNotRunningTask(false, false);
        runWithNotRunningTask(false, true);
    }

    /**
     * Builds a two-vertex graph in which the first vertex is RUNNING and the second is in some
     * other {@link ExecutionState}, and asserts that the plan calculation fails with {@link
     * CheckpointFailureReason#NOT_ALL_REQUIRED_TASKS_RUNNING} for every non-RUNNING state.
     *
     * @param isRunningVertexSource flag passed to {@code addJobVertex} for the RUNNING vertex
     *     (NOTE(review): presumably marks it as a source/triggering vertex — confirm against
     *     {@code CheckpointExecutionGraphBuilder#addJobVertex})
     * @param isNotRunningVertexSource same flag for the vertex that is moved out of RUNNING
     */
    private void runWithNotRunningTask(
            boolean isRunningVertexSource, boolean isNotRunningVertexSource) throws Exception {
        for (ExecutionState notRunningState :
                EnumSet.complementOf(EnumSet.of(ExecutionState.RUNNING))) {
            JobVertexID runningVertex = new JobVertexID();
            JobVertexID notRunningVertex = new JobVertexID();

            ExecutionGraph graph =
                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
                            .addJobVertex(runningVertex, isRunningVertexSource)
                            .addJobVertex(notRunningVertex, isNotRunningVertexSource)
                            .setTransitToRunning(false)
                            .build(EXECUTOR_EXTENSION.getExecutor());

            transitVertexToState(graph, runningVertex, ExecutionState.RUNNING);
            transitVertexToState(graph, notRunningVertex, notRunningState);

            FlinkAssertions.assertThatFuture(
                            createCheckpointPlanCalculator(graph).calculateCheckpointPlan())
                    .withFailMessage(
                            "The computation should fail since some tasks to trigger are in %s state",
                            notRunningState)
                    .eventuallyFailsWith(ExecutionException.class)
                    .havingCause()
                    .isInstanceOfSatisfying(
                            CheckpointException.class,
                            checkpointException ->
                                    Assertions.assertThat(
                                                    checkpointException.getCheckpointFailureReason())
                                            .isEqualTo(
                                                    CheckpointFailureReason
                                                            .NOT_ALL_REQUIRED_TASKS_RUNNING));
        }
    }

    /**
     * Transitions the current execution attempt of the first subtask of the given job vertex whose
     * job vertex id matches {@code jobVertexID} into {@code state}.
     */
    private void transitVertexToState(
            ExecutionGraph graph, JobVertexID jobVertexID, ExecutionState state) {
        Arrays.stream(graph.getJobVertex(jobVertexID).getTaskVertices())
                .filter(vertex -> vertex.getJobvertexId().equals(jobVertexID))
                .findFirst()
                .get()
                .getCurrentExecutionAttempt()
                .transitionState(state);
    }

    /**
     * Runs a scenario whose expected finished tasks are exactly the subtasks each {@link
     * VertexDeclaration} marks as finished.
     */
    private void runSingleTest(
            List<VertexDeclaration> vertexDeclarations,
            List<EdgeDeclaration> edgeDeclarations,
            List<TaskDeclaration> expectedToTriggerTaskDeclarations)
            throws Exception {
        List<TaskDeclaration> expectedFinishedTaskDeclarations =
                IntStream.range(0, vertexDeclarations.size())
                        .mapToObj(
                                i ->
                                        new TaskDeclaration(
                                                i, vertexDeclarations.get(i).finishedSubtaskIndices))
                        .collect(Collectors.toList());

        runSingleTest(
                vertexDeclarations,
                edgeDeclarations,
                expectedToTriggerTaskDeclarations,
                expectedFinishedTaskDeclarations);
    }

    /**
     * Builds the declared graph, computes a checkpoint plan and verifies it.
     *
     * <p>For every vertex in {@code expectedFinishedTaskDeclarations}, all of its subtasks not
     * listed as finished are expected to be "running" (committed to / waited for), and a vertex
     * whose every subtask is finished is expected among the fully finished job vertices.
     */
    private void runSingleTest(
            List<VertexDeclaration> vertexDeclarations,
            List<EdgeDeclaration> edgeDeclarations,
            List<TaskDeclaration> expectedToTriggerTaskDeclarations,
            List<TaskDeclaration> expectedFinishedTaskDeclarations)
            throws Exception {
        ExecutionGraph graph = createExecutionGraph(vertexDeclarations, edgeDeclarations);
        DefaultCheckpointPlanCalculator planCalculator = createCheckpointPlanCalculator(graph);

        List<TaskDeclaration> expectedRunningTaskDeclarations = new ArrayList<>();
        List<ExecutionJobVertex> expectedFullyFinishedJobVertices = new ArrayList<>();
        for (TaskDeclaration finishedDeclaration : expectedFinishedTaskDeclarations) {
            ExecutionJobVertex jobVertex = chooseJobVertex(graph, finishedDeclaration.vertexIndex);
            int parallelism = jobVertex.getParallelism();

            expectedRunningTaskDeclarations.add(
                    new TaskDeclaration(
                            finishedDeclaration.vertexIndex,
                            minus(range(0, parallelism), finishedDeclaration.subtaskIndices)));

            if (finishedDeclaration.subtaskIndices.size() == parallelism) {
                expectedFullyFinishedJobVertices.add(jobVertex);
            }
        }

        List<ExecutionVertex> expectedToTriggerTasks =
                chooseTasks(graph, expectedToTriggerTaskDeclarations.toArray(new TaskDeclaration[0]));
        List<ExecutionVertex> expectedRunningTasks =
                chooseTasks(graph, expectedRunningTaskDeclarations.toArray(new TaskDeclaration[0]));
        List<Execution> expectedFinishedTasks =
                chooseTasks(graph, expectedFinishedTaskDeclarations.toArray(new TaskDeclaration[0]))
                        .stream()
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());

        checkCheckpointPlan(
                expectedToTriggerTasks,
                expectedRunningTasks,
                expectedFinishedTasks,
                expectedFullyFinishedJobVertices,
                planCalculator.calculateCheckpointPlan().get());
    }

    /**
     * Creates and starts an execution graph with one job vertex per declaration (named via
     * {@link #vertexName(int)}), connected by the declared PIPELINED edges. All executions are
     * moved to RUNNING, then the declared finished subtasks are marked finished.
     */
    private ExecutionGraph createExecutionGraph(
            List<VertexDeclaration> vertexDeclarations, List<EdgeDeclaration> edgeDeclarations)
            throws Exception {
        JobVertex[] jobVertices = new JobVertex[vertexDeclarations.size()];
        for (int i = 0; i < vertexDeclarations.size(); i++) {
            jobVertices[i] =
                    ExecutionGraphTestUtils.createJobVertex(
                            vertexName(i),
                            vertexDeclarations.get(i).parallelism,
                            NoOpInvokable.class);
        }

        for (EdgeDeclaration edge : edgeDeclarations) {
            jobVertices[edge.target].connectNewDataSetAsInput(
                    jobVertices[edge.source],
                    edge.distributionPattern,
                    ResultPartitionType.PIPELINED);
        }

        DefaultExecutionGraph graph =
                ExecutionGraphTestUtils.createExecutionGraph(
                        EXECUTOR_EXTENSION.getExecutor(), jobVertices);
        graph.start(ComponentMainThreadExecutorServiceAdapter.forMainThread());
        graph.transitionToRunning();
        graph.getAllExecutionVertices()
                .forEach(
                        vertex ->
                                vertex.getCurrentExecutionAttempt()
                                        .transitionState(ExecutionState.RUNNING));

        for (int i = 0; i < vertexDeclarations.size(); i++) {
            ExecutionVertex[] taskVertices =
                    graph.getJobVertex(jobVertices[i].getID()).getTaskVertices();
            for (int subtaskIndex : vertexDeclarations.get(i).finishedSubtaskIndices) {
                taskVertices[subtaskIndex].getCurrentExecutionAttempt().markFinished();
            }
        }

        return graph;
    }

    private DefaultCheckpointPlanCalculator createCheckpointPlanCalculator(ExecutionGraph graph) {
        // NOTE(review): the trailing `true` presumably enables checkpoints after tasks finished —
        // confirm against the DefaultCheckpointPlanCalculator constructor.
        return new DefaultCheckpointPlanCalculator(
                graph.getJobID(),
                new ExecutionGraphCheckpointPlanCalculatorContext(graph),
                graph.getVerticesTopologically(),
                true);
    }

    /**
     * Verifies every component of {@code plan}. The expected "running" vertices are checked both
     * as the tasks to commit to and (via their current attempts) as the tasks to wait for.
     */
    private void checkCheckpointPlan(
            List<ExecutionVertex> expectedToTrigger,
            List<ExecutionVertex> expectedRunning,
            List<Execution> expectedFinished,
            List<ExecutionJobVertex> expectedFullyFinished,
            CheckpointPlan plan) {
        List<Execution> expectedToTriggerAttempts =
                expectedToTrigger.stream()
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());
        List<Execution> expectedRunningAttempts =
                expectedRunning.stream()
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());

        assertSameInstancesWithoutOrder(
                "The computed tasks to trigger is different from expected",
                expectedToTriggerAttempts,
                plan.getTasksToTrigger());
        assertSameInstancesWithoutOrder(
                "The computed running tasks is different from expected",
                expectedRunning,
                plan.getTasksToCommitTo());
        assertSameInstancesWithoutOrder(
                "The computed finished tasks is different from expected",
                expectedFinished,
                plan.getFinishedTasks());
        assertSameInstancesWithoutOrder(
                "The computed fully finished JobVertex is different from expected",
                expectedFullyFinished,
                plan.getFullyFinishedJobVertex());
        assertSameInstancesWithoutOrder(
                "The computed tasks to ack is different from expected",
                expectedRunningAttempts,
                plan.getTasksToWaitFor());
    }

    /**
     * Asserts that {@code actual} contains exactly the elements of {@code expected}, ignoring
     * order. The computed collection is the assertion subject so that AssertJ failure messages
     * label the actual and expected values correctly.
     */
    private <T> void assertSameInstancesWithoutOrder(
            String comment, Collection<T> expected, Collection<T> actual) {
        Assertions.assertThat(actual).as(comment).containsExactlyInAnyOrderElementsOf(expected);
    }

    /** Resolves each declaration to the {@link ExecutionVertex} instances of its subtasks. */
    private List<ExecutionVertex> chooseTasks(
            ExecutionGraph graph, TaskDeclaration... taskDeclarations) {
        List<ExecutionVertex> tasks = new ArrayList<>();
        for (TaskDeclaration declaration : taskDeclarations) {
            ExecutionVertex[] taskVertices =
                    chooseJobVertex(graph, declaration.vertexIndex).getTaskVertices();
            for (int subtaskIndex : declaration.subtaskIndices) {
                tasks.add(taskVertices[subtaskIndex]);
            }
        }
        return tasks;
    }

    /**
     * Finds the job vertex created for the declaration at {@code vertexIndex} by its name.
     *
     * @throws RuntimeException if no vertex with the expected name exists
     */
    private ExecutionJobVertex chooseJobVertex(ExecutionGraph graph, int vertexIndex) {
        String expectedName = vertexName(vertexIndex);
        return graph.getAllVertices().values().stream()
                .filter(jobVertex -> jobVertex.getName().equals(expectedName))
                .findFirst()
                .orElseThrow(
                        () -> new RuntimeException("Vertex not found with index " + vertexIndex));
    }

    private String vertexName(int vertexIndex) {
        return "vertex_" + vertexIndex;
    }

    /** Returns the integers in {@code [start, end)} as a set. */
    private Set<Integer> range(int start, int end) {
        return IntStream.range(start, end).boxed().collect(Collectors.toSet());
    }

    /** Returns a mutable set containing the given values. */
    private Set<Integer> of(Integer... values) {
        return new HashSet<>(Arrays.asList(values));
    }

    /** Returns the elements of {@code all} that are not contained in {@code toRemove}. */
    private Set<Integer> minus(Set<Integer> all, Set<Integer> toRemove) {
        return all.stream().filter(value -> !toRemove.contains(value)).collect(Collectors.toSet());
    }

    /** A vertex with the given parallelism whose listed subtask indices are marked finished. */
    private static final class VertexDeclaration {
        final int parallelism;
        final Set<Integer> finishedSubtaskIndices;

        VertexDeclaration(int parallelism, Set<Integer> finishedSubtaskIndices) {
            this.parallelism = parallelism;
            this.finishedSubtaskIndices = finishedSubtaskIndices;
        }
    }

    /** An edge from vertex index {@code source} to vertex index {@code target}. */
    private static final class EdgeDeclaration {
        final int source;
        final int target;
        final DistributionPattern distributionPattern;

        EdgeDeclaration(int source, int target, DistributionPattern distributionPattern) {
            this.source = source;
            this.target = target;
            this.distributionPattern = distributionPattern;
        }
    }

    /** A selection of subtasks, by subtask index, of the vertex at {@code vertexIndex}. */
    private static final class TaskDeclaration {
        final int vertexIndex;
        final Set<Integer> subtaskIndices;

        TaskDeclaration(int vertexIndex, Set<Integer> subtaskIndices) {
            this.vertexIndex = vertexIndex;
            this.subtaskIndices = subtaskIndices;
        }
    }
}
