/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes {@link JobVertexInputInfo}s for a job vertex from its blocking inputs.
 *
 * <p>Each input is cut into {@link SubpartitionSlice}s. Consecutive slices are grouped into index
 * ranges by {@link VertexParallelismAndInputInfosDeciderUtils}. If no legal grouping is found, it
 * falls back to {@link VertexInputInfoComputationUtils#computeVertexInputInfos}.
 */
public class PointwiseVertexInputInfoComputer {
    private static final Logger LOG = LoggerFactory.getLogger(PointwiseVertexInputInfoComputer.class);

    /**
     * Caps the number of subpartition slices created per input at this multiple of the vertex's
     * max parallelism. The cap is applied only when no input has intra-input correlation.
     */
    private static final int MAX_NUM_SUBPARTITION_SLICES_FACTOR = 32;

    /**
     * Computes the input infos of a vertex for the given blocking inputs.
     *
     * @param inputInfos the blocking inputs of the vertex
     * @param parallelism parallelism used only by the fallback path
     * @param minParallelism lower bound for the parallelism derived from the slice ranges
     * @param maxParallelism upper bound for the parallelism derived from the slice ranges; also
     *     bounds the number of slices per input
     * @param dataVolumePerTask target data volume per task, passed to the slice-range computation
     * @return the computed {@link JobVertexInputInfo} per intermediate data set
     * @throws IllegalStateException if the computed number of slice ranges is not a legal
     *     parallelism for {@code [minParallelism, maxParallelism]}
     */
    public Map<IntermediateDataSetID, JobVertexInputInfo> compute(
            List<BlockingInputInfo> inputInfos,
            int parallelism,
            int minParallelism,
            int maxParallelism,
            long dataVolumePerTask) {
        Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByInputIndex =
                createSubpartitionSlicesByInputIndex(inputInfos, maxParallelism);

        Optional<List<IndexRange>> optionalSubpartitionSliceRanges =
                VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange(
                        minParallelism,
                        maxParallelism,
                        dataVolumePerTask,
                        subpartitionSlicesByInputIndex);

        if (optionalSubpartitionSliceRanges.isEmpty()) {
            LOG.info(
                    "Cannot find a legal parallelism to evenly distribute data amount for inputs {}, fallback to compute a parallelism that can evenly distribute num subpartitions.",
                    inputInfos.stream()
                            .map(BlockingInputInfo::getResultId)
                            .collect(Collectors.toList()));
            // NOTE(review): the meaning of the trailing 'true' flag is defined by
            // VertexInputInfoComputationUtils.computeVertexInputInfos; confirm there.
            return VertexInputInfoComputationUtils.computeVertexInputInfos(
                    parallelism, inputInfos, true);
        }

        List<IndexRange> subpartitionSliceRanges = optionalSubpartitionSliceRanges.get();
        // The number of slice ranges is the resulting parallelism, so it must be legal.
        Preconditions.checkState(
                VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(
                        subpartitionSliceRanges.size(), minParallelism, maxParallelism),
                "Computed parallelism %s is not legal for min parallelism %s and max parallelism %s.",
                subpartitionSliceRanges.size(),
                minParallelism,
                maxParallelism);

        // NOTE(review): identity mapping; its exact role is defined by createJobVertexInputInfos.
        return VertexParallelismAndInputInfosDeciderUtils.createJobVertexInputInfos(
                inputInfos, subpartitionSlicesByInputIndex, subpartitionSliceRanges, index -> index);
    }

    /**
     * Creates the subpartition slices of every input, keyed by the input's position in {@code
     * inputInfos}.
     *
     * <p>If any input has intra-input correlation, every input is cut into as many slices as the
     * partition count of those correlated inputs. Otherwise the slice count is the minimum
     * subpartition count over all inputs, capped at {@code MAX_NUM_SUBPARTITION_SLICES_FACTOR *
     * maxParallelism}.
     */
    private static Map<Integer, List<SubpartitionSlice>> createSubpartitionSlicesByInputIndex(
            List<BlockingInputInfo> inputInfos, int maxParallelism) {
        List<BlockingInputInfo> inputsWithIntraCorrelation =
                VertexParallelismAndInputInfosDeciderUtils.getInputsWithIntraCorrelation(inputInfos);

        int numSubpartitionSlices;
        if (!inputsWithIntraCorrelation.isEmpty()) {
            numSubpartitionSlices =
                    VertexParallelismAndInputInfosDeciderUtils.checkAndGetPartitionNum(
                            inputsWithIntraCorrelation);
        } else {
            // Compute the cap in long so that a large maxParallelism cannot overflow it; the
            // result never exceeds the (int) minimum subpartition count, so the cast is safe.
            long maxNumSubpartitionSlices =
                    (long) MAX_NUM_SUBPARTITION_SLICES_FACTOR * maxParallelism;
            numSubpartitionSlices =
                    (int)
                            Math.min(
                                    VertexParallelismAndInputInfosDeciderUtils
                                            .getMinSubpartitionCount(inputInfos),
                                    maxNumSubpartitionSlices);
        }

        Map<Integer, List<SubpartitionSlice>> subpartitionSlices = new HashMap<>();
        for (int i = 0; i < inputInfos.size(); ++i) {
            subpartitionSlices.put(
                    i, createSubpartitionSlices(inputInfos.get(i), numSubpartitionSlices));
        }
        return subpartitionSlices;
    }

    /**
     * Cuts one input into {@code total} subpartition slices.
     *
     * <p>If the input has at least {@code total} partitions, each slice covers a contiguous run
     * of partitions and all of their subpartitions. Otherwise each partition gets its near-equal
     * share of the {@code total} slices. Each of those slices covers a contiguous run of that
     * partition's subpartitions.
     *
     * @throws IllegalStateException if a partition would need more slices than it has
     *     subpartitions
     */
    private static List<SubpartitionSlice> createSubpartitionSlices(
            BlockingInputInfo inputInfo, int total) {
        List<SubpartitionSlice> subpartitionSlices = new ArrayList<>();
        int numPartitions = inputInfo.getNumPartitions();
        int numSubpartitions =
                VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(
                        List.of(inputInfo));

        if (numPartitions >= total) {
            for (int i = 0; i < total; ++i) {
                IndexRange partitionRange =
                        new IndexRange(
                                splitPoint(i, numPartitions, total),
                                splitPoint(i + 1, numPartitions, total) - 1);
                IndexRange subpartitionRange = new IndexRange(0, numSubpartitions - 1);
                subpartitionSlices.add(
                        SubpartitionSlice.createSubpartitionSlice(
                                partitionRange,
                                subpartitionRange,
                                inputInfo.getNumBytesProduced(partitionRange, subpartitionRange)));
            }
        } else {
            for (int i = 0; i < numPartitions; ++i) {
                // Number of slices assigned to partition i when 'total' slices are spread
                // near-evenly across all partitions.
                int count =
                        splitPoint(i + 1, total, numPartitions) - splitPoint(i, total, numPartitions);
                Preconditions.checkState(
                        count > 0 && count <= numSubpartitions,
                        "Invalid number of subpartition slices %s for partition %s, expected a value in [1, %s].",
                        count,
                        i,
                        numSubpartitions);
                IndexRange partitionRange = new IndexRange(i, i);
                for (int j = 0; j < count; ++j) {
                    IndexRange subpartitionRange =
                            new IndexRange(
                                    splitPoint(j, numSubpartitions, count),
                                    splitPoint(j + 1, numSubpartitions, count) - 1);
                    subpartitionSlices.add(
                            SubpartitionSlice.createSubpartitionSlice(
                                    partitionRange,
                                    subpartitionRange,
                                    inputInfo.getNumBytesProduced(
                                            partitionRange, subpartitionRange)));
                }
            }
        }
        return subpartitionSlices;
    }

    /**
     * Returns {@code floor(index * size / parts)}, the start of the {@code index}-th of {@code
     * parts} near-equal consecutive ranges covering {@code [0, size)}.
     *
     * <p>The product is computed in {@code long}, because {@code index * size} can exceed {@code
     * Integer.MAX_VALUE} for large partition or slice counts. The result is at most {@code size}
     * when {@code 0 <= index <= parts}, so narrowing back to {@code int} is safe.
     */
    private static int splitPoint(int index, int size, int parts) {
        return (int) ((long) index * size / parts);
    }
}

