/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition.consumer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventPublisher;
import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
import org.apache.flink.runtime.io.network.partition.consumer.GateBuffersSpec;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputGateSpecUtils;
import org.apache.flink.runtime.io.network.partition.consumer.LocalRecoveredInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteRecoveredInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.UnknownInputChannel;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageConfiguration;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty.TieredStorageNettyServiceImpl;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageConsumerClient;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageConsumerSpec;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.UnknownTierShuffleDescriptor;
import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
import org.apache.flink.runtime.shuffle.NettyShuffleUtils;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext;
import org.apache.flink.runtime.shuffle.ShuffleUtils;
import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration;
import org.apache.flink.runtime.throughput.BufferDebloatConfiguration;
import org.apache.flink.runtime.throughput.BufferDebloater;
import org.apache.flink.runtime.throughput.ThroughputCalculator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.clock.SystemClock;
import org.apache.flink.util.function.SupplierWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factory that builds {@link SingleInputGate} instances and wires them up with one
 * {@link InputChannel} per consumed shuffle descriptor.
 *
 * <p>All tunables (backoffs, buffer counts, compression, debloating) are read once from the
 * {@link NettyShuffleEnvironmentConfiguration} passed to the constructor. When a
 * {@link TieredStorageConfiguration} is supplied, every created gate is additionally given a
 * {@link TieredStorageConsumerClient}.
 */
public class SingleInputGateFactory {
    private static final Logger LOG = LoggerFactory.getLogger(SingleInputGateFactory.class);

    @Nonnull
    protected final ResourceID taskExecutorResourceId;
    protected final int partitionRequestInitialBackoff;
    protected final int partitionRequestMaxBackoff;
    protected final int partitionRequestListenerTimeout;
    @Nonnull
    protected final ConnectionManager connectionManager;
    @Nonnull
    protected final ResultPartitionManager partitionManager;
    @Nonnull
    protected final TaskEventPublisher taskEventPublisher;
    @Nonnull
    protected final NetworkBufferPool networkBufferPool;
    private final Optional<Integer> maxRequiredBuffersPerGate;
    protected final int configuredNetworkBuffersPerChannel;
    private final int floatingNetworkBuffersPerGate;
    private final boolean batchShuffleCompressionEnabled;
    private final NettyShuffleEnvironmentOptions.CompressionCodec compressionCodec;
    private final int networkBufferSize;
    private final BufferDebloatConfiguration debloatConfiguration;
    /** Non-null only when tiered storage is in use; its presence switches on all tiered wiring. */
    @Nullable
    private final TieredStorageConfiguration tieredStorageConfiguration;
    @Nullable
    private final TieredStorageNettyServiceImpl tieredStorageNettyService;

    /**
     * Creates a factory, copying the relevant settings out of {@code networkConfig}.
     *
     * @param taskExecutorResourceId id used to decide whether a known partition is local
     * @param networkConfig source of backoff, buffer, compression and debloat settings
     * @param connectionManager passed to unknown and remote channels
     * @param partitionManager passed to unknown and local channels
     * @param taskEventPublisher passed to unknown and local channels
     * @param networkBufferPool pool backing the gate's buffer pool and the gate itself
     * @param tieredStorageConfiguration tiered storage settings, or {@code null} to disable
     * @param tieredStorageNettyService netty service handed to the tiered storage consumer
     *     client; may be {@code null}
     */
    public SingleInputGateFactory(@Nonnull ResourceID taskExecutorResourceId, @Nonnull NettyShuffleEnvironmentConfiguration networkConfig, @Nonnull ConnectionManager connectionManager, @Nonnull ResultPartitionManager partitionManager, @Nonnull TaskEventPublisher taskEventPublisher, @Nonnull NetworkBufferPool networkBufferPool, @Nullable TieredStorageConfiguration tieredStorageConfiguration, @Nullable TieredStorageNettyServiceImpl tieredStorageNettyService) {
        this.taskExecutorResourceId = taskExecutorResourceId;
        this.partitionRequestInitialBackoff = networkConfig.partitionRequestInitialBackoff();
        this.partitionRequestMaxBackoff = networkConfig.partitionRequestMaxBackoff();
        this.partitionRequestListenerTimeout = networkConfig.getPartitionRequestListenerTimeout();
        this.maxRequiredBuffersPerGate = networkConfig.maxRequiredBuffersPerGate();
        this.configuredNetworkBuffersPerChannel = NettyShuffleUtils.getNetworkBuffersPerInputChannel(networkConfig.networkBuffersPerChannel());
        this.floatingNetworkBuffersPerGate = networkConfig.floatingNetworkBuffersPerGate();
        this.batchShuffleCompressionEnabled = networkConfig.isBatchShuffleCompressionEnabled();
        this.compressionCodec = networkConfig.getCompressionCodec();
        this.networkBufferSize = networkConfig.networkBufferSize();
        this.connectionManager = connectionManager;
        this.partitionManager = partitionManager;
        this.taskEventPublisher = taskEventPublisher;
        this.networkBufferPool = networkBufferPool;
        this.debloatConfiguration = networkConfig.getDebloatConfiguration();
        this.tieredStorageConfiguration = tieredStorageConfiguration;
        this.tieredStorageNettyService = tieredStorageNettyService;
    }

