/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.factory.ops;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Thin convenience wrappers around the recurrent-layer ops (GRU, LSTM and SRU variants).
 *
 * <p>Every method follows the same three steps: the {@link INDArray} arguments are handed to
 * {@link NDValidation#validateNumerical(String, String, INDArray)} tagged with the op name and the
 * argument name, the matching op object is constructed, and the op is run through
 * {@code Nd4j.exec}. Methods declared to return a single {@link INDArray} return element 0 of the
 * array produced by {@code Nd4j.exec}; the others return that array unchanged.
 */
public class NDRNN {
    /**
     * Runs a {@link GRU} op.
     *
     * @param x      input array
     * @param hLast  array forwarded as the op's second argument (presumably the previous hidden
     *               state - confirm against the op's documentation)
     * @param Wx     array forwarded as the op's third argument
     * @param Wh     array forwarded as the op's fourth argument
     * @param biases array forwarded as the op's fifth argument
     * @return the first output of the executed op
     */
    public INDArray gru(INDArray x, INDArray hLast, INDArray Wx, INDArray Wh, INDArray biases) {
        final String opName = "gru";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "hLast", hLast);
        NDValidation.validateNumerical(opName, "Wx", Wx);
        NDValidation.validateNumerical(opName, "Wh", Wh);
        NDValidation.validateNumerical(opName, "biases", biases);

        GRU op = new GRU(x, hLast, Wx, Wh, biases);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs a {@link GRUCell} op.
     *
     * @param x           input array
     * @param hLast       array forwarded as the op's second argument
     * @param GRUWeights2 weights object forwarded to the op (not validated here)
     * @return every output produced by the executed op
     */
    public INDArray[] gruCell(INDArray x, INDArray hLast, GRUWeights GRUWeights2) {
        final String opName = "gruCell";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "hLast", hLast);

        GRUCell op = new GRUCell(x, hLast, GRUWeights2);
        return Nd4j.exec(op);
    }

    /**
     * Runs an {@link LSTMBlockCell} op.
     *
     * @param x                  input array
     * @param cLast              array forwarded as the op's second argument
     * @param yLast              array forwarded as the op's third argument
     * @param LSTMWeights2       weights object forwarded to the op (not validated here)
     * @param LSTMConfiguration2 configuration object forwarded to the op (not validated here)
     * @return every output produced by the executed op
     */
    public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        final String opName = "lstmCell";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "cLast", cLast);
        NDValidation.validateNumerical(opName, "yLast", yLast);

        LSTMBlockCell op = new LSTMBlockCell(x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        return Nd4j.exec(op);
    }

    /**
     * Runs an {@link LSTMLayer} op with all array inputs supplied.
     *
     * @param x                 input array
     * @param cLast             array forwarded as the op's second argument
     * @param yLast             array forwarded as the op's third argument
     * @param maxTSLength       array forwarded as the op's fourth argument
     * @param LSTMLayerWeights2 weights object forwarded to the op (not validated here)
     * @param LSTMLayerConfig2  configuration object forwarded to the op (not validated here)
     * @return every output produced by the executed op
     */
    public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        final String opName = "lstmLayer";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "cLast", cLast);
        NDValidation.validateNumerical(opName, "yLast", yLast);
        NDValidation.validateNumerical(opName, "maxTSLength", maxTSLength);

        LSTMLayer op = new LSTMLayer(x, cLast, yLast, maxTSLength, LSTMLayerWeights2, LSTMLayerConfig2);
        return Nd4j.exec(op);
    }

    /**
     * Runs an {@link LSTMLayer} op with only the input array supplied; {@code null} is passed to
     * the op in place of {@code cLast}, {@code yLast} and {@code maxTSLength}.
     *
     * @param x                 input array
     * @param LSTMLayerWeights2 weights object forwarded to the op (not validated here)
     * @param LSTMLayerConfig2  configuration object forwarded to the op (not validated here)
     * @return every output produced by the executed op
     */
    public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        final String opName = "lstmLayer";
        NDValidation.validateNumerical(opName, "x", x);

        LSTMLayer op = new LSTMLayer(x, null, null, null, LSTMLayerWeights2, LSTMLayerConfig2);
        return Nd4j.exec(op);
    }

    /**
     * Runs an {@link LSTMBlock} op with all array inputs supplied.
     *
     * <p>Note the argument order: {@code maxTSLength} comes first here, unlike
     * {@link #lstmLayer(INDArray, INDArray, INDArray, INDArray, LSTMLayerWeights, LSTMLayerConfig)}.
     *
     * @param maxTSLength        array forwarded as the op's first argument
     * @param x                  input array
     * @param cLast              array forwarded as the op's third argument
     * @param yLast              array forwarded as the op's fourth argument
     * @param LSTMWeights2       weights object forwarded to the op (not validated here)
     * @param LSTMConfiguration2 configuration object forwarded to the op (not validated here)
     * @return the first output of the executed op
     */
    public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        final String opName = "lstmblock";
        NDValidation.validateNumerical(opName, "maxTSLength", maxTSLength);
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "cLast", cLast);
        NDValidation.validateNumerical(opName, "yLast", yLast);

        LSTMBlock op = new LSTMBlock(maxTSLength, x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs an {@link LSTMBlock} op with only the input array supplied; {@code null} is passed to
     * the op in place of {@code maxTSLength}, {@code cLast} and {@code yLast}.
     *
     * @param x                  input array
     * @param LSTMWeights2       weights object forwarded to the op (not validated here)
     * @param LSTMConfiguration2 configuration object forwarded to the op (not validated here)
     * @return the first output of the executed op
     */
    public INDArray lstmblock(INDArray x, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        final String opName = "lstmblock";
        NDValidation.validateNumerical(opName, "x", x);

        LSTMBlock op = new LSTMBlock(null, x, null, null, LSTMWeights2, LSTMConfiguration2);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs an {@link SRU} op with a mask array.
     *
     * @param x           input array
     * @param initialC    array forwarded as the op's second argument
     * @param mask        array forwarded as the op's third argument
     * @param SRUWeights2 weights object forwarded to the op (not validated here)
     * @return the first output of the executed op
     */
    public INDArray sru(INDArray x, INDArray initialC, INDArray mask, SRUWeights SRUWeights2) {
        final String opName = "sru";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "initialC", initialC);
        NDValidation.validateNumerical(opName, "mask", mask);

        SRU op = new SRU(x, initialC, mask, SRUWeights2);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs an {@link SRU} op without a mask; {@code null} is passed to the op in place of the
     * mask array.
     *
     * @param x           input array
     * @param initialC    array forwarded as the op's second argument
     * @param SRUWeights2 weights object forwarded to the op (not validated here)
     * @return the first output of the executed op
     */
    public INDArray sru(INDArray x, INDArray initialC, SRUWeights SRUWeights2) {
        final String opName = "sru";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "initialC", initialC);

        SRU op = new SRU(x, initialC, null, SRUWeights2);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs an {@link SRUCell} op.
     *
     * @param x           input array
     * @param cLast       array forwarded as the op's second argument
     * @param SRUWeights2 weights object forwarded to the op (not validated here)
     * @return the first output of the executed op
     */
    public INDArray sruCell(INDArray x, INDArray cLast, SRUWeights SRUWeights2) {
        final String opName = "sruCell";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "cLast", cLast);

        SRUCell op = new SRUCell(x, cLast, SRUWeights2);
        INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }
}

