/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link GravesLSTM} layers.
 *
 * <p>All parameters of the layer live in a single flattened row vector (and the gradient uses the
 * same layout). With nIn = {@code getNIn()} and nOut = {@code getNOut()}, the vector is split into
 * three consecutive segments:
 * <ol>
 *   <li>{@link #INPUT_WEIGHT_KEY}: nIn * 4*nOut entries, treated as an [nIn, 4*nOut] matrix;</li>
 *   <li>{@link #RECURRENT_WEIGHT_KEY}: nOut * (4*nOut + 3) entries, treated as an
 *       [nOut, 4*nOut + 3] matrix (the 3 extra columns are presumably peephole weights —
 *       TODO confirm against the GravesLSTM layer implementation);</li>
 *   <li>{@link #BIAS_KEY}: 4*nOut entries, kept as a row slice.</li>
 * </ol>
 */
public class GravesLSTMParamInitializer
implements ParamInitializer {
    /** Map/variable key of the recurrent weight matrix. */
    public static final String RECURRENT_WEIGHT_KEY = "RW";
    /** Map/variable key of the bias row vector. */
    public static final String BIAS_KEY = "b";
    /** Map/variable key of the input weight matrix. */
    public static final String INPUT_WEIGHT_KEY = "W";

    /**
     * Returns the total number of parameters of the LSTM layer.
     *
     * @param conf configuration whose layer must be a {@link GravesLSTM}
     * @param backprop ignored; the count does not depend on it
     * @return nIn*(4*nOut) + nOut*(4*nOut + 3) + 4*nOut
     */
    @Override
    public int numParams(NeuralNetConfiguration conf, boolean backprop) {
        GravesLSTM lstm = (GravesLSTM) conf.getLayer();
        int nOut = lstm.getNOut();
        int nIn = lstm.getNIn();
        return inputWeightCount(nIn, nOut) + recurrentWeightCount(nOut) + biasCount(nOut);
    }

    /**
     * Registers the layer's variables and puts the three parameter arrays into {@code params},
     * each backed by a slice of {@code paramsView}.
     *
     * <p>When {@code initializeParams} is true, the weight slices are handed to
     * {@link WeightInitUtil#initWeights} (presumably filled in place — confirm), and the second
     * block of nOut bias entries (columns [nOut, 2*nOut)) is set to the configured forget-gate bias.
     * The remaining bias entries are not written here and keep whatever {@code paramsView} held.
     * Otherwise the existing values are reused and only reshaped via
     * {@link WeightInitUtil#reshapeWeights}.
     *
     * @param params map receiving the entries for {@link #INPUT_WEIGHT_KEY},
     *     {@link #RECURRENT_WEIGHT_KEY} and {@link #BIAS_KEY}
     * @param conf configuration whose layer must be a {@link GravesLSTM}; variables are added to it
     * @param paramsView flattened row vector holding all layer parameters
     * @param initializeParams whether to initialize the values or just wrap existing ones
     * @throws IllegalStateException if {@code paramsView} does not have exactly
     *     {@link #numParams} elements
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        GravesLSTM lstm = (GravesLSTM) conf.getLayer();
        double forgetGateInit = lstm.getForgetGateBiasInit();
        Distribution dist = Distributions.createDistribution(lstm.getDist());
        int nOut = lstm.getNOut();
        int nIn = lstm.getNIn();

        conf.addVariable(INPUT_WEIGHT_KEY);
        conf.addVariable(RECURRENT_WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);

        int expected = numParams(conf, true);
        if (paramsView.length() != expected) {
            throw new IllegalStateException("Expected params view of length " + expected + ", got length " + paramsView.length());
        }

        // Segment boundaries inside the flattened vector: [input | recurrent | bias].
        int recurrentStart = inputWeightCount(nIn, nOut);
        int biasStart = recurrentStart + recurrentWeightCount(nOut);
        int biasEnd = biasStart + biasCount(nOut);

        INDArray inputWeights = rowSlice(paramsView, 0, recurrentStart);
        INDArray recurrentWeights = rowSlice(paramsView, recurrentStart, biasStart);
        INDArray biases = rowSlice(paramsView, biasStart, biasEnd);

        if (!initializeParams) {
            params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[]{nIn, 4 * nOut}, inputWeights));
            params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[]{nOut, 4 * nOut + 3}, recurrentWeights));
            params.put(BIAS_KEY, biases);
            return;
        }

        params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(nIn, 4 * nOut, lstm.getWeightInit(), dist, inputWeights));
        params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(nOut, 4 * nOut + 3, lstm.getWeightInit(), dist, recurrentWeights));
        // Only the forget-gate block of the bias (columns nOut .. 2*nOut-1) is explicitly set.
        INDArrayIndex[] forgetGateColumns = {NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut)};
        biases.put(forgetGateColumns, Nd4j.ones(1, nOut).muli((Number) forgetGateInit));
        params.put(BIAS_KEY, biases);
    }

    /**
     * Splits a flattened gradient row vector into per-parameter gradient arrays, using the same
     * layout as the parameters.
     *
     * @param conf configuration whose layer must be a {@link GravesLSTM}
     * @param gradientView flattened gradient row vector
     * @return insertion-ordered map: input weight gradient reshaped ('f' order) to [nIn, 4*nOut],
     *     recurrent weight gradient reshaped ('f' order) to [nOut, 4*nOut + 3], and the bias
     *     gradient row slice
     * @throws IllegalStateException if {@code gradientView} does not have exactly
     *     {@link #numParams} elements
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesLSTM lstm = (GravesLSTM) conf.getLayer();
        int nOut = lstm.getNOut();
        int nIn = lstm.getNIn();

        int expected = numParams(conf, true);
        if (gradientView.length() != expected) {
            throw new IllegalStateException("Expected gradient view of length " + expected + ", got length " + gradientView.length());
        }

        int recurrentStart = inputWeightCount(nIn, nOut);
        int biasStart = recurrentStart + recurrentWeightCount(nOut);
        int biasEnd = biasStart + biasCount(nOut);

        INDArray inputWeightGrad = rowSlice(gradientView, 0, recurrentStart).reshape('f', nIn, 4 * nOut);
        INDArray recurrentWeightGrad = rowSlice(gradientView, recurrentStart, biasStart).reshape('f', nOut, 4 * nOut + 3);
        INDArray biasGrad = rowSlice(gradientView, biasStart, biasEnd);

        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        gradients.put(INPUT_WEIGHT_KEY, inputWeightGrad);
        gradients.put(RECURRENT_WEIGHT_KEY, recurrentWeightGrad);
        gradients.put(BIAS_KEY, biasGrad);
        return gradients;
    }

    /** Number of input-weight entries: nIn rows by 4*nOut columns. */
    private static int inputWeightCount(int nIn, int nOut) {
        return nIn * (4 * nOut);
    }

    /** Number of recurrent-weight entries: nOut rows by 4*nOut + 3 columns. */
    private static int recurrentWeightCount(int nOut) {
        return nOut * (4 * nOut + 3);
    }

    /** Number of bias entries: 4*nOut. */
    private static int biasCount(int nOut) {
        return 4 * nOut;
    }

    /** Returns columns [from, to) of row 0 of {@code flat}. */
    private static INDArray rowSlice(INDArray flat, int from, int to) {
        return flat.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }
}

