/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * LSTM recurrent layer with peephole connections (Graves formulation).
 *
 * <p>Parameters (as sliced in {@code activateHelper} and assembled in {@code packGradients}),
 * with {@code H = hiddenLayerSize}:
 * <ul>
 *   <li>{@code "W"} input weights, {@code [prevLayerSize, 4H]}; column blocks ordered {i, f, o, g}.</li>
 *   <li>{@code "RW"} recurrent weights, {@code [H, 4H + 3]}; column blocks {i, f, o, g} followed by
 *       the three peephole columns FF, OO, GG.</li>
 *   <li>{@code "b"} biases, row 0 of length {@code 4H}; blocks ordered {i, f, o, g}.</li>
 * </ul>
 *
 * <p>Per time step, with x_t the input slice and h/c the output/cell state, the forward pass computes
 * (act = the layer's configured activation function, o = element-wise product):
 * <pre>
 *   i_t = act(x_t W_i + h_{t-1} RW_i + b_i)
 *   f_t = sigmoid(x_t W_f + h_{t-1} RW_f + c_{t-1} o w_FF + b_f)
 *   g_t = sigmoid(x_t W_g + h_{t-1} RW_g + c_{t-1} o w_GG + b_g)
 *   c_t = f_t o c_{t-1} + g_t o i_t
 *   o_t = sigmoid(x_t W_o + h_{t-1} RW_o + c_t o w_OO + b_o)
 *   h_t = act(c_t) o o_t
 * </pre>
 * Input is {@code [miniBatch, prevLayerSize, T]}, or 2d {@code [miniBatch, prevLayerSize]} for T == 1.
 */
public class GravesLSTM
extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GravesLSTM> {
    /** State-map key for the output activations of the last processed time step. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
    /** State-map key for the memory cell state of the last processed time step. */
    public static final String STATE_KEY_PREV_MEMCELL = "prevMem";

    public GravesLSTM(NeuralNetConfiguration conf) {
        super(conf);
    }

    public GravesLSTM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /** Not supported by this layer; always throws {@link UnsupportedOperationException}. */
    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not supported by this layer; always throws {@link UnsupportedOperationException}. */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray activation) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Full backpropagation through time, starting from a zero initial state. */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        return this.backpropGradientHelper(epsilon, false, -1);
    }

    /**
     * Truncated BPTT: starts from the state stored in {@code stateMap} and back-propagates through
     * at most {@code tbpttBackwardLength} time steps, counted back from the last one.
     */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackwardLength) {
        return this.backpropGradientHelper(epsilon, true, tbpttBackwardLength);
    }

    /**
     * Re-runs the forward pass, then back-propagates {@code epsilon} through time.
     *
     * @param epsilon             dL/d(layer output); {@code [miniBatch, H, T]}, or 2d for T == 1
     * @param truncatedBPTT       if true, the forward pass starts from the state in {@code stateMap},
     *                            and its final state is stored in {@code tBpttStateMap}
     * @param tbpttBackwardLength number of trailing time steps to back-propagate through
     *                            (only used when {@code truncatedBPTT})
     * @return gradients for "W", "RW", "b", and dL/d(input) shaped {@code [miniBatch, prevLayerSize, T]}
     */
    private Pair<Gradient, INDArray> backpropGradientHelper(INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength) {
        FwdPassReturn fwdPass;
        if (truncatedBPTT) {
            fwdPass = this.activateHelper(true, (INDArray) this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                    (INDArray) this.stateMap.get(STATE_KEY_PREV_MEMCELL), true);
            // Keep the final state so the next TBPTT segment can continue from it.
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        } else {
            fwdPass = this.activateHelper(true, null, null, true);
        }

        INDArray inputWeights = this.getParam("W");
        INDArray recurrentWeights = this.getParam("RW");
        int hiddenLayerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        // A 2d epsilon is a single time step.
        boolean is2dInput = epsilon.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);

        // Same order as the array built in activateHelper.
        INDArray wi = fwdPass.paramsMmulCompatible[0];
        INDArray wI = fwdPass.paramsMmulCompatible[1];
        INDArray wf = fwdPass.paramsMmulCompatible[2];
        INDArray wF = fwdPass.paramsMmulCompatible[3];
        INDArray wFFTranspose = fwdPass.paramsMmulCompatible[4];
        INDArray wo = fwdPass.paramsMmulCompatible[5];
        INDArray wO = fwdPass.paramsMmulCompatible[6];
        INDArray wOOTranspose = fwdPass.paramsMmulCompatible[7];
        INDArray wg = fwdPass.paramsMmulCompatible[8];
        INDArray wG = fwdPass.paramsMmulCompatible[9];
        INDArray wGGTranspose = fwdPass.paramsMmulCompatible[10];

        // Gradient accumulators summed over time. Block order {i, f, o, g};
        // rwGradients[4..6] hold the peepholes {FF, OO, GG}.
        INDArray[] bGradients = new INDArray[4];
        INDArray[] iwGradients = new INDArray[4];
        INDArray[] rwGradients = new INDArray[7];
        for (int i = 0; i < 4; i++) {
            bGradients[i] = Nd4j.create(new int[]{1, hiddenLayerSize});
            // 'f'-ordered: these are used as gemm output buffers below.
            iwGradients[i] = Nd4j.create(new int[]{prevLayerSize, hiddenLayerSize}, 'f');
            rwGradients[i] = Nd4j.create(new int[]{hiddenLayerSize, hiddenLayerSize}, 'f');
        }
        for (int i = 0; i < 3; i++) {
            rwGradients[i + 4] = Nd4j.zeros(1, hiddenLayerSize);
        }

        INDArray epsilonNext = Nd4j.zeros(new int[]{miniBatchSize, prevLayerSize, timeSeriesLength});
        INDArray nablaCellStateNext = null;
        INDArray deltaiNext = null;
        INDArray deltafNext = null;
        INDArray deltaoNext = null;
        INDArray deltagNext = null;
        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        int endIdx = truncatedBPTT ? Math.max(0, timeSeriesLength - tbpttBackwardLength) : 0;

        for (int t = timeSeriesLength - 1; t >= endIdx; t--) {
            // At t == 0 the "previous" state is the forward pass's initial state. It is null for a
            // zero start, but non-null under TBPTT, and must then contribute to the gradients.
            INDArray prevMemCellState = t == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[t - 1];
            INDArray prevHiddenUnitActivation = t == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[t - 1];
            INDArray currMemCellState = fwdPass.memCellState[t];

            // dL/d(h_t): the external epsilon slice (copied) plus the recurrent terms from step t+1.
            // Nd4j.gemm(a, b, c, transA, transB, alpha, beta) is used with assumed BLAS semantics
            // c = alpha*op(a)*op(b) + beta*c.
            INDArray epsilonSlice = is2dInput ? epsilon : epsilon.tensorAlongDimension(t, new int[]{1, 0});
            INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f');
            if (t != timeSeriesLength - 1) {
                Nd4j.gemm(deltaiNext, wI, nablaOut, false, true, 1.0, 1.0);
                Nd4j.gemm(deltafNext, wF, nablaOut, false, true, 1.0, 1.0);
                Nd4j.gemm(deltaoNext, wO, nablaOut, false, true, 1.0, 1.0);
                Nd4j.gemm(deltagNext, wG, nablaOut, false, true, 1.0, 1.0);
            }

            // Output gate delta. "timesoneminus" on a sigmoid output is presumably a*(1-a).
            INDArray sigmahOfS = fwdPass.memCellActivations[t];
            INDArray ao = fwdPass.oa[t];
            INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform("timesoneminus", ao.dup('f')));
            INDArray deltao = nablaOut.dup('f').muli(sigmahOfS).muli(sigmaoPrimeOfZo);

            // dL/d(c_t): via h_t, via the OO peephole, and (for t < T-1) via c_{t+1}, FF and GG.
            // NOTE: ao aliases fwdPass.oa[t] and is overwritten here; it is not read again.
            INDArray sigmahPrimeOfS = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory()
                    .createTransform(this.conf.getLayer().getActivationFunction(), currMemCellState.dup('f'))
                    .derivative());
            INDArray nablaCellState = ao.muli(nablaOut).muli(sigmahPrimeOfS);
            INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
            l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
            if (t != timeSeriesLength - 1) {
                // fa[t+1] was already consumed at step t+1, so modifying it in place is safe.
                INDArray nextForgetGateAs = fwdPass.fa[t + 1];
                int length = nablaCellState.length();
                l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
                l1BLAS.axpy(length, 1.0, deltafNext.dup('f').muliRowVector(wFFTranspose), nablaCellState);
                l1BLAS.axpy(length, 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
            }
            nablaCellStateNext = nablaCellState;

            // Forget gate delta. It is zero (skipped) only when there is no previous cell state.
            INDArray deltaf = null;
            if (prevMemCellState != null) {
                INDArray af = fwdPass.fa[t];
                deltaf = nablaCellState.dup('f').muli(prevMemCellState).muli(Nd4j.getExecutioner().execAndReturn(
                        Nd4j.getOpFactory().createTransform("timesoneminus", af.dup('f'))));
            }

            // Input-modulation gate (g) and cell input (i) deltas.
            // NOTE: ai/ag/zi alias fwdPass.ia/ga/iz[t] and are overwritten in place. The order matters:
            // ag is copied for deltag before ag itself is overwritten to form deltai.
            INDArray ag = fwdPass.ga[t];
            INDArray ai = fwdPass.ia[t];
            INDArray deltag = ai.muli(nablaCellState).muli(Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform("timesoneminus", ag.dup('f'))));
            INDArray zi = fwdPass.iz[t];
            INDArray deltai = ag.muli(nablaCellState).muli(Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), zi).derivative()));

            // Input weight gradients: x_t^T * delta.
            INDArray prevLayerActivationSlice = Shape.toMmulCompatible(
                    is2dInput ? this.input : this.input.tensorAlongDimension(t, new int[]{1, 0}));
            Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients[0], true, false, 1.0, 1.0);
            if (deltaf != null) {
                Nd4j.gemm(prevLayerActivationSlice, deltaf, iwGradients[1], true, false, 1.0, 1.0);
            }
            Nd4j.gemm(prevLayerActivationSlice, deltao, iwGradients[2], true, false, 1.0, 1.0);
            Nd4j.gemm(prevLayerActivationSlice, deltag, iwGradients[3], true, false, 1.0, 1.0);

            // Recurrent weight gradients: h_{t-1}^T * delta (h_{t-1} may be the TBPTT initial state).
            if (prevHiddenUnitActivation != null) {
                Nd4j.gemm(prevHiddenUnitActivation, deltai, rwGradients[0], true, false, 1.0, 1.0);
                if (deltaf != null) {
                    Nd4j.gemm(prevHiddenUnitActivation, deltaf, rwGradients[1], true, false, 1.0, 1.0);
                }
                Nd4j.gemm(prevHiddenUnitActivation, deltao, rwGradients[2], true, false, 1.0, 1.0);
                Nd4j.gemm(prevHiddenUnitActivation, deltag, rwGradients[3], true, false, 1.0, 1.0);
            }

            // Peephole gradients: FF and GG see c_{t-1}; OO sees c_t.
            if (prevMemCellState != null) {
                INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(new int[]{0});
                l1BLAS.axpy(rwGradients[4].length(), 1.0, dLdwFF, rwGradients[4]);
                INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(new int[]{0});
                l1BLAS.axpy(rwGradients[6].length(), 1.0, dLdwGG, rwGradients[6]);
            }
            INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(new int[]{0});
            l1BLAS.axpy(rwGradients[5].length(), 1.0, dLdwOO, rwGradients[5]);

            // Bias gradients: deltas summed over the mini-batch.
            l1BLAS.axpy(bGradients[0].length(), 1.0, deltai.sum(new int[]{0}), bGradients[0]);
            if (deltaf != null) {
                l1BLAS.axpy(bGradients[1].length(), 1.0, deltaf.sum(new int[]{0}), bGradients[1]);
            }
            l1BLAS.axpy(bGradients[2].length(), 1.0, deltao.sum(new int[]{0}), bGradients[2]);
            l1BLAS.axpy(bGradients[3].length(), 1.0, deltag.sum(new int[]{0}), bGradients[3]);

            // dL/d(x_t), passed back to the previous layer.
            INDArray epsilonNextSlice = Nd4j.gemm(deltai, wi, false, true);
            Nd4j.gemm(deltao, wo, epsilonNextSlice, false, true, 1.0, 1.0);
            Nd4j.gemm(deltag, wg, epsilonNextSlice, false, true, 1.0, 1.0);
            if (deltaf != null) {
                Nd4j.gemm(deltaf, wf, epsilonNextSlice, false, true, 1.0, 1.0);
            }
            epsilonNext.tensorAlongDimension(t, new int[]{1, 0}).assign(epsilonNextSlice);

            deltaiNext = deltai;
            deltafNext = deltaf;
            deltaoNext = deltao;
            deltagNext = deltag;
        }

        Gradient retGradient = packGradients(iwGradients, rwGradients, bGradients, prevLayerSize, hiddenLayerSize);
        return new Pair<Gradient, INDArray>(retGradient, epsilonNext);
    }

    /**
     * Packs the per-block gradient accumulators into the parameter layout used by "W", "RW" and "b".
     */
    private static Gradient packGradients(INDArray[] iwGradients, INDArray[] rwGradients, INDArray[] bGradients,
                                          int prevLayerSize, int hiddenLayerSize) {
        INDArray iwGradientsOut = Nd4j.zeros(prevLayerSize, 4 * hiddenLayerSize);
        INDArray rwGradientsOut = Nd4j.zeros(hiddenLayerSize, 4 * hiddenLayerSize + 3);
        INDArray bGradientsOut = Nd4j.hstack(bGradients);
        for (int block = 0; block < 4; block++) {
            int from = block * hiddenLayerSize;
            int to = from + hiddenLayerSize;
            // Fresh index objects per put(): index instances are not shared across arrays.
            iwGradientsOut.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)},
                    iwGradients[block]);
            rwGradientsOut.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)},
                    rwGradients[block]);
        }
        // Peephole weights FF, OO, GG occupy the last three columns of RW.
        for (int k = 0; k < 3; k++) {
            rwGradientsOut.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + k)},
                    rwGradients[4 + k].transpose());
        }
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put("W", iwGradientsOut);
        retGradient.gradientForVariable().put("RW", rwGradientsOut);
        retGradient.gradientForVariable().put("b", bGradientsOut);
        return retGradient;
    }

    /** Note: returns the layer's activations (post-nonlinearity), same as {@code activate(x, true)}. */
    @Override
    public INDArray preOutput(INDArray x) {
        return this.activate(x, true);
    }

    /** Note: returns the layer's activations (post-nonlinearity), same as {@code activate(x, training)}. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.activate(x, training);
    }

    /** Sets the input and runs the forward pass from a zero initial state. */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input, training);
        return this.activateHelper(training, null, null, false).fwdPassOutput;
    }

    /** Sets the input and runs the forward pass (training mode) from a zero initial state. */
    @Override
    public INDArray activate(INDArray input) {
        this.setInput(input);
        return this.activateHelper(true, null, null, false).fwdPassOutput;
    }

    /** Runs the forward pass on the current input from a zero initial state. */
    @Override
    public INDArray activate(boolean training) {
        return this.activateHelper(training, null, null, false).fwdPassOutput;
    }

    /** Runs the forward pass (test mode) on the current input from a zero initial state. */
    @Override
    public INDArray activate() {
        return this.activateHelper(false, null, null, false).fwdPassOutput;
    }

    /** Returns columns {@code [from, to)} (all rows) of a 2d parameter array. */
    private static INDArray columns(INDArray arr, int from, int to) {
        return arr.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)});
    }

    /** Returns columns {@code [from, to)} of row 0 of the bias array. */
    private static INDArray biasColumns(INDArray biases, int from, int to) {
        return biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }

    /**
     * Forward pass over all time steps of {@code this.input}.
     *
     * @param training               enables DropConnect on "W" when configured
     * @param prevOutputActivations  initial h state, or null for zeros
     * @param prevMemCellState       initial c state, or null for zeros
     * @param forBackprop            if true, per-step intermediates are recorded for backprop and
     *                               {@code fwdPassOutput} is left null; otherwise only
     *                               {@code fwdPassOutput} ({@code [miniBatch, H, T]}) is filled
     * @return the forward-pass results; {@code lastAct}/{@code lastMemCell} are always set
     * @throws IllegalArgumentException if the input is unset or empty
     */
    private FwdPassReturn activateHelper(boolean training, INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop) {
        if (this.input == null || this.input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: not set or 0 length");
        }
        INDArray recurrentWeights = this.getParam("RW");
        INDArray inputWeights = this.getParam("W");
        INDArray biases = this.getParam("b");
        boolean is2dInput = this.input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : this.input.size(2);
        int hiddenLayerSize = recurrentWeights.size(0);
        int miniBatchSize = this.input.size(0);
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0) {
            inputWeights = Dropout.applyDropConnect(this, "W");
        }

        int h = hiddenLayerSize;
        INDArray wi = columns(inputWeights, 0, h);
        INDArray wI = columns(recurrentWeights, 0, h);
        INDArray bi = biasColumns(biases, 0, h);
        INDArray wf = columns(inputWeights, h, 2 * h);
        INDArray wF = columns(recurrentWeights, h, 2 * h);
        INDArray wFFTranspose = columns(recurrentWeights, 4 * h, 4 * h + 1).transpose();
        INDArray bf = biasColumns(biases, h, 2 * h);
        INDArray wo = columns(inputWeights, 2 * h, 3 * h);
        INDArray wO = columns(recurrentWeights, 2 * h, 3 * h);
        INDArray wOOTranspose = columns(recurrentWeights, 4 * h + 1, 4 * h + 2).transpose();
        INDArray bo = biasColumns(biases, 2 * h, 3 * h);
        INDArray wg = columns(inputWeights, 3 * h, 4 * h);
        INDArray wG = columns(recurrentWeights, 3 * h, 4 * h);
        INDArray wGGTranspose = columns(recurrentWeights, 4 * h + 2, 4 * h + 3).transpose();
        INDArray bg = biasColumns(biases, 3 * h, 4 * h);
        if (timeSeriesLength > 1 || forBackprop) {
            // Converted once up front because they are reused across steps (and in backprop).
            wi = Shape.toMmulCompatible(wi);
            wI = Shape.toMmulCompatible(wI);
            wf = Shape.toMmulCompatible(wf);
            wF = Shape.toMmulCompatible(wF);
            wFFTranspose = Shape.toMmulCompatible(wFFTranspose);
            wo = Shape.toMmulCompatible(wo);
            wO = Shape.toMmulCompatible(wO);
            wOOTranspose = Shape.toMmulCompatible(wOOTranspose);
            wg = Shape.toMmulCompatible(wg);
            wG = Shape.toMmulCompatible(wG);
            wGGTranspose = Shape.toMmulCompatible(wGGTranspose);
            bi = Shape.toMmulCompatible(bi);
            bf = Shape.toMmulCompatible(bf);
            bo = Shape.toMmulCompatible(bo);
            bg = Shape.toMmulCompatible(bg);
        }

        INDArray outputActivations = null;
        FwdPassReturn toReturn = new FwdPassReturn();
        if (forBackprop) {
            toReturn.paramsMmulCompatible = new INDArray[]{wi, wI, wf, wF, wFFTranspose, wo, wO, wOOTranspose, wg, wG, wGGTranspose};
            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            toReturn.memCellState = new INDArray[timeSeriesLength];
            toReturn.memCellActivations = new INDArray[timeSeriesLength];
            toReturn.iz = new INDArray[timeSeriesLength];
            toReturn.ia = new INDArray[timeSeriesLength];
            toReturn.fa = new INDArray[timeSeriesLength];
            toReturn.oa = new INDArray[timeSeriesLength];
            toReturn.ga = new INDArray[timeSeriesLength];
            // Record the caller's initial state BEFORE null is replaced by zeros below:
            // backprop needs it at t == 0 (non-null only when resuming from stored state).
            toReturn.prevAct = prevOutputActivations;
            toReturn.prevMemCell = prevMemCellState;
        } else {
            outputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength});
            toReturn.fwdPassOutput = outputActivations;
        }

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }
        if (prevMemCellState == null) {
            prevMemCellState = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }

        for (int t = 0; t < timeSeriesLength; t++) {
            INDArray miniBatchData = Shape.toMmulCompatible(
                    is2dInput ? this.input : this.input.tensorAlongDimension(t, new int[]{1, 0}));

            // Cell input i_t. The transform's result is not captured: the code relies on the op
            // writing into inputActivations in place.
            INDArray inputActivations = miniBatchData.mmul(wi);
            Nd4j.gemm(prevOutputActivations, wI, inputActivations, false, false, 1.0, 1.0);
            inputActivations.addiRowVector(bi);
            if (forBackprop) {
                toReturn.iz[t] = inputActivations.dup('f');
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), inputActivations));
            if (forBackprop) {
                toReturn.ia[t] = inputActivations;
            }

            // Forget gate f_t (peephole on c_{t-1}).
            INDArray forgetGateActivations = miniBatchData.mmul(wf);
            Nd4j.gemm(prevOutputActivations, wF, forgetGateActivations, false, false, 1.0, 1.0);
            INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);
            l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations);
            forgetGateActivations.addiRowVector(bf);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", forgetGateActivations));
            if (forBackprop) {
                toReturn.fa[t] = forgetGateActivations;
            }

            // Input-modulation gate g_t (peephole on c_{t-1}).
            INDArray inputModGateActivations = miniBatchData.mmul(wg);
            Nd4j.gemm(prevOutputActivations, wG, inputModGateActivations, false, false, 1.0, 1.0);
            INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);
            l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations);
            inputModGateActivations.addiRowVector(bg);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", inputModGateActivations));
            if (forBackprop) {
                toReturn.ga[t] = inputModGateActivations;
            }

            // Cell state c_t = f_t o c_{t-1} + g_t o i_t.
            INDArray currentMemoryCellState = forgetGateActivations.dup('f').muli(prevMemCellState);
            INDArray inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
            l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState);

            // Output gate o_t (peephole on the current cell state c_t).
            INDArray outputGateActivations = miniBatchData.mmul(wo);
            Nd4j.gemm(prevOutputActivations, wO, outputGateActivations, false, false, 1.0, 1.0);
            INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);
            l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations);
            outputGateActivations.addiRowVector(bo);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", outputGateActivations));
            if (forBackprop) {
                toReturn.oa[t] = outputGateActivations;
            }

            // Layer output h_t = act(c_t) o o_t.
            INDArray currMemoryCellActivation = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), currentMemoryCellState.dup('f')));
            INDArray currHiddenUnitActivations = currMemoryCellActivation.dup('f').muli(outputGateActivations);
            if (forBackprop) {
                toReturn.fwdPassOutputAsArrays[t] = currHiddenUnitActivations;
                toReturn.memCellState[t] = currentMemoryCellState;
                toReturn.memCellActivations[t] = currMemoryCellActivation;
            } else {
                outputActivations.tensorAlongDimension(t, new int[]{1, 0}).assign(currHiddenUnitActivations);
            }
            prevOutputActivations = currHiddenUnitActivations;
            prevMemCellState = currentMemoryCellState;
            toReturn.lastAct = currHiddenUnitActivations;
            toReturn.lastMemCell = currentMemoryCellState;
        }
        return toReturn;
    }

    /** Returns {@link #activate()}. */
    @Override
    public INDArray activationMean() {
        return this.activate();
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /** Not supported by this layer; always throws {@link UnsupportedOperationException}. */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** L2 penalty 0.5 * l2 * (sum RW^2 + sum W^2); biases are not regularized. */
    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL2() <= 0.0) {
            return 0.0;
        }
        double l2 = Transforms.pow(this.getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0)
                + Transforms.pow(this.getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return 0.5 * this.conf.getLayer().getL2() * l2;
    }

    /** L1 penalty l1 * (sum |RW| + sum |W|); biases are not regularized. */
    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL1() <= 0.0) {
            return 0.0;
        }
        double l1 = Transforms.abs(this.getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0)
                + Transforms.abs(this.getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return this.conf.getLayer().getL1() * l1;
    }

    /**
     * Forward pass (test mode) continuing from the state in {@code stateMap}; the final state is
     * written back to {@code stateMap}.
     */
    @Override
    public INDArray rnnTimeStep(INDArray input) {
        this.setInput(input);
        FwdPassReturn fwdPass = this.activateHelper(false, (INDArray) this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                (INDArray) this.stateMap.get(STATE_KEY_PREV_MEMCELL), false);
        INDArray outAct = fwdPass.fwdPassOutput;
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
        this.stateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        return outAct;
    }

    /**
     * Forward pass continuing from the state in {@code stateMap} without updating it. When
     * {@code storeLastForTBPTT} is set, the final state is written to {@code tBpttStateMap}.
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        this.setInput(input);
        FwdPassReturn fwdPass = this.activateHelper(training, (INDArray) this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                (INDArray) this.stateMap.get(STATE_KEY_PREV_MEMCELL), false);
        INDArray outAct = fwdPass.fwdPassOutput;
        if (storeLastForTBPTT) {
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        }
        return outAct;
    }

    /**
     * Results of {@code activateHelper}. All per-step arrays (index = time step) and
     * {@code paramsMmulCompatible}/{@code prevAct}/{@code prevMemCell} are filled only when
     * {@code forBackprop} is true. Backprop overwrites some of the per-step arrays in place.
     */
    private static class FwdPassReturn {
        /** Layer output {@code [miniBatch, H, T]}; set only when not run for backprop. */
        private INDArray fwdPassOutput;
        /** {wi, wI, wf, wF, wFF^T, wo, wO, wOO^T, wg, wG, wGG^T}, mmul-compatible. */
        private INDArray[] paramsMmulCompatible;
        /** h_t per step. */
        private INDArray[] fwdPassOutputAsArrays;
        /** c_t per step. */
        private INDArray[] memCellState;
        /** act(c_t) per step. */
        private INDArray[] memCellActivations;
        /** Cell-input pre-activations per step. */
        private INDArray[] iz;
        /** Cell-input activations i_t per step. */
        private INDArray[] ia;
        /** Forget gate activations f_t per step. */
        private INDArray[] fa;
        /** Output gate activations o_t per step. */
        private INDArray[] oa;
        /** Input-modulation gate activations g_t per step. */
        private INDArray[] ga;
        /** Initial h state supplied by the caller (null means zeros). */
        private INDArray prevAct;
        /** Initial c state supplied by the caller (null means zeros). */
        private INDArray prevMemCell;
        /** h at the last time step. */
        private INDArray lastAct;
        /** c at the last time step. */
        private INDArray lastMemCell;

        private FwdPassReturn() {
        }
    }
}

