/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.library.vertexmanager;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.EdgeManagerPluginContext;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.TaskIdentifier;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.apache.tez.runtime.library.utils.DATA_RANGE_IN_MB;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InterfaceAudience.Public
@InterfaceStability.Evolving
public class ShuffleVertexManager
extends VertexManagerPlugin {
    /** Config key: completed fraction of source tasks at which this vertex may start scheduling tasks. */
    public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION = "tez.shuffle-vertex-manager.min-src-fraction";
    public static final float TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f;
    /** Config key: completed fraction of source tasks by which all tasks of this vertex are scheduled. */
    public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION = "tez.shuffle-vertex-manager.max-src-fraction";
    public static final float TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f;
    /** Config key: whether parallelism may be reduced at runtime from observed source output sizes. */
    public static final String TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL = "tez.shuffle-vertex-manager.enable.auto-parallel";
    public static final boolean TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT = false;
    /** Config key: target input size per task, used when auto parallelism is enabled. */
    public static final String TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE = "tez.shuffle-vertex-manager.desired-task-input-size";
    /** 0x6400000 = 104857600 bytes (100 MiB). */
    public static final long TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 0x6400000L;
    /** Config key: lower bound on the parallelism chosen by auto parallelism. */
    public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM = "tez.shuffle-vertex-manager.min-task-parallelism";
    public static final int TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT = 1;
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleVertexManager.class);
    // Slow-start window; both read from configuration in initialize().
    float slowStartMinSrcCompletionFraction;
    float slowStartMaxSrcCompletionFraction;
    // Auto-parallelism settings; initial values are overwritten in initialize().
    long desiredTaskInputDataSize = 0x6400000L;
    int minTaskParallelism = 1;
    boolean enableAutoParallelism = false;
    // Set once determineParallelismAndApply() reports that the decision is final.
    boolean parallelismDetermined = false;
    // Sum of numTasks over configured SCATTER_GATHER source vertices.
    int totalNumBipartiteSourceTasks = 0;
    // Distinct SCATTER_GATHER source tasks reported complete (deduplicated per task id).
    int numBipartiteSourceTasksCompleted = 0;
    // VertexManagerEvents accepted (only the first one per producer task is counted).
    int numVertexManagerEventsReceived = 0;
    // Tasks of this vertex that have not been handed to scheduleTasks() yet.
    List<PendingTaskInfo> pendingTasks = Lists.newLinkedList();
    // Events that arrived before onVertexStarted(); replayed from there.
    List<VertexManagerEvent> pendingVMEvents = Lists.newLinkedList();
    // Task count of this vertex when pendingTasks was last rebuilt.
    int totalTasksToSchedule = 0;
    private AtomicBoolean onVertexStartedDone = new AtomicBoolean(false);
    // Producer tasks from which a VertexManagerEvent has already been accepted.
    private Set<TaskIdentifier> taskWithVmEvents = Sets.newHashSet();
    // Per input (source) vertex bookkeeping, keyed by source vertex name.
    private final Map<String, SourceVertexInfo> srcVertexInfo = Maps.newConcurrentMap();
    // Set by canScheduleTasks() once every source vertex is configured (despite the name).
    boolean sourceVerticesScheduled = false;
    // Number of SCATTER_GATHER input edges.
    @VisibleForTesting
    int bipartiteSources = 0;
    // Sum of output sizes reported in accepted VertexManagerEvents.
    long completedSourceTasksOutputSize = 0L;
    // CONFIGURED notifications that arrived before onVertexStarted(); replayed from there.
    List<VertexStateUpdate> pendingStateUpdates = Lists.newArrayList();
    // After auto-reduction: for each new task index, the original partition indices it covers.
    private int[][] targetIndexes;
    // Partitions per reduced task, and the (possibly smaller) range of the last task (0 = none).
    private int basePartitionRange;
    private int remainderRangeForLastShuffler;
    // Per original partition: accumulated size in MB decoded from partition stats.
    @VisibleForTesting
    long[] stats;

    /**
     * Creates the manager bound to the given plugin context. Configuration is read later, in
     * {@link #initialize()}.
     *
     * @param context the vertex manager plugin context supplied by the framework
     */
    public ShuffleVertexManager(VertexManagerPluginContext context) {
        super(context);
    }

    /**
     * Builds the contiguous run of partition indices handled by one task.
     *
     * @param partitionRange number of indices to produce
     * @param taskIndex index of the task whose run is being built
     * @param offSetPerTask stride between the first indices of consecutive tasks
     * @return {@code [taskIndex*offSetPerTask, taskIndex*offSetPerTask + partitionRange)} as an array
     */
    static int[] createIndices(int partitionRange, int taskIndex, int offSetPerTask) {
        int[] result = new int[partitionRange];
        int value = taskIndex * offSetPerTask;
        for (int pos = 0; pos < result.length; pos++, value++) {
            result[pos] = value;
        }
        return result;
    }

    /**
     * Registers all input edges, replays notifications buffered before start, applies any source
     * completions that happened before start, and then attempts scheduling.
     *
     * @param completions source task attempts already completed when this vertex started; may be null
     * @throws TezUncheckedException if no SCATTER_GATHER input edge exists
     */
    public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions) {
        Map<String, EdgeProperty> inputEdges = this.getContext().getInputVertexEdgeProperties();
        for (Map.Entry<String, EdgeProperty> edge : inputEdges.entrySet()) {
            String sourceName = edge.getKey();
            EdgeProperty property = edge.getValue();
            this.srcVertexInfo.put(sourceName, new SourceVertexInfo(property));
            this.getContext().registerForVertexStateUpdates(sourceName, EnumSet.of(VertexState.CONFIGURED));
            if (property.getDataMovementType() == EdgeProperty.DataMovementType.SCATTER_GATHER) {
                this.bipartiteSources++;
            }
        }
        if (this.bipartiteSources == 0) {
            throw new TezUncheckedException("Atleast 1 bipartite source should exist");
        }
        // Replay what arrived before start; scheduling is still suppressed because
        // onVertexStartedDone is not set yet.
        for (VertexStateUpdate buffered : this.pendingStateUpdates) {
            this.handleVertexStateUpdate(buffered);
        }
        this.pendingStateUpdates.clear();
        for (VertexManagerEvent buffered : this.pendingVMEvents) {
            this.handleVertexManagerEvent(buffered);
        }
        this.pendingVMEvents.clear();
        this.updatePendingTasks();
        LOG.info("OnVertexStarted vertex: " + this.getContext().getVertexName() + " with " + this.totalNumBipartiteSourceTasks + " source tasks and " + this.totalTasksToSchedule + " pending tasks");
        if (completions != null) {
            for (TaskAttemptIdentifier done : completions) {
                this.onSourceTaskCompleted(done);
            }
        }
        this.onVertexStartedDone.set(true);
        this.schedulePendingTasks();
    }

    /**
     * Records completion of a source task and re-evaluates scheduling.
     *
     * <p>Completion is tracked per source task id in a BitSet, so repeated notifications for
     * the same task are counted once. Only SCATTER_GATHER sources contribute to
     * {@link #numBipartiteSourceTasksCompleted}.
     *
     * @param attempt the completed source task attempt
     * @throws IllegalStateException if the attempt belongs to a vertex that is not an input of this
     *     vertex, or if its task id is out of range for an already-configured source vertex
     */
    public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
        TaskIdentifier task = attempt.getTaskIdentifier();
        String srcVertexName = task.getVertexIdentifier().getName();
        int srcTaskId = task.getIdentifier();
        SourceVertexInfo srcInfo = this.srcVertexInfo.get(srcVertexName);
        // Fail with a diagnostic (as handleVertexManagerEvent does) instead of an NPE below.
        Preconditions.checkState(srcInfo != null, (Object)("Received completion for srcTaskId " + srcTaskId + " from unknown source vertex: " + srcVertexName));
        if (srcInfo.vertexIsConfigured) {
            Preconditions.checkState(srcTaskId < srcInfo.numTasks, (Object)("Received completion for srcTaskId " + srcTaskId + " but Vertex: " + srcVertexName + " has only " + srcInfo.numTasks + " tasks"));
        }
        BitSet completedSourceTasks = srcInfo.finishedTaskSet;
        if (!completedSourceTasks.get(srcTaskId)) {
            completedSourceTasks.set(srcTaskId);
            if (srcInfo.edgeProperty.getDataMovementType() == EdgeProperty.DataMovementType.SCATTER_GATHER) {
                ++this.numBipartiteSourceTasksCompleted;
            }
        }
        this.schedulePendingTasks();
    }

    /**
     * Adds decoded per-partition sizes from a partition-stats bitmap into {@link #stats}.
     *
     * <p>Each set bit {@code pos} encodes a (partition, size bucket) pair:
     * {@code partition = pos / DATA_RANGE_IN_MB.values().length} and
     * {@code bucket = pos % DATA_RANGE_IN_MB.values().length}. Buckets whose size is not positive
     * contribute nothing.
     *
     * @param partitionStats bitmap decoded from a VertexManagerEvent payload
     * @throws IllegalStateException if {@link #stats} has not been allocated yet
     */
    @VisibleForTesting
    void parsePartitionStats(RoaringBitmap partitionStats) {
        Preconditions.checkState(this.stats != null, (Object)"Stats should be initialized");
        DATA_RANGE_IN_MB[] buckets = DATA_RANGE_IN_MB.values();
        int bucketCount = buckets.length;
        for (Integer encoded : partitionStats) {
            int bit = encoded;
            int partition = bit / bucketCount;
            DATA_RANGE_IN_MB bucket = buckets[bit % bucketCount];
            if (bucket.getSizeInMB() > 0) {
                this.stats[partition] += bucket.getSizeInMB();
            }
        }
    }

    /**
     * Handles a VertexManagerEvent, or buffers it until {@link #onVertexStarted} has run.
     *
     * @param vmEvent event carrying source output size and optional partition stats
     */
    public synchronized void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
        if (!this.onVertexStartedDone.get()) {
            this.pendingVMEvents.add(vmEvent);
            return;
        }
        this.handleVertexManagerEvent(vmEvent);
    }

    /**
     * Accumulates output-size information from a source task's VertexManagerEvent.
     *
     * <p>Only the first event per producer task is used; later ones are logged and ignored. When
     * a payload is present, its output size is added to per-source and overall totals, and any
     * partition stats are folded into {@link #stats}.
     *
     * @throws IllegalStateException if the producer's vertex is not an input of this vertex
     * @throws TezUncheckedException if the payload or its partition stats cannot be decoded
     */
    private void handleVertexManagerEvent(VertexManagerEvent vmEvent) {
        TaskIdentifier producer = vmEvent.getProducerAttemptIdentifier().getTaskIdentifier();
        boolean firstFromProducer = this.taskWithVmEvents.add(producer);
        if (!firstFromProducer) {
            LOG.info("Ignoring vertex manager event from: " + producer);
            return;
        }
        SourceVertexInfo sourceInfo = this.srcVertexInfo.get(producer.getVertexIdentifier().getName());
        Preconditions.checkState(sourceInfo != null, (Object)("Unknown vmEvent from " + producer));
        this.numVertexManagerEventsReceived++;
        long reportedSize = 0L;
        ByteBuffer payload = vmEvent.getUserPayload();
        if (payload != null) {
            ShuffleUserPayloads.VertexManagerEventPayloadProto proto;
            try {
                proto = ShuffleUserPayloads.VertexManagerEventPayloadProto.parseFrom(ByteString.copyFrom(payload));
            } catch (InvalidProtocolBufferException e) {
                throw new TezUncheckedException((Throwable)e);
            }
            reportedSize = proto.getOutputSize();
            if (proto.hasPartitionStats()) {
                this.addCompressedPartitionStats(proto.getPartitionStats());
            }
            sourceInfo.numVMEventsReceived++;
            sourceInfo.outputSize += reportedSize;
            this.completedSourceTasksOutputSize += reportedSize;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("For attempt: " + vmEvent.getProducerAttemptIdentifier() + " received info of output size: " + reportedSize + " vertex numEventsReceived: " + sourceInfo.numVMEventsReceived + " vertex output size: " + sourceInfo.outputSize + " total numEventsReceived: " + this.numVertexManagerEventsReceived + " total output size: " + this.completedSourceTasksOutputSize);
        }
    }

    /** Decompresses a serialized RoaringBitmap of partition stats and adds it into {@link #stats}. */
    private void addCompressedPartitionStats(ByteString compressedPartitionStats) {
        try {
            byte[] raw = TezCommonUtils.decompressByteStringToByteArray(compressedPartitionStats);
            RoaringBitmap bitmap = new RoaringBitmap();
            DataInput in = new DataInputStream(new ByteArrayInputStream(raw));
            bitmap.deserialize(in);
            this.parsePartitionStats(bitmap);
        } catch (IOException e) {
            throw new TezUncheckedException((Throwable)e);
        }
    }

    /**
     * Rebuilds {@link #pendingTasks} as {@code 0..n-1} when the vertex's current task count
     * {@code n} is positive and differs from the pending list's size.
     *
     * <p>{@link #stats} is sized on the first rebuild only, so it keeps the original partition
     * count after a later reduction.
     */
    void updatePendingTasks() {
        int vertexTasks = this.getContext().getVertexNumTasks(this.getContext().getVertexName());
        if (vertexTasks > 0 && vertexTasks != this.pendingTasks.size()) {
            this.pendingTasks.clear();
            for (int taskIdx = 0; taskIdx < vertexTasks; taskIdx++) {
                this.pendingTasks.add(new PendingTaskInfo(taskIdx));
            }
            this.totalTasksToSchedule = this.pendingTasks.size();
            if (this.stats == null) {
                this.stats = new long[this.totalTasksToSchedule];
            }
        }
    }

    /**
     * Returns a live, filtered view of {@link #srcVertexInfo} restricted to SCATTER_GATHER sources.
     *
     * @return lazily filtered entries; reflects later changes to the underlying map
     */
    Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() {
        Predicate<Map.Entry<String, SourceVertexInfo>> scatterGatherOnly = new Predicate<Map.Entry<String, SourceVertexInfo>>() {
            public boolean apply(Map.Entry<String, SourceVertexInfo> entry) {
                return EdgeProperty.DataMovementType.SCATTER_GATHER == entry.getValue().edgeProperty.getDataMovementType();
            }
        };
        return Iterables.filter(this.srcVertexInfo.entrySet(), scatterGatherOnly);
    }

    /**
     * Decides whether, and how far, to reduce this vertex's parallelism from the source output
     * sizes observed so far, and applies a reduction through {@code reconfigureVertex}.
     *
     * <p>Each SCATTER_GATHER source's total output is extrapolated as
     * {@code numTasks * outputSize / numVMEventsReceived}. The desired parallelism is that total
     * divided, rounding up, by {@link #desiredTaskInputDataSize}, and floored at
     * {@link #minTaskParallelism}. Consecutive original partitions are then grouped into ranges of
     * {@link #basePartitionRange}, with the last task taking any remainder.
     *
     * @param minSourceVertexCompletedTaskFraction completed-task fraction supplied by the caller
     *     (the minimum across SCATTER_GATHER sources, or 1.0 once all have completed)
     * @return {@code false} to defer the decision until more output is known; {@code true} once
     *     the decision is final, whether or not parallelism was changed
     */
    @VisibleForTesting
    boolean determineParallelismAndApply(float minSourceVertexCompletedTaskFraction) {
        // Sources have tasks but none reported a size yet: keep the current parallelism.
        if (this.numVertexManagerEventsReceived == 0 && this.totalNumBipartiteSourceTasks > 0) {
            return true;
        }
        int currentParallelism = this.pendingTasks.size();
        // Less than one task's worth of output is known and max slow-start is not reached yet:
        // wait for more data before committing.
        boolean canDetermineParallelismLater = this.completedSourceTasksOutputSize < this.desiredTaskInputDataSize
                && minSourceVertexCompletedTaskFraction < this.slowStartMaxSrcCompletionFraction;
        if (canDetermineParallelismLater) {
            LOG.info("Defer scheduling tasks; vertex=" + this.getContext().getVertexName() + ", totalNumBipartiteSourceTasks=" + this.totalNumBipartiteSourceTasks + ", completedSourceTasksOutputSize=" + this.completedSourceTasksOutputSize + ", numVertexManagerEventsReceived=" + this.numVertexManagerEventsReceived + ", numBipartiteSourceTasksCompleted=" + this.numBipartiteSourceTasksCompleted + ", minSourceVertexCompletedTaskFraction=" + minSourceVertexCompletedTaskFraction);
            return false;
        }
        long expectedTotalSourceTasksOutputSize = 0L;
        for (Map.Entry<String, SourceVertexInfo> vInfo : this.getBipartiteInfo()) {
            SourceVertexInfo srcInfo = vInfo.getValue();
            if (srcInfo.numTasks <= 0 || srcInfo.numVMEventsReceived <= 0) {
                continue;
            }
            // Extrapolate this source's total output from the average reported so far.
            expectedTotalSourceTasksOutputSize += (long)srcInfo.numTasks * srcInfo.outputSize / (long)srcInfo.numVMEventsReceived;
        }
        LOG.info("Expected output: " + expectedTotalSourceTasksOutputSize + " based on actual output: " + this.completedSourceTasksOutputSize + " from " + this.numVertexManagerEventsReceived + " vertex manager events. " + " desiredTaskInputSize: " + this.desiredTaskInputDataSize + " max slow start tasks:" + (float)this.totalNumBipartiteSourceTasks * this.slowStartMaxSrcCompletionFraction + " num sources completed:" + this.numBipartiteSourceTasksCompleted);
        // Ceiling division in long. Clamp before narrowing: a bare (int) cast of a value above
        // Integer.MAX_VALUE wraps (possibly negative), which the floor below would then turn into
        // minTaskParallelism and collapse the vertex instead of leaving it alone.
        long desiredParallelismEstimate = (expectedTotalSourceTasksOutputSize + this.desiredTaskInputDataSize - 1L) / this.desiredTaskInputDataSize;
        int desiredTaskParallelism = (int)Math.min((long)Integer.MAX_VALUE, desiredParallelismEstimate);
        if (desiredTaskParallelism < this.minTaskParallelism) {
            desiredTaskParallelism = this.minTaskParallelism;
        }
        if (desiredTaskParallelism >= currentParallelism) {
            LOG.info("Not reducing auto parallelism for vertex: " + this.getContext().getVertexName() + " since the desired parallelism of " + desiredTaskParallelism + " is greater than or equal to the current parallelism of " + this.pendingTasks.size());
            return true;
        }
        // Most reduced tasks cover this many original partitions.
        this.basePartitionRange = currentParallelism / desiredTaskParallelism;
        if (this.basePartitionRange <= 1) {
            LOG.info("Not reducing auto parallelism for vertex: " + this.getContext().getVertexName() + " by less than half since combining two inputs will potentially break the desired task input size of " + this.desiredTaskInputDataSize);
            return true;
        }
        int numShufflersWithBaseRange = currentParallelism / this.basePartitionRange;
        this.remainderRangeForLastShuffler = currentParallelism % this.basePartitionRange;
        int finalTaskParallelism = this.remainderRangeForLastShuffler > 0 ? numShufflersWithBaseRange + 1 : numShufflersWithBaseRange;
        LOG.info("Reducing auto parallelism for vertex: " + this.getContext().getVertexName() + " from " + this.pendingTasks.size() + " to " + finalTaskParallelism);
        if (finalTaskParallelism < currentParallelism) {
            // Replace every SCATTER_GATHER edge with a CustomShuffleEdgeManager that maps the
            // original partition count (currentParallelism) onto the reduced task count.
            HashMap<String, EdgeProperty> edgeProperties = new HashMap<String, EdgeProperty>(this.bipartiteSources);
            Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = this.getBipartiteInfo();
            for (Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
                String vertex = entry.getKey();
                EdgeProperty oldEdgeProp = entry.getValue().edgeProperty;
                CustomShuffleEdgeManagerConfig edgeManagerConfig = new CustomShuffleEdgeManagerConfig(currentParallelism, finalTaskParallelism, this.basePartitionRange, this.remainderRangeForLastShuffler > 0 ? this.remainderRangeForLastShuffler : this.basePartitionRange);
                EdgeManagerPluginDescriptor edgeManagerDescriptor = EdgeManagerPluginDescriptor.create((String)CustomShuffleEdgeManager.class.getName());
                edgeManagerDescriptor.setUserPayload(edgeManagerConfig.toUserPayload());
                EdgeProperty newEdgeProp = EdgeProperty.create((EdgeManagerPluginDescriptor)edgeManagerDescriptor, (EdgeProperty.DataSourceType)oldEdgeProp.getDataSourceType(), (EdgeProperty.SchedulingType)oldEdgeProp.getSchedulingType(), (OutputDescriptor)oldEdgeProp.getEdgeSource(), (InputDescriptor)oldEdgeProp.getEdgeDestination());
                edgeProperties.put(vertex, newEdgeProp);
            }
            this.getContext().reconfigureVertex(finalTaskParallelism, null, edgeProperties);
            this.updatePendingTasks();
            this.configureTargetMapping(finalTaskParallelism);
        }
        return true;
    }

    /**
     * Fills {@link #targetIndexes} so that reduced task {@code t} covers original partitions
     * {@code [t*basePartitionRange, t*basePartitionRange + range)}. {@code range} is
     * {@link #basePartitionRange}, except for the last task, which uses
     * {@link #remainderRangeForLastShuffler} when that is positive.
     *
     * @param tasks reduced task count
     */
    void configureTargetMapping(int tasks) {
        int lastTask = tasks - 1;
        int lastTaskRange = this.remainderRangeForLastShuffler > 0 ? this.remainderRangeForLastShuffler : this.basePartitionRange;
        this.targetIndexes = new int[tasks][];
        for (int t = 0; t < tasks; t++) {
            int range = (t == lastTask) ? lastTaskRange : this.basePartitionRange;
            this.targetIndexes[t] = createIndices(range, t, this.basePartitionRange);
            if (LOG.isDebugEnabled()) {
                LOG.debug("targetIdx[" + t + "] to " + Arrays.toString(this.targetIndexes[t]));
            }
        }
    }

    /**
     * Schedules up to {@code numTasksToSchedule} pending tasks, largest estimated input first.
     *
     * <p>With auto parallelism enabled, parallelism is determined first. If that decision is
     * deferred, nothing is scheduled; once it is final, the vertex is told reconfiguration is done.
     *
     * @param numTasksToSchedule maximum number of tasks to hand to the context
     * @param minSourceVertexCompletedTaskFraction completed fraction passed on to
     *     {@link #determineParallelismAndApply(float)}
     */
    void schedulePendingTasks(int numTasksToSchedule, float minSourceVertexCompletedTaskFraction) {
        if (this.enableAutoParallelism && !this.parallelismDetermined) {
            this.parallelismDetermined = this.determineParallelismAndApply(minSourceVertexCompletedTaskFraction);
            if (!this.parallelismDetermined) {
                return;
            }
            this.getContext().doneReconfiguringVertex();
        }
        if (this.totalNumBipartiteSourceTasks > 0) {
            this.sortPendingTasksBasedOnDataSize();
        }
        List<VertexManagerPluginContext.ScheduleTaskRequest> scheduledTasks = Lists.newArrayListWithCapacity(numTasksToSchedule);
        int remaining = numTasksToSchedule;
        while (remaining > 0 && !this.pendingTasks.isEmpty()) {
            int taskIndex = this.pendingTasks.get(0).index;
            scheduledTasks.add(VertexManagerPluginContext.ScheduleTaskRequest.create(taskIndex, null));
            this.pendingTasks.remove(0);
            remaining--;
        }
        this.getContext().scheduleTasks(scheduledTasks);
    }

    /**
     * Refreshes per-task size estimates and, if any changed, orders {@link #pendingTasks} by
     * descending {@code outputStats}, so the largest tasks are scheduled first.
     */
    private void sortPendingTasksBasedOnDataSize() {
        if (!this.computePartitionSizes()) {
            return;
        }
        Collections.sort(this.pendingTasks, new Comparator<PendingTaskInfo>(){

            @Override
            public int compare(PendingTaskInfo a, PendingTaskInfo b) {
                if (a.outputStats == b.outputStats) {
                    return 0;
                }
                return a.outputStats > b.outputStats ? -1 : 1;
            }
        });
        if (LOG.isDebugEnabled()) {
            for (PendingTaskInfo task : this.pendingTasks) {
                LOG.debug("Pending task:" + task.toString());
            }
        }
    }

    /**
     * Updates each pending task's {@code outputStats} from {@link #stats}.
     *
     * <p>Before any reduction, a task's size is {@code stats[index]}. After a reduction, it is the
     * sum of {@code stats} over the task's {@link #targetIndexes} entry. Only positive sizes that
     * differ from the stored value are written.
     *
     * @return {@code true} if at least one task's {@code outputStats} changed
     * @throws IllegalStateException if a task index has no target mapping
     */
    private synchronized boolean computePartitionSizes() {
        boolean anyChanged = false;
        for (PendingTaskInfo pending : this.pendingTasks) {
            int taskIdx = pending.index;
            long size;
            if (this.targetIndexes == null) {
                size = this.stats[taskIdx];
            } else {
                Preconditions.checkState(taskIdx < this.targetIndexes.length, (Object)("index=" + taskIdx + ", targetIndexes length=" + this.targetIndexes.length));
                size = 0L;
                for (int partition : this.targetIndexes[taskIdx]) {
                    size += this.stats[partition];
                }
            }
            if (size > 0L && size != pending.outputStats) {
                pending.outputStats = size;
                anyChanged = true;
            }
        }
        return anyChanged;
    }

    /**
     * Checks whether every source vertex has been configured.
     *
     * <p>On success, {@link #sourceVerticesScheduled} is set so that callers can skip this scan
     * afterwards.
     *
     * @return {@code true} if all source vertices are configured
     */
    boolean canScheduleTasks() {
        for (Map.Entry<String, SourceVertexInfo> source : this.srcVertexInfo.entrySet()) {
            if (!source.getValue().vertexIsConfigured) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Waiting for vertex: " + source.getKey() + " in vertex: " + this.getContext().getVertexName());
                }
                return false;
            }
        }
        this.sourceVerticesScheduled = true;
        return true;
    }

    /**
     * Slow-start scheduling: releases a share of pending tasks proportional to the completion
     * progress of the least-complete SCATTER_GATHER source.
     *
     * <p>The schedulable fraction is
     * {@code (minCompleted - minFraction) / (maxFraction - minFraction)}, clamped to [0, 1].
     * When all SCATTER_GATHER source tasks have completed, every remaining task is scheduled.
     * Nothing happens before {@link #onVertexStarted} or while any source vertex is unconfigured.
     */
    void schedulePendingTasks() {
        if (!this.onVertexStartedDone.get()) {
            return;
        }
        int numPendingTasks = this.pendingTasks.size();
        if (numPendingTasks == 0) {
            return;
        }
        boolean sourcesReady = this.sourceVerticesScheduled || this.canScheduleTasks();
        if (!sourcesReady) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Defer scheduling tasks for vertex:" + this.getContext().getVertexName() + " as one task needs to be completed per source vertex");
            }
            return;
        }
        if (this.numBipartiteSourceTasksCompleted == this.totalNumBipartiteSourceTasks) {
            LOG.info("All source tasks assigned. Ramping up " + numPendingTasks + " remaining tasks for vertex: " + this.getContext().getVertexName());
            this.schedulePendingTasks(numPendingTasks, 1.0f);
            return;
        }
        // Find the least-complete SCATTER_GATHER source (sources with no tasks are ignored).
        float minFraction = 1.0f;
        String minVertex = "";
        for (Map.Entry<String, SourceVertexInfo> entry : this.getBipartiteInfo()) {
            SourceVertexInfo info = entry.getValue();
            Preconditions.checkState((boolean)info.vertexIsConfigured, (Object)("Vertex: " + entry.getKey()));
            if (info.numTasks > 0) {
                float fraction = (float)info.getNumCompletedTasks() / (float)info.numTasks;
                if (fraction < minFraction) {
                    minFraction = fraction;
                    minVertex = entry.getKey();
                }
            }
        }
        float window = this.slowStartMaxSrcCompletionFraction - this.slowStartMinSrcCompletionFraction;
        float fractionToSchedule;
        if (window > 0.0f) {
            fractionToSchedule = (minFraction - this.slowStartMinSrcCompletionFraction) / window;
        } else {
            // Degenerate window (min == max): all-or-nothing at the threshold.
            fractionToSchedule = minFraction < this.slowStartMinSrcCompletionFraction ? 0.0f : 1.0f;
        }
        fractionToSchedule = Math.max(0.0f, Math.min(1.0f, fractionToSchedule));
        int alreadyScheduled = this.totalTasksToSchedule - numPendingTasks;
        int numTasksToSchedule = (int)(fractionToSchedule * (float)this.totalTasksToSchedule) - alreadyScheduled;
        if (numTasksToSchedule <= 0) {
            return;
        }
        LOG.info("Scheduling " + numTasksToSchedule + " tasks for vertex: " + this.getContext().getVertexName() + " with totalTasks: " + this.totalTasksToSchedule + ". " + this.numBipartiteSourceTasksCompleted + " source tasks completed out of " + this.totalNumBipartiteSourceTasks + ". MinSourceTaskCompletedFraction: " + minFraction + " in Vertex: " + minVertex + " min: " + this.slowStartMinSrcCompletionFraction + " max: " + this.slowStartMaxSrcCompletionFraction);
        this.schedulePendingTasks(numTasksToSchedule, minFraction);
    }

    /**
     * Reads this manager's settings from the plugin user payload and validates them.
     *
     * <p>The max slow-start fraction defaults to the larger of
     * {@link #TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT} and the configured min fraction,
     * so configuring only the min cannot produce max &lt; min by default.
     *
     * @throws TezUncheckedException if the user payload cannot be converted to a Configuration
     * @throws IllegalArgumentException if the slow-start fractions are out of range, or if auto
     *     parallelism is enabled with a non-positive desired task input size
     */
    public void initialize() {
        Configuration conf;
        try {
            conf = TezUtils.createConfFromUserPayload((UserPayload)this.getContext().getUserPayload());
        }
        catch (IOException e) {
            throw new TezUncheckedException((Throwable)e);
        }
        this.slowStartMinSrcCompletionFraction = conf.getFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
        float defaultSlowStartMaxSrcFraction = TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT;
        if (this.slowStartMinSrcCompletionFraction > defaultSlowStartMaxSrcFraction) {
            defaultSlowStartMaxSrcFraction = this.slowStartMinSrcCompletionFraction;
        }
        this.slowStartMaxSrcCompletionFraction = conf.getFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, defaultSlowStartMaxSrcFraction);
        if (this.slowStartMinSrcCompletionFraction < 0.0f || this.slowStartMaxSrcCompletionFraction > 1.0f || this.slowStartMaxSrcCompletionFraction < this.slowStartMinSrcCompletionFraction) {
            throw new IllegalArgumentException("Invalid values for slowStartMinSrcCompletionFraction/slowStartMaxSrcCompletionFraction. Min cannot be < 0, max cannot be > 1, and max cannot be < min., configuredMin=" + this.slowStartMinSrcCompletionFraction + ", configuredMax=" + this.slowStartMaxSrcCompletionFraction);
        }
        this.enableAutoParallelism = conf.getBoolean(TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT);
        this.desiredTaskInputDataSize = conf.getLong(TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
        // determineParallelismAndApply() divides by this value; fail fast here instead of with an
        // ArithmeticException (or a meaningless negative estimate) in the middle of the DAG.
        if (this.enableAutoParallelism && this.desiredTaskInputDataSize <= 0L) {
            throw new IllegalArgumentException("Invalid value for " + TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE + ": must be > 0 when auto parallelism is enabled, configured=" + this.desiredTaskInputDataSize);
        }
        this.minTaskParallelism = Math.max(1, conf.getInt(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM, TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT));
        LOG.info("Shuffle Vertex Manager: settings minFrac:" + this.slowStartMinSrcCompletionFraction + " maxFrac:" + this.slowStartMaxSrcCompletionFraction + " auto:" + this.enableAutoParallelism + " desiredTaskIput:" + this.desiredTaskInputDataSize + " minTasks:" + this.minTaskParallelism);
        this.updatePendingTasks();
        if (this.enableAutoParallelism) {
            this.getContext().vertexReconfigurationPlanned();
        }
    }

    /**
     * Marks a source vertex as configured, records its task count, and re-evaluates scheduling.
     *
     * @param stateUpdate must be a CONFIGURED notification for a known source vertex
     * @throws IllegalArgumentException if the state is not CONFIGURED or the vertex is unknown
     * @throws IllegalStateException if the source vertex was already marked configured
     */
    private void handleVertexStateUpdate(VertexStateUpdate stateUpdate) {
        String sourceName = stateUpdate.getVertexName();
        VertexState state = stateUpdate.getVertexState();
        String selfName = this.getContext().getVertexName();
        Preconditions.checkArgument(state == VertexState.CONFIGURED, (Object)("Received incorrect state notification : " + state + " for vertex: " + sourceName + " in vertex: " + selfName));
        Preconditions.checkArgument(this.srcVertexInfo.containsKey(sourceName), (Object)("Received incorrect vertex notification : " + state + " for vertex: " + sourceName + " in vertex: " + selfName));
        SourceVertexInfo info = this.srcVertexInfo.get(sourceName);
        Preconditions.checkState(!info.vertexIsConfigured);
        info.vertexIsConfigured = true;
        info.numTasks = this.getContext().getVertexNumTasks(sourceName);
        if (info.edgeProperty.getDataMovementType() == EdgeProperty.DataMovementType.SCATTER_GATHER) {
            this.totalNumBipartiteSourceTasks += info.numTasks;
        }
        LOG.info("Received configured notification : " + state + " for vertex: " + sourceName + " in vertex: " + selfName + " numBipartiteSourceTasks: " + this.totalNumBipartiteSourceTasks);
        this.schedulePendingTasks();
    }

    /**
     * Handles a source vertex state notification. Only CONFIGURED is acted on; it is buffered
     * until {@link #onVertexStarted} has run.
     *
     * @param stateUpdate the notification delivered by the framework
     */
    public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
        if (stateUpdate.getVertexState() != VertexState.CONFIGURED) {
            return;
        }
        if (!this.onVertexStartedDone.get()) {
            this.pendingStateUpdates.add(stateUpdate);
            return;
        }
        this.handleVertexStateUpdate(stateUpdate);
    }

    /**
     * Intentionally a no-op: this manager does not act on root input initialization events.
     *
     * @param inputName name of the root input
     * @param inputDescriptor descriptor of the root input
     * @param events events produced by the input initializer
     */
    public synchronized void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List<Event> events) {
    }

    /**
     * Creates a builder for this manager's plugin descriptor.
     *
     * @param conf configuration to populate; when null, a fresh {@code new Configuration(false)}
     *     is used
     * @return a new builder
     */
    public static ShuffleVertexManagerConfigBuilder createConfigBuilder(@Nullable Configuration conf) {
        return new ShuffleVertexManagerConfigBuilder(conf);
    }

    /**
     * Fluent builder that records ShuffleVertexManager settings in a Hadoop Configuration and
     * serializes them into a {@link VertexManagerPluginDescriptor}.
     */
    public static final class ShuffleVertexManagerConfigBuilder {
        /** Settings written by the setters and serialized by {@link #build()}. */
        private final Configuration configuration;

        private ShuffleVertexManagerConfigBuilder(@Nullable Configuration conf) {
            if (conf != null) {
                this.configuration = conf;
            } else {
                this.configuration = new Configuration(false);
            }
        }

        /** Enables or disables runtime reduction of parallelism. */
        public ShuffleVertexManagerConfigBuilder setAutoReduceParallelism(boolean enabled) {
            this.configuration.setBoolean(TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, enabled);
            return this;
        }

        /** Sets the completed source fraction at which scheduling may begin. */
        public ShuffleVertexManagerConfigBuilder setSlowStartMinSrcCompletionFraction(float minFraction) {
            this.configuration.setFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, minFraction);
            return this;
        }

        /** Sets the completed source fraction by which all tasks are scheduled. */
        public ShuffleVertexManagerConfigBuilder setSlowStartMaxSrcCompletionFraction(float maxFraction) {
            this.configuration.setFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, maxFraction);
            return this;
        }

        /** Sets the target input size per task used by auto parallelism. */
        public ShuffleVertexManagerConfigBuilder setDesiredTaskInputSize(long desiredTaskInputSize) {
            this.configuration.setLong(TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, desiredTaskInputSize);
            return this;
        }

        /** Sets the lower bound on the parallelism chosen by auto parallelism. */
        public ShuffleVertexManagerConfigBuilder setMinTaskParallelism(int minTaskParallelism) {
            this.configuration.setInt(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM, minTaskParallelism);
            return this;
        }

        /**
         * Builds a descriptor for {@link ShuffleVertexManager} whose user payload carries the
         * accumulated configuration.
         *
         * @throws TezUncheckedException if the configuration cannot be serialized
         */
        public VertexManagerPluginDescriptor build() {
            UserPayload payload;
            try {
                payload = TezUtils.createUserPayloadFromConf(this.configuration);
            } catch (IOException e) {
                throw new TezUncheckedException((Throwable)e);
            }
            VertexManagerPluginDescriptor descriptor = VertexManagerPluginDescriptor.create(ShuffleVertexManager.class.getName());
            return (VertexManagerPluginDescriptor)descriptor.setUserPayload(payload);
        }
    }

    /**
     * Value holder for the CustomShuffleEdgeManager settings, with conversion to and from its
     * protobuf user payload.
     */
    private static class CustomShuffleEdgeManagerConfig {
        int numSourceTaskOutputs;
        int numDestinationTasks;
        int basePartitionRange;
        int remainderRangeForLastShuffler;

        private CustomShuffleEdgeManagerConfig(int sourceOutputs, int destinationTasks, int baseRange, int lastRange) {
            this.numSourceTaskOutputs = sourceOutputs;
            this.numDestinationTasks = destinationTasks;
            this.basePartitionRange = baseRange;
            this.remainderRangeForLastShuffler = lastRange;
        }

        /** Serializes these settings as a ShuffleEdgeManagerConfigPayloadProto payload. */
        public UserPayload toUserPayload() {
            ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto proto = ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto.newBuilder()
                    .setNumSourceTaskOutputs(this.numSourceTaskOutputs)
                    .setNumDestinationTasks(this.numDestinationTasks)
                    .setBasePartitionRange(this.basePartitionRange)
                    .setRemainderRangeForLastShuffler(this.remainderRangeForLastShuffler)
                    .build();
            byte[] bytes = proto.toByteArray();
            return UserPayload.create(ByteBuffer.wrap(bytes));
        }

        /**
         * Parses settings previously produced by {@link #toUserPayload()}.
         *
         * @throws InvalidProtocolBufferException if the payload is not a valid proto
         */
        public static CustomShuffleEdgeManagerConfig fromUserPayload(UserPayload payload) throws InvalidProtocolBufferException {
            ByteString raw = ByteString.copyFrom(payload.getPayload());
            ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto proto = ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto.parseFrom(raw);
            return new CustomShuffleEdgeManagerConfig(proto.getNumSourceTaskOutputs(), proto.getNumDestinationTasks(), proto.getBasePartitionRange(), proto.getRemainderRangeForLastShuffler());
        }
    }

    /**
     * On-demand edge manager that maps source partitions onto destination tasks in contiguous
     * ranges.
     *
     * <p>Source partition {@code p} is routed to destination task {@code p / basePartitionRange}.
     * Every destination task except the last consumes {@code basePartitionRange} partitions
     * from each source task; the last consumes {@code remainderRangeForLastShuffler}
     * partitions.
     *
     * <p>Within a destination task, the inputs that come from one source task occupy the
     * physical input indices {@code [sourceTaskIndex * range, (sourceTaskIndex + 1) * range)},
     * in partition order.
     *
     * <p>The routing parameters are decoded from the edge user payload in
     * {@link #initialize()}. The composite routing method reads index tables that are only
     * populated by {@link #prepareForRouting()}.
     */
    public static class CustomShuffleEdgeManager
    extends EdgeManagerPluginOnDemand {
        /** Physical outputs per source task (reported by getNumSourceTaskPhysicalOutputs). */
        int numSourceTaskOutputs;
        /** Destination task count; checked against the destination vertex in initialize(). */
        int numDestinationTasks;
        /** Partitions per source task consumed by every destination task except the last. */
        int basePartitionRange;
        /** Partitions per source task consumed by the last destination task. */
        int remainderRangeForLastShuffler;
        /** Source vertex parallelism captured in initialize(). */
        int numSourceTasks;
        /** Indexed by destination task; source indices for composite routing (set in prepareForRouting). */
        int[][] sourceIndices;
        /** Indexed by source task; target indices for composite routing (set in prepareForRouting). */
        int[][] targetIndices;

        public CustomShuffleEdgeManager(EdgeManagerPluginContext context) {
            super(context);
        }

        /**
         * Decodes the routing configuration from the edge user payload and captures the
         * source vertex parallelism.
         *
         * @throws RuntimeException if the payload is missing, empty or not a valid config proto
         * @throws IllegalStateException if the configured destination task count differs from
         *         the destination vertex's task count
         */
        public void initialize() {
            UserPayload userPayload = this.getContext().getUserPayload();
            // A missing or zero-length payload carries no routing configuration at all.
            if (userPayload == null || userPayload.getPayload() == null || userPayload.getPayload().limit() == 0) {
                throw new RuntimeException("Could not initialize CustomShuffleEdgeManager from provided user payload");
            }
            CustomShuffleEdgeManagerConfig config;
            try {
                config = CustomShuffleEdgeManagerConfig.fromUserPayload(userPayload);
            }
            catch (InvalidProtocolBufferException e) {
                throw new RuntimeException("Could not initialize CustomShuffleEdgeManager from provided user payload", e);
            }
            this.numSourceTaskOutputs = config.numSourceTaskOutputs;
            this.numDestinationTasks = config.numDestinationTasks;
            this.basePartitionRange = config.basePartitionRange;
            this.remainderRangeForLastShuffler = config.remainderRangeForLastShuffler;
            this.numSourceTasks = this.getContext().getSourceVertexNumTasks();
            int destinationVertexNumTasks = this.getContext().getDestinationVertexNumTasks();
            // The routing arithmetic is only valid for the parallelism the payload was built for.
            Preconditions.checkState(this.numDestinationTasks == destinationVertexNumTasks,
                    "CustomShuffleEdgeManager configured for %s destination tasks but destination vertex has %s",
                    this.numDestinationTasks, destinationVertexNumTasks);
        }

        /**
         * Returns how many source partitions (per source task) the given destination task reads.
         * Any index at or beyond {@code numDestinationTasks - 1} is treated as the last task.
         */
        private int partitionRangeFor(int destinationTaskIndex) {
            return destinationTaskIndex < this.numDestinationTasks - 1 ? this.basePartitionRange : this.remainderRangeForLastShuffler;
        }

        /**
         * Returns the physical input index, within {@code destinationTaskIndex}, that receives
         * {@code sourcePartition} produced by {@code sourceTaskIndex}.
         */
        private int targetInputIndex(int sourceTaskIndex, int sourcePartition, int destinationTaskIndex) {
            int partitionRange = this.partitionRangeFor(destinationTaskIndex);
            // Inputs from one source task are adjacent and keep their original partition order.
            return sourceTaskIndex * partitionRange + sourcePartition % partitionRange;
        }

        /** Builds an unmodifiable list {@code [startOffset, startOffset + count)}. */
        private static List<Integer> contiguousIndexList(int startOffset, int count) {
            List<Integer> indices = new ArrayList<Integer>(count);
            for (int i = 0; i < count; ++i) {
                indices.add(startOffset + i);
            }
            return Collections.unmodifiableList(indices);
        }

        /** Each destination task reads its partition range from every source task. */
        public int getNumDestinationTaskPhysicalInputs(int destinationTaskIndex) {
            return this.numSourceTasks * this.partitionRangeFor(destinationTaskIndex);
        }

        public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) {
            return this.numSourceTaskOutputs;
        }

        /** Routes a single data movement event to the one destination task owning its partition. */
        public void routeDataMovementEventToDestination(DataMovementEvent event, int sourceTaskIndex, int sourceOutputIndex, Map<Integer, List<Integer>> destinationTaskAndInputIndices) {
            int sourceIndex = event.getSourceIndex();
            int destinationTaskIndex = sourceIndex / this.basePartitionRange;
            int targetIndex = this.targetInputIndex(sourceTaskIndex, sourceIndex, destinationTaskIndex);
            destinationTaskAndInputIndices.put(destinationTaskIndex, Collections.singletonList(targetIndex));
        }

        /**
         * Returns routing metadata for {@code destTaskIndex}, or {@code null} when the source
         * output belongs to a different destination task.
         */
        public EdgeManagerPluginOnDemand.EventRouteMetadata routeDataMovementEventToDestination(int sourceTaskIndex, int sourceOutputIndex, int destTaskIndex) throws Exception {
            int destinationTaskIndex = sourceOutputIndex / this.basePartitionRange;
            if (destinationTaskIndex != destTaskIndex) {
                return null;
            }
            int targetIndex = this.targetInputIndex(sourceTaskIndex, sourceOutputIndex, destinationTaskIndex);
            return EdgeManagerPluginOnDemand.EventRouteMetadata.create(1, new int[]{targetIndex});
        }

        /**
         * Precomputes the per-source-task target index arrays and the per-destination-task
         * source index arrays used by composite routing.
         */
        public void prepareForRouting() throws Exception {
            int sourceTaskCount = this.getContext().getSourceVertexNumTasks();
            this.targetIndices = new int[sourceTaskCount][];
            for (int srcTaskIndex = 0; srcTaskIndex < sourceTaskCount; ++srcTaskIndex) {
                this.targetIndices[srcTaskIndex] = ShuffleVertexManager.createIndices(this.basePartitionRange, srcTaskIndex, this.basePartitionRange);
            }
            int destinationTaskCount = this.getContext().getDestinationVertexNumTasks();
            this.sourceIndices = new int[destinationTaskCount][];
            for (int destTaskIndex = 0; destTaskIndex < destinationTaskCount; ++destTaskIndex) {
                int partitionRange = this.basePartitionRange;
                if (destTaskIndex == destinationTaskCount - 1) {
                    partitionRange = this.remainderRangeForLastShuffler;
                }
                // Stride stays basePartitionRange: destination task i starts at partition i * base.
                this.sourceIndices[destTaskIndex] = ShuffleVertexManager.createIndices(partitionRange, destTaskIndex, this.basePartitionRange);
            }
        }

        /** Builds target indices for the last destination task on demand instead of caching them. */
        private int[] createTargetIndicesForRemainder(int srcTaskIndex) {
            return ShuffleVertexManager.createIndices(this.remainderRangeForLastShuffler, srcTaskIndex, this.remainderRangeForLastShuffler);
        }

        /**
         * Returns composite routing metadata for one (source task, destination task) pair.
         * Requires {@link #prepareForRouting()} to have populated the index tables.
         */
        @Nullable
        public EdgeManagerPluginOnDemand.EventRouteMetadata routeCompositeDataMovementEventToDestination(int sourceTaskIndex, int destinationTaskIndex) throws Exception {
            int partitionRange;
            int[] targetIndicesToSend;
            if (destinationTaskIndex == this.numDestinationTasks - 1) {
                // The cached per-source arrays assume basePartitionRange; rebuild only if it differs.
                targetIndicesToSend = this.remainderRangeForLastShuffler != this.basePartitionRange ? this.createTargetIndicesForRemainder(sourceTaskIndex) : this.targetIndices[sourceTaskIndex];
                partitionRange = this.remainderRangeForLastShuffler;
            } else {
                targetIndicesToSend = this.targetIndices[sourceTaskIndex];
                partitionRange = this.basePartitionRange;
            }
            return EdgeManagerPluginOnDemand.EventRouteMetadata.create(partitionRange, targetIndicesToSend, this.sourceIndices[destinationTaskIndex]);
        }

        /**
         * Returns metadata covering every input of {@code destinationTaskIndex} that came from
         * the failed {@code sourceTaskIndex}.
         */
        public EdgeManagerPluginOnDemand.EventRouteMetadata routeInputSourceTaskFailedEventToDestination(int sourceTaskIndex, int destinationTaskIndex) throws Exception {
            int partitionRange = this.basePartitionRange;
            if (destinationTaskIndex == this.numDestinationTasks - 1) {
                partitionRange = this.remainderRangeForLastShuffler;
            }
            int startOffset = sourceTaskIndex * partitionRange;
            int[] failedInputIndices = new int[partitionRange];
            for (int i = 0; i < partitionRange; ++i) {
                failedInputIndices[i] = startOffset + i;
            }
            return EdgeManagerPluginOnDemand.EventRouteMetadata.create(partitionRange, failedInputIndices);
        }

        /**
         * Fills {@code destinationTaskAndInputIndices} with, for every destination task, the
         * inputs that came from the failed {@code sourceTaskIndex}.
         */
        public void routeInputSourceTaskFailedEventToDestination(int sourceTaskIndex, Map<Integer, List<Integer>> destinationTaskAndInputIndices) {
            List<Integer> baseIndices = contiguousIndexList(sourceTaskIndex * this.basePartitionRange, this.basePartitionRange);
            if (this.remainderRangeForLastShuffler < this.basePartitionRange) {
                for (int i = 0; i < this.numDestinationTasks - 1; ++i) {
                    destinationTaskAndInputIndices.put(i, baseIndices);
                }
                // The last destination task uses the smaller remainder range for its layout.
                List<Integer> remainderIndices = contiguousIndexList(sourceTaskIndex * this.remainderRangeForLastShuffler, this.remainderRangeForLastShuffler);
                destinationTaskAndInputIndices.put(this.numDestinationTasks - 1, remainderIndices);
            } else {
                // Every destination task shares the same (immutable) index list.
                for (int i = 0; i < this.numDestinationTasks; ++i) {
                    destinationTaskAndInputIndices.put(i, baseIndices);
                }
            }
        }

        /** Maps a failed destination input back to the source task that produced it. */
        public int routeInputErrorEventToSource(InputReadErrorEvent event, int destinationTaskIndex, int destinationFailedInputIndex) {
            return destinationFailedInputIndex / this.partitionRangeFor(destinationTaskIndex);
        }

        /** Maps a failed destination input back to the source task that produced it. */
        public int routeInputErrorEventToSource(int destinationTaskIndex, int destinationFailedInputIndex) {
            return destinationFailedInputIndex / this.partitionRangeFor(destinationTaskIndex);
        }

        /** Every source task feeds every destination task. */
        public int getNumDestinationConsumerTasks(int sourceTaskIndex) {
            return this.numDestinationTasks;
        }
    }

    /**
     * Pairs a pending task index with an output-statistics value. Only the index is set here;
     * {@code outputStats} starts at 0 and is presumably updated by the enclosing manager.
     */
    static class PendingTaskInfo {
        private int index;
        private long outputStats;

        public PendingTaskInfo(int index) {
            this.index = index;
        }

        /** Renders as {@code [index=<index>, outputStats=<outputStats>]}. */
        @Override
        public String toString() {
            return new StringBuilder("[index=")
                    .append(index)
                    .append(", outputStats=")
                    .append(outputStats)
                    .append(']')
                    .toString();
        }
    }

    /**
     * Mutable bookkeeping for one source vertex feeding the managed vertex. Apart from
     * {@code edgeProperty}, every field is populated outside this class.
     */
    static class SourceVertexInfo {
        EdgeProperty edgeProperty;
        boolean vertexIsConfigured;
        // Presumably bit i is set once source task i has finished; the set starts empty.
        BitSet finishedTaskSet = new BitSet();
        int numTasks;
        int numVMEventsReceived;
        long outputSize;

        SourceVertexInfo(EdgeProperty edgeProperty) {
            this.edgeProperty = edgeProperty;
        }

        /** Returns the recorded task count of the source vertex. */
        int getNumTasks() {
            return numTasks;
        }

        /** Returns how many bits are set in {@code finishedTaskSet}. */
        int getNumCompletedTasks() {
            return finishedTaskSet.cardinality();
        }
    }
}

