/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptive.allocator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.jobmaster.SlotInfo;
import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan;
import org.apache.flink.runtime.scheduler.adaptive.allocator.AllocatorUtil;
import org.apache.flink.runtime.scheduler.adaptive.allocator.JobAllocationsInformation;
import org.apache.flink.runtime.scheduler.adaptive.allocator.JobInformation;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotAssigner;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator;
import org.apache.flink.runtime.scheduler.adaptive.allocator.VertexParallelism;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.util.Preconditions;

@Internal
public class StateLocalitySlotAssigner
implements SlotAssigner {
    @Override
    public Collection<JobSchedulingPlan.SlotAssignment> assignSlots(JobInformation jobInformation, Collection<? extends SlotInfo> freeSlots, VertexParallelism vertexParallelism, JobAllocationsInformation previousAllocations) {
        AllocationScore score;
        AllocatorUtil.checkMinimumRequiredSlots(jobInformation, freeSlots);
        ArrayList<SlotSharingSlotAllocator.ExecutionSlotSharingGroup> allGroups = new ArrayList<SlotSharingSlotAllocator.ExecutionSlotSharingGroup>();
        for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) {
            allGroups.addAll(AllocatorUtil.createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup));
        }
        Map<JobVertexID, Integer> parallelism = StateLocalitySlotAssigner.getParallelism(allGroups);
        PriorityQueue<AllocationScore> scores = this.calculateScores(jobInformation, previousAllocations, allGroups, parallelism);
        Map groupsById = allGroups.stream().collect(Collectors.toMap(SlotSharingSlotAllocator.ExecutionSlotSharingGroup::getId, Function.identity()));
        Map slotsById = freeSlots.stream().collect(Collectors.toMap(SlotInfo::getAllocationId, Function.identity()));
        ArrayList<JobSchedulingPlan.SlotAssignment> assignments = new ArrayList<JobSchedulingPlan.SlotAssignment>();
        while ((score = scores.poll()) != null) {
            if (!slotsById.containsKey(score.getAllocationId()) || !groupsById.containsKey(score.getGroupId())) continue;
            assignments.add(new JobSchedulingPlan.SlotAssignment((SlotInfo)slotsById.remove(score.getAllocationId()), groupsById.remove(score.getGroupId())));
        }
        Iterator remainingSlots = slotsById.values().iterator();
        for (SlotSharingSlotAllocator.ExecutionSlotSharingGroup group : groupsById.values()) {
            Preconditions.checkState(remainingSlots.hasNext(), "No slots available for group %s (%s more in total). This is likely a bug.", group, groupsById.size());
            assignments.add(new JobSchedulingPlan.SlotAssignment((SlotInfo)remainingSlots.next(), group));
            remainingSlots.remove();
        }
        return assignments;
    }

    @Nonnull
    private PriorityQueue<AllocationScore> calculateScores(JobInformation jobInformation, JobAllocationsInformation previousAllocations, List<SlotSharingSlotAllocator.ExecutionSlotSharingGroup> allGroups, Map<JobVertexID, Integer> parallelism) {
        PriorityQueue<AllocationScore> scores = new PriorityQueue<AllocationScore>(Comparator.reverseOrder());
        for (SlotSharingSlotAllocator.ExecutionSlotSharingGroup group : allGroups) {
            scores.addAll(this.calculateScore(group, parallelism, jobInformation, previousAllocations));
        }
        return scores;
    }

    private static Map<JobVertexID, Integer> getParallelism(List<SlotSharingSlotAllocator.ExecutionSlotSharingGroup> groups) {
        HashMap<JobVertexID, Integer> parallelism = new HashMap<JobVertexID, Integer>();
        for (SlotSharingSlotAllocator.ExecutionSlotSharingGroup group : groups) {
            for (ExecutionVertexID evi : group.getContainedExecutionVertices()) {
                parallelism.merge(evi.getJobVertexId(), 1, Integer::sum);
            }
        }
        return parallelism;
    }

    public Collection<AllocationScore> calculateScore(SlotSharingSlotAllocator.ExecutionSlotSharingGroup group, Map<JobVertexID, Integer> parallelism, JobInformation jobInformation, JobAllocationsInformation previousAllocations) {
        HashMap score = new HashMap();
        for (ExecutionVertexID evi : group.getContainedExecutionVertices()) {
            KeyGroupRange kgr = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(jobInformation.getVertexInformation(evi.getJobVertexId()).getMaxParallelism(), parallelism.get(evi.getJobVertexId()), evi.getSubtaskIndex());
            previousAllocations.getAllocations(evi.getJobVertexId()).forEach(allocation -> {
                long value = allocation.getKeyGroupRange().getIntersection(kgr).getNumberOfKeyGroups();
                if (value > 0L) {
                    score.merge(allocation.getAllocationID(), value, Long::sum);
                }
            });
        }
        return score.entrySet().stream().map(e -> new AllocationScore(group.getId(), (AllocationID)e.getKey(), (Long)e.getValue())).collect(Collectors.toList());
    }

    private static class AllocationScore
    implements Comparable<AllocationScore> {
        private final String groupId;
        private final AllocationID allocationId;
        private final long score;

        public AllocationScore(String groupId, AllocationID allocationId, long score) {
            this.groupId = groupId;
            this.allocationId = allocationId;
            this.score = score;
        }

        public String getGroupId() {
            return this.groupId;
        }

        public AllocationID getAllocationId() {
            return this.allocationId;
        }

        public long getScore() {
            return this.score;
        }

        @Override
        public int compareTo(AllocationScore other) {
            int result = Long.compare(this.score, other.score);
            if (result != 0) {
                return result;
            }
            result = other.allocationId.compareTo(this.allocationId);
            if (result != 0) {
                return result;
            }
            return other.groupId.compareTo(this.groupId);
        }
    }
}

