/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

/**
 * MXNet implementation of {@link GradientCollector}.
 *
 * <p>Creating an instance switches the MXNet autograd "recording" and "training" flags on, and
 * {@link #close()} switches them back off. The flags are read and written through {@link
 * JnaUtils}, so they are not per-instance state. Construction fails if either flag is already on,
 * which happens when a previous collector was never closed. Use try-with-resources so that {@link
 * #close()} always runs.
 */
public class MxGradientCollector
implements GradientCollector {
    /**
     * Enables autograd recording and training mode.
     *
     * @throws IllegalStateException if recording or training was already enabled. If the check
     *     fails, any flag this constructor changed is restored before the exception is thrown.
     */
    MxGradientCollector() {
        boolean prevRecordingState = MxGradientCollector.setRecording(true);
        if (prevRecordingState) {
            // Recording was already on, so setRecording(true) changed nothing; nothing to undo.
            throw new IllegalStateException("Autograd Recording is already set to True. Please create autograd using try with resource ");
        }
        boolean prevTrainingState = MxGradientCollector.setTraining(true);
        if (prevTrainingState) {
            // Recording was switched on just above. Undo that before failing, so the global
            // autograd state is not left half-initialised. Otherwise every later collector would
            // fail with a misleading "Recording is already set" error.
            MxGradientCollector.setRecording(false);
            throw new IllegalStateException("Autograd Training is already set to True. Please create autograd using try with resource ");
        }
    }

    /**
     * Queries the autograd recording flag.
     *
     * @return the value reported by {@link JnaUtils#autogradIsRecording()}
     */
    public static boolean isRecording() {
        return JnaUtils.autogradIsRecording();
    }

    /**
     * Queries the autograd training flag.
     *
     * @return the value reported by {@link JnaUtils#autogradIsTraining()}
     */
    public static boolean isTraining() {
        return JnaUtils.autogradIsTraining();
    }

    /**
     * Sets the autograd recording flag.
     *
     * @param isRecording the new value of the recording flag
     * @return the value returned by the native call; the constructor treats it as the previous
     *     recording state
     */
    public static boolean setRecording(boolean isRecording) {
        return JnaUtils.autogradSetIsRecording(isRecording);
    }

    /**
     * Sets the autograd training flag.
     *
     * @param isTraining the new value of the training flag
     * @return the value returned by the native call; the constructor treats it as the previous
     *     training state
     */
    public static boolean setTraining(boolean isTraining) {
        return JnaUtils.autogradSetTraining(isTraining);
    }

    /**
     * Wraps the autograd symbol recorded for {@code array} in a {@link Symbol}.
     *
     * @param manager the manager that will own the symbol; it must be an {@link MxNDManager},
     *     otherwise the cast throws {@link ClassCastException}
     * @param array the array whose recorded computation is requested
     * @return a new {@link Symbol} wrapping the handle from {@link JnaUtils#autogradGetSymbol}
     */
    public static Symbol getSymbol(NDManager manager, NDArray array) {
        return new Symbol((MxNDManager)manager, JnaUtils.autogradGetSymbol(array));
    }

    /**
     * Switches autograd recording and training mode back off.
     *
     * <p>The constructor only succeeds when both flags were off, so this restores the state that
     * existed before this collector was created.
     */
    public void close() {
        MxGradientCollector.setRecording(false);
        MxGradientCollector.setTraining(false);
    }

    /**
     * Runs the backward pass starting from {@code array}, without retaining the graph.
     *
     * @param array the array to differentiate
     */
    public void backward(NDArray array) {
        this.backward(array, false);
    }

    /**
     * Runs the backward pass starting from {@code array}.
     *
     * @param array the array to differentiate
     * @param retainGraph whether to keep the recorded graph; passed to native code as {@code 1}
     *     or {@code 0}
     */
    private void backward(NDArray array, boolean retainGraph) {
        JnaUtils.autogradBackward(new NDList(new NDArray[]{array}), retainGraph ? 1 : 0);
    }
}

