/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import com.google.common.primitives.Ints;
import java.util.Map;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link ParamInitializer} for convolution layers.
 *
 * <p>Creates two parameters and stores them in the parameter map: a weight array under
 * {@link #WEIGHT_KEY} with shape {@code [nOut, nIn, kernelSize...]}, and a bias array under
 * {@link #BIAS_KEY} with {@code nOut} entries.
 */
public class ConvolutionParamInitializer implements ParamInitializer {
    /** Parameter-map key for the convolution weights. */
    public static final String WEIGHT_KEY = "W";
    /** Parameter-map key for the bias. */
    public static final String BIAS_KEY = "b";

    /**
     * Creates the bias and weight parameters of a convolution layer, stores them in
     * {@code params}, and registers both parameter names as variables on {@code conf}.
     *
     * @param params map that receives the new parameters (mutated in place)
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @throws IllegalArgumentException if the layer's kernel size is {@code null} or has fewer
     *     than 2 entries
     * @throws ClassCastException if {@code conf.getLayer()} is not a {@link ConvolutionLayer}
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf) {
        int[] kernelSize = convolutionLayer(conf).getKernelSize();
        // NOTE(review): kernels with more than 2 entries are accepted here; confirm whether
        // exactly 2 should be required for this layer type.
        if (kernelSize == null || kernelSize.length < 2) {
            throw new IllegalArgumentException(
                    "Kernel size must have at least 2 entries, got "
                            + (kernelSize == null ? "null" : String.valueOf(kernelSize.length)));
        }
        params.put(BIAS_KEY, createBias(conf));
        params.put(WEIGHT_KEY, createWeightMatrix(conf));
        conf.addVariable(WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);
    }

    /**
     * Same as {@link #init(Map, NeuralNetConfiguration)}; {@code extraConf} is ignored.
     *
     * @param params map that receives the new parameters (mutated in place)
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param extraConf unused
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf, Configuration extraConf) {
        init(params, conf);
    }

    /**
     * Creates the bias parameter via {@code Nd4j.valueArrayOf}, sized by the layer's
     * {@code nOut} and filled with the layer's bias-init value.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @return the newly created bias array
     */
    protected INDArray createBias(NeuralNetConfiguration conf) {
        ConvolutionLayer layerConf = convolutionLayer(conf);
        double biasInit = layerConf.getBiasInit();
        return Nd4j.valueArrayOf(layerConf.getNOut(), biasInit);
    }

    /**
     * Creates the weight parameter with shape {@code [nOut, nIn, kernelSize...]}, initialized
     * by {@code WeightInitUtil} using the layer's weight-init scheme and configured distribution.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @return the newly created weight array
     */
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf) {
        ConvolutionLayer layerConf = convolutionLayer(conf);
        Distribution dist = Distributions.createDistribution(layerConf.getDist());
        // Leading dims are [nOut, nIn]; the kernel dimensions are appended after them.
        int[] shape = Ints.concat(
                new int[] {layerConf.getNOut(), layerConf.getNIn()}, layerConf.getKernelSize());
        return WeightInitUtil.initWeights(shape, layerConf.getWeightInit(), dist);
    }

    /** Returns the configuration's layer cast to {@link ConvolutionLayer}. */
    private static ConvolutionLayer convolutionLayer(NeuralNetConfiguration conf) {
        return (ConvolutionLayer) conf.getLayer();
    }
}

