/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RNN;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Common base for recurrent blocks that delegate to the engine's fused {@code rnn} operator.
 *
 * <p>Callers supply data in (batch, time, channel) layout. The block transposes it to
 * (time, batch, channel) before invoking the operator and transposes the sequence output back.
 *
 * <p>{@link #mode} and {@link #gates} are not assigned in this class; presumably concrete
 * subclasses set them before the block is used (verify against the subclasses).
 */
public abstract class RecurrentCell extends ParameterBlock {
    /** Layout that {@link #beforeInitialize(Shape[])} requires for the first input. */
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.BATCH, LayoutType.TIME, LayoutType.CHANNEL};

    protected long stateSize;
    protected float dropRate;
    protected int numStackedLayers;
    protected String mode;
    protected boolean useSequenceLength;
    protected boolean useBidirectional;
    protected int gates;
    /** Version byte written by {@link #saveParameters} and checked by {@link #loadParameters}. */
    protected byte currentVersion = 1;
    protected boolean stateOutputs;
    /** Initial-state shape computed from the batch size seen in {@link #beforeInitialize(Shape[])}. */
    protected Shape stateShape;
    /**
     * Direct parameters in registration order. This order is also the order in which they are
     * flattened and concatenated for the operator and the order used for (de)serialization.
     */
    protected List<Parameter> parameters = new ArrayList<>();

    /**
     * Copies the configuration from the builder and registers four parameters
     * (i2h/h2h weight and bias) per layer, plus four more per layer when bidirectional.
     *
     * @param builder the builder holding the block configuration
     */
    public RecurrentCell(BaseBuilder<?> builder) {
        this.stateSize = builder.stateSize;
        this.dropRate = builder.dropRate;
        this.numStackedLayers = builder.numStackedLayers;
        this.useSequenceLength = builder.useSequenceLength;
        this.useBidirectional = builder.useBidirectional;
        this.stateOutputs = builder.stateOutputs;
        for (int layer = 0; layer < this.numStackedLayers; ++layer) {
            this.registerLayerParameters("l", layer);
            if (this.useBidirectional) {
                this.registerLayerParameters("r", layer);
            }
        }
    }

    /** Registers the i2h/h2h weights and biases of one direction of one layer. */
    private void registerLayerParameters(String direction, int layer) {
        // Plain concatenation keeps the layer index in ASCII digits regardless of default locale.
        String prefix = direction + layer;
        this.parameters.add(new Parameter(prefix + "_i2h_weight", this, ParameterType.WEIGHT));
        this.parameters.add(new Parameter(prefix + "_h2h_weight", this, ParameterType.WEIGHT));
        this.parameters.add(new Parameter(prefix + "_i2h_bias", this, ParameterType.BIAS));
        this.parameters.add(new Parameter(prefix + "_h2h_bias", this, ParameterType.BIAS));
    }

    /** Returns 2 when the block is bidirectional, otherwise 1. */
    private int numDirections() {
        return this.useBidirectional ? 2 : 1;
    }

    /**
     * Builds the initial hidden-state shape for the given batch size.
     *
     * <p>The fused RNN operator keeps one state slice per layer and per direction, so the
     * leading dimension is {@code numStackedLayers * numDirections}.
     */
    private Shape createStateShape(long batchSize) {
        return new Shape((long) this.numStackedLayers * this.numDirections(), batchSize, this.stateSize);
    }

    /**
     * Extracts the layer index from a parameter name of the form {@code <l|r><layer>_...}.
     *
     * @return the parsed index, or 0 when the name does not follow that pattern
     */
    private static int layerIndexOf(String name) {
        int end = name.indexOf('_');
        if (end < 2) {
            return 0;
        }
        int layer = 0;
        for (int i = 1; i < end; ++i) {
            int digit = Character.digit(name.charAt(i), 10);
            if (digit < 0) {
                return 0;
            }
            layer = layer * 10 + digit;
        }
        return layer;
    }

    /**
     * Checks that the input list has exactly one element, or exactly two when
     * {@link #useSequenceLength} is set.
     *
     * @param inputs the inputs passed to {@code forward}
     * @throws IllegalArgumentException if the number of inputs does not match
     */
    protected void validateInputSize(NDList inputs) {
        int numberofInputsRequired = this.useSequenceLength ? 2 : 1;
        if (inputs.size() != numberofInputsRequired) {
            throw new IllegalArgumentException("Invalid number of inputs for RNN. Size of input NDList must be " + numberofInputsRequired + " when useSequenceLength is " + this.useSequenceLength);
        }
    }

    /**
     * Runs the fused RNN operator.
     *
     * @return a list whose first element is the operator's first output with its first two axes
     *     swapped back; when {@link #stateOutputs} is set, the operator's second output is appended
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        inputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        NDList output = ex.rnn(inputs, this.mode, this.stateSize, this.dropRate, this.numStackedLayers, this.useSequenceLength, this.useBidirectional, this.stateOutputs, params);
        NDList result = new NDList(output.head().transpose(1, 0, 2));
        if (this.stateOutputs) {
            result.add(output.get(1));
        }
        return result;
    }

    /**
     * Reports the shape of the sequence output with the first two input axes swapped.
     *
     * <p>NOTE(review): this appears to expect {@code inputs[0]} already rewritten to
     * (time, batch, channel) by {@link #beforeInitialize(Shape[])}; confirm against the caller.
     * Only the sequence output is reported, even when {@link #stateOutputs} is set.
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        Shape inputShape = inputs[0];
        // A bidirectional operator concatenates both directions along the feature axis.
        long outputSize = this.stateSize * this.numDirections();
        return new Shape[]{new Shape(inputShape.get(1), inputShape.get(0), outputSize)};
    }

    @Override
    public List<Parameter> getDirectParameters() {
        return this.parameters;
    }

    /**
     * Validates the input layout, caches the initial-state shape and rewrites {@code inputs[0]}
     * in place with its first two dimensions swapped.
     *
     * <p>The in-place rewrite is visible to the caller, which also holds the array now stored in
     * {@code inputShapes}.
     */
    @Override
    public void beforeInitialize(Shape[] inputs) {
        this.inputShapes = inputs;
        Shape inputShape = inputs[0];
        Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
        long batchSize = inputShape.get(0);
        inputs[0] = new Shape(inputShape.get(1), batchSize, inputShape.get(2));
        this.stateShape = this.createStateShape(batchSize);
    }

    /**
     * Returns the shape of the named parameter.
     *
     * <p>Biases are {@code (gates * stateSize)}, h2h weights are
     * {@code (gates * stateSize, stateSize)} and i2h weights are
     * {@code (gates * stateSize, inputSize)}. For the first layer {@code inputSize} is dimension 2
     * of {@code inputShapes[0]}; deeper layers consume the previous layer's output, whose feature
     * size is {@code stateSize * numDirections}.
     *
     * @throws IllegalArgumentException if the name contains none of "bias", "i2h" or "h2h"
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape shape = inputShapes[0];
        long inputSize = shape.get(2);
        long gateRows = (long) this.gates * this.stateSize;
        if (name.contains("bias")) {
            return new Shape(gateRows);
        }
        if (name.contains("i2h")) {
            if (layerIndexOf(name) > 0) {
                inputSize = this.stateSize * this.numDirections();
            }
            return new Shape(gateRows, inputSize);
        }
        if (name.contains("h2h")) {
            return new Shape(gateRows, this.stateSize);
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /** Writes the version byte followed by every parameter in registration order. */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(this.currentVersion);
        for (Parameter parameter : this.parameters) {
            parameter.save(os);
        }
    }

    /**
     * Reads the version byte and then every parameter in registration order.
     *
     * @throws MalformedModelException if the stored version differs from {@link #currentVersion}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != this.currentVersion) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        for (Parameter parameter : this.parameters) {
            parameter.load(manager, is);
        }
    }

    /**
     * Assembles the operator inputs: transposed data, all parameters flattened and concatenated
     * into one array, a zero initial state and, when {@link #useSequenceLength} is set, the
     * second caller-supplied input.
     *
     * @param parameterStore source of the parameter values for the data's device
     * @param inputs the inputs passed to {@code forward}
     * @return the input list for the {@code rnn} operator
     */
    protected NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        this.validateInputSize(inputs);
        inputs = this.updateInputLayoutToTNC(inputs);
        NDArray head = inputs.head();
        Device device = head.getDevice();
        NDList result = new NDList(head);
        try (NDList parameterList = new NDList()) {
            for (Parameter parameter : this.parameters) {
                parameterList.add(parameterStore.getValue(parameter, device).flatten());
            }
            result.add(NDArrays.concat(parameterList));
        }
        // Size the initial state from the live input (axis 1 is batch after the transpose) so the
        // block does not depend on beforeInitialize having run with the same batch size.
        long batchSize = head.getShape().get(1);
        result.add(head.getManager().zeros(this.createStateShape(batchSize)));
        if (this.useSequenceLength) {
            result.add(inputs.get(1));
        }
        return result;
    }

    /**
     * Swaps the first two axes of the first input and passes any further inputs through
     * unchanged, so the sequence-length input is still available to {@link #opInputs}.
     *
     * @param inputs the caller-supplied inputs, data first
     * @return a new list with the transposed data followed by the remaining inputs
     */
    protected NDList updateInputLayoutToTNC(NDList inputs) {
        NDList result = new NDList(inputs.head().transpose(1, 0, 2));
        for (int i = 1; i < inputs.size(); ++i) {
            result.add(inputs.get(i));
        }
        return result;
    }

    /**
     * Fluent builder base holding the configuration read by the {@link RecurrentCell} constructor.
     *
     * <p>NOTE(review): the type bound uses the raw {@code BaseBuilder}; it is left as-is to avoid
     * changing the signature seen by subclasses.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class BaseBuilder<T extends BaseBuilder> {
        protected float dropRate;
        protected long stateSize = -1L;
        protected int numStackedLayers = -1;
        protected double lstmStateClipMin;
        protected double lstmStateClipMax;
        protected boolean clipLstmState;
        protected boolean useSequenceLength;
        protected boolean useBidirectional;
        protected boolean stateOutputs;
        protected RNN.Activation activation;

        /** Sets the drop rate handed to the operator. */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return this.self();
        }

        /** Sets the size of the hidden state. */
        public T setStateSize(int stateSize) {
            this.stateSize = stateSize;
            return this.self();
        }

        /** Sets the number of stacked layers. */
        public T setNumStackedLayers(int numStackedLayers) {
            this.numStackedLayers = numStackedLayers;
            return this.self();
        }

        /** Sets whether a second input is expected and forwarded to the operator. */
        public T setSequenceLength(boolean useSequenceLength) {
            this.useSequenceLength = useSequenceLength;
            return this.self();
        }

        /** Sets whether the block is bidirectional. The method name's spelling is kept for compatibility. */
        public T optBidrectional(boolean useBidirectional) {
            this.useBidirectional = useBidirectional;
            return this.self();
        }

        /** Sets whether {@code forward} also returns the operator's second output. */
        public T optStateOutput(boolean stateOutputs) {
            this.stateOutputs = stateOutputs;
            return this.self();
        }

        /** Returns this builder typed as the concrete builder class. */
        protected abstract T self();
    }
}

