/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.io.recovery;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.plugable.DeserializationDelegate;
import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.shaded.guava32.com.google.common.collect.Iterables;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.io.recovery.DemultiplexingRecordDeserializer;
import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link DemultiplexingRecordDeserializer}.
 *
 * <p>Each test builds a deserializer from a rescaling descriptor, writes serialized
 * {@link StreamElement}s into a buffer for one selected virtual channel
 * ({@link SubtaskConnectionDescriptor}), reads them back through the deserializer and finally
 * checks that the backing {@link MemorySegment} has been freed.
 */
class DemultiplexingRecordDeserializerTest {
    private final ThreadLocalRandom random = ThreadLocalRandom.current();
    /** Created per test in {@link #setup()} and closed in {@link #cleanup()}. */
    private IOManager ioManager;

    DemultiplexingRecordDeserializerTest() {
    }

    /** Creates the I/O manager whose spilling directories are handed to the deserializers. */
    @BeforeEach
    void setup() {
        this.ioManager = new IOManagerAsync();
    }

    /**
     * Closes the I/O manager created in {@link #setup()}.
     *
     * <p>The manager must be closed here rather than replaced: allocating a new instance would
     * leave both the old and the new manager open after every test.
     *
     * @throws Exception if closing the I/O manager fails
     */
    @AfterEach
    void cleanup() throws Exception {
        this.ioManager.close();
    }

