package org.apache.flink.runtime.checkpoint.channel;

import java.io.Closeable;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointedStateScope;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.RunnableWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Writes the channel state of a single checkpoint into one checkpoint output stream and publishes
 * the resulting state handles through a {@link ChannelStateWriter.ChannelStateWriteResult}.
 *
 * <p>Usage as visible from this class: the serializer header is written on construction; buffers
 * are appended with {@link #writeInput} / {@link #writeOutput}; once both {@link #completeInput()}
 * and {@link #completeOutput()} have been called, the {@code onComplete} callback runs, the stream
 * is finished and the handle futures of the result are completed. Any failure while writing fails
 * the result and closes the stream.
 *
 * <p>Instances are not thread safe.
 */
@NotThreadSafe
public class ChannelStateCheckpointWriter {
    private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class);

    /** Stream the serializer writes into. */
    private final DataOutputStream dataStream;

    /** Checkpoint stream used to query write positions and to obtain the final state handle. */
    private final CheckpointStreamFactory.CheckpointStateOutputStream checkpointStream;

    /** Result whose handle futures are completed (or failed) by this writer. */
    private final ChannelStateWriter.ChannelStateWriteResult result;

    /** Offsets and sizes of the data written so far, per input channel. */
    private final Map<InputChannelInfo, AbstractChannelStateHandle.StateContentMetaInfo> inputChannelOffsets = new HashMap<>();

    /** Offsets and sizes of the data written so far, per result subpartition. */
    private final Map<ResultSubpartitionInfo, AbstractChannelStateHandle.StateContentMetaInfo> resultSubpartitionOffsets = new HashMap<>();

    private final ChannelStateSerializer serializer;

    /** Checkpoint id; only used for logging and tracing in this class. */
    private final long checkpointId;

    /** Set by {@link #completeInput()}; afterwards no more input data is accepted. */
    private boolean allInputsReceived;

    /** Set by {@link #completeOutput()}; afterwards no more output data is accepted. */
    private boolean allOutputsReceived;

    /** Callback run once both inputs and outputs have been completed. */
    private final RunnableWithException onComplete;

    /** Subtask index passed to every created state handle. */
    private final int subtaskIndex;

    /** Task name; only used for tracing in this class. */
    private final String taskName;

    /**
     * Creates a channel state handle for one channel.
     *
     * @param <I> type of the channel info the handle belongs to
     * @param <H> type of the created handle
     */
    @FunctionalInterface
    public interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
        HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL = InputChannelStateHandle::new;
        HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION = ResultSubpartitionStateHandle::new;

        /**
         * @param subtaskIndex index of the subtask that wrote the state
         * @param info channel the state belongs to
         * @param underlying handle to the stream that contains the state
         * @param offsets offsets of the channel's data within {@code underlying}
         * @param size size value recorded for the channel's data
         * @return the new handle
         */
        H create(int subtaskIndex, I info, StreamStateHandle underlying, List<Long> offsets, long size);
    }

    /**
     * Creates a writer for the checkpoint described by {@code startCheckpointItem}, opening an
     * {@link CheckpointedStateScope#EXCLUSIVE} output stream from {@code streamFactory}.
     *
     * @param taskName task name used for tracing
     * @param subtaskIndex subtask index stored in the created handles
     * @param startCheckpointItem supplies the checkpoint id and the target result
     * @param streamFactory factory for the checkpoint output stream
     * @param serializer serializer for the header and the buffers
     * @param onComplete callback run once inputs and outputs are both completed
     * @throws Exception if the stream cannot be created or the header cannot be written
     */
    public ChannelStateCheckpointWriter(
            String taskName,
            int subtaskIndex,
            CheckpointStartRequest startCheckpointItem,
            CheckpointStreamFactory streamFactory,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete)
            throws Exception {
        this(
                taskName,
                subtaskIndex,
                startCheckpointItem.getCheckpointId(),
                startCheckpointItem.getTargetResult(),
                streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE),
                serializer,
                onComplete);
    }

    /**
     * Creates a writer on top of an existing checkpoint stream, wrapping it in a new
     * {@link DataOutputStream}.
     */
    @VisibleForTesting
    ChannelStateCheckpointWriter(
            String taskName,
            int subtaskIndex,
            long checkpointId,
            ChannelStateWriter.ChannelStateWriteResult result,
            CheckpointStreamFactory.CheckpointStateOutputStream stream,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete)
            throws Exception {
        this(taskName, subtaskIndex, checkpointId, result, serializer, onComplete, stream, new DataOutputStream(stream));
    }

    /**
     * Creates a writer with explicitly supplied streams and immediately writes the serializer
     * header to {@code dataStream}.
     *
     * @throws NullPointerException if {@code result}, {@code checkpointStream}, {@code serializer},
     *     {@code dataStream} or {@code onComplete} is null
     * @throws Exception if writing the header fails; the result is failed and the checkpoint
     *     stream closed before the exception is rethrown
     */
    @VisibleForTesting
    ChannelStateCheckpointWriter(
            String taskName,
            int subtaskIndex,
            long checkpointId,
            ChannelStateWriter.ChannelStateWriteResult result,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete,
            CheckpointStreamFactory.CheckpointStateOutputStream checkpointStream,
            DataOutputStream dataStream)
            throws Exception {
        this.taskName = taskName;
        this.subtaskIndex = subtaskIndex;
        this.checkpointId = checkpointId;
        this.result = Preconditions.checkNotNull(result);
        this.checkpointStream = Preconditions.checkNotNull(checkpointStream);
        this.serializer = Preconditions.checkNotNull(serializer);
        this.dataStream = Preconditions.checkNotNull(dataStream);
        this.onComplete = Preconditions.checkNotNull(onComplete);
        runWithChecks(() -> serializer.writeHeader(dataStream));
    }

    /**
     * Appends the buffer of an input channel to the checkpoint stream. The buffer is always
     * recycled, whether or not it was written.
     *
     * @throws Exception if inputs were already completed or writing fails
     */
    public void writeInput(InputChannelInfo inputChannelInfo, Buffer buffer) throws Exception {
        boolean inputsStillOpen = !allInputsReceived;
        write(inputChannelOffsets, inputChannelInfo, buffer, inputsStillOpen, "ChannelStateCheckpointWriter#writeInput");
    }

    /**
     * Appends the buffer of a result subpartition to the checkpoint stream. The buffer is always
     * recycled, whether or not it was written.
     *
     * @throws Exception if outputs were already completed or writing fails
     */
    public void writeOutput(ResultSubpartitionInfo resultSubpartitionInfo, Buffer buffer) throws Exception {
        boolean outputsStillOpen = !allOutputsReceived;
        write(resultSubpartitionOffsets, resultSubpartitionInfo, buffer, outputsStillOpen, "ChannelStateCheckpointWriter#writeOutput");
    }

    /**
     * Serializes {@code buffer} into the stream and records its offset and size under {@code key}.
     *
     * <p>If the result is already done the buffer is dropped without being written. In every case
     * the buffer is recycled exactly once.
     *
     * @param offsets per-channel bookkeeping to update
     * @param key channel the buffer belongs to
     * @param buffer data to write; ownership is taken by this method
     * @param precondition must be true, i.e. the corresponding side must not be completed yet
     * @param action label passed to the network actions logger
     * @throws Exception if the precondition is violated or writing fails; the result is failed
     *     before the exception is rethrown
     */
    private <K> void write(
            Map<K, AbstractChannelStateHandle.StateContentMetaInfo> offsets,
            K key,
            Buffer buffer,
            boolean precondition,
            String action)
            throws Exception {
        try {
            if (result.isDone()) {
                return;
            }
            runWithChecks(() -> {
                Preconditions.checkState(precondition);
                long offset = checkpointStream.getPos();
                // Only the serialization itself is measured; the measurement handle (which may be
                // null) is closed exactly once when this statement is left.
                try (Closeable ignored = NetworkActionsLogger.measureIO(action, buffer)) {
                    serializer.writeData(dataStream, buffer);
                }
                long size = checkpointStream.getPos() - offset;
                offsets.computeIfAbsent(key, unused -> new AbstractChannelStateHandle.StateContentMetaInfo())
                        .withDataAdded(offset, size);
                NetworkActionsLogger.tracePersist(action, buffer, taskName, key, checkpointId);
            });
        } finally {
            // Single recycle point for all paths (early return, success, failure).
            buffer.recycleBuffer();
        }
    }

    /**
     * Marks the input side as completed; finishes the checkpoint if outputs are completed too.
     *
     * @throws Exception if inputs were already completed or finishing the write fails
     */
    public void completeInput() throws Exception {
        LOG.debug("complete input, output completed: {}", allOutputsReceived);
        complete(!allInputsReceived, () -> allInputsReceived = true);
    }

    /**
     * Marks the output side as completed; finishes the checkpoint if inputs are completed too.
     *
     * @throws Exception if outputs were already completed or finishing the write fails
     */
    public void completeOutput() throws Exception {
        LOG.debug("complete output, input completed: {}", allInputsReceived);
        complete(!allOutputsReceived, () -> allOutputsReceived = true);
    }

    /**
     * Completes one side (input or output).
     *
     * @param precondition must be true, i.e. that side must not have been completed before
     * @param markReceived sets the flag of the side being completed
     */
    private void complete(boolean precondition, RunnableWithException markReceived) throws Exception {
        if (result.isDone()) {
            // The result is already done: only update the flag and, when both sides are complete,
            // run the callback. Nothing is written and failures here do not touch the result.
            doComplete(precondition, markReceived, onComplete);
            return;
        }
        runWithChecks(() -> doComplete(precondition, markReceived, onComplete, this::finishWriteAndResult));
    }

    /**
     * Finishes the stream and completes both handle futures of the result.
     *
     * <p>If nothing was written for any channel, the data stream is closed and both futures are
     * completed with empty lists. Otherwise the data stream is flushed, the checkpoint stream is
     * closed into a state handle and one handle per channel is created from it.
     */
    private void finishWriteAndResult() throws IOException {
        boolean nothingWritten = inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty();
        if (nothingWritten) {
            dataStream.close();
            result.inputChannelStateHandles.complete(Collections.emptyList());
            result.resultSubpartitionStateHandles.complete(Collections.emptyList());
            return;
        }
        dataStream.flush();
        StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
        complete(underlying, result.inputChannelStateHandles, inputChannelOffsets, HandleFactory.INPUT_CHANNEL);
        complete(underlying, result.resultSubpartitionStateHandles, resultSubpartitionOffsets, HandleFactory.RESULT_SUBPARTITION);
    }

    /**
     * Runs {@code markReceived} and, if both sides are then completed, runs {@code callbacks} in
     * the given order.
     *
     * @param precondition checked as an argument precondition before anything is run
     */
    private void doComplete(boolean precondition, RunnableWithException markReceived, RunnableWithException... callbacks) throws Exception {
        Preconditions.checkArgument(precondition);
        markReceived.run();
        if (!allInputsReceived || !allOutputsReceived) {
            return;
        }
        for (RunnableWithException callback : callbacks) {
            callback.run();
        }
    }

    /**
     * Creates one handle per entry of {@code offsets} and completes {@code future} with them.
     *
     * @param underlying handle of the stream all channel data was written to
     * @param future future to complete with the created handles
     * @param offsets recorded offsets and sizes per channel
     * @param handleFactory creates the concrete handle type
     */
    private <I, H extends AbstractChannelStateHandle<I>> void complete(
            StreamStateHandle underlying,
            CompletableFuture<Collection<H>> future,
            Map<I, AbstractChannelStateHandle.StateContentMetaInfo> offsets,
            HandleFactory<I, H> handleFactory)
            throws IOException {
        List<H> handles = new ArrayList<>(offsets.size());
        for (Map.Entry<I, AbstractChannelStateHandle.StateContentMetaInfo> entry : offsets.entrySet()) {
            handles.add(createHandle(handleFactory, underlying, entry.getKey(), entry.getValue()));
        }
        future.complete(handles);
        LOG.debug("channel state write completed, checkpointId: {}, handles: {}", checkpointId, handles);
    }

    /**
     * Creates the handle for a single channel.
     *
     * <p>If {@code underlying} exposes its content as in-memory bytes, the channel's data is
     * extracted and merged by the serializer into a new {@link ByteStreamStateHandle} with a
     * random name, and the handle points at the serializer's header length as its only offset.
     * Otherwise the handle refers to {@code underlying} with the recorded offsets and size.
     */
    private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
            HandleFactory<I, H> handleFactory,
            StreamStateHandle underlying,
            I channelInfo,
            AbstractChannelStateHandle.StateContentMetaInfo contentMetaInfo)
            throws IOException {
        Optional<byte[]> inMemoryBytes = underlying.asBytesIfInMemory();
        if (inMemoryBytes.isPresent()) {
            String handleName = UUID.randomUUID().toString();
            ByteStreamStateHandle extracted =
                    new ByteStreamStateHandle(
                            handleName,
                            serializer.extractAndMerge(inMemoryBytes.get(), contentMetaInfo.getOffsets()));
            return handleFactory.create(
                    subtaskIndex,
                    channelInfo,
                    extracted,
                    Collections.singletonList(Long.valueOf(serializer.getHeaderLength())),
                    extracted.getStateSize());
        }
        return handleFactory.create(
                subtaskIndex, channelInfo, underlying, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
    }

    /**
     * Runs {@code action} after checking that the result is not done yet. On any exception the
     * writer is failed via {@link #fail(Throwable)} and the original exception is rethrown.
     */
    private void runWithChecks(RunnableWithException action) throws Exception {
        try {
            Preconditions.checkState(!result.isDone(), "result is already completed", result);
            action.run();
        } catch (Exception e) {
            try {
                fail(e);
            } catch (Exception cleanupFailure) {
                // Keep the root cause visible to the caller: a failure while cleaning up (e.g.
                // closing the stream) is attached to it instead of replacing it.
                if (cleanupFailure != e) {
                    e.addSuppressed(cleanupFailure);
                }
            }
            throw e;
        }
    }

    /**
     * Fails the result with {@code cause} and closes the checkpoint stream.
     *
     * @param cause reason of the failure
     * @throws Exception if closing the checkpoint stream fails
     */
    public void fail(Throwable cause) throws Exception {
        result.fail(cause);
        checkpointStream.close();
    }
}
