/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Optional;
import java.util.stream.IntStream;
import org.apache.flink.core.memory.MemorySegmentProvider;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NoOpBufferPool;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.PipelinedResultPartition;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
import org.apache.flink.runtime.io.network.partition.PrioritizedDeque;
import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.throughput.ThroughputCalculator;
import org.apache.flink.util.clock.Clock;
import org.apache.flink.util.clock.SystemClock;
import org.apache.flink.util.function.SupplierWithException;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests that a {@link SingleInputGate} drains its input channels fairly. After every {@code
 * getNext()} call, the number of buffers still queued in the most- and least-filled channel may
 * differ by at most one.
 *
 * <p>Every gate used here is a {@link FairnessVerifyingInputGate}. On each {@code getNext()} it
 * additionally asserts that the gate's queue of channels-with-data contains no duplicate channel
 * and no more entries than the gate has channels.
 */
public class InputGateFairnessTest {

    /** Number of input channels every test creates. */
    private static final int NUMBER_OF_CHANNELS = 37;

    /** Number of data buffers each channel holds in the pre-filled tests. */
    private static final int BUFFERS_PER_CHANNEL = 27;

    /** Size argument passed to the test buffer factories. */
    private static final int TEST_BUFFER_SIZE = 42;

    /** Number of {@code getNext()} calls performed by the incrementally-filled tests. */
    private static final int NUM_INCREMENTAL_READS = 999;

    /** The incrementally-filled tests refill all channels every this many reads. */
    private static final int REFILL_INTERVAL = 2 * NUMBER_OF_CHANNELS;

    /** Number of buffers added to each channel per refill. */
    private static final int BUFFERS_PER_REFILL = 3;

    @Test
    void testFairConsumptionLocalChannelsPreFilled() throws Exception {
        PipelinedResultPartition[] resultPartitions = createResultPartitions(NUMBER_OF_CHANNELS);

        // Close the template consumer when done; only copies of it are handed to the partitions.
        try (BufferConsumer bufferConsumer =
                BufferBuilderTestUtils.createFilledFinishedBufferConsumer(TEST_BUFFER_SIZE)) {

            // ----- completely fill every source subpartition, then finish it -----
            PipelinedSubpartition[] sources = firstSubpartitions(resultPartitions);
            for (PipelinedSubpartition source : sources) {
                for (int p = 0; p < BUFFERS_PER_CHANNEL; p++) {
                    source.add(bufferConsumer.copy());
                }
                source.finish();
            }
            for (PipelinedResultPartition resultPartition : resultPartitions) {
                resultPartition.setup();
            }

            // ----- read everything through the gate -----
            SingleInputGate gate = createFairnessVerifyingInputGate(NUMBER_OF_CHANNELS);
            setupInputGate(gate, createLocalChannels(gate, resultPartitions));

            // Each channel yields its data buffers plus one extra element (presumably the
            // end-of-partition event enqueued by finish()).
            for (int remaining = NUMBER_OF_CHANNELS * (BUFFERS_PER_CHANNEL + 1);
                    remaining > 0;
                    remaining--) {
                Assertions.assertThat(gate.getNext()).isPresent();
                assertFairlyDrained(sources);
            }

            Assertions.assertThat(gate.getNext()).isNotPresent();
        }
    }

    @Test
    void testFairConsumptionLocalChannels() throws Exception {
        PipelinedResultPartition[] resultPartitions = createResultPartitions(NUMBER_OF_CHANNELS);

        try (BufferConsumer bufferConsumer =
                BufferBuilderTestUtils.createFilledFinishedBufferConsumer(TEST_BUFFER_SIZE)) {
            PipelinedSubpartition[] sources = firstSubpartitions(resultPartitions);
            SingleInputGate gate = createFairnessVerifyingInputGate(NUMBER_OF_CHANNELS);
            InputChannel[] inputChannels = createLocalChannels(gate, resultPartitions);

            for (PipelinedResultPartition resultPartition : resultPartitions) {
                resultPartition.setup();
            }

            // Seed a single channel so there is data for the first read; the first refill
            // happens right after it (i == 0).
            sources[12].add(bufferConsumer.copy());
            setupInputGate(gate, inputChannels);

            for (int i = 0; i < NUM_INCREMENTAL_READS; i++) {
                Assertions.assertThat(gate.getNext()).isPresent();
                assertFairlyDrained(sources);

                if (i % REFILL_INTERVAL == 0) {
                    fillRandom(sources, BUFFERS_PER_REFILL, bufferConsumer);
                }
            }
        }
    }

    @Test
    void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
        Buffer mockBuffer = TestBufferFactory.createBuffer(TEST_BUFFER_SIZE);
        SingleInputGate gate = createFairnessVerifyingInputGate(NUMBER_OF_CHANNELS);
        ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();

