package org.apache.flink.streaming.runtime.io.recovery;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Predicate;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.plugable.DeserializationDelegate;
import org.apache.flink.shaded.guava31.com.google.common.collect.Maps;
import org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput;
import org.apache.flink.streaming.runtime.io.DataInputStatus;
import org.apache.flink.streaming.runtime.io.RecoverableStreamTaskInput;
import org.apache.flink.streaming.runtime.io.StreamTaskInput;
import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput;
import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointedInputGate;
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link StreamTaskInput} used while recovering in-flight (channel state) data after rescaling.
 *
 * <p>Every input channel of the gate gets its own {@link DemultiplexingRecordDeserializer}, created
 * from the {@link InflightDataRescalingDescriptor}. {@link SubtaskConnectionDescriptor} events that
 * arrive on a channel are handed to that channel's deserializer via {@code select(...)} instead of
 * being processed by the generic event handling. When recovery is over, {@link #finishRecovery()}
 * closes this input and returns a plain {@link StreamTaskNetworkInput} over the same gate.
 * Snapshots are always declined while this input is in use.
 *
 * @param <T> type of the records carried by this input.
 */
@Internal
public final class RescalingStreamTaskNetworkInput<T>
        extends AbstractStreamTaskNetworkInput<T, DemultiplexingRecordDeserializer<T>>
        implements RecoverableStreamTaskInput<T> {

    private static final Logger LOG =
            LoggerFactory.getLogger(RescalingStreamTaskNetworkInput.class);

    /** Retained only so that {@link #finishRecovery()} can build the follow-up input. */
    private final IOManager ioManager;

    /**
     * Creates a {@link SpillingAdaptiveSpanningRecordDeserializer} whose two size parameters are the
     * fixed budgets below divided by the function argument.
     */
    public static class DeserializerFactory
            implements Function<
                    Integer, RecordDeserializer<DeserializationDelegate<StreamElement>>> {

        // NOTE(review): assumed to correspond to the deserializer's
        // (thresholdForSpilling, fileBufferSize) constructor arguments - confirm.
        private static final int SPILLING_THRESHOLD_BYTES = 5 * 1024 * 1024;
        private static final int FILE_BUFFER_SIZE_BYTES = 2 * 1024 * 1024;

        private final IOManager ioManager;

        public DeserializerFactory(IOManager ioManager) {
            this.ioManager = ioManager;
        }

        /**
         * Creates a new deserializer spilling into the I/O manager's spilling directories.
         *
         * @param divisor value both size budgets are divided by (integer division); presumably the
         *     number of deserializers sharing the budget. A {@code null} value throws a {@link
         *     NullPointerException} on unboxing; {@code 0} throws an {@link ArithmeticException}.
         * @return a fresh deserializer instance.
         */
        @Override
        public RecordDeserializer<DeserializationDelegate<StreamElement>> apply(Integer divisor) {
            final int parts = divisor;
            return new SpillingAdaptiveSpanningRecordDeserializer<>(
                    ioManager.getSpillingDirectoriesPaths(),
                    SPILLING_THRESHOLD_BYTES / parts,
                    FILE_BUFFER_SIZE_BYTES / parts);
        }
    }

    /**
     * Creates, per input channel, a {@link RecordFilter} predicate built from the partitioner of the
     * channel's input gate, the input serializer and this subtask's index. The filter presumably
     * keeps only records the partitioner routes to this subtask; confirm against {@code
     * RecordFilter}.
     *
     * <p>The partitioner of each gate is created lazily once and cached. Every channel then
     * receives its own {@code copy()} of it, so channels never share a partitioner instance.
     */
    public static class RecordFilterFactory<T>
            implements Function<InputChannelInfo, Predicate<StreamRecord<T>>> {

        /** Set-up partitioners keyed by gate index. */
        private final Map<Integer, StreamPartitioner<T>> partitionerCache =
                CollectionUtil.newHashMapWithExpectedSize(1);

        private final Function<Integer, StreamPartitioner<?>> gatePartitioners;
        private final TypeSerializer<T> inputSerializer;
        private final int numberOfChannels;
        private final int subtaskIndex;
        private final int maxParallelism;

        /**
         * @param subtaskIndex index of this subtask, passed to every created filter.
         * @param inputSerializer serializer of the input records, passed to every created filter.
         * @param numberOfChannels value passed to {@code StreamPartitioner#setup}.
         * @param gatePartitioners supplies the partitioner for a given gate index.
         * @param maxParallelism value passed to {@code ConfigurableStreamPartitioner#configure}.
         */
        public RecordFilterFactory(
                int subtaskIndex,
                TypeSerializer<T> inputSerializer,
                int numberOfChannels,
                Function<Integer, StreamPartitioner<?>> gatePartitioners,
                int maxParallelism) {
            this.gatePartitioners = gatePartitioners;
            this.inputSerializer = inputSerializer;
            this.numberOfChannels = numberOfChannels;
            this.subtaskIndex = subtaskIndex;
            this.maxParallelism = maxParallelism;
        }

        @Override
        public Predicate<StreamRecord<T>> apply(InputChannelInfo inputChannelInfo) {
            final StreamPartitioner<T> gatePartitioner =
                    partitionerCache.computeIfAbsent(
                            inputChannelInfo.getGateIdx(), this::createPartitioner);
            return new RecordFilter<>(gatePartitioner.copy(), inputSerializer, subtaskIndex);
        }

        /**
         * Fetches the partitioner of the given gate and calls {@code setup(numberOfChannels)} on
         * it. If it is a {@link ConfigurableStreamPartitioner}, it is also configured with {@code
         * maxParallelism}. The instance returned by {@code gatePartitioners} is mutated in place.
         */
        @SuppressWarnings("unchecked") // partitioners of this input operate on records of type T
        private StreamPartitioner<T> createPartitioner(Integer gateIndex) {
            final StreamPartitioner<T> partitioner =
                    (StreamPartitioner<T>) gatePartitioners.apply(gateIndex);
            partitioner.setup(numberOfChannels);
            if (partitioner instanceof ConfigurableStreamPartitioner) {
                ((ConfigurableStreamPartitioner) partitioner).configure(maxParallelism);
            }
            return partitioner;
        }
    }

    /**
     * @param checkpointedInputGate gate whose channels are read.
     * @param inputSerializer serializer of the input records.
     * @param ioManager provides spilling directories; also reused by {@link #finishRecovery()}.
     * @param statusWatermarkValve valve handed to the parent input.
     * @param inputIndex index of this input, handed to the parent input.
     * @param inflightDataRescalingDescriptor describes how recovered channel state maps onto the
     *     channels of this gate.
     * @param gatePartitioners supplies the partitioner for a given gate index.
     * @param taskInfo provides subtask index, parallelism and max parallelism.
     * @param canEmitBatchOfRecords checker handed to the parent input.
     */
    public RescalingStreamTaskNetworkInput(
            CheckpointedInputGate checkpointedInputGate,
            TypeSerializer<T> inputSerializer,
            IOManager ioManager,
            StatusWatermarkValve statusWatermarkValve,
            int inputIndex,
            InflightDataRescalingDescriptor inflightDataRescalingDescriptor,
            Function<Integer, StreamPartitioner<?>> gatePartitioners,
            TaskInfo taskInfo,
            StreamTask.CanEmitBatchOfRecordsChecker canEmitBatchOfRecords) {
        super(
                checkpointedInputGate,
                inputSerializer,
                statusWatermarkValve,
                inputIndex,
                getRecordDeserializers(
                        checkpointedInputGate,
                        inputSerializer,
                        ioManager,
                        inflightDataRescalingDescriptor,
                        gatePartitioners,
                        taskInfo),
                canEmitBatchOfRecords);
        this.ioManager = ioManager;
        LOG.info(
                "Created demultiplexer for input {} from {}",
                inputIndex,
                inflightDataRescalingDescriptor);
    }

    /** Builds one {@link DemultiplexingRecordDeserializer} per channel of the given gate. */
    private static <T> Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>>
            getRecordDeserializers(
                    CheckpointedInputGate checkpointedInputGate,
                    TypeSerializer<T> inputSerializer,
                    IOManager ioManager,
                    InflightDataRescalingDescriptor rescalingDescriptor,
                    Function<Integer, StreamPartitioner<?>> gatePartitioners,
                    TaskInfo taskInfo) {
        final RecordFilterFactory<T> recordFilterFactory =
                new RecordFilterFactory<>(
                        taskInfo.getIndexOfThisSubtask(),
                        inputSerializer,
                        taskInfo.getNumberOfParallelSubtasks(),
                        gatePartitioners,
                        taskInfo.getMaxNumberOfParallelSubtasks());
        final DeserializerFactory deserializerFactory = new DeserializerFactory(ioManager);

        final Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>> deserializers =
                Maps.newHashMapWithExpectedSize(checkpointedInputGate.getChannelInfos().size());
        for (InputChannelInfo channelInfo : checkpointedInputGate.getChannelInfos()) {
            deserializers.put(
                    channelInfo,
                    DemultiplexingRecordDeserializer.create(
                            channelInfo,
                            rescalingDescriptor,
                            deserializerFactory,
                            recordFilterFactory));
        }
        return deserializers;
    }

    /**
     * Ends recovery. It checks that no deserializer still holds partial data, closes this input and
     * returns a regular {@link StreamTaskNetworkInput}. The new input uses the same gate,
     * serializer, I/O manager, watermark valve, input index and batch checker.
     *
     * @throws IllegalStateException if some deserializer still has partial data.
     * @throws IOException if closing this input fails.
     */
    @Override
    public StreamTaskInput<T> finishRecovery() throws IOException {
        Preconditions.checkState(
                recordDeserializers.values().stream()
                        .noneMatch(DemultiplexingRecordDeserializer::hasPartialData),
                "Not all data has been fully consumed");
        close();
        return new StreamTaskNetworkInput<>(
                checkpointedInputGate,
                inputSerializer,
                ioManager,
                statusWatermarkValve,
                inputIndex,
                canEmitBatchOfRecords);
    }

    /**
     * Returns the deserializer of the given channel.
     *
     * @throws IllegalStateException if that deserializer has no mappings, meaning the channel is
     *     not expected to deliver data during recovery.
     */
    @Override
    public DemultiplexingRecordDeserializer<T> getActiveSerializer(
            InputChannelInfo inputChannelInfo) {
        final DemultiplexingRecordDeserializer<T> deserializer =
                super.getActiveSerializer(inputChannelInfo);
        if (!deserializer.hasMappings()) {
            throw new IllegalStateException(
                    "Channel " + inputChannelInfo + " should not receive data during recovery.");
        }
        return deserializer;
    }

    /**
     * Routes {@link SubtaskConnectionDescriptor} events to the channel's deserializer. Every other
     * event goes to the parent implementation.
     */
    @Override
    public DataInputStatus processEvent(BufferOrEvent bufferOrEvent) {
        final Object event = bufferOrEvent.getEvent();
        if (!(event instanceof SubtaskConnectionDescriptor)) {
            return super.processEvent(bufferOrEvent);
        }
        getActiveSerializer(bufferOrEvent.getChannelInfo())
                .select((SubtaskConnectionDescriptor) event);
        return DataInputStatus.MORE_AVAILABLE;
    }

    /**
     * Always declines the snapshot with {@link
     * CheckpointFailureReason#CHECKPOINT_DECLINED_TASK_NOT_READY}.
     */
    @Override
    public CompletableFuture<Void> prepareSnapshot(
            ChannelStateWriter channelStateWriter, long checkpointId) throws CheckpointException {
        throw new CheckpointException(CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY);
    }
}