    /**
     * Creates an input gate for the given deployment descriptor and populates its channels.
     *
     * @param owner supplies the owning task name and the metric group for inputs
     * @param gateIndex index of the gate; also used as the name of its metric sub-group
     * @param igdd deployment descriptor describing the consumed result and its partitions
     * @param partitionProducerStateProvider forwarded to the gate
     * @param metrics forwarded to every created input channel
     * @return the fully wired gate
     */
    public SingleInputGate create(@Nonnull ShuffleIOOwnerContext owner, int gateIndex, @Nonnull InputGateDeploymentDescriptor igdd, @Nonnull PartitionProducerStateProvider partitionProducerStateProvider, @Nonnull InputChannelMetrics metrics) {
        GateBuffersSpec gateBuffersSpec = InputGateSpecUtils.createGateBuffersSpec(this.maxRequiredBuffersPerGate, this.configuredNetworkBuffersPerChannel, this.floatingNetworkBuffersPerGate, igdd.getConsumedPartitionType(), igdd.getNumConsumedShuffleDescriptors(), this.tieredStorageConfiguration != null);
        SupplierWithException<BufferPool, IOException> bufferPoolFactory = SingleInputGateFactory.createBufferPoolFactory(this.networkBufferPool, gateBuffersSpec.getRequiredFloatingBuffers(), gateBuffersSpec.getTotalFloatingBuffers());

        // A decompressor is only needed when the partition type supports compression AND
        // batch shuffle compression is switched on; otherwise the gate receives null.
        BufferDecompressor bufferDecompressor = null;
        if (igdd.getConsumedPartitionType().supportCompression() && this.batchShuffleCompressionEnabled) {
            bufferDecompressor = new BufferDecompressor(this.networkBufferSize, this.compressionCodec);
        }

        String owningTaskName = owner.getOwnerName();
        MetricGroup networkInputGroup = owner.getInputGroup();
        SingleInputGate inputGate = new SingleInputGate(owningTaskName, gateIndex, igdd.getConsumedResultId(), igdd.getConsumedPartitionType(), igdd.getNumConsumedShuffleDescriptors(), partitionProducerStateProvider, bufferPoolFactory, bufferDecompressor, this.networkBufferPool, this.networkBufferSize, new ThroughputCalculator(SystemClock.getInstance()), this.maybeCreateBufferDebloater(owningTaskName, gateIndex, networkInputGroup.addGroup(gateIndex)));
        this.createInputChannelsAndTieredStorageService(owningTaskName, igdd, inputGate, gateBuffersSpec, metrics);
        return inputGate;
    }

    /**
     * Creates a {@link BufferDebloater} and registers its two gauges on {@code inputGroup} when
     * debloating is enabled in the configuration.
     *
     * @return the debloater, or {@code null} when debloating is disabled
     */
    @Nullable
    private BufferDebloater maybeCreateBufferDebloater(String owningTaskName, int gateIndex, MetricGroup inputGroup) {
        if (!this.debloatConfiguration.isEnabled()) {
            return null;
        }
        BufferDebloater bufferDebloater = new BufferDebloater(owningTaskName, gateIndex, this.debloatConfiguration.getTargetTotalTime().toMillis(), this.debloatConfiguration.getStartingBufferSize(), this.debloatConfiguration.getMaxBufferSize(), this.debloatConfiguration.getMinBufferSize(), this.debloatConfiguration.getBufferDebloatThresholdPercentages(), this.debloatConfiguration.getNumberOfSamples());
        inputGroup.gauge("estimatedTimeToConsumeBuffersMs", () -> bufferDebloater.getLastEstimatedTimeToConsumeBuffers().toMillis());
        inputGroup.gauge("debloatedBufferSize", bufferDebloater::getLastBufferSize);
        return bufferDebloater;
    }

