/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismAndInputInfosDecider;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.AllToAllVertexInputInfoComputer;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.PointwiseVertexInputInfoComputer;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultVertexParallelismAndInputInfosDecider
implements VertexParallelismAndInputInfosDecider {
    private static final Logger LOG = LoggerFactory.getLogger(DefaultVertexParallelismAndInputInfosDecider.class);
    private final int globalMaxParallelism;
    private final int globalMinParallelism;
    private final long dataVolumePerTask;
    private final int globalDefaultSourceParallelism;
    private final AllToAllVertexInputInfoComputer allToAllVertexInputInfoComputer;
    private final PointwiseVertexInputInfoComputer pointwiseVertexInputInfoComputer;

    private DefaultVertexParallelismAndInputInfosDecider(int globalMaxParallelism, int globalMinParallelism, MemorySize dataVolumePerTask, int globalDefaultSourceParallelism, double skewedFactor, long skewedThreshold) {
        Preconditions.checkArgument((globalMinParallelism > 0 ? 1 : 0) != 0, (Object)"The minimum parallelism must be larger than 0.");
        Preconditions.checkArgument((globalMaxParallelism >= globalMinParallelism ? 1 : 0) != 0, (Object)"Maximum parallelism should be greater than or equal to the minimum parallelism.");
        Preconditions.checkArgument((globalDefaultSourceParallelism > 0 ? 1 : 0) != 0, (Object)"The default source parallelism must be larger than 0.");
        Preconditions.checkNotNull((Object)dataVolumePerTask);
        Preconditions.checkArgument((skewedFactor > 0.0 ? 1 : 0) != 0, (Object)"The default skewed partition factor must be larger than 0.");
        Preconditions.checkArgument((skewedThreshold > 0L ? 1 : 0) != 0, (Object)"The default skewed threshold must be larger than 0.");
        this.globalMaxParallelism = globalMaxParallelism;
        this.globalMinParallelism = globalMinParallelism;
        this.dataVolumePerTask = dataVolumePerTask.getBytes();
        this.globalDefaultSourceParallelism = globalDefaultSourceParallelism;
        this.allToAllVertexInputInfoComputer = new AllToAllVertexInputInfoComputer(skewedFactor, skewedThreshold);
        this.pointwiseVertexInputInfoComputer = new PointwiseVertexInputInfoComputer();
    }

    @Override
    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(JobVertexID jobVertexId, List<BlockingInputInfo> consumedResults, int vertexInitialParallelism, int vertexMinParallelism, int vertexMaxParallelism) {
        Preconditions.checkArgument((vertexInitialParallelism == -1 || vertexInitialParallelism > 0 ? 1 : 0) != 0);
        Preconditions.checkArgument((vertexMinParallelism == -1 || vertexMinParallelism > 0 ? 1 : 0) != 0);
        Preconditions.checkArgument((vertexMaxParallelism > 0 && vertexMaxParallelism >= vertexInitialParallelism && vertexMaxParallelism >= vertexMinParallelism ? 1 : 0) != 0);
        if (consumedResults.isEmpty()) {
            int parallelism = vertexInitialParallelism > 0 ? vertexInitialParallelism : this.computeSourceParallelismUpperBound(jobVertexId, vertexMaxParallelism);
            return new ParallelismAndInputInfos(parallelism, Collections.emptyMap());
        }
        Preconditions.checkArgument((vertexInitialParallelism == -1 ? 1 : 0) != 0);
        int minParallelism = Math.max(this.globalMinParallelism, vertexMinParallelism);
        int maxParallelism = this.globalMaxParallelism;
        if (vertexMaxParallelism < minParallelism) {
            LOG.info("The vertex maximum parallelism {} is smaller than the minimum parallelism {}. Use {} as the lower bound to decide parallelism of job vertex {}.", new Object[]{vertexMaxParallelism, minParallelism, vertexMaxParallelism, jobVertexId});
            minParallelism = vertexMaxParallelism;
        }
        if (vertexMaxParallelism < maxParallelism) {
            LOG.info("The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. Use {} as the upper bound to decide parallelism of job vertex {}.", new Object[]{vertexMaxParallelism, maxParallelism, vertexMaxParallelism, jobVertexId});
            maxParallelism = vertexMaxParallelism;
        }
        Preconditions.checkState((maxParallelism >= minParallelism ? 1 : 0) != 0);
        return this.decideParallelismAndInputInfosForNonSource(jobVertexId, consumedResults, minParallelism, maxParallelism);
    }

    @Override
    public int computeSourceParallelismUpperBound(JobVertexID jobVertexId, int maxParallelism) {
        if (this.globalDefaultSourceParallelism > maxParallelism) {
            LOG.info("The global default source parallelism {} is larger than the maximum parallelism {}. Use {} as the upper bound parallelism of source job vertex {}.", new Object[]{this.globalDefaultSourceParallelism, maxParallelism, maxParallelism, jobVertexId});
            return maxParallelism;
        }
        return this.globalDefaultSourceParallelism;
    }

    @Override
    public long getDataVolumePerTask() {
        return this.dataVolumePerTask;
    }

    private ParallelismAndInputInfos decideParallelismAndInputInfosForNonSource(JobVertexID jobVertexId, List<BlockingInputInfo> consumedResults, int minParallelism, int maxParallelism) {
        int parallelism = this.decideParallelism(jobVertexId, consumedResults, minParallelism, maxParallelism);
        ArrayList<BlockingInputInfo> pointwiseInputs = new ArrayList<BlockingInputInfo>();
        ArrayList<BlockingInputInfo> allToAllInputs = new ArrayList<BlockingInputInfo>();
        consumedResults.forEach(inputInfo -> {
            if (inputInfo.isPointwise()) {
                pointwiseInputs.add((BlockingInputInfo)inputInfo);
            } else {
                allToAllInputs.add((BlockingInputInfo)inputInfo);
            }
        });
        if (!(pointwiseInputs.isEmpty() || allToAllInputs.isEmpty() || VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos(allToAllInputs).isEmpty())) {
            minParallelism = parallelism;
            maxParallelism = parallelism;
        }
        HashMap<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<IntermediateDataSetID, JobVertexInputInfo>();
        if (!pointwiseInputs.isEmpty()) {
            vertexInputInfos.putAll(this.pointwiseVertexInputInfoComputer.compute(pointwiseInputs, parallelism, minParallelism, maxParallelism, VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup(this.dataVolumePerTask, pointwiseInputs, consumedResults)));
            if (!allToAllInputs.isEmpty()) {
                minParallelism = parallelism = VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism(vertexInputInfos.values());
                maxParallelism = parallelism;
            }
        }
        if (!allToAllInputs.isEmpty()) {
            vertexInputInfos.putAll(this.allToAllVertexInputInfoComputer.compute(jobVertexId, allToAllInputs, parallelism, minParallelism, maxParallelism, VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup(this.dataVolumePerTask, allToAllInputs, consumedResults)));
        }
        for (BlockingInputInfo inputInfo2 : consumedResults) {
            VertexParallelismAndInputInfosDeciderUtils.logBalancedDataDistributionOptimizationResult(LOG, jobVertexId, inputInfo2, (JobVertexInputInfo)vertexInputInfos.get(inputInfo2.getResultId()));
        }
        return new ParallelismAndInputInfos(VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism(vertexInputInfos.values()), vertexInputInfos);
    }

    int decideParallelism(JobVertexID jobVertexId, List<BlockingInputInfo> consumedResults, int minParallelism, int maxParallelism) {
        Preconditions.checkArgument((!consumedResults.isEmpty() ? 1 : 0) != 0);
        List<BlockingInputInfo> nonBroadcastResults = VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos(consumedResults);
        if (nonBroadcastResults.isEmpty()) {
            return minParallelism;
        }
        long totalBytes = nonBroadcastResults.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum();
        int parallelism = (int)Math.ceil((double)totalBytes / (double)this.dataVolumePerTask);
        LOG.debug("The total size of non-broadcast data is {}, the initially decided parallelism of job vertex {} is {}.", new Object[]{new MemorySize(totalBytes), jobVertexId, parallelism});
        if (parallelism < minParallelism) {
            LOG.info("The initially decided parallelism {} is smaller than the minimum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.", new Object[]{parallelism, minParallelism, minParallelism, jobVertexId});
            parallelism = minParallelism;
        } else if (parallelism > maxParallelism) {
            LOG.info("The initially decided parallelism {} is larger than the maximum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.", new Object[]{parallelism, maxParallelism, maxParallelism, jobVertexId});
            parallelism = maxParallelism;
        }
        return parallelism;
    }

    static DefaultVertexParallelismAndInputInfosDecider from(int maxParallelism, double skewedFactor, long skewedThreshold, Configuration configuration) {
        return new DefaultVertexParallelismAndInputInfosDecider(maxParallelism, (Integer)configuration.get(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM), (MemorySize)configuration.get(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK), (Integer)configuration.get(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM, (Object)maxParallelism), skewedFactor, skewedThreshold);
    }
}