        RemoteInputChannel[] channels = new RemoteInputChannel[NUMBER_OF_CHANNELS];
        for (int i = 0; i < NUMBER_OF_CHANNELS; i++) {
            RemoteInputChannel channel = createRemoteInputChannel(gate, i, connManager);
            channels[i] = channel;

            for (int seq = 0; seq < BUFFERS_PER_CHANNEL; seq++) {
                channel.onBuffer(mockBuffer, seq, -1, 0);
            }
            // Terminate each channel with an end-of-partition event as its last sequence number.
            channel.onBuffer(
                    EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE, false),
                    BUFFERS_PER_CHANNEL,
                    -1,
                    0);
        }

        setupInputGate(gate, channels);

        // Every channel yields its data buffers plus the end-of-partition event.
        for (int remaining = NUMBER_OF_CHANNELS * (BUFFERS_PER_CHANNEL + 1);
                remaining > 0;
                remaining--) {
            Assertions.assertThat(gate.getNext()).isPresent();
            assertFairlyDrained(channels);
        }

        Assertions.assertThat(gate.getNext()).isNotPresent();
    }

    @Test
    void testFairConsumptionRemoteChannels() throws Exception {
        Buffer mockBuffer = TestBufferFactory.createBuffer(TEST_BUFFER_SIZE);
        SingleInputGate gate = createFairnessVerifyingInputGate(NUMBER_OF_CHANNELS);
        ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();

        RemoteInputChannel[] channels = new RemoteInputChannel[NUMBER_OF_CHANNELS];
        // Next sequence number to pass to onBuffer(), tracked per channel.
        int[] channelSequenceNums = new int[NUMBER_OF_CHANNELS];
        for (int i = 0; i < NUMBER_OF_CHANNELS; i++) {
            channels[i] = createRemoteInputChannel(gate, i, connManager);
        }

        // Seed a single channel so there is data for the first read.
        channels[11].onBuffer(mockBuffer, channelSequenceNums[11]++, -1, 0);
        setupInputGate(gate, channels);

        for (int i = 0; i < NUM_INCREMENTAL_READS; i++) {
            Assertions.assertThat(gate.getNext()).isPresent();
            assertFairlyDrained(channels);

            if (i % REFILL_INTERVAL == 0) {
                fillRandom(channels, channelSequenceNums, BUFFERS_PER_REFILL, mockBuffer);
            }
        }
    }

    // ------------------------------------------------------------------------
    //  Helpers
    // ------------------------------------------------------------------------

    /** Creates a gate that checks its channels-with-data queue on every {@code getNext()}. */
    private SingleInputGate createFairnessVerifyingInputGate(int numberOfChannels) {
        return new FairnessVerifyingInputGate(
                "Test Task Name", new IntermediateDataSetID(), numberOfChannels);
    }

    /** Builds {@code count} result partitions via {@link ResultPartitionBuilder}. */
    private static PipelinedResultPartition[] createResultPartitions(int count) {
        return IntStream.range(0, count)
                .mapToObj(i -> (PipelinedResultPartition) new ResultPartitionBuilder().build())
                .toArray(PipelinedResultPartition[]::new);
    }

    /** Returns subpartition 0 of every given result partition, in the same order. */
    private static PipelinedSubpartition[] firstSubpartitions(
            PipelinedResultPartition[] resultPartitions) {
        return Arrays.stream(resultPartitions)
                .map(resultPartition -> resultPartition.getAllPartitions()[0])
                .toArray(PipelinedSubpartition[]::new);
    }

    /**
     * Builds one local input channel per result partition. Channel {@code i} reads from
     * {@code resultPartitions[i]} and uses that partition's partition manager.
     */
    private static InputChannel[] createLocalChannels(
            SingleInputGate gate, PipelinedResultPartition[] resultPartitions) {
        return IntStream.range(0, resultPartitions.length)
                .mapToObj(
                        i ->
                                InputChannelBuilder.newBuilder()
                                        .setChannelIndex(i)
                                        .setPartitionManager(resultPartitions[i].partitionManager)
                                        .setPartitionId(resultPartitions[i].getPartitionId())
                                        .buildLocalChannel(gate))
                .toArray(InputChannel[]::new);
    }

    /** Asserts that the queued-buffer counts of all local sources differ by at most one. */
    private static void assertFairlyDrained(PipelinedSubpartition[] sources) {
        int[] queued = new int[sources.length];
        for (int i = 0; i < sources.length; i++) {
            queued[i] = sources[i].getNumberOfQueuedBuffers();
        }
        assertQueueSizesBalanced(queued);
    }

    /** Asserts that the queued-buffer counts of all remote channels differ by at most one. */
    private static void assertFairlyDrained(RemoteInputChannel[] channels) {
        int[] queued = new int[channels.length];
        for (int i = 0; i < channels.length; i++) {
            queued[i] = channels[i].getNumberOfQueuedBuffers();
        }
        assertQueueSizesBalanced(queued);
    }

    /**
     * Asserts {@code max == min || max == min + 1} over the given queue sizes.
     *
     * @param queueSizes per-channel queued-buffer counts; expected to be non-empty
     */
    private static void assertQueueSizesBalanced(int[] queueSizes) {
        int min = Integer.MAX_VALUE;
        int max = 0;
        for (int size : queueSizes) {
            min = Math.min(min, size);
            max = Math.max(max, size);
        }
        Assertions.assertThat(max == min || max == min + 1)
                .withFailMessage(
                        "unfair consumption: queued buffer counts range from %d to %d", min, max)
                .isTrue();
    }

    /** Adds {@code numPerPartition} copies of {@code buffer} to each partition, in random order. */
    private void fillRandom(
            PipelinedSubpartition[] partitions, int numPerPartition, BufferConsumer buffer)
            throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].add(buffer.copy());
        }
    }

    /**
     * Delivers {@code numPerPartition} buffers to each remote channel, in random order.
     *
     * @param sequenceNumbers per-channel next sequence number; advanced for every delivered buffer
     */
    private void fillRandom(
            RemoteInputChannel[] partitions,
            int[] sequenceNumbers,
            int numPerPartition,
            Buffer buffer)
            throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].onBuffer(buffer, sequenceNumbers[index]++, -1, 0);
        }
    }

    /** Returns every index in {@code [0, count)} repeated {@code repetitions} times, shuffled. */
    private static Collection<Integer> shuffledIndices(int count, int repetitions) {
        ArrayList<Integer> indices = new ArrayList<>(count * repetitions);
        for (int i = 0; i < count; i++) {
            for (int k = 0; k < repetitions; k++) {
                indices.add(i);
            }
        }
        Collections.shuffle(indices);
        return indices;
    }

    /** Builds a remote input channel with the given index and connection manager. */
    public static RemoteInputChannel createRemoteInputChannel(
            SingleInputGate inputGate, int channelIndex, ConnectionManager connectionManager) {
        return InputChannelBuilder.newBuilder()
                .setChannelIndex(channelIndex)
                .setConnectionManager(connectionManager)
                .buildRemoteChannel(inputGate);
    }

    /** Installs {@code channels} on {@code gate}, sets the gate up and requests its partitions. */
    public static void setupInputGate(SingleInputGate gate, InputChannel... channels)
            throws IOException {
        gate.setInputChannels(channels);
        gate.setup();
        gate.requestPartitions();
    }

    // ------------------------------------------------------------------------

    /**
     * A {@link SingleInputGate} that inspects its channels-with-data queue before every {@link
     * #getNext()}. It asserts that the queue holds at most {@link #getNumberOfInputChannels()}
     * entries and that no channel appears in it twice.
     */
    private static class FairnessVerifyingInputGate extends SingleInputGate {

        private static final int BUFFER_SIZE = 32768;

        private static final SupplierWithException<BufferPool, IOException>
                STUB_BUFFER_POOL_FACTORY = NoOpBufferPool::new;

        /** The gate's own queue of channels with data, obtained from the superclass. */
        private final PrioritizedDeque<InputChannel> channelsWithData = getInputChannelsWithData();

        /** Scratch set reused by {@link #ensureUnique}; emptied after every check. */
        private final HashSet<InputChannel> uniquenessChecker = new HashSet<>();

        public FairnessVerifyingInputGate(
                String owningTaskName,
                IntermediateDataSetID consumedResultId,
                int numberOfInputChannels) {
            super(
                    owningTaskName,
                    0,
                    consumedResultId,
                    ResultPartitionType.PIPELINED,
                    numberOfInputChannels,
                    SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER,
                    STUB_BUFFER_POOL_FACTORY,
                    null,
                    new InputChannelTestUtils.UnpooledMemorySegmentProvider(BUFFER_SIZE),
                    BUFFER_SIZE,
                    new ThroughputCalculator(SystemClock.getInstance()),
                    null);
        }

        @Override
        public Optional<BufferOrEvent> getNext() throws IOException, InterruptedException {
            // Hold the deque's monitor while inspecting it (presumably the same lock the gate
            // takes when enqueueing channels).
            synchronized (channelsWithData) {
                Assertions.assertThat(channelsWithData.size())
                        .withFailMessage("too many input channels")
                        .isLessThanOrEqualTo(getNumberOfInputChannels());
                ensureUnique(channelsWithData.asUnmodifiableCollection());
            }

            return super.getNext();
        }

        /** Fails if any channel occurs more than once in {@code channels}. */
        private void ensureUnique(Collection<InputChannel> channels) {
            try {
                for (InputChannel channel : channels) {
                    if (!uniquenessChecker.add(channel)) {
                        Assertions.fail("Duplicate channel in input gate: " + channel);
                    }
                }
                Assertions.assertThat(uniquenessChecker)
                        .withFailMessage("found duplicate input channels")
                        .hasSameSizeAs(channels);
            } finally {
                // Always leave the scratch set empty for the next call, even after a failure.
                uniquenessChecker.clear();
            }
        }
    }
}

