/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.AggregatedBlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class VertexParallelismAndInputInfosDeciderUtils {
    private static final Logger LOG = LoggerFactory.getLogger(VertexParallelismAndInputInfosDeciderUtils.class);

    public static Optional<List<IndexRange>> adjustToClosestLegalParallelism(long currentDataVolumeLimit, int currentParallelism, int minParallelism, int maxParallelism, long minLimit, long maxLimit, Function<Long, Integer> parallelismComputer, Function<Long, List<IndexRange>> subpartitionRangesComputer) {
        long adjustedDataVolumeLimit = currentDataVolumeLimit;
        if (currentParallelism < minParallelism) {
            adjustedDataVolumeLimit = BisectionSearchUtils.findMaxLegalValue(value -> (Integer)parallelismComputer.apply((Long)value) >= minParallelism, minLimit, currentDataVolumeLimit);
            long minPossibleLegalParallelism = parallelismComputer.apply(adjustedDataVolumeLimit).intValue();
            adjustedDataVolumeLimit = BisectionSearchUtils.findMinLegalValue(value -> (long)((Integer)parallelismComputer.apply((Long)value)).intValue() == minPossibleLegalParallelism, minLimit, adjustedDataVolumeLimit);
        } else if (currentParallelism > maxParallelism) {
            adjustedDataVolumeLimit = BisectionSearchUtils.findMinLegalValue(value -> (Integer)parallelismComputer.apply((Long)value) <= maxParallelism, currentDataVolumeLimit, maxLimit);
        }
        int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit);
        if (VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) {
            return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit));
        }
        return Optional.empty();
    }

    public static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
        ArrayList<List<T>> resultLists = new ArrayList<List<T>>();
        if (lists.isEmpty()) {
            resultLists.add(new ArrayList());
            return resultLists;
        }
        List<T> firstList = lists.get(0);
        List<List<T>> remainingLists = VertexParallelismAndInputInfosDeciderUtils.cartesianProduct(lists.subList(1, lists.size()));
        for (T condition : firstList) {
            for (List<T> remainingList : remainingLists) {
                ArrayList<T> resultList = new ArrayList<T>();
                resultList.add(condition);
                resultList.addAll(remainingList);
                resultLists.add(resultList);
            }
        }
        return resultLists;
    }

    public static long median(long[] nums) {
        int len = nums.length;
        long[] sortedNums = LongStream.of(nums).sorted().toArray();
        if (len % 2 == 0) {
            return Math.max((sortedNums[len / 2] + sortedNums[len / 2 - 1]) / 2L, 1L);
        }
        return Math.max(sortedNums[len / 2], 1L);
    }

    public static long computeSkewThreshold(long medianSize, double skewedFactor, long defaultSkewedThreshold) {
        return (long)Math.max((double)medianSize * skewedFactor, (double)defaultSkewedThreshold);
    }

    public static long computeTargetSize(long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) {
        long[] nonSkewPartitions = LongStream.of(subpartitionBytes).filter(v -> v <= skewedThreshold).toArray();
        if (nonSkewPartitions.length == 0) {
            return dataVolumePerTask;
        }
        return Math.max(dataVolumePerTask, LongStream.of(nonSkewPartitions).sum() / (long)nonSkewPartitions.length);
    }

    public static List<BlockingInputInfo> getNonBroadcastInputInfos(List<BlockingInputInfo> consumedResults) {
        return consumedResults.stream().filter(resultInfo -> !resultInfo.isBroadcast()).collect(Collectors.toList());
    }

    public static boolean hasSameNumPartitions(List<BlockingInputInfo> inputInfos) {
        Set partitionNums = inputInfos.stream().map(BlockingInputInfo::getNumPartitions).collect(Collectors.toSet());
        return partitionNums.size() == 1;
    }

    public static int getMaxNumPartitions(List<BlockingInputInfo> consumedResults) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        return consumedResults.stream().mapToInt(BlockingInputInfo::getNumPartitions).max().getAsInt();
    }

    public static int checkAndGetSubpartitionNum(List<BlockingInputInfo> consumedResults) {
        Set subpartitionNumSet = consumedResults.stream().flatMap(resultInfo -> IntStream.range(0, resultInfo.getNumPartitions()).boxed().map(resultInfo::getNumSubpartitions)).collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNumSet.size() == 1);
        return (Integer)subpartitionNumSet.iterator().next();
    }

    public static int checkAndGetSubpartitionNumForAggregatedInputs(Collection<AggregatedBlockingInputInfo> inputInfos) {
        Set subpartitionNumSet = inputInfos.stream().map(AggregatedBlockingInputInfo::getNumSubpartitions).collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNumSet.size() == 1);
        return (Integer)subpartitionNumSet.iterator().next();
    }

    public static boolean isLegalParallelism(int parallelism, int minParallelism, int maxParallelism) {
        return parallelism >= minParallelism && parallelism <= maxParallelism;
    }

    public static boolean checkAndGetIntraCorrelation(List<BlockingInputInfo> inputInfos) {
        Set intraCorrelationSet = inputInfos.stream().map(BlockingInputInfo::isIntraInputKeyCorrelated).collect(Collectors.toSet());
        Preconditions.checkArgument(intraCorrelationSet.size() == 1);
        return (Boolean)intraCorrelationSet.iterator().next();
    }

    public static int checkAndGetParallelism(Collection<JobVertexInputInfo> vertexInputInfos) {
        Set parallelismSet = vertexInputInfos.stream().map(vertexInputInfo -> vertexInputInfo.getExecutionVertexInputInfos().size()).collect(Collectors.toSet());
        Preconditions.checkState(parallelismSet.size() == 1);
        return (Integer)parallelismSet.iterator().next();
    }

    public static Optional<List<IndexRange>> tryComputeSubpartitionSliceRange(int minParallelism, int maxParallelism, long maxDataVolumePerTask, Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        Optional<List<IndexRange>> subpartitionSliceRanges = VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRangeEvenlyDistributedData(minParallelism, maxParallelism, maxDataVolumePerTask, subpartitionSlices);
        if (subpartitionSliceRanges.isEmpty()) {
            LOG.info("Failed to compute a legal subpartition slice range that can evenly distribute data amount, fallback to compute it that can evenly distribute the number of subpartition slices.");
            subpartitionSliceRanges = VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(minParallelism, maxParallelism, subpartitionSlices);
        }
        return subpartitionSliceRanges;
    }

    public static Map<IntermediateDataSetID, JobVertexInputInfo> createJobVertexInputInfos(List<BlockingInputInfo> inputInfos, Map<Integer, List<SubpartitionSlice>> subpartitionSlices, List<IndexRange> subpartitionSliceRanges, Function<Integer, Integer> subpartitionSliceKeyResolver) {
        HashMap<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<IntermediateDataSetID, JobVertexInputInfo>();
        for (int i = 0; i < inputInfos.size(); ++i) {
            BlockingInputInfo inputInfo = inputInfos.get(i);
            if (inputInfo.isBroadcast()) {
                vertexInputInfos.put(inputInfo.getResultId(), VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForBroadcast(inputInfo, subpartitionSliceRanges.size()));
                continue;
            }
            vertexInputInfos.put(inputInfo.getResultId(), VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForNonBroadcast(inputInfo, subpartitionSliceRanges, subpartitionSlices.get(subpartitionSliceKeyResolver.apply(i))));
        }
        return vertexInputInfos;
    }

    public static JobVertexInputInfo createdJobVertexInputInfoForBroadcast(BlockingInputInfo inputInfo, int parallelism) {
        Preconditions.checkArgument(inputInfo.isBroadcast());
        int numPartitions = inputInfo.getNumPartitions();
        ArrayList<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<ExecutionVertexInputInfo>();
        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertexInputInfo executionVertexInputInfo = inputInfo.isSingleSubpartitionContainsAllData() ? new ExecutionVertexInputInfo(i, new IndexRange(0, numPartitions - 1), new IndexRange(0, 0)) : new ExecutionVertexInputInfo(i, new IndexRange(0, numPartitions - 1), new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1));
            executionVertexInputInfos.add(executionVertexInputInfo);
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    public static JobVertexInputInfo createdJobVertexInputInfoForNonBroadcast(BlockingInputInfo inputInfo, List<IndexRange> subpartitionSliceRanges, List<SubpartitionSlice> subpartitionSlices) {
        Preconditions.checkArgument(!inputInfo.isBroadcast());
        int numPartitions = inputInfo.getNumPartitions();
        ArrayList<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<ExecutionVertexInputInfo>();
        for (int i = 0; i < subpartitionSliceRanges.size(); ++i) {
            IndexRange subpartitionSliceRange = subpartitionSliceRanges.get(i);
            Map<IndexRange, IndexRange> consumedSubpartitionGroups = VertexParallelismAndInputInfosDeciderUtils.computeConsumedSubpartitionGroups(subpartitionSliceRange, subpartitionSlices, numPartitions, inputInfo.isPointwise());
            executionVertexInputInfos.add(new ExecutionVertexInputInfo(i, consumedSubpartitionGroups));
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    private static Optional<List<IndexRange>> tryComputeSubpartitionSliceRangeEvenlyDistributedData(int minParallelism, int maxParallelism, long maxDataVolumePerTask, Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        int subpartitionSlicesSize = VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionSlicesSize(subpartitionSlices);
        List<IndexRange> subpartitionSliceRanges = VertexParallelismAndInputInfosDeciderUtils.computeSubpartitionSliceRanges(maxDataVolumePerTask, subpartitionSlicesSize, subpartitionSlices);
        if (!VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism)) {
            LOG.info("Failed to compute a legal subpartition slice range that can evenly distribute data amount, try to adjust to a legal parallelism.");
            long minBytesSize = maxDataVolumePerTask;
            long sumBytesSize = 0L;
            for (int i = 0; i < subpartitionSlicesSize; ++i) {
                long currentBytesSize = 0L;
                for (List<SubpartitionSlice> subpartitionSlice : subpartitionSlices.values()) {
                    currentBytesSize += subpartitionSlice.get(i).getDataBytes();
                }
                minBytesSize = Math.min(minBytesSize, currentBytesSize);
                sumBytesSize += currentBytesSize;
            }
            return VertexParallelismAndInputInfosDeciderUtils.adjustToClosestLegalParallelism(maxDataVolumePerTask, subpartitionSliceRanges.size(), minParallelism, maxParallelism, minBytesSize, sumBytesSize, limit -> VertexParallelismAndInputInfosDeciderUtils.computeParallelism(limit, subpartitionSlicesSize, subpartitionSlices), limit -> VertexParallelismAndInputInfosDeciderUtils.computeSubpartitionSliceRanges(limit, subpartitionSlicesSize, subpartitionSlices));
        }
        return Optional.of(subpartitionSliceRanges);
    }

    private static Optional<List<IndexRange>> tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(int minParallelism, int maxParallelism, Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        int subpartitionSlicesSize = VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionSlicesSize(subpartitionSlices);
        if (subpartitionSlicesSize < minParallelism) {
            return Optional.empty();
        }
        int parallelism = Math.min(subpartitionSlicesSize, maxParallelism);
        ArrayList<IndexRange> subpartitionSliceRanges = new ArrayList<IndexRange>();
        for (int i = 0; i < parallelism; ++i) {
            int start = i * subpartitionSlicesSize / parallelism;
            int nextStart = (i + 1) * subpartitionSlicesSize / parallelism;
            subpartitionSliceRanges.add(new IndexRange(start, nextStart - 1));
        }
        Preconditions.checkState(subpartitionSliceRanges.size() == parallelism);
        return Optional.of(subpartitionSliceRanges);
    }

    /*
     * WARNING - void declaration
     */
    private static Map<IndexRange, IndexRange> computeConsumedSubpartitionGroups(IndexRange subpartitionSliceRange, List<SubpartitionSlice> subpartitionSlices, int numPartitions, boolean isPointwise) {
        IndexRange valueRange;
        Map<IndexRange, Object> rangeMap = new TreeMap(Comparator.comparingInt(IndexRange::getStartIndex));
        for (int i = subpartitionSliceRange.getStartIndex(); i <= subpartitionSliceRange.getEndIndex(); ++i) {
            void var7_8;
            SubpartitionSlice subpartitionSlice = subpartitionSlices.get(i);
            if (isPointwise) {
                IndexRange indexRange = subpartitionSlice.getPartitionRange(numPartitions);
                valueRange = subpartitionSlice.getSubpartitionRange();
            } else {
                IndexRange indexRange = subpartitionSlice.getSubpartitionRange();
                valueRange = subpartitionSlice.getPartitionRange(numPartitions);
            }
            rangeMap.computeIfAbsent((IndexRange)var7_8, k -> new ArrayList()).add(valueRange);
        }
        rangeMap = rangeMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> IndexRangeUtil.mergeIndexRanges((Collection)entry.getValue())));
        HashMap<IndexRange, List> reversedRangeMap = new HashMap<IndexRange, List>();
        for (Map.Entry entry2 : rangeMap.entrySet()) {
            valueRange = (IndexRange)entry2.getKey();
            for (IndexRange keyRange : (List)entry2.getValue()) {
                reversedRangeMap.computeIfAbsent(keyRange, k -> new ArrayList()).add(valueRange);
            }
        }
        Map<IndexRange, IndexRange> mergedReversedRangeMap = reversedRangeMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> {
            List<IndexRange> mergedRange = IndexRangeUtil.mergeIndexRanges((Collection)entry.getValue());
            Preconditions.checkState(mergedRange.size() == 1);
            return mergedRange.get(0);
        }));
        if (isPointwise) {
            return VertexParallelismAndInputInfosDeciderUtils.reverseIndexRangeMap(mergedReversedRangeMap);
        }
        return mergedReversedRangeMap;
    }

    private static List<IndexRange> computeSubpartitionSliceRanges(long limit, int subpartitionGroupSize, Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        ArrayList<IndexRange> subpartitionSliceRanges = new ArrayList<IndexRange>();
        long accumulatedSize = 0L;
        int startIndex = 0;
        HashMap<Integer, Set> bucketsByTypeNumber = new HashMap<Integer, Set>();
        for (int i = 0; i < subpartitionGroupSize; ++i) {
            SubpartitionSlice subpartitionSlice;
            Integer typeNumber;
            long currentGroupSize = 0L;
            long currentGroupSizeDeduplicated = 0L;
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry : subpartitionSlices.entrySet()) {
                typeNumber = entry.getKey();
                subpartitionSlice = entry.getValue().get(i);
                Set bucket = bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet());
                if (!bucket.contains(subpartitionSlice)) {
                    currentGroupSizeDeduplicated += subpartitionSlice.getDataBytes();
                }
                currentGroupSize += subpartitionSlice.getDataBytes();
            }
            if (i == startIndex || accumulatedSize + currentGroupSizeDeduplicated <= limit) {
                accumulatedSize += currentGroupSizeDeduplicated;
            } else {
                subpartitionSliceRanges.add(new IndexRange(startIndex, i - 1));
                startIndex = i;
                accumulatedSize = currentGroupSize;
                bucketsByTypeNumber.clear();
            }
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry : subpartitionSlices.entrySet()) {
                typeNumber = entry.getKey();
                subpartitionSlice = entry.getValue().get(i);
                bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet()).add(subpartitionSlice);
            }
        }
        subpartitionSliceRanges.add(new IndexRange(startIndex, subpartitionGroupSize - 1));
        return subpartitionSliceRanges;
    }

    private static int computeParallelism(long limit, int subpartitionSlicesSize, Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        int count = 1;
        long accumulatedSize = 0L;
        int startIndex = 0;
        HashMap<Integer, Set> bucketsByTypeNumber = new HashMap<Integer, Set>();
        for (int i = 0; i < subpartitionSlicesSize; ++i) {
            SubpartitionSlice subpartitionSlice;
            Integer typeNumber;
            long currentGroupSize = 0L;
            long currentGroupSizeDeduplicated = 0L;
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry : subpartitionSlices.entrySet()) {
                typeNumber = entry.getKey();
                subpartitionSlice = entry.getValue().get(i);
                Set bucket = bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet());
                if (!bucket.contains(subpartitionSlice)) {
                    currentGroupSizeDeduplicated += subpartitionSlice.getDataBytes();
                }
                currentGroupSize += subpartitionSlice.getDataBytes();
            }
            if (i == startIndex || accumulatedSize + currentGroupSizeDeduplicated <= limit) {
                accumulatedSize += currentGroupSizeDeduplicated;
            } else {
                ++count;
                startIndex = i;
                accumulatedSize = currentGroupSize;
                bucketsByTypeNumber.clear();
            }
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry : subpartitionSlices.entrySet()) {
                typeNumber = entry.getKey();
                subpartitionSlice = entry.getValue().get(i);
                bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet()).add(subpartitionSlice);
            }
        }
        return count;
    }

    private static int checkAndGetSubpartitionSlicesSize(Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        Set subpartitionSliceSizes = subpartitionSlices.values().stream().map(List::size).collect(Collectors.toSet());
        Preconditions.checkArgument(subpartitionSliceSizes.size() == 1);
        return (Integer)subpartitionSliceSizes.iterator().next();
    }

    private static Map<IndexRange, IndexRange> reverseIndexRangeMap(Map<IndexRange, IndexRange> indexRangeMap) {
        HashMap<IndexRange, IndexRange> reversedRangeMap = new HashMap<IndexRange, IndexRange>();
        for (Map.Entry<IndexRange, IndexRange> entry : indexRangeMap.entrySet()) {
            Preconditions.checkState(!reversedRangeMap.containsKey(entry.getValue()));
            reversedRangeMap.put(entry.getValue(), entry.getKey());
        }
        return reversedRangeMap;
    }

    public static long calculateDataVolumePerTaskForInputsGroup(long globalDataVolumePerTask, List<BlockingInputInfo> inputsGroup, List<BlockingInputInfo> allInputs) {
        return VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInput(globalDataVolumePerTask, inputsGroup.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum(), allInputs.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum());
    }

    public static long calculateDataVolumePerTaskForInput(long globalDataVolumePerTask, long inputsGroupBytes, long totalDataBytes) {
        return (long)((double)inputsGroupBytes / (double)totalDataBytes * (double)globalDataVolumePerTask);
    }

    public static void logBalancedDataDistributionOptimizationResult(Logger logger, JobVertexID jobVertexId, BlockingInputInfo inputInfo, JobVertexInputInfo optimizedJobVertexInputInfo) {
        int parallelism;
        List<ExecutionVertexInputInfo> nonOptimizedExecutionVertexInputInfos;
        List<ExecutionVertexInputInfo> optimizedExecutionVertexInputInfos = optimizedJobVertexInputInfo.getExecutionVertexInputInfos();
        if (!optimizedExecutionVertexInputInfos.equals(nonOptimizedExecutionVertexInputInfos = VertexParallelismAndInputInfosDeciderUtils.computeNumBasedJobVertexInputInfo(parallelism = optimizedExecutionVertexInputInfos.size(), inputInfo).getExecutionVertexInputInfos())) {
            logger.info("Optimized the balanced data distribution for vertex {}, which reads from result {} with type number {}", new Object[]{jobVertexId, inputInfo.getResultId(), inputInfo.getInputTypeNumber()});
        }
    }

    private static JobVertexInputInfo computeNumBasedJobVertexInputInfo(int parallelism, BlockingInputInfo inputInfo) {
        int sourceParallelism = inputInfo.getNumPartitions();
        if (inputInfo.isPointwise()) {
            return VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise(sourceParallelism, parallelism, inputInfo::getNumSubpartitions, true);
        }
        return VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll(sourceParallelism, parallelism, inputInfo::getNumSubpartitions, true, inputInfo.isBroadcast(), inputInfo.isSingleSubpartitionContainsAllData());
    }

    static int checkAndGetPartitionNum(List<BlockingInputInfo> consumedResults) {
        Set subpartitionNumSet = consumedResults.stream().map(BlockingInputInfo::getNumPartitions).collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNumSet.size() == 1);
        return (Integer)subpartitionNumSet.iterator().next();
    }

    static int getMinSubpartitionCount(List<BlockingInputInfo> consumedResults) {
        Preconditions.checkState(!consumedResults.isEmpty());
        int minSubpartitionCount = Integer.MAX_VALUE;
        for (BlockingInputInfo inputInfo : consumedResults) {
            int numPartitions = inputInfo.getNumPartitions();
            int numSubpartitions = VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(List.of(inputInfo));
            minSubpartitionCount = Math.min(minSubpartitionCount, numPartitions * numSubpartitions);
        }
        return minSubpartitionCount;
    }

    static List<BlockingInputInfo> getInputsWithIntraCorrelation(List<BlockingInputInfo> inputInfos) {
        return inputInfos.stream().filter(BlockingInputInfo::isIntraInputKeyCorrelated).collect(Collectors.toList());
    }
}

