/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.IValueUtils;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A symbol block backed by a native PyTorch module that is referenced through a {@code long}
 * handle.
 *
 * <p>The block registers itself with the supplied {@link PtNDManager} on construction and
 * releases the native module when {@link #close()} is called. Input and output descriptions are
 * not known up front; they are recorded from the arrays seen during the first forward call.
 */
public class PtSymbolBlock
extends AbstractSymbolBlock
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(PtSymbolBlock.class);
    private static final byte VERSION = 1;

    /** Native module handle; set to {@code null} once the block has been closed. */
    private final AtomicReference<Long> handle;
    /** Key under which this block is attached to the manager (the handle rendered as text). */
    private final String uid;
    private PtNDManager manager;
    /** Mode last applied to the native module; starts as {@code true}. */
    private boolean isTrain;
    private PairList<String, Shape> inputDescriptions;
    private PairList<String, Shape> outputDescriptions;
    /**
     * {@code true} until the first forward call has recorded the descriptions. Declared
     * {@code volatile} because it is read outside the monitor in {@link #forwardInternal}.
     */
    private volatile boolean first;

    /**
     * Constructs a block around an existing native module and attaches it to the manager.
     *
     * @param manager the manager this block is attached to
     * @param handle the native module handle
     */
    public PtSymbolBlock(PtNDManager manager, long handle) {
        super(VERSION);
        this.handle = new AtomicReference<>(handle);
        this.manager = manager;
        this.uid = String.valueOf(handle);
        manager.attach(this.uid, this);
        this.isTrain = true;
        this.first = true;
    }

    /**
     * Releases the native module and detaches this block from its manager.
     *
     * <p>The handle is swapped to {@code null} atomically, so only the first call performs the
     * release; later calls are no-ops.
     */
    @Override
    public void close() {
        Long pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JniUtils.deleteModule(pointer);
            this.manager.detach(this.uid);
            this.manager = null;
        }
    }

    /**
     * Runs the native module on the given inputs.
     *
     * <p>If {@code training} differs from the mode applied last, the native module is switched
     * to the requested mode first. The very first call additionally records the name and shape
     * of every input and output array so that {@link #describeInput()} and
     * {@link #describeOutput()} can report them afterwards.
     *
     * @param parameterStore not used by this implementation
     * @param inputs the input arrays
     * @param training {@code true} to run in training mode, {@code false} for inference mode
     * @param params not used by this implementation
     * @return the arrays produced by the native forward call
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (this.isTrain != training) {
            this.isTrain = training;
            if (this.isTrain) {
                JniUtils.enableTrainingMode(this);
            } else {
                JniUtils.enableInferenceMode(this);
            }
        }
        if (this.first) {
            // The first call is serialized on the class-wide monitor; re-check the flag after
            // acquiring it because another thread may have completed the first call meanwhile.
            synchronized (PtSymbolBlock.class) {
                if (this.first) {
                    this.inputDescriptions = new PairList<>();
                    this.outputDescriptions = new PairList<>();
                    for (NDArray array : inputs) {
                        this.inputDescriptions.add(array.getName(), array.getShape());
                    }
                    NDList outputs = IValueUtils.forward(this, inputs, training);
                    for (NDArray array : outputs) {
                        this.outputDescriptions.add(array.getName(), array.getShape());
                    }
                    // Cleared last so the flag only flips once both descriptions are filled in.
                    this.first = false;
                    return outputs;
                }
            }
        }
        // Regular path: descriptions are already recorded, no lock is held during forward.
        return IValueUtils.forward(this, inputs, training);
    }

    /**
     * Returns the names and shapes of the inputs seen by the first forward call.
     *
     * @return the recorded input descriptions, or {@code null} (after logging a warning) if no
     *     forward call has been made yet
     */
    public PairList<String, Shape> describeInput() {
        if (this.inputDescriptions == null) {
            logger.warn("Input shapes are unknown, please run predict or forward once and call describeInput again.");
        }
        return this.inputDescriptions;
    }

    /**
     * Returns the names and shapes of the outputs produced by the first forward call.
     *
     * @return the recorded output descriptions, or {@code null} (after logging a warning) if no
     *     forward call has been made yet
     */
    public PairList<String, Shape> describeOutput() {
        if (this.outputDescriptions == null) {
            logger.warn("Output shapes are unknown, please run predict or forward once and call describeOutput again.");
        }
        return this.outputDescriptions;
    }

    /**
     * Output shapes are not computed by this block.
     *
     * @param manager ignored
     * @param inputShapes ignored
     * @return always an empty array
     */
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[0];
    }

    /**
     * Returns the native module handle.
     *
     * @return the handle passed to the constructor
     * @throws IllegalStateException if this block has already been closed
     */
    public Long getHandle() {
        Long reference = this.handle.get();
        if (reference == null) {
            throw new IllegalStateException("PyTorch model handle has been released!");
        }
        return reference;
    }
}