    /**
     * Creates one input channel per shuffle descriptor index covered by the descriptor's consumed
     * ranges, installs them on the gate and, when tiered storage is configured, also builds and
     * installs the {@link TieredStorageConsumerClient}.
     *
     * @throws IllegalStateException if the number of created channels differs from
     *     {@code getNumConsumedShuffleDescriptors()}
     */
    private void createInputChannelsAndTieredStorageService(String owningTaskName, InputGateDeploymentDescriptor inputGateDeploymentDescriptor, SingleInputGate inputGate, GateBuffersSpec gateBuffersSpec, InputChannelMetrics metrics) {
        ShuffleDescriptor[] shuffleDescriptors = inputGateDeploymentDescriptor.getShuffleDescriptors();
        int inputChannelSize = inputGateDeploymentDescriptor.getNumConsumedShuffleDescriptors();
        InputChannel[] inputChannels = new InputChannel[inputChannelSize];
        ChannelStatistics channelStatistics = new ChannelStatistics();
        int channelIdx = 0;
        List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs = new ArrayList<>();
        List<List<TierShuffleDescriptor>> tierShuffleDescriptors = new ArrayList<>();
        for (IndexRange consumedShuffleDescriptorRange : inputGateDeploymentDescriptor.getConsumedShuffleDescriptorRanges()) {
            // Range bounds are inclusive on both ends.
            for (int i = consumedShuffleDescriptorRange.getStartIndex(); i <= consumedShuffleDescriptorRange.getEndIndex(); ++i) {
                ResultSubpartitionIndexSet subpartitionIndexSet = new ResultSubpartitionIndexSet(inputGateDeploymentDescriptor.getConsumedSubpartitionRange(i));
                ShuffleDescriptor descriptor = shuffleDescriptors[i];
                inputChannels[channelIdx] = this.createInputChannel(inputGate, channelIdx, gateBuffersSpec.getEffectiveExclusiveBuffersPerChannel(), descriptor, subpartitionIndexSet, channelStatistics, metrics);
                if (this.tieredStorageConfiguration != null) {
                    // The tiered partition id is only consumed by the spec below, so it is
                    // derived here rather than for every channel.
                    TieredStoragePartitionId partitionId = TieredStorageIdMappingUtils.convertId(descriptor.getResultPartitionID());
                    this.addTierShuffleDescriptors(tierShuffleDescriptors, descriptor);
                    tieredStorageConsumerSpecs.add(new TieredStorageConsumerSpec(inputGate.getInputGateIndex(), partitionId, new TieredStorageInputChannelId(channelIdx), subpartitionIndexSet));
                }
                ++channelIdx;
            }
        }
        Preconditions.checkState(channelIdx == inputChannelSize);
        inputGate.setInputChannels(inputChannels);
        if (this.tieredStorageConfiguration != null) {
            // NOTE(review): tieredStorageNettyService is @Nullable and is forwarded unchecked;
            // confirm it is always provided together with tieredStorageConfiguration.
            TieredStorageConsumerClient tieredStorageConsumerClient = new TieredStorageConsumerClient(this.tieredStorageConfiguration.getTierFactories(), tieredStorageConsumerSpecs, tierShuffleDescriptors, this.tieredStorageNettyService);
            inputGate.setTieredStorageService(tieredStorageConsumerSpecs, tieredStorageConsumerClient, this.tieredStorageNettyService);
        }
        LOG.debug("{}: Created {} input channels ({}).", owningTaskName, inputChannels.length, channelStatistics);
    }

    /**
     * Creates a single input channel: an {@link UnknownInputChannel} for an unknown shuffle
     * descriptor, otherwise whatever {@link #createKnownInputChannel} returns. The matching
     * counter in {@code channelStatistics} is incremented.
     */
    private InputChannel createInputChannel(SingleInputGate inputGate, int index, int buffersPerChannel, ShuffleDescriptor shuffleDescriptor, ResultSubpartitionIndexSet subpartitionIndexSet, ChannelStatistics channelStatistics, InputChannelMetrics metrics) {
        return ShuffleUtils.applyWithShuffleTypeCheck(NettyShuffleDescriptor.class, shuffleDescriptor, unknownShuffleDescriptor -> {
            ++channelStatistics.numUnknownChannels;
            return new UnknownInputChannel(inputGate, index, unknownShuffleDescriptor.getResultPartitionID(), subpartitionIndexSet, this.partitionManager, this.taskEventPublisher, this.connectionManager, this.partitionRequestInitialBackoff, this.partitionRequestMaxBackoff, this.partitionRequestListenerTimeout, buffersPerChannel, metrics);
        }, nettyShuffleDescriptor -> this.createKnownInputChannel(inputGate, index, buffersPerChannel, (NettyShuffleDescriptor)nettyShuffleDescriptor, subpartitionIndexSet, channelStatistics, metrics));
    }

