/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.MappingBasedRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateAssignment;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assigns the operator states of a restored checkpoint to the subtasks of a set of {@link
 * ExecutionJobVertex execution job vertices}.
 *
 * <p>For every vertex, the restored {@link OperatorState}s are matched to the vertex's operators
 * (preferring the user-defined operator ID when state was stored under it), redistributed to the
 * vertex's current parallelism where necessary (managed/raw operator state, keyed state, input
 * channel state and result subpartition state), and finally set as the initial state of the
 * current execution attempt of each subtask.
 */
@Internal
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** The vertices whose subtasks receive state. */
    private final Set<ExecutionJobVertex> tasks;

    /** Restored operator states, keyed by the operator ID they were stored under. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** ID of the checkpoint being restored; passed to every {@link JobManagerTaskRestore}. */
    private final long restoreCheckpointId;

    /** Whether restored state that maps to no operator may be skipped instead of failing. */
    private final boolean allowNonRestoredState;

    /** Per-vertex assignment bookkeeping, filled in {@link #assignStates()}. */
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /** Maps each consumed intermediate data set to the assignment of its consuming vertex. */
    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment =
            new HashMap<IntermediateDataSetID, TaskStateAssignment>();

    /**
     * Creates a new state assignment operation.
     *
     * @param restoreCheckpointId ID of the checkpoint that is restored
     * @param tasks vertices whose subtasks receive state; must not be null
     * @param operatorStates restored operator states; must not be null
     * @param allowNonRestoredState whether unmapped restored state may be skipped
     */
    public StateAssignmentOperation(
            long restoreCheckpointId,
            Set<ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        this.restoreCheckpointId = restoreCheckpointId;
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
        this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size());
    }

    /**
     * Matches the restored operator states to the operators of all tasks, repartitions them where
     * required and sets the initial state of each subtask's current execution attempt.
     *
     * @throws IllegalStateException if restored state maps to no operator and non-restored state
     *     is not allowed, if the max-parallelism constraints of the restored state are violated,
     *     or if channel state cannot be recovered after a topology change
     */
    public void assignStates() {
        checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);

        // Work on a copy so that each restored state is claimed by at most one operator.
        Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
        for (ExecutionJobVertex executionJobVertex : tasks) {
            List<OperatorIDPair> operatorIDPairs = executionJobVertex.getOperatorIDs();
            Map<OperatorID, OperatorState> vertexOperatorStates =
                    CollectionUtil.newHashMapWithExpectedSize(operatorIDPairs.size());
            for (OperatorIDPair operatorIDPair : operatorIDPairs) {
                // Use the user-defined ID only if state was actually stored under it.
                OperatorID operatorID =
                        operatorIDPair
                                .getUserDefinedOperatorID()
                                .filter(localOperators::containsKey)
                                .orElse(operatorIDPair.getGeneratedOperatorID());
                OperatorState operatorState = localOperators.remove(operatorID);
                if (operatorState == null) {
                    // No restored state for this operator: start from an empty state.
                    operatorState =
                            new OperatorState(
                                    operatorID,
                                    executionJobVertex.getParallelism(),
                                    executionJobVertex.getMaxParallelism());
                }
                // Within the vertex, states are always keyed by the generated operator ID.
                vertexOperatorStates.put(operatorIDPair.getGeneratedOperatorID(), operatorState);
            }

            TaskStateAssignment stateAssignment =
                    new TaskStateAssignment(
                            executionJobVertex,
                            vertexOperatorStates,
                            consumerAssignment,
                            vertexAssignments);
            vertexAssignments.put(executionJobVertex, stateAssignment);
            // Register this vertex as the consumer of each of its input data sets.
            for (IntermediateResult consumedDataSet : executionJobVertex.getInputs()) {
                consumerAssignment.put(consumedDataSet.getId(), stateAssignment);
            }
        }

        // Phase 1: repartition the state of every vertex that has (or neighbours) state.
        for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
            if (stateAssignment.hasNonFinishedState
                    || stateAssignment.hasUpstreamOutputStates()
                    || stateAssignment.hasDownstreamInputStates()) {
                assignAttemptState(stateAssignment);
            }
        }

        // Phase 2: hand the repartitioned (or finished) state to the execution attempts.
        for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
            if (stateAssignment.hasNonFinishedState
                    || stateAssignment.isFullyFinished
                    || stateAssignment.hasUpstreamOutputStates()
                    || stateAssignment.hasDownstreamInputStates()) {
                assignTaskStateToExecutionJobVertices(stateAssignment);
            }
        }
    }

    /** Repartitions all kinds of state of one vertex to its new parallelism. */
    private void assignAttemptState(TaskStateAssignment taskStateAssignment) {
        // May rescale the vertex's max parallelism, so it must run before the key groups are
        // computed from getMaxParallelism() below.
        checkParallelismPreconditions(taskStateAssignment);

        List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(
                        taskStateAssignment.executionJobVertex.getMaxParallelism(),
                        taskStateAssignment.newParallelism);

        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getManagedOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subManagedOperatorState);
        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getRawOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subRawOperatorState);

        reDistributeInputChannelStates(taskStateAssignment);
        reDistributeResultSubpartitionStates(taskStateAssignment);
        reDistributeKeyedStates(keyGroupPartitions, taskStateAssignment);
    }

    /** Sets the initial state on the current execution attempt of every subtask of the vertex. */
    private void assignTaskStateToExecutionJobVertices(TaskStateAssignment assignment) {
        ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
        List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        int newParallelism = executionJobVertex.getParallelism();
        for (int subTaskIndex = 0; subTaskIndex < newParallelism; ++subTaskIndex) {
            Execution currentExecutionAttempt =
                    executionJobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();
            if (assignment.isFullyFinished) {
                assignFinishedStateToTask(currentExecutionAttempt);
            } else {
                assignNonFinishedStateToTask(
                        assignment, operatorIDs, subTaskIndex, currentExecutionAttempt);
            }
        }
    }

    /** Marks the attempt as finished-on-restore for the restored checkpoint. */
    private void assignFinishedStateToTask(Execution currentExecutionAttempt) {
        JobManagerTaskRestore taskRestore =
                new JobManagerTaskRestore(
                        restoreCheckpointId, TaskStateSnapshot.FINISHED_ON_RESTORE);
        currentExecutionAttempt.setInitialState(taskRestore);
    }

    /** Collects the subtask state of every chained operator and sets it on the attempt. */
    private void assignNonFinishedStateToTask(
            TaskStateAssignment assignment,
            List<OperatorIDPair> operatorIDs,
            int subTaskIndex,
            Execution currentExecutionAttempt) {
        TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size(), false);
        for (OperatorIDPair operatorID : operatorIDs) {
            OperatorInstanceID instanceID =
                    OperatorInstanceID.of(subTaskIndex, operatorID.getGeneratedOperatorID());
            OperatorSubtaskState operatorSubtaskState = assignment.getSubtaskState(instanceID);
            taskState.putSubtaskStateByOperatorID(
                    operatorID.getGeneratedOperatorID(), operatorSubtaskState);
        }
        JobManagerTaskRestore taskRestore =
                new JobManagerTaskRestore(restoreCheckpointId, taskState);
        currentExecutionAttempt.setInitialState(taskRestore);
    }

    /**
     * Validates the max parallelism of every old operator state of the assignment against the
     * vertex, possibly rescaling the vertex's max parallelism.
     *
     * @throws IllegalStateException if the state cannot be restored with the vertex's parallelism
     */
    public void checkParallelismPreconditions(TaskStateAssignment taskStateAssignment) {
        for (OperatorState operatorState : taskStateAssignment.oldState.values()) {
            checkParallelismPreconditions(operatorState, taskStateAssignment.executionJobVertex);
        }
    }

    /** Assigns managed and raw keyed state of every operator to every new subtask. */
    private void reDistributeKeyedStates(
            List<KeyGroupRange> keyGroupPartitions, TaskStateAssignment stateAssignment) {
        stateAssignment.oldState.forEach(
                (operatorID, operatorState) -> {
                    for (int subTaskIndex = 0;
                            subTaskIndex < stateAssignment.newParallelism;
                            ++subTaskIndex) {
                        OperatorInstanceID instanceID =
                                OperatorInstanceID.of(subTaskIndex, operatorID);
                        Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates =
                                reAssignSubKeyedStates(
                                        operatorState,
                                        keyGroupPartitions,
                                        subTaskIndex,
                                        stateAssignment.newParallelism,
                                        operatorState.getParallelism());
                        stateAssignment.subManagedKeyedState.put(instanceID, subKeyedStates.f0);
                        stateAssignment.subRawKeyedState.put(instanceID, subKeyedStates.f1);
                    }
                });
    }

    /**
     * Computes the (managed, raw) keyed state handles of one new subtask. With unchanged
     * parallelism the subtask keeps its old handles; otherwise the handles intersecting the
     * subtask's key group range are extracted from all old subtasks.
     */
    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(
            OperatorState operatorState,
            List<KeyGroupRange> keyGroupPartitions,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism) {
        List<KeyedStateHandle> subManagedKeyedState;
        List<KeyedStateHandle> subRawKeyedState;
        if (newParallelism == oldParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                subManagedKeyedState = subtaskState.getManagedKeyedState().asList();
                subRawKeyedState = subtaskState.getRawKeyedState().asList();
            } else {
                subManagedKeyedState = Collections.emptyList();
                subRawKeyedState = Collections.emptyList();
            }
        } else {
            KeyGroupRange subtaskKeyGroupRange = keyGroupPartitions.get(subTaskIndex);
            subManagedKeyedState = getManagedKeyedStateHandles(operatorState, subtaskKeyGroupRange);
            subRawKeyedState = getRawKeyedStateHandles(operatorState, subtaskKeyGroupRange);
        }
        if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
            return new Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>>(
                    Collections.emptyList(), Collections.emptyList());
        }
        return new Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>>(
                subManagedKeyedState, subRawKeyedState);
    }

    /**
     * Repartitions the operator state selected by {@code extractHandle} of every operator to
     * {@code newParallelism} subtasks and adds the result to {@code result}.
     *
     * @param oldOperatorStates old operator states, keyed by operator ID
     * @param newParallelism target number of subtasks
     * @param extractHandle selects the state kind (e.g. managed or raw operator state)
     * @param stateRepartitioner strategy used to repartition the state
     * @param result receives the repartitioned state per operator instance
     */
    public static <T extends StateObject> void reDistributePartitionableStates(
            Map<OperatorID, OperatorState> oldOperatorStates,
            int newParallelism,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorStateRepartitioner<T> stateRepartitioner,
            Map<OperatorInstanceID, List<T>> result) {
        Map<OperatorID, List<List<T>>> oldStates =
                splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);
        oldOperatorStates.forEach(
                (operatorID, oldOperatorState) ->
                        result.putAll(
                                applyRepartitioner(
                                        operatorID,
                                        stateRepartitioner,
                                        oldStates.get(operatorID),
                                        oldOperatorState.getParallelism(),
                                        newParallelism)));
    }

    /**
     * Assigns result subpartition (output channel) state of the output operator to the new
     * subtasks, repartitioning it per produced data set when parallelism changed.
     *
     * @throws IllegalStateException if an operator other than the output operator has output
     *     channel state
     */
    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
            return;
        }
        checkForUnsupportedTopologyChanges(
                assignment.oldState,
                OperatorSubtaskState::getResultSubpartitionState,
                assignment.outputOperatorID);

        OperatorState outputState = assignment.oldState.get(assignment.outputOperatorID);
        // NOTE(review): raw types remain because the concrete channel handle type is not
        // visible in this (decompiled) file.
        List outputOperatorState =
                splitBySubtasks(outputState, OperatorSubtaskState::getResultSubpartitionState);
        ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
        List<IntermediateDataSet> outputs = executionJobVertex.getJobVertex().getProducedDataSets();

        if (outputState.getParallelism() == executionJobVertex.getParallelism()) {
            assignment.resultSubpartitionStates.putAll(
                    toInstanceMap(assignment.outputOperatorID, outputOperatorState));
            return;
        }

        for (int partitionIndex = 0; partitionIndex < outputs.size(); ++partitionIndex) {
            List partitionState =
                    outputs.size() == 1
                            ? outputOperatorState
                            : getPartitionState(
                                    outputOperatorState,
                                    ResultSubpartitionInfo::getPartitionIdx,
                                    partitionIndex);
            MappingBasedRepartitioner repartitioner =
                    new MappingBasedRepartitioner(
                            assignment.getOutputMapping(partitionIndex).getRescaleMappings());
            Map repartitioned =
                    applyRepartitioner(
                            assignment.outputOperatorID,
                            repartitioner,
                            partitionState,
                            outputOperatorState.size(),
                            executionJobVertex.getParallelism());
            addToSubtasks(assignment.resultSubpartitionStates, repartitioned);
        }
    }

    /**
     * Assigns input channel state of the input operator to the new subtasks. The state is kept
     * as-is when the vertex's parallelism is unchanged, unless a FULL downstream mapper is used
     * while an upstream vertex's parallelism differs; otherwise it is repartitioned per gate.
     *
     * @throws IllegalStateException if an operator other than the input operator has input
     *     channel state
     */
    public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
            return;
        }
        checkForUnsupportedTopologyChanges(
                stateAssignment.oldState,
                OperatorSubtaskState::getInputChannelState,
                stateAssignment.inputOperatorID);

        ExecutionJobVertex executionJobVertex = stateAssignment.executionJobVertex;
        List<IntermediateResult> inputs = executionJobVertex.getInputs();
        OperatorState inputState = stateAssignment.oldState.get(stateAssignment.inputOperatorID);
        // NOTE(review): raw types remain because the concrete channel handle type is not
        // visible in this (decompiled) file.
        List inputOperatorState =
                splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState);

        // Loop-invariant: previously re-looked-up for every upstream vertex.
        final int oldInputParallelism = inputState.getParallelism();
        boolean hasAnyFullMapper =
                executionJobVertex.getJobVertex().getInputs().stream()
                        .map(JobEdge::getDownstreamSubtaskStateMapper)
                        .anyMatch(SubtaskStateMapper.FULL::equals);
        boolean hasAnyPreviousOperatorChanged =
                inputs.stream()
                        .map(IntermediateResult::getProducer)
                        .map(vertexAssignments::get)
                        .anyMatch(
                                upstreamAssignment ->
                                        oldInputParallelism
                                                != upstreamAssignment.executionJobVertex
                                                        .getParallelism());

        boolean parallelismUnchanged =
                oldInputParallelism == executionJobVertex.getParallelism();
        if (parallelismUnchanged && !(hasAnyFullMapper && hasAnyPreviousOperatorChanged)) {
            stateAssignment.inputChannelStates.putAll(
                    toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState));
            return;
        }

        for (int gateIndex = 0; gateIndex < inputs.size(); ++gateIndex) {
            RescaleMappings mapping = stateAssignment.getInputMapping(gateIndex).getRescaleMappings();
            List gateState =
                    inputs.size() == 1
                            ? inputOperatorState
                            : getPartitionState(
                                    inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
            MappingBasedRepartitioner repartitioner = new MappingBasedRepartitioner(mapping);
            Map repartitioned =
                    applyRepartitioner(
                            stateAssignment.inputOperatorID,
                            repartitioner,
                            gateState,
                            inputOperatorState.size(),
                            stateAssignment.newParallelism);
            addToSubtasks(stateAssignment.inputChannelStates, repartitioned);
        }
    }

    /** Appends every list of {@code toAdd} to the list stored under the same key in {@code target}. */
    private static <K, V> void addToSubtasks(Map<K, List<V>> target, Map<K, List<V>> toAdd) {
        toAdd.forEach(
                (key, values) ->
                        target.computeIfAbsent(key, unused -> new ArrayList<>(values.size()))
                                .addAll(values));
    }

    /**
     * Fails if any operator other than {@code expectedOperatorID} holds non-empty channel state of
     * the kind selected by {@code extractHandle}.
     */
    private <T extends AbstractChannelStateHandle<?>> void checkForUnsupportedTopologyChanges(
            Map<OperatorID, OperatorState> oldOperatorStates,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorID expectedOperatorID) {
        List<OperatorID> unexpectedState =
                oldOperatorStates.entrySet().stream()
                        .filter(idAndState -> !idAndState.getKey().equals(expectedOperatorID))
                        .filter(idAndState -> hasChannelState(idAndState.getValue(), extractHandle))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
        if (!unexpectedState.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot recover from unaligned checkpoint when topology changes, such that data exchanges with persisted data are now chained.\nThe following operators contain channel state: "
                            + unexpectedState);
        }
    }

    /** Returns whether any subtask has a channel state handle with at least one offset. */
    private <T extends AbstractChannelStateHandle<?>> boolean hasChannelState(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorState.getSubtaskStates().values().stream()
                .anyMatch(subState -> !isEmpty(extractHandle.apply(subState)));
    }

    /** Returns whether every handle in the collection has no offsets. */
    private <T extends AbstractChannelStateHandle<?>> boolean isEmpty(StateObjectCollection<T> s) {
        return s.stream().allMatch(state -> state.getOffsets().isEmpty());
    }

    /** Keeps, per subtask, only the handles whose info maps to {@code partitionId}. */
    private static <T extends AbstractChannelStateHandle<I>, I> List<List<T>> getPartitionState(
            List<List<T>> subtaskStates, Function<I, Integer> partitionExtractor, int partitionId) {
        return subtaskStates.stream()
                .map(
                        subtaskState ->
                                subtaskState.stream()
                                        .filter(
                                                state ->
                                                        partitionExtractor.apply(state.getInfo())
                                                                == partitionId)
                                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /** Splits the state selected by {@code extractHandle} of each operator by subtask. */
    private static <T extends StateObject>
            Map<OperatorID, List<List<T>>> splitManagedAndRawOperatorStates(
                    Map<OperatorID, OperatorState> operatorStates,
                    Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorStates.entrySet().stream()
                .collect(
                        Collectors.toMap(
                                Map.Entry::getKey,
                                operatorIdAndState ->
                                        splitBySubtasks(
                                                operatorIdAndState.getValue(), extractHandle)));
    }

    /**
     * Returns one list of handles per old subtask (index = subtask index); subtasks without state
     * yield an empty list.
     */
    private static <T extends StateObject> List<List<T>> splitBySubtasks(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        List<List<T>> statePerSubtask = new ArrayList<>(operatorState.getParallelism());
        for (int subTaskIndex = 0; subTaskIndex < operatorState.getParallelism(); ++subTaskIndex) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            statePerSubtask.add(
                    subtaskState == null
                            ? Collections.emptyList()
                            : extractHandle.apply(subtaskState).asList());
        }
        return statePerSubtask;
    }

    /**
     * Collects the managed keyed state handles of all old subtasks, intersected with the given key
     * group range.
     *
     * @return the intersecting handles, or an empty list if no subtask has state
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return extractKeyedStateHandles(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getManagedKeyedState);
    }

    /**
     * Collects the raw keyed state handles of all old subtasks, intersected with the given key
     * group range.
     *
     * @return the intersecting handles, or an empty list if no subtask has state
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return extractKeyedStateHandles(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getRawKeyedState);
    }

    /** Shared implementation of the managed/raw keyed state extraction above. */
    private static List<KeyedStateHandle> extractKeyedStateHandles(
            OperatorState operatorState,
            KeyGroupRange subtaskKeyGroupRange,
            Function<OperatorSubtaskState, StateObjectCollection<KeyedStateHandle>> extractHandle) {
        int parallelism = operatorState.getParallelism();
        // Allocated lazily so that state-less operators return the shared empty list.
        List<KeyedStateHandle> extractedKeyedStateHandles = null;
        for (int i = 0; i < parallelism; ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> keyedStateHandles =
                    extractHandle.apply(subtaskState);
            if (extractedKeyedStateHandles == null) {
                extractedKeyedStateHandles =
                        new ArrayList<>(parallelism * keyedStateHandles.size());
            }
            extractIntersectingState(
                    keyedStateHandles, subtaskKeyGroupRange, extractedKeyedStateHandles);
        }
        return extractedKeyedStateHandles != null
                ? extractedKeyedStateHandles
                : Collections.emptyList();
    }

    /**
     * Adds the intersection of each non-null handle with {@code rangeToExtract} to the collector,
     * skipping handles whose intersection is null.
     */
    @VisibleForTesting
    public static void extractIntersectingState(
            Collection<? extends KeyedStateHandle> originalSubtaskStateHandles,
            KeyGroupRange rangeToExtract,
            List<KeyedStateHandle> extractedStateCollector) {
        for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {
            if (keyedStateHandle == null) {
                continue;
            }
            KeyedStateHandle intersectedKeyedStateHandle =
                    keyedStateHandle.getIntersection(rangeToExtract);
            if (intersectedKeyedStateHandle != null) {
                extractedStateCollector.add(intersectedKeyedStateHandle);
            }
        }
    }

    /**
     * Computes the key group range of every operator index for the given parallelism.
     *
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(
                numberKeyGroups >= parallelism,
                "The number of key groups (%s) must not be lower than the parallelism (%s).",
                numberKeyGroups,
                parallelism);
        List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; ++i) {
            result.add(
                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                            numberKeyGroups, parallelism, i));
        }
        return result;
    }

    /**
     * Validates the restored max parallelism against the vertex. If they differ and the vertex
     * allows it, the vertex's max parallelism is set to the restored one.
     *
     * @throws IllegalStateException if the restored max parallelism is lower than the vertex's
     *     parallelism, or differs from the vertex's max parallelism and cannot be rescaled
     */
    private static void checkParallelismPreconditions(
            OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
            throw new IllegalStateException(
                    "The state for task "
                            + executionJobVertex.getJobVertexId()
                            + " can not be restored. The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") of the restored state is lower than the configured parallelism ("
                            + executionJobVertex.getParallelism()
                            + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }
        if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
            if (executionJobVertex.canRescaleMaxParallelism(operatorState.getMaxParallelism())) {
                LOG.debug(
                        "Rescaling maximum parallelism for JobVertex {} from {} to {}",
                        executionJobVertex.getJobVertexId(),
                        executionJobVertex.getMaxParallelism(),
                        operatorState.getMaxParallelism());
                executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
            } else {
                throw new IllegalStateException(
                        "The maximum parallelism ("
                                + operatorState.getMaxParallelism()
                                + ") with which the latest checkpoint of the execution job vertex "
                                + executionJobVertex
                                + " has been taken and the current maximum parallelism ("
                                + executionJobVertex.getMaxParallelism()
                                + ") changed. This is currently not supported.");
            }
        }
    }

    /**
     * Verifies that every restored state maps to a generated or user-defined operator ID of some
     * task; unmapped state is logged and skipped if {@code allowNonRestoredState} is set.
     *
     * @throws IllegalStateException if unmapped state exists and it may not be skipped
     */
    private static void checkStateMappingCompleteness(
            boolean allowNonRestoredState,
            Map<OperatorID, OperatorState> operatorStates,
            Set<ExecutionJobVertex> tasks) {
        Set<OperatorID> allOperatorIDs = new HashSet<>();
        for (ExecutionJobVertex executionJobVertex : tasks) {
            for (OperatorIDPair operatorIDPair : executionJobVertex.getOperatorIDs()) {
                allOperatorIDs.add(operatorIDPair.getGeneratedOperatorID());
                operatorIDPair.getUserDefinedOperatorID().ifPresent(allOperatorIDs::add);
            }
        }
        for (Map.Entry<OperatorID, OperatorState> entry : operatorStates.entrySet()) {
            if (allOperatorIDs.contains(entry.getKey())) {
                continue;
            }
            OperatorState operatorState = entry.getValue();
            if (allowNonRestoredState) {
                LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
                continue;
            }
            throw new IllegalStateException(
                    "There is no operator for the state " + operatorState.getOperatorID());
        }
    }

    /**
     * Repartitions {@code chainOpParallelStates} and keys the result by operator instance.
     *
     * @return repartitioned state per {@link OperatorInstanceID}; empty if the input is null
     */
    public static <T> Map<OperatorInstanceID, List<T>> applyRepartitioner(
            OperatorID operatorID,
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        List<List<T>> states =
                applyRepartitioner(
                        opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism);
        return toInstanceMap(operatorID, states);
    }

    /**
     * Keys each per-subtask list by {@code OperatorInstanceID.of(subtaskIndex, operatorID)}.
     *
     * @throws NullPointerException if any per-subtask list is null
     */
    private static <T> Map<OperatorInstanceID, List<T>> toInstanceMap(
            OperatorID operatorID, List<List<T>> states) {
        Map<OperatorInstanceID, List<T>> result =
                CollectionUtil.newHashMapWithExpectedSize(states.size());
        for (int subtaskIndex = 0; subtaskIndex < states.size(); ++subtaskIndex) {
            // Previously checked a boxed Boolean (never null), which made the guard a no-op.
            List<T> subtaskStates =
                    Preconditions.checkNotNull(
                            states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskStates);
        }
        return result;
    }

    /**
     * Repartitions per-subtask state from {@code oldParallelism} to {@code newParallelism}.
     *
     * @return the repartitioned states, or an empty list if {@code chainOpParallelStates} is null
     */
    public static <T> List<List<T>> applyRepartitioner(
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        if (chainOpParallelStates == null) {
            return Collections.emptyList();
        }
        return opStateRepartitioner.repartitionState(
                chainOpParallelStates, oldParallelism, newParallelism);
    }
}