    /**
     * Mixes buffers from four virtual channels and checks that each buffer's records are read
     * back unchanged for the selected channel.
     */
    @Test
    void testUpscale() throws IOException {
        DemultiplexingRecordDeserializer deserializer = DemultiplexingRecordDeserializer.create((InputChannelInfo)new InputChannelInfo(2, 0), (InflightDataRescalingDescriptor)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(new int[0][]), InflightDataRescalingDescriptorUtil.mappings(new int[0][]), InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(2, 3), InflightDataRescalingDescriptorUtil.to(4, 5))), Collections.emptySet()), unused -> new SpillingAdaptiveSpanningRecordDeserializer(this.ioManager.getSpillingDirectoriesPaths()), unused -> RecordFilter.all());
        Assertions.assertThat((Collection)deserializer.getVirtualChannelSelectors()).containsOnly((Object[])new SubtaskConnectionDescriptor[]{new SubtaskConnectionDescriptor(0, 2), new SubtaskConnectionDescriptor(0, 3), new SubtaskConnectionDescriptor(1, 2), new SubtaskConnectionDescriptor(1, 3)});
        for (int i = 0; i < 100; ++i) {
            // Pick one of the four virtual channels at random for this round.
            SubtaskConnectionDescriptor selector = (SubtaskConnectionDescriptor)Iterables.get((Iterable)deserializer.getVirtualChannelSelectors(), (int)this.random.nextInt(4));
            // Derive the payload from the selector so the expected values differ per channel.
            long start = selector.getInputSubtaskIndex() << 4 | selector.getOutputSubtaskIndex();
            MemorySegment memorySegment = MemorySegmentFactory.allocateUnpooledSegment((int)128);
            try (BufferBuilder bufferBuilder = BufferBuilderTestUtils.createBufferBuilder(memorySegment);){
                Buffer buffer = this.writeLongs(bufferBuilder, start + 1L, start + 2L, start + 3L);
                deserializer.select(selector);
                deserializer.setNextBuffer(buffer);
            }
            Assertions.assertThat(this.readLongs((DemultiplexingRecordDeserializer<Long>)deserializer)).containsExactly((Object[])new Long[]{start + 1L, start + 2L, start + 3L});
            // After the builder is closed and the buffer fully read, the segment must be released.
            Assertions.assertThat((boolean)memorySegment.isFreed()).isTrue();
        }
    }

    /**
     * Checks that the configured {@link RecordFilter} is applied: for input subtask 41 both
     * written values are returned, for the other input subtask only the odd value is returned.
     */
    @Test
    void testAmbiguousChannels() throws IOException {
        DemultiplexingRecordDeserializer deserializer = DemultiplexingRecordDeserializer.create((InputChannelInfo)new InputChannelInfo(1, 0), (InflightDataRescalingDescriptor)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(41, 42), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(new int[0][]), InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(2, 3), InflightDataRescalingDescriptorUtil.to(4, 5))), InflightDataRescalingDescriptorUtil.set(42)), unused -> new SpillingAdaptiveSpanningRecordDeserializer(this.ioManager.getSpillingDirectoriesPaths()), unused -> new RecordFilter((ChannelSelector)new ModSelector(2), (TypeSerializer)LongSerializer.INSTANCE, 1));
        Assertions.assertThat((Collection)deserializer.getVirtualChannelSelectors()).containsOnly((Object[])new SubtaskConnectionDescriptor[]{new SubtaskConnectionDescriptor(41, 2), new SubtaskConnectionDescriptor(41, 3), new SubtaskConnectionDescriptor(42, 2), new SubtaskConnectionDescriptor(42, 3)});
        for (int i = 0; i < 100; ++i) {
            MemorySegment memorySegment = MemorySegmentFactory.allocateUnpooledSegment((int)128);
            try (BufferBuilder bufferBuilder = BufferBuilderTestUtils.createBufferBuilder(memorySegment);){
                // Two consecutive numbers: exactly one of them is even and one is odd.
                Buffer buffer = this.writeLongs(bufferBuilder, i, (long)i + 1L);
                // Alternate between the first two selectors (in iteration order) every ten rounds.
                SubtaskConnectionDescriptor selector = (SubtaskConnectionDescriptor)Iterables.get((Iterable)deserializer.getVirtualChannelSelectors(), (int)(i / 10 % 2));
                deserializer.select(selector);
                deserializer.setNextBuffer(buffer);
                if (selector.getInputSubtaskIndex() == 41) {
                    Assertions.assertThat(this.readLongs((DemultiplexingRecordDeserializer<Long>)deserializer)).containsExactly((Object[])new Long[]{i, (long)i + 1L});
                } else {
                    // i / 2 * 2 + 1 is the odd one of {i, i + 1}; only that value is expected.
                    Assertions.assertThat(this.readLongs((DemultiplexingRecordDeserializer<Long>)deserializer)).containsExactly((Object[])new Long[]{(long)(i / 2 * 2) + 1L});
                }
            }
            Assertions.assertThat((boolean)memorySegment.isFreed()).isTrue();
        }
    }

    /**
     * Writes one watermark per virtual channel and checks that nothing is emitted until the last
     * channel has delivered its watermark, at which point a single {@code Watermark(42)} is read.
     */
    @Test
    void testWatermarks() throws IOException {
        DemultiplexingRecordDeserializer deserializer = DemultiplexingRecordDeserializer.create((InputChannelInfo)new InputChannelInfo(0, 0), (InflightDataRescalingDescriptor)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(4, 5))), Collections.emptySet()), unused -> new SpillingAdaptiveSpanningRecordDeserializer(this.ioManager.getSpillingDirectoriesPaths()), unused -> RecordFilter.all());
        Assertions.assertThat((Collection)deserializer.getVirtualChannelSelectors()).hasSize(4);
        // An explicit iterator is used so the loop can tell whether the current selector is the last.
        Iterator iterator = deserializer.getVirtualChannelSelectors().iterator();
        while (iterator.hasNext()) {
            SubtaskConnectionDescriptor selector = (SubtaskConnectionDescriptor)iterator.next();
            MemorySegment memorySegment = MemorySegmentFactory.allocateUnpooledSegment((int)128);
            try (BufferBuilder bufferBuilder = BufferBuilderTestUtils.createBufferBuilder(memorySegment);){
                // Timestamp is 42 plus both subtask indices of the selector.
                long ts = 42L + (long)selector.getInputSubtaskIndex() + (long)selector.getOutputSubtaskIndex();
                Buffer buffer = this.write(bufferBuilder, new StreamElement[]{new Watermark(ts)});
                deserializer.select(selector);
                deserializer.setNextBuffer(buffer);
            }
            if (iterator.hasNext()) {
                // Watermarks from some channels are still outstanding: nothing may be emitted yet.
                Assertions.assertThat(this.read((DemultiplexingRecordDeserializer<Long>)deserializer)).isEmpty();
            } else {
                // The last channel has reported: exactly one watermark with timestamp 42 is expected.
                Assertions.assertThat(this.read((DemultiplexingRecordDeserializer<Long>)deserializer)).containsExactly((Object[])new StreamElement[]{new Watermark(42L)});
            }
            Assertions.assertThat((boolean)memorySegment.isFreed()).isTrue();
        }
    }

    /**
     * Wraps each value in a {@link StreamRecord} and serializes them via
     * {@link #write(BufferBuilder, StreamElement...)}.
     *
     * @param bufferBuilder builder the serialized records are appended to
     * @param elements values to write, in order
     * @return the buffer built from {@code bufferBuilder}
     * @throws IOException if serialization fails
     */
    private Buffer writeLongs(BufferBuilder bufferBuilder, long ... elements) throws IOException {
        return this.write(bufferBuilder, (StreamElement[])Arrays.stream(elements).mapToObj(StreamRecord::new).toArray(StreamElement[]::new));
    }

    /**
     * Serializes the given elements with a {@link StreamElementSerializer} over
     * {@link LongSerializer} and appends them to {@code bufferBuilder}.
     *
     * @param bufferBuilder builder the serialized records are appended to
     * @param elements stream elements to write, in order
     * @return the buffer built from a consumer of {@code bufferBuilder}
     * @throws IOException if serialization fails
     */
    private Buffer write(BufferBuilder bufferBuilder, StreamElement ... elements) throws IOException {
        // The consumer is created before any data is appended and closed once the buffer is built.
        try (BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumer();){
            DataOutputSerializer output = new DataOutputSerializer(128);
            SerializationDelegate delegate = new SerializationDelegate((TypeSerializer)new StreamElementSerializer((TypeSerializer)LongSerializer.INSTANCE));
            for (StreamElement element : elements) {
                delegate.setInstance((Object)element);
                bufferBuilder.appendAndCommit(RecordWriter.serializeRecord((DataOutputSerializer)output, (IOReadableWritable)delegate));
            }
            Buffer buffer = bufferConsumer.build();
            return buffer;
        }
    }

    /**
     * Reads from the deserializer until it reports the current buffer as consumed.
     *
     * @param deserializer deserializer that already had a buffer set
     * @return every full record that was produced, in read order
     * @throws IOException if deserialization fails
     */
    private List<StreamElement> read(DemultiplexingRecordDeserializer<Long> deserializer) throws IOException {
        RecordDeserializer.DeserializationResult result;
        NonReusingDeserializationDelegate delegate = new NonReusingDeserializationDelegate((TypeSerializer)new StreamElementSerializer((TypeSerializer)LongSerializer.INSTANCE));
        ArrayList<StreamElement> results = new ArrayList<StreamElement>();
        do {
            // Only full records are collected; partial results just advance the loop.
            if (!(result = deserializer.getNextRecord((DeserializationDelegate)delegate)).isFullRecord()) continue;
            results.add((StreamElement)delegate.getInstance());
        } while (!result.isBufferConsumed());
        return results;
    }

    /**
     * Same as {@link #read(DemultiplexingRecordDeserializer)}, but unwraps every element as a
     * record and returns its {@code Long} value.
     *
     * @param deserializer deserializer that already had a buffer set
     * @return the record values in read order
     * @throws IOException if deserialization fails
     */
    private List<Long> readLongs(DemultiplexingRecordDeserializer<Long> deserializer) throws IOException {
        return this.read(deserializer).stream().map(element -> (Long)element.asRecord().getValue()).collect(Collectors.toList());
    }

    /** Channel selector that picks the channel as {@code value % numberOfChannels}. */
    private static class ModSelector
    implements ChannelSelector<SerializationDelegate<StreamRecord<Long>>> {
        private final int numberOfChannels;

        private ModSelector(int numberOfChannels) {
            this.numberOfChannels = numberOfChannels;
        }

        /** Intentionally a no-op: the channel count given to the constructor is kept. */
        public void setup(int numberOfChannels) {
        }

        /** Returns the record's {@code Long} value modulo the configured number of channels. */
        public int selectChannel(SerializationDelegate<StreamRecord<Long>> record) {
            return (int)((Long)((StreamRecord)record.getInstance()).getValue() % (long)this.numberOfChannels);
        }

        public boolean isBroadcast() {
            return false;
        }
    }
}