    /**
     * Creates the channel for a known partition: a {@link LocalRecoveredInputChannel} when the
     * descriptor is local to this factory's task executor, a {@link RemoteRecoveredInputChannel}
     * otherwise. The matching counter in {@code channelStatistics} is incremented.
     */
    @VisibleForTesting
    protected InputChannel createKnownInputChannel(SingleInputGate inputGate, int index, int buffersPerChannel, NettyShuffleDescriptor inputChannelDescriptor, ResultSubpartitionIndexSet subpartitionIndexSet, ChannelStatistics channelStatistics, InputChannelMetrics metrics) {
        ResultPartitionID partitionId = inputChannelDescriptor.getResultPartitionID();
        if (inputChannelDescriptor.isLocalTo(this.taskExecutorResourceId)) {
            ++channelStatistics.numLocalChannels;
            return new LocalRecoveredInputChannel(inputGate, index, partitionId, subpartitionIndexSet, this.partitionManager, this.taskEventPublisher, this.partitionRequestInitialBackoff, this.partitionRequestMaxBackoff, buffersPerChannel, metrics);
        }
        ++channelStatistics.numRemoteChannels;
        return new RemoteRecoveredInputChannel(inputGate, index, partitionId, subpartitionIndexSet, inputChannelDescriptor.getConnectionId(), this.connectionManager, this.partitionRequestInitialBackoff, this.partitionRequestMaxBackoff, this.partitionRequestListenerTimeout, buffersPerChannel, metrics);
    }

    /**
     * Appends the per-tier descriptors for {@code descriptor} to {@code tierShuffleDescriptors}.
     * A {@link NettyShuffleDescriptor} contributes its own tier descriptors; an unknown
     * descriptor contributes one {@link UnknownTierShuffleDescriptor#INSTANCE} per configured
     * tier factory.
     *
     * @throws IllegalArgumentException if the descriptor is neither of the above
     * @throws NullPointerException if no tiered storage configuration is set and the descriptor
     *     is unknown
     */
    private void addTierShuffleDescriptors(List<List<TierShuffleDescriptor>> tierShuffleDescriptors, ShuffleDescriptor descriptor) {
        if (descriptor instanceof NettyShuffleDescriptor) {
            tierShuffleDescriptors.add(((NettyShuffleDescriptor)descriptor).getTierShuffleDescriptors());
        } else if (descriptor.isUnknown()) {
            int numTiers = Preconditions.checkNotNull(this.tieredStorageConfiguration).getTierFactories().size();
            // Element type must be TierShuffleDescriptor: generics are invariant, so a
            // List<UnknownTierShuffleDescriptor> could not be added to the outer list.
            List<TierShuffleDescriptor> unknownDescriptors = new ArrayList<>(numTiers);
            for (int i = 0; i < numTiers; ++i) {
                unknownDescriptors.add(UnknownTierShuffleDescriptor.INSTANCE);
            }
            tierShuffleDescriptors.add(unknownDescriptors);
        } else {
            throw new IllegalArgumentException("Unsupported shuffle descriptor type " + descriptor);
        }
    }

    /**
     * Returns a supplier that, each time it is invoked, asks {@code bufferPoolFactory} for a
     * buffer pool with the given floating-buffer bounds.
     *
     * @param bufferPoolFactory factory the supplier delegates to
     * @param minFloatingBuffersPerGate first argument passed to {@code createBufferPool}
     * @param maxFloatingBuffersPerGate second argument passed to {@code createBufferPool}
     */
    @VisibleForTesting
    static SupplierWithException<BufferPool, IOException> createBufferPoolFactory(BufferPoolFactory bufferPoolFactory, int minFloatingBuffersPerGate, int maxFloatingBuffersPerGate) {
        // The int parameters are effectively final and can be captured directly.
        return () -> bufferPoolFactory.createBufferPool(minFloatingBuffersPerGate, maxFloatingBuffersPerGate);
    }

    /** Mutable counters of how many channels of each kind were created; used for debug logging. */
    protected static class ChannelStatistics {
        int numLocalChannels;
        int numRemoteChannels;
        int numUnknownChannels;

        protected ChannelStatistics() {
        }

        @Override
        public String toString() {
            return String.format("local: %s, remote: %s, unknown: %s", this.numLocalChannels, this.numRemoteChannels, this.numUnknownChannels);
        }
    }
}

