/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.AggregatedBlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes the {@link JobVertexInputInfo} for each input of a job vertex.
 *
 * <p>Inputs are split into two groups by {@link BlockingInputInfo#areInterInputsKeysCorrelated()}:
 * inputs whose keys are correlated across inputs are processed jointly, while the remaining inputs
 * are processed one by one using the parallelism decided for the first group (if any).
 */
public class AllToAllVertexInputInfoComputer {
    private static final Logger LOG = LoggerFactory.getLogger(AllToAllVertexInputInfoComputer.class);

    /** Passed through to {@code AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo}. */
    private final double skewedFactor;

    /** Passed through to {@code AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo}. */
    private final long defaultSkewedThreshold;

    /**
     * Creates a computer with the given skew settings.
     *
     * @param skewedFactor factor handed to the aggregated input info when it is created
     * @param defaultSkewedThreshold threshold handed to the aggregated input info when it is created
     */
    public AllToAllVertexInputInfoComputer(double skewedFactor, long defaultSkewedThreshold) {
        this.skewedFactor = skewedFactor;
        this.defaultSkewedThreshold = defaultSkewedThreshold;
    }

    /**
     * Computes the input infos for all given inputs of a vertex.
     *
     * @param jobVertexId id of the vertex, only used for logging
     * @param inputInfos all inputs of the vertex
     * @param parallelism parallelism used for the fallback computation and, when no input has
     *     inter-input key correlation, for the per-input computation
     * @param minParallelism lower bound used when slicing inputs with inter-input key correlation
     * @param maxParallelism upper bound used when slicing inputs with inter-input key correlation
     * @param dataVolumePerTask data volume per task for the whole vertex; it is re-scaled per
     *     input group via {@code calculateDataVolumePerTaskForInputsGroup}
     * @return the computed input info keyed by result id
     */
    public Map<IntermediateDataSetID, JobVertexInputInfo> compute(JobVertexID jobVertexId, List<BlockingInputInfo> inputInfos, int parallelism, int minParallelism, int maxParallelism, long dataVolumePerTask) {
        List<BlockingInputInfo> inputInfosWithoutInterKeysCorrelation = new ArrayList<>();
        List<BlockingInputInfo> inputInfosWithInterKeysCorrelation = new ArrayList<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            if (inputInfo.areInterInputsKeysCorrelated()) {
                inputInfosWithInterKeysCorrelation.add(inputInfo);
            } else {
                inputInfosWithoutInterKeysCorrelation.add(inputInfo);
            }
        }

        Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        if (!inputInfosWithInterKeysCorrelation.isEmpty()) {
            long groupDataVolumePerTask = VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup(dataVolumePerTask, inputInfosWithInterKeysCorrelation, inputInfos);
            vertexInputInfos.putAll(this.computeJobVertexInputInfosForInputsWithInterKeysCorrelation(jobVertexId, inputInfosWithInterKeysCorrelation, parallelism, minParallelism, maxParallelism, groupDataVolumePerTask));
            // The second group must use the parallelism that was decided for the first group.
            parallelism = VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism(vertexInputInfos.values());
        }
        if (!inputInfosWithoutInterKeysCorrelation.isEmpty()) {
            long groupDataVolumePerTask = VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup(dataVolumePerTask, inputInfosWithoutInterKeysCorrelation, inputInfos);
            vertexInputInfos.putAll(this.computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation(inputInfosWithoutInterKeysCorrelation, parallelism, groupDataVolumePerTask));
        }
        return vertexInputInfos;
    }

    /**
     * Computes the input infos for the group of inputs that have inter-input key correlation.
     *
     * <p>Falls back to {@code VertexInputInfoComputationUtils.computeVertexInputInfos} with the
     * given {@code parallelism} when every input is broadcast or when no slice range within
     * {@code [minParallelism, maxParallelism]} can be computed.
     */
    private Map<IntermediateDataSetID, JobVertexInputInfo> computeJobVertexInputInfosForInputsWithInterKeysCorrelation(JobVertexID jobVertexId, List<BlockingInputInfo> inputInfos, int parallelism, int minParallelism, int maxParallelism, long dataVolumePerTask) {
        List<BlockingInputInfo> nonBroadcastInputInfos = VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos(inputInfos);
        if (nonBroadcastInputInfos.isEmpty()) {
            LOG.info("All inputs are broadcast for vertex {}, fallback to compute a parallelism that can evenly distribute num subpartitions.", jobVertexId);
            return VertexInputInfoComputationUtils.computeVertexInputInfos(parallelism, inputInfos, true);
        }

        // Describe the non-broadcast data as slices, grouped by input type number.
        Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber = this.createSubpartitionSlicesForInputsWithInterKeysCorrelation(nonBroadcastInputInfos, dataVolumePerTask);

        // Try to assign a range of slices to each downstream task.
        Optional<List<IndexRange>> optionalSubpartitionSliceRanges = VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange(minParallelism, maxParallelism, dataVolumePerTask, subpartitionSlicesByTypeNumber);
        if (optionalSubpartitionSliceRanges.isEmpty()) {
            LOG.info("Cannot find a legal parallelism to evenly distribute data amount for job vertex {}, fallback to compute a parallelism that can evenly distribute num subpartitions.", jobVertexId);
            return VertexInputInfoComputationUtils.computeVertexInputInfos(parallelism, inputInfos, true);
        }

        List<IndexRange> subpartitionSliceRanges = optionalSubpartitionSliceRanges.get();
        Preconditions.checkState(VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism));
        return VertexParallelismAndInputInfosDeciderUtils.createJobVertexInputInfos(inputInfos, subpartitionSlicesByTypeNumber, subpartitionSliceRanges, index -> inputInfos.get(index).getInputTypeNumber());
    }

    /**
     * Builds, per input type number, the list of subpartition slices for inputs with inter-input
     * key correlation.
     *
     * <p>For every subpartition index the per-type slice lists are combined with a cartesian
     * product, and element {@code j} of each product tuple is appended to the list of the
     * {@code j}-th type number.
     */
    private Map<Integer, List<SubpartitionSlice>> createSubpartitionSlicesForInputsWithInterKeysCorrelation(List<BlockingInputInfo> nonBroadcastInputInfos, long dataVolumePerTask) {
        Map<Integer, AggregatedBlockingInputInfo> aggregatedInputInfoByTypeNumber = this.createAggregatedBlockingInputInfos(nonBroadcastInputInfos, dataVolumePerTask);
        int subPartitionNum = VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNumForAggregatedInputs(aggregatedInputInfoByTypeNumber.values());
        Map<Integer, List<SubpartitionSlice>> subpartitionSliceGroupByTypeNumber = new HashMap<>();
        for (int subpartitionIndex = 0; subpartitionIndex < subPartitionNum; ++subpartitionIndex) {
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices = AllToAllVertexInputInfoComputer.createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation(subpartitionIndex, aggregatedInputInfoByTypeNumber);

            // Keys and values are snapshotted from the same unmodified map, so both lists share
            // one iteration order; the inner loop relies on position j of a product tuple
            // belonging to typeNumberList.get(j).
            List<Integer> typeNumberList = new ArrayList<>(subpartitionSlices.keySet());
            List<List<SubpartitionSlice>> originalRangeLists = new ArrayList<>(subpartitionSlices.values());
            List<List<SubpartitionSlice>> cartesianProductRangeList = VertexParallelismAndInputInfosDeciderUtils.cartesianProduct(originalRangeLists);

            for (List<SubpartitionSlice> subpartitionSlice : cartesianProductRangeList) {
                for (int j = 0; j < subpartitionSlice.size(); ++j) {
                    int typeNumber = typeNumberList.get(j);
                    subpartitionSliceGroupByTypeNumber.computeIfAbsent(typeNumber, ignored -> new ArrayList<>()).add(subpartitionSlice.get(j));
                }
            }
        }
        return subpartitionSliceGroupByTypeNumber;
    }

    /**
     * Groups the given inputs by input type number and creates one {@link
     * AggregatedBlockingInputInfo} per group.
     *
     * @throws IllegalStateException if the inputs of one type number do not all report the same
     *     {@link BlockingInputInfo#isIntraInputKeyCorrelated()} value
     */
    private Map<Integer, AggregatedBlockingInputInfo> createAggregatedBlockingInputInfos(List<BlockingInputInfo> nonBroadcastInputInfos, long dataVolumePerTask) {
        Map<Integer, List<BlockingInputInfo>> inputsByTypeNumber = nonBroadcastInputInfos.stream().collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber));
        Preconditions.checkState(AllToAllVertexInputInfoComputer.hasSameIntraInputKeyCorrelation(inputsByTypeNumber));
        Map<Integer, AggregatedBlockingInputInfo> blockingInputInfoContexts = new HashMap<>();
        for (Map.Entry<Integer, List<BlockingInputInfo>> entry : inputsByTypeNumber.entrySet()) {
            blockingInputInfoContexts.put(entry.getKey(), AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo(this.defaultSkewedThreshold, this.skewedFactor, dataVolumePerTask, entry.getValue()));
        }
        return blockingInputInfoContexts;
    }

    /**
     * Creates, per input type number, the slices that describe one subpartition index.
     *
     * <p>If the aggregated input is splittable and reports the subpartition as skewed, the
     * partitions are split into several ranges (one slice each); otherwise a single slice that
     * spans partitions {@code 0 .. maxPartitionNum - 1} is produced.
     */
    private static Map<Integer, List<SubpartitionSlice>> createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation(int subpartitionIndex, Map<Integer, AggregatedBlockingInputInfo> aggregatedInputInfoByTypeNumber) {
        Map<Integer, List<SubpartitionSlice>> subpartitionSlices = new HashMap<>();
        IndexRange subpartitionRange = new IndexRange(subpartitionIndex, subpartitionIndex);
        for (Map.Entry<Integer, AggregatedBlockingInputInfo> entry : aggregatedInputInfoByTypeNumber.entrySet()) {
            Integer typeNumber = entry.getKey();
            AggregatedBlockingInputInfo aggregatedBlockingInputInfo = entry.getValue();
            if (aggregatedBlockingInputInfo.isSplittable() && aggregatedBlockingInputInfo.isSkewedSubpartition(subpartitionIndex)) {
                List<IndexRange> partitionRanges = AllToAllVertexInputInfoComputer.computePartitionRangesEvenlyData(subpartitionIndex, aggregatedBlockingInputInfo.getTargetSize(), aggregatedBlockingInputInfo.getSubpartitionBytesByPartition());
                subpartitionSlices.put(typeNumber, SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges(partitionRanges, subpartitionRange, aggregatedBlockingInputInfo.getSubpartitionBytesByPartition()));
            } else {
                IndexRange partitionRange = new IndexRange(0, aggregatedBlockingInputInfo.getMaxPartitionNum() - 1);
                subpartitionSlices.put(typeNumber, Collections.singletonList(SubpartitionSlice.createSubpartitionSlice(partitionRange, subpartitionRange, aggregatedBlockingInputInfo.getAggregatedSubpartitionBytes(subpartitionIndex))));
            }
        }
        return subpartitionSlices;
    }

    /**
     * Splits the partitions {@code 0 .. n-1} into consecutive ranges for one subpartition.
     *
     * <p>A partition joins the current range when it is the first partition of that range or when
     * the accumulated bytes plus its own bytes stay strictly below {@code targetSize}; otherwise
     * the current range is closed and a new one starts at this partition.
     *
     * @param subPartitionIndex index into each partition's byte array
     * @param targetSize byte limit used to decide when to close a range
     * @param subPartitionBytesByPartitionIndex bytes per subpartition, looked up by the partition
     *     indices {@code 0 .. size-1}
     * @return the partition ranges in ascending order
     */
    private static List<IndexRange> computePartitionRangesEvenlyData(int subPartitionIndex, long targetSize, Map<Integer, long[]> subPartitionBytesByPartitionIndex) {
        List<IndexRange> splitPartitionRange = new ArrayList<>();
        int partitionNum = subPartitionBytesByPartitionIndex.size();
        long tmpSum = 0L;
        int startIndex = 0;
        for (int i = 0; i < partitionNum; ++i) {
            long num = subPartitionBytesByPartitionIndex.get(i)[subPartitionIndex];
            if (i == startIndex || tmpSum + num < targetSize) {
                tmpSum += num;
            } else {
                splitPartitionRange.add(new IndexRange(startIndex, i - 1));
                startIndex = i;
                tmpSum = num;
            }
        }
        // Close the trailing range.
        splitPartitionRange.add(new IndexRange(startIndex, partitionNum - 1));
        return splitPartitionRange;
    }

    /**
     * Computes the input info of every input without inter-input key correlation, giving each
     * input a data volume per task derived from its own bytes and the group's total bytes.
     */
    private Map<IntermediateDataSetID, JobVertexInputInfo> computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation(List<BlockingInputInfo> inputInfos, int parallelism, long dataVolumePerTask) {
        long totalDataBytes = inputInfos.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum();
        Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            long inputDataVolumePerTask = VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInput(dataVolumePerTask, inputInfo.getNumBytesProduced(), totalDataBytes);
            vertexInputInfos.put(inputInfo.getResultId(), this.computeVertexInputInfoForInputWithoutInterKeysCorrelation(inputInfo, parallelism, inputDataVolumePerTask));
        }
        return vertexInputInfos;
    }

    /**
     * Computes the input info of a single input without inter-input key correlation.
     *
     * <p>Broadcast inputs are delegated to {@code createdJobVertexInputInfoForBroadcast}. For other
     * inputs the slice ranges are computed for exactly {@code parallelism} tasks; if that is not
     * possible the method falls back to {@code computeVertexInputInfoForPointwise}.
     */
    private JobVertexInputInfo computeVertexInputInfoForInputWithoutInterKeysCorrelation(BlockingInputInfo inputInfo, int parallelism, long dataVolumePerTask) {
        if (inputInfo.isBroadcast()) {
            return VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForBroadcast(inputInfo, parallelism);
        }
        List<SubpartitionSlice> subpartitionSlices = this.createSubpartitionSlicesForInputWithoutInterKeysCorrelation(inputInfo);
        // min == max == parallelism: the parallelism is already fixed at this point.
        Optional<List<IndexRange>> optionalSubpartitionSliceRanges = VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange(parallelism, parallelism, dataVolumePerTask, Map.of(inputInfo.getInputTypeNumber(), subpartitionSlices));
        if (optionalSubpartitionSliceRanges.isEmpty()) {
            LOG.info("Cannot find a legal parallelism to evenly distribute data amount for input {}, fallback to compute a parallelism that can evenly distribute num subpartitions.", inputInfo.getResultId());
            return VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise(inputInfo.getNumPartitions(), parallelism, inputInfo::getNumSubpartitions, true);
        }
        List<IndexRange> subpartitionSliceRanges = optionalSubpartitionSliceRanges.get();
        Preconditions.checkState(VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(subpartitionSliceRanges.size(), parallelism, parallelism));
        return VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForNonBroadcast(inputInfo, subpartitionSliceRanges, subpartitionSlices);
    }

    /**
     * Creates the slices of a single input without inter-input key correlation.
     *
     * <p>With intra-input key correlation, one slice per subpartition index is created and each
     * slice spans all partitions. Without it, one slice is created per (partition, subpartition)
     * pair.
     */
    private List<SubpartitionSlice> createSubpartitionSlicesForInputWithoutInterKeysCorrelation(BlockingInputInfo inputInfo) {
        List<SubpartitionSlice> subpartitionSlices = new ArrayList<>();
        if (inputInfo.isIntraInputKeyCorrelated()) {
            int numSubpartitions = VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(List.of(inputInfo));
            IndexRange partitionRange = new IndexRange(0, inputInfo.getNumPartitions() - 1);
            for (int i = 0; i < numSubpartitions; ++i) {
                IndexRange subpartitionRange = new IndexRange(i, i);
                subpartitionSlices.add(SubpartitionSlice.createSubpartitionSlice(partitionRange, subpartitionRange, inputInfo.getNumBytesProduced(partitionRange, subpartitionRange)));
            }
        } else {
            for (int i = 0; i < inputInfo.getNumPartitions(); ++i) {
                IndexRange partitionRange = new IndexRange(i, i);
                for (int j = 0; j < inputInfo.getNumSubpartitions(i); ++j) {
                    IndexRange subpartitionRange = new IndexRange(j, j);
                    subpartitionSlices.add(SubpartitionSlice.createSubpartitionSlice(partitionRange, subpartitionRange, inputInfo.getNumBytesProduced(partitionRange, subpartitionRange)));
                }
            }
        }
        return subpartitionSlices;
    }

    /**
     * Returns {@code true} if, within every group, all inputs report the same {@link
     * BlockingInputInfo#isIntraInputKeyCorrelated()} value.
     */
    private static boolean hasSameIntraInputKeyCorrelation(Map<Integer, List<BlockingInputInfo>> inputGroups) {
        return inputGroups.values().stream().allMatch(inputs -> inputs.stream().map(BlockingInputInfo::isIntraInputKeyCorrelated).distinct().count() == 1L);
    }
}

