/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Quad;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Implementation of the simple (fully connected) recurrent layer.
 *
 * <p>Without layer normalization, each time step t computes
 * {@code z_t = x_t * W + b + a_{t-1} * RW} and {@code a_t = activation(z_t)}.
 * With layer normalization, it computes
 * {@code z_t = LayerNorm(x_t * W; gx, b) + LayerNorm(a_{t-1} * RW; gr)}, where {@code gx}/{@code gr}
 * are the first/second halves of the gain parameter {@code g}.
 *
 * <p>Internally all sequence arrays are laid out as [minibatch, size, timeSeriesLength];
 * {@code permuteIfNWC} converts between that layout and the configured data format at the
 * layer boundaries.
 */
public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> {
    /** Key under which the last time step's activations are kept in the RNN state maps. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

    public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Runs inference on {@code input}, continuing from the activations stored by the previous call
     * (or with no recurrent input if nothing is stored), then stores a detached copy of this call's
     * final time step for the next call.
     *
     * @param input        3D input sequence in the layer's configured data format
     * @param workspaceMgr workspace manager used for the forward pass
     * @return the activations for every time step, in the configured data format
     */
    @Override
    public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        this.setInput(input, workspaceMgr);
        INDArray last = this.stateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = this.activateHelper(last, false, false, workspaceMgr).getFirst();
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, this.lastTimeStepDetached(out));
        return out;
    }

    /**
     * Forward pass that starts from the truncated-BPTT state, optionally replacing that state with
     * this call's final time step.
     *
     * @param input             3D input sequence in the layer's configured data format
     * @param training          whether this is a training-mode forward pass
     * @param storeLastForTBPTT if true, store a detached copy of the last time step's activations
     * @param workspaceMgr      workspace manager used for the forward pass
     * @return the activations for every time step, in the configured data format
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMgr) {
        this.setInput(input, workspaceMgr);
        INDArray last = this.tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = this.activateHelper(last, training, false, workspaceMgr).getFirst();
        if (storeLastForTBPTT) {
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, this.lastTimeStepDetached(out));
        }
        return out;
    }

    /** Full backpropagation through time: delegates to {@link #tbpttBackpropGradient} with no truncation. */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return this.tbpttBackpropGradient(epsilon, -1, workspaceMgr);
    }

    /**
     * Backpropagates {@code epsilon} through the layer. The forward pass is re-run to recover the
     * per-step pre-activations, and parameter gradients are accumulated into the gradient views.
     *
     * @param epsilon         dL/dOutput, in the layer's configured data format
     * @param tbpttBackLength number of trailing time steps to backpropagate through; values
     *                        {@code <= 0} mean all time steps
     * @param workspaceMgr    workspace manager for working memory and the returned gradient
     * @return the parameter gradients and dL/dInput (in the configured data format). Time steps
     *         outside the truncation window have a zero dL/dInput.
     */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        if (epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon)) {
            epsilon = epsilon.dup('f');
        }
        final boolean layerNorm = this.hasLayerNorm();
        final long nOut = this.layerConf().getNOut();
        // Read the input before activateHelper runs, matching the original evaluation order
        INDArray input = this.permuteIfNWC(this.input.castTo(this.dataType));
        Quad<INDArray, INDArray, INDArray, INDArray> fwd = this.activateHelper(null, true, true, workspaceMgr);

        INDArray w = this.getParamWithNoise("W", true, workspaceMgr);
        INDArray rw = this.getParamWithNoise("RW", true, workspaceMgr);
        INDArray b = this.getParamWithNoise("b", true, workspaceMgr);
        INDArray g = layerNorm ? this.getParamWithNoise("g", true, workspaceMgr) : null;
        INDArray gx = g != null ? gainColumns(g, 0L, nOut) : null;
        INDArray gr = g != null ? gainColumns(g, nOut, nOut * 2L) : null;

        INDArray wg = this.gradientViews.get("W");
        INDArray rwg = this.gradientViews.get("RW");
        INDArray bg = this.gradientViews.get("b");
        INDArray gg = layerNorm ? this.gradientViews.get("g") : null;
        INDArray gxg = gg != null ? gainColumns(gg, 0L, nOut) : null;
        INDArray grg = gg != null ? gainColumns(gg, nOut, nOut * 2L) : null;
        // Every gradient below is accumulated with addi / beta = 1, so start from zero
        this.gradientsFlattened.assign(0);

        IActivation a = this.layerConf().getActivationFn();
        long tsLength = input.size(2);
        long end = tbpttBackLength > 0 ? Math.max(0L, tsLength - (long) tbpttBackLength) : 0L;
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'f');
        if (end > 0L) {
            // The loop below never visits steps [0, end); zero them instead of returning uninitialized memory
            epsOut.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0L, end)).assign(0);
        }
        epsilon = this.permuteIfNWC(epsilon);

        INDArray dldzNext = null;
        for (long t = tsLength - 1L; t >= end; --t) {
            INDArray dldaCurrent = sliceTimeStep(epsilon, t).dup();
            INDArray aCurrent = sliceTimeStep(fwd.getFirst(), t);
            INDArray zCurrent = sliceTimeStep(fwd.getSecond(), t);
            INDArray nCurrent = layerNorm ? sliceTimeStep(fwd.getThird(), t) : null;
            INDArray rCurrent = layerNorm ? sliceTimeStep(fwd.getFourth(), t) : null;
            INDArray inCurrent = sliceTimeStep(input, t);
            INDArray epsOutCurrent = sliceTimeStep(epsOut, t);

            if (dldzNext != null) {
                // Gradient reaching a_t through step t+1's recurrent term (a_t * RW)
                Nd4j.gemm(dldzNext, rw, dldaCurrent, false, true, 1.0, 1.0);
                // Recurrent weight gradient: a_t^T * dL/d(recurrent term at t+1)
                Nd4j.gemm(aCurrent, dldzNext, rwg, true, false, 1.0, 1.0);
            }
            INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent).getFirst();
            INDArray maskCol = null;
            if (this.maskArray != null) {
                maskCol = this.maskArray.getColumn(t, true).castTo(this.dataType);
                dldzCurrent.muliColumnVector(maskCol);
            }

            INDArray dldnCurrent;
            if (layerNorm) {
                // Backprop through LayerNorm(x_t * W; gx, b): yields dL/d(x_t * W) plus gain/bias gradients
                dldnCurrent = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray dGain = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), gxg.shape());
                INDArray dBias = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, bg.dataType(), bg.shape());
                Nd4j.getExecutioner().exec((CustomOp) new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, dGain, dBias, true, new int[]{1}));
                gxg.addi(dGain);
                bg.addi(dBias);
            } else {
                dldnCurrent = dldzCurrent;
                bg.addi(dldzCurrent.sum(new int[]{0}));
            }
            Nd4j.gemm(inCurrent, dldnCurrent, wg, true, false, 1.0, 1.0);
            Nd4j.gemm(dldnCurrent, w, epsOutCurrent, false, true, 1.0, 0.0);

            if (layerNorm && t > end) {
                // Backprop through LayerNorm(a_{t-1} * RW; gr) so the next (earlier) step sees dL/d(a_{t-1} * RW)
                dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray dRecGain = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape());
                Nd4j.getExecutioner().exec((CustomOp) new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, dRecGain, true, new int[]{1}));
                grg.addi(dRecGain);
            } else {
                dldzNext = dldzCurrent;
            }
            if (maskCol != null) {
                epsOutCurrent.muliColumnVector(maskCol);
            }
        }

        this.weightNoiseParams.clear();
        Gradient grad = new DefaultGradient(this.gradientsFlattened);
        grad.gradientForVariable().put("W", wg);
        grad.gradientForVariable().put("RW", rwg);
        grad.gradientForVariable().put("b", bg);
        if (layerNorm) {
            grad.gradientForVariable().put("g", gg);
        }
        epsOut = this.backpropDropOutIfPresent(epsOut);
        epsOut = this.permuteIfNWC(epsOut);
        return new Pair<>(grad, epsOut);
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Forward pass over the whole input sequence with no stored recurrent state. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.activateHelper(null, training, false, workspaceMgr).getFirst();
    }

    /**
     * Shared forward pass.
     *
     * @param prevStepOut  activations of the step preceding the first input step, or null for none
     * @param training     training-mode flag (dropout, weight noise, activation)
     * @param forBackprop  if true, also record the pre-activations (and pre-norm values when layer
     *                     normalization is enabled) needed by backprop, and skip the output permutation
     * @param workspaceMgr workspace manager for all allocations
     * @return (activations, pre-activations z, input pre-norm, recurrent pre-norm). The last three are
     *         null unless {@code forBackprop} is set (the pre-norm entries also need layer normalization).
     *         The activations are in the configured data format only when {@code forBackprop} is false.
     */
    private Quad<INDArray, INDArray, INDArray, INDArray> activateHelper(INDArray prevStepOut, boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        Preconditions.checkState(this.input.rank() == 3, "3D input expected to RNN layer expected, got " + this.input.rank());
        Preconditions.checkState(prevStepOut == null || prevStepOut.size(0) == this.input.size(0), "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches? Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevStepOut, this.input);
        this.applyDropOutIfNecessary(training, workspaceMgr);
        final boolean layerNorm = this.hasLayerNorm();
        INDArray input = this.permuteIfNWC(this.input.castTo(this.dataType));
        long m = input.size(0);
        long tsLength = input.size(2);
        long nOut = this.layerConf().getNOut();

        INDArray w = this.getParamWithNoise("W", training, workspaceMgr);
        INDArray rw = this.getParamWithNoise("RW", training, workspaceMgr);
        INDArray b = this.getParamWithNoise("b", training, workspaceMgr);
        INDArray g = layerNorm ? this.getParamWithNoise("g", training, workspaceMgr) : null;
        INDArray gx = g != null ? gainColumns(g, 0L, nOut) : null;
        INDArray gr = g != null ? gainColumns(g, nOut, nOut * 2L) : null;

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, w.dataType(), new long[]{m, nOut, tsLength}, 'f');
        INDArray outZ = forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape()) : null;
        INDArray outPreNorm = forBackprop && layerNorm ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null;
        INDArray recPreNorm = forBackprop && layerNorm ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null;

        if (input.ordering() != 'f' || Shape.strideDescendingCAscendingF(input)) {
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f');
        }
        if (!layerNorm) {
            // Seed every step with the bias; the gemms below accumulate onto it (beta = 1)
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastCopyOp(out, b, out, new int[]{1}));
        }

        IActivation a = this.layerConf().getActivationFn();
        for (long t = 0L; t < tsLength; t++) {
            INDArray currOut = sliceTimeStep(out, t);
            INDArray currIn = sliceTimeStep(input, t);
            if (layerNorm) {
                INDArray currOutPreNorm = sliceTimeStep(forBackprop ? outPreNorm : out, t);
                Nd4j.gemm(currIn, w, currOutPreNorm, false, false, 1.0, 0.0);
                Nd4j.getExecutioner().exec((CustomOp) new LayerNorm(currOutPreNorm, gx, b, currOut, true, new int[]{1}));
            } else {
                Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0);
            }

            // No recurrent contribution at the first step unless a previous state was supplied
            if (t > 0L || prevStepOut != null) {
                if (layerNorm) {
                    INDArray currRecPreNorm = forBackprop
                            ? sliceTimeStep(recPreNorm, t)
                            : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
                    Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0);
                    INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
                    Nd4j.getExecutioner().exec((CustomOp) new LayerNorm(currRecPreNorm, gr, recNorm, true, new int[]{1}));
                    currOut.addi(recNorm);
                } else {
                    Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0);
                }
            }
            if (forBackprop) {
                // Keep z_t before the in-place activation overwrites currOut
                sliceTimeStep(outZ, t).assign(currOut);
            }
            a.getActivation(currOut, training);
            if (this.maskArray != null) {
                INDArray maskCol = this.maskArray.getColumn(t, true).castTo(this.dataType);
                currOut.muliColumnVector(maskCol);
            }
            prevStepOut = currOut;
        }

        if (this.maskArray != null) {
            INDArray mask = this.maskArray.castTo(this.dataType);
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMulOp(out, mask, out, new int[]{0, 2}));
            if (forBackprop) {
                Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMulOp(outZ, mask, outZ, new int[]{0, 2}));
            }
        }
        if (!forBackprop) {
            // outZ / outPreNorm / recPreNorm are only allocated when forBackprop is true, so only out needs converting
            out = this.permuteIfNWC(out);
        }
        return new Quad<>(out, outZ, outPreNorm, recPreNorm);
    }

    @Override
    public boolean hasLayerNorm() {
        return this.layerConf().hasLayerNorm();
    }

    /**
     * Returns a copy of the final time step of {@code out}, made outside any workspace, as
     * [minibatch, nOut].
     *
     * <p>{@code out} arrives in the configured data format (see {@link #activateHelper}). It is mapped
     * back to [minibatch, nOut, timeSeriesLength] with {@code permuteIfNWC} first, so axis 2 is
     * always the time axis. Slicing axis 2 directly would select the last output unit instead of the
     * last time step for NWC data.
     */
    private INDArray lastTimeStepDetached(INDArray out) {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            INDArray timeLast = this.permuteIfNWC(out);
            return sliceTimeStep(timeLast, timeLast.size(2) - 1L).dup();
        }
    }

    /** View of time step {@code t} (axis 2) of a rank-3 array. */
    private static INDArray sliceTimeStep(INDArray arr, long t) {
        return arr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(t));
    }

    /** View of columns [{@code from}, {@code to}) of row 0 of {@code g}, keeping it 2D. */
    private static INDArray gainColumns(INDArray g, long from, long to) {
        return g.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to));
    }
}

