/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;

/**
 * A recurrent block implementing a Long Short-Term Memory (LSTM) layer.
 *
 * <p>The computation is delegated to the engine's {@link NDArrayEx} operators: when state clipping
 * was enabled through {@link Builder#optLstmStateClipMin(float, float)} the {@code lstm} operator
 * is used (with the configured clip bounds), otherwise the generic {@code rnn} operator is used in
 * {@code "lstm"} mode.
 *
 * <p>Begin states supplied through {@link #setBeginStates(NDList)} are consumed by the next call to
 * {@link #forward} only; they are cleared when that call returns or throws.
 */
public class LSTM extends RecurrentBlock {
    private boolean clipLstmState;
    private double lstmStateClipMin;
    private double lstmStateClipMax;
    private NDArray beginStateCell;

    /**
     * Creates an LSTM block from the configuration held by {@code builder}.
     *
     * @param builder the builder holding the layer configuration
     */
    LSTM(Builder builder) {
        super(builder);
        this.mode = "lstm";
        // NOTE(review): LSTM has four gate blocks; presumably RecurrentBlock uses this to size
        // the weights — confirm against the superclass.
        this.gates = 4;
        this.clipLstmState = builder.clipLstmState;
        this.lstmStateClipMin = builder.lstmStateClipMin;
        this.lstmStateClipMax = builder.lstmStateClipMax;
    }

    /**
     * Runs the LSTM over {@code inputs}.
     *
     * <p>The first operator output is transposed with {@code transpose(1, 0, 2)} (swapping its first
     * two axes) — presumably converting the engine's time-major layout back to the caller's
     * batch-major layout; confirm against {@code updateInputLayoutToTNC}. When {@code stateOutputs}
     * is set, the second and third operator outputs (the final states) are appended to the result.
     *
     * @param parameterStore the store used to fetch the block parameters
     * @param inputs the block inputs, as accepted by {@link #opInputs(ParameterStore, NDList)}
     * @param training whether this is a training pass (not read by this method)
     * @param params extra operator parameters forwarded to the engine
     * @return the output sequence, optionally followed by the final states
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        try {
            NDList operatorInputs = this.opInputs(parameterStore, inputs);
            NDArrayEx ex = operatorInputs.head().getNDArrayInternal();
            NDList output;
            if (this.clipLstmState) {
                output =
                        ex.lstm(
                                operatorInputs,
                                this.stateSize,
                                this.dropRate,
                                this.numStackedLayers,
                                this.useSequenceLength,
                                this.isBidirectional(),
                                true,
                                this.lstmStateClipMin,
                                this.lstmStateClipMax,
                                params);
            } else {
                output =
                        ex.rnn(
                                operatorInputs,
                                this.mode,
                                this.stateSize,
                                this.dropRate,
                                this.numStackedLayers,
                                this.useSequenceLength,
                                this.isBidirectional(),
                                true,
                                params);
            }
            NDList result = new NDList(output.head().transpose(1, 0, 2));
            if (this.stateOutputs) {
                result.add(output.get(1));
                result.add(output.get(2));
            }
            return result;
        } finally {
            // Begin states are one-shot: clear them even when the pass fails, so a later call
            // (possibly with a different batch) does not silently reuse stale states.
            this.resetBeginStates();
        }
    }

    /**
     * Sets the initial states used by the next forward pass.
     *
     * @param beginStates a list whose first element is the initial hidden state and whose second
     *     element is the initial cell state; any further elements are ignored
     * @throws IllegalArgumentException if {@code beginStates} is null or holds fewer than two arrays
     */
    @Override
    public void setBeginStates(NDList beginStates) {
        Preconditions.checkArgument(
                beginStates != null && beginStates.size() >= 2,
                "LSTM requires two begin states: the hidden state and the cell state");
        this.beginState = beginStates.get(0);
        this.beginStateCell = beginStates.get(1);
    }

    /** Clears both begin states so the next forward pass starts from zero states. */
    @Override
    protected void resetBeginStates() {
        this.beginState = null;
        this.beginStateCell = null;
    }

    /**
     * Assembles the operator inputs.
     *
     * <p>The result holds, in order:
     * <ul>
     *   <li>the data produced by {@code updateInputLayoutToTNC};</li>
     *   <li>all block parameters flattened and concatenated into a single array;</li>
     *   <li>the initial hidden and cell states (the stored begin states, or FLOAT32 zeros of shape
     *       {@code (numStackedLayers * numDirections, batchSize, stateSize)});</li>
     *   <li>when {@code useSequenceLength} is set, the sequence-length array taken from the
     *       converted inputs.</li>
     * </ul>
     *
     * @param parameterStore the store used to fetch the block parameters
     * @param inputs the block inputs; the data array first, then the sequence lengths when
     *     {@code useSequenceLength} is set
     * @return the operator input list described above
     */
    @Override
    protected NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        this.validateInputSize(inputs);
        // Read before the layout change; assumes the caller's data is batch-major — TODO confirm.
        long batchSize = inputs.head().getShape().get(0);
        inputs = this.updateInputLayoutToTNC(inputs);
        // head() rather than singletonOrThrow(): with useSequenceLength the list legitimately
        // carries a second element, which is forwarded below.
        NDArray data = inputs.head();
        Device device = data.getDevice();
        NDList result = new NDList(data);
        // The temporary list of flattened parameters is closed once they have been concatenated.
        try (NDList flattened = new NDList()) {
            for (Parameter parameter : this.parameters.values()) {
                flattened.add(parameterStore.getValue(parameter, device).flatten());
            }
            result.add(NDArrays.concat(flattened));
        }
        if (this.beginState != null) {
            result.add(this.beginState);
            result.add(this.beginStateCell);
        } else {
            Shape stateShape =
                    new Shape(this.numStackedLayers * this.numDirections, batchSize, this.stateSize);
            result.add(data.getManager().zeros(stateShape, DataType.FLOAT32, device));
            result.add(data.getManager().zeros(stateShape, DataType.FLOAT32, device));
        }
        if (this.useSequenceLength) {
            result.add(inputs.get(1));
        }
        return result;
    }

    /**
     * Creates a builder for an {@link LSTM} block.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link LSTM} blocks. */
    public static final class Builder extends RecurrentBlock.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Enables LSTM state clipping with the given bounds.
         *
         * <p>The bounds are passed to the engine's {@code lstm} operator.
         *
         * @param lstmStateClipMin the lower clipping bound
         * @param lstmStateClipMax the upper clipping bound
         * @return this builder
         * @throws IllegalArgumentException if {@code lstmStateClipMin > lstmStateClipMax}
         */
        public Builder optLstmStateClipMin(float lstmStateClipMin, float lstmStateClipMax) {
            Preconditions.checkArgument(
                    lstmStateClipMin <= lstmStateClipMax,
                    "lstmStateClipMin must not be greater than lstmStateClipMax");
            this.lstmStateClipMin = lstmStateClipMin;
            this.lstmStateClipMax = lstmStateClipMax;
            this.clipLstmState = true;
            return this.self();
        }

        /**
         * Builds the {@link LSTM} block.
         *
         * @return the configured block
         * @throws IllegalArgumentException if {@code stateSize} or {@code numStackedLayers} is not
         *     positive
         */
        public LSTM build() {
            Preconditions.checkArgument(
                    this.stateSize > 0L && this.numStackedLayers > 0,
                    "Must set stateSize and numStackedLayers");
            return new LSTM(this);
        }
    }
}

