/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for feed-forward layers that own a weight matrix ({@link #WEIGHT_KEY},
 * {@code nIn * nOut} entries) and, optionally, a bias vector ({@link #BIAS_KEY}, {@code nOut}
 * entries).
 *
 * <p>Both parameters live in a single flattened view: the weight entries occupy
 * {@code [0, nIn*nOut)} and the bias entries, when present, occupy
 * {@code [nIn*nOut, nIn*nOut + nOut)}. The same layout is used for gradients.
 */
public class DefaultParamInitializer implements ParamInitializer {
    private static final DefaultParamInitializer INSTANCE = new DefaultParamInitializer();
    /** Parameter key of the weight matrix. */
    public static final String WEIGHT_KEY = "W";
    /** Parameter key of the bias vector. */
    public static final String BIAS_KEY = "b";

    /** Returns the shared instance of this initializer. */
    public static DefaultParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Returns {@code nIn * nOut}, plus {@code nOut} when the layer has a bias.
     *
     * @throws IllegalArgumentException if {@code l} is not a {@link FeedForwardLayer}
     */
    @Override
    public long numParams(Layer l) {
        FeedForwardLayer layerConf = asFeedForward(l);
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        return nIn * nOut + (this.hasBias(l) ? nOut : 0L);
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        if (this.hasBias(layer)) {
            return Arrays.asList(WEIGHT_KEY, BIAS_KEY);
        }
        return this.weightKeys(layer);
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return Collections.singletonList(WEIGHT_KEY);
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        if (this.hasBias(layer)) {
            return Collections.singletonList(BIAS_KEY);
        }
        return Collections.emptyList();
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEY.equals(key);
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY.equals(key);
    }

    /**
     * Slices {@code paramsView} into weight and (optional) bias parameters and registers them
     * as variables on {@code conf}.
     *
     * @param conf             configuration whose layer must be a {@link FeedForwardLayer}
     * @param paramsView       flattened parameter view of exactly {@link #numParams} entries
     * @param initializeParams whether to write initial values into the view
     * @return a synchronized, insertion-ordered map: {@code W}, then {@code b} if present
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}
     * @throws IllegalStateException    if the view length does not match {@link #numParams}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        FeedForwardLayer layerConf = asFeedForward(conf.getLayer());
        long length = this.numParams(conf);
        if (paramsView.length() != length) {
            throw new IllegalStateException("Expected params view of length " + length + ", got length " + paramsView.length());
        }
        long nOut = layerConf.getNOut();
        long nWeightParams = layerConf.getNIn() * nOut;

        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        INDArray weightView = weightSlice(paramsView, nWeightParams);
        params.put(WEIGHT_KEY, this.createWeightMatrix(conf, weightView, initializeParams));
        conf.addVariable(WEIGHT_KEY);
        if (this.hasBias(layerConf)) {
            INDArray biasView = biasSlice(paramsView, nWeightParams, nOut);
            params.put(BIAS_KEY, this.createBias(conf, biasView, initializeParams));
            conf.addVariable(BIAS_KEY);
        }
        return params;
    }

    /**
     * Splits a flattened gradient view into per-parameter gradient views.
     *
     * <p>The weight gradient is reshaped to {@code [nIn, nOut]} in 'f' order; the bias gradient
     * is returned as the raw slice.
     *
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        FeedForwardLayer layerConf = asFeedForward(conf.getLayer());
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        long nWeightParams = nIn * nOut;

        Map<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        out.put(WEIGHT_KEY, weightSlice(gradientView, nWeightParams).reshape('f', new long[]{nIn, nOut}));
        if (this.hasBias(layerConf)) {
            out.put(BIAS_KEY, biasSlice(gradientView, nWeightParams, nOut));
        }
        return out;
    }

    /** Initializes the bias view from the layer's {@code nOut} and {@code biasInit} settings. */
    protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasParamView, boolean initializeParameters) {
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        return this.createBias(layerConf.getNOut(), layerConf.getBiasInit(), biasParamView, initializeParameters);
    }

    /**
     * Optionally fills {@code biasParamView} with {@code biasInit} and returns the view itself.
     *
     * @param nOut bias length; retained for signature compatibility with overriding subclasses
     */
    protected INDArray createBias(long nOut, double biasInit, INDArray biasParamView, boolean initializeParameters) {
        if (initializeParameters) {
            // Scalar assign fills the view in place: no temporary [1, nOut] array, and no
            // dependence on whether the view is rank 1 or rank 2.
            biasParamView.assign(biasInit);
        }
        return biasParamView;
    }

    /** Builds the weight matrix from the layer's nIn/nOut, weight-init scheme and distribution. */
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, boolean initializeParameters) {
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        if (initializeParameters) {
            Distribution dist = Distributions.createDistribution(layerConf.getDist());
            return this.createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(), dist, weightParamView, true);
        }
        return this.createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, null, weightParamView, false);
    }

    /**
     * Returns an {@code [nIn, nOut]} weight matrix backed by {@code weightParamView}.
     *
     * <p>When {@code initializeParameters} is true, values are written via
     * {@link WeightInitUtil#initWeights}; otherwise the view is only reshaped. (Presumably both
     * helpers return a view over {@code weightParamView} rather than a copy — confirm in
     * WeightInitUtil.)
     */
    protected INDArray createWeightMatrix(long nIn, long nOut, WeightInit weightInit, Distribution dist, INDArray weightParamView, boolean initializeParameters) {
        long[] shape = new long[]{nIn, nOut};
        if (initializeParameters) {
            return WeightInitUtil.initWeights((double) nIn, (double) nOut, shape, weightInit, dist, weightParamView);
        }
        return WeightInitUtil.reshapeWeights(shape, weightParamView);
    }

    /**
     * Reports whether {@code layer} has a bias parameter. Known layer types are queried; any
     * other layer type is assumed to have one.
     */
    protected boolean hasBias(Layer layer) {
        if (layer instanceof BaseOutputLayer) {
            return ((BaseOutputLayer) layer).hasBias();
        }
        if (layer instanceof DenseLayer) {
            return ((DenseLayer) layer).hasBias();
        }
        if (layer instanceof EmbeddingLayer) {
            return ((EmbeddingLayer) layer).hasBias();
        }
        if (layer instanceof EmbeddingSequenceLayer) {
            return ((EmbeddingSequenceLayer) layer).hasBias();
        }
        return true;
    }

    /** Casts {@code layer} to {@link FeedForwardLayer}, or throws a descriptive exception. */
    private static FeedForwardLayer asFeedForward(Layer layer) {
        if (!(layer instanceof FeedForwardLayer)) {
            String type = layer == null ? "null" : layer.getClass().getName();
            throw new IllegalArgumentException("unsupported layer type: " + type);
        }
        return (FeedForwardLayer) layer;
    }

    /** Returns entries {@code [0, nWeightParams)} of row 0 of a flattened view. */
    private static INDArray weightSlice(INDArray flat, long nWeightParams) {
        return flat.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, nWeightParams));
    }

    /** Returns entries {@code [nWeightParams, nWeightParams + nOut)} of row 0 of a flattened view. */
    private static INDArray biasSlice(INDArray flat, long nWeightParams, long nOut) {
        return flat.get(NDArrayIndex.point(0L), NDArrayIndex.interval(nWeightParams, nWeightParams + nOut));
    }
}

