package org.apache.flink.runtime.checkpoint;

import java.util.Collections;
import java.util.Random;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/** Tests for {@link TaskStateSnapshot}. */
class TaskStateSnapshotTest {

    /** Fixed seed so the randomly generated dummy state handles are reproducible across runs. */
    private static final long RANDOM_SEED = 66L;

    /**
     * Checks that subtask states are stored and looked up per {@link OperatorID}, and that putting
     * a new state under an existing ID replaces the entry and returns the previous one.
     */
    @Test
    void putGetSubtaskStateByOperatorID() {
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();

        OperatorID operatorID1 = new OperatorID();
        OperatorID operatorID2 = new OperatorID();
        OperatorSubtaskState state1 = OperatorSubtaskState.builder().build();
        OperatorSubtaskState state2 = OperatorSubtaskState.builder().build();
        OperatorSubtaskState state1Replacement = OperatorSubtaskState.builder().build();

        Assertions.assertThat(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1)).isNull();
        Assertions.assertThat(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID2)).isNull();

        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, state1);
        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID2, state2);

        // All three states are built empty and may compare equal() to each other, so identity
        // (isSameAs) is asserted to prove the exact instance is stored, returned and replaced.
        Assertions.assertThat(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1))
                .isSameAs(state1);
        Assertions.assertThat(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID2))
                .isSameAs(state2);
        Assertions.assertThat(
                        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, state1Replacement))
                .isSameAs(state1);
        Assertions.assertThat(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1))
                .isSameAs(state1Replacement);
    }

    /**
     * Checks that a snapshot reports state only once it contains at least one subtask state that
     * itself has state; an empty subtask state does not count.
     */
    @Test
    void hasState() {
        Random random = new Random(RANDOM_SEED);
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        Assertions.assertThat(taskStateSnapshot.hasState()).isFalse();

        OperatorSubtaskState emptyState = OperatorSubtaskState.builder().build();
        Assertions.assertThat(emptyState.hasState()).isFalse();
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyState);
        Assertions.assertThat(taskStateSnapshot.hasState()).isFalse();

        OperatorSubtaskState nonEmptyState =
                OperatorSubtaskState.builder()
                        .setManagedOperatorState(
                                StateHandleDummyUtil.createNewOperatorStateHandle(2, random))
                        .build();
        Assertions.assertThat(nonEmptyState.hasState()).isTrue();
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), nonEmptyState);
        Assertions.assertThat(taskStateSnapshot.hasState()).isTrue();
    }

    /** Checks that discarding the snapshot discards every contained subtask state. */
    @Test
    void discardState() throws Exception {
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        OperatorID operatorID1 = new OperatorID();
        OperatorID operatorID2 = new OperatorID();

        OperatorSubtaskState state1 = Mockito.mock(OperatorSubtaskState.class);
        OperatorSubtaskState state2 = Mockito.mock(OperatorSubtaskState.class);

        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, state1);
        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID2, state2);

        taskStateSnapshot.discardState();

        Mockito.verify(state1).discardState();
        Mockito.verify(state2).discardState();
    }

    /**
     * Checks that the snapshot's state size is zero when empty and otherwise the sum of the
     * managed and raw operator state handle sizes of its subtask states.
     */
    @Test
    void getStateSize() {
        Random random = new Random(RANDOM_SEED);
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        Assertions.assertThat(taskStateSnapshot.getStateSize()).isZero();

        OperatorSubtaskState emptyState = OperatorSubtaskState.builder().build();
        Assertions.assertThat(emptyState.hasState()).isFalse();
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyState);
        Assertions.assertThat(taskStateSnapshot.getStateSize()).isZero();

        OperatorStateHandle managedHandle =
                StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
        OperatorSubtaskState managedState =
                OperatorSubtaskState.builder().setManagedOperatorState(managedHandle).build();

        OperatorStateHandle rawHandle = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
        OperatorSubtaskState rawState =
                OperatorSubtaskState.builder().setRawOperatorState(rawHandle).build();

        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), managedState);
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), rawState);

        Assertions.assertThat(taskStateSnapshot.getStateSize())
                .isEqualTo(managedHandle.getStateSize() + rawHandle.getStateSize());
    }

    /**
     * Checks that input-channel and result-subpartition (channel) state contribute to the
     * snapshot's state size and make it report {@code hasState()}.
     */
    @Test
    void testSizeIncludesChannelState() {
        // Seeded like the other tests so failures are reproducible.
        Random random = new Random(RANDOM_SEED);
        InputChannelStateHandle inputChannelStateHandle =
                StateHandleDummyUtil.createNewInputChannelStateHandle(10, random);
        ResultSubpartitionStateHandle resultSubpartitionStateHandle =
                StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random);

        OperatorSubtaskState channelState =
                OperatorSubtaskState.builder()
                        .setInputChannelState(StateObjectCollection.singleton(inputChannelStateHandle))
                        .setResultSubpartitionState(
                                StateObjectCollection.singleton(resultSubpartitionStateHandle))
                        .build();
        TaskStateSnapshot taskStateSnapshot =
                new TaskStateSnapshot(Collections.singletonMap(new OperatorID(), channelState));

        Assertions.assertThat(taskStateSnapshot.getStateSize())
                .isEqualTo(
                        inputChannelStateHandle.getStateSize()
                                + resultSubpartitionStateHandle.getStateSize());
        Assertions.assertThat(taskStateSnapshot.hasState()).isTrue();
    }
}
