/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.distributions.Distributions;
import org.deeplearning4j.nn.WeightInit;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers for building initial weight arrays for network layers.
 */
public class WeightInitUtil {

    /**
     * Samples an array of the given shape from {@code Distributions.uniform} with bounds
     * {@code -b} and {@code b}, where {@code b = 4 * sqrt(6 / (nIn + nOut))}.
     *
     * <p>A new {@link MersenneTwister} seeded with {@code 123} is created on every call, so two
     * calls with identical arguments produce identical values.
     *
     * @param shape shape of the returned array
     * @param nIn   number of input units; used only to scale the bound
     * @param nOut  number of output units; used only to scale the bound
     * @return a newly created array of the requested shape
     */
    public static INDArray uniformBasedOnInAndOut(int[] shape, int nIn, int nOut) {
        // The lower and upper bounds are exact negations of each other, so compute the
        // magnitude once. The factor 4 is kept as-is from the original implementation.
        double bound = 4.0 * Math.sqrt(6.0 / (double) (nOut + nIn));
        RealDistribution uniform = Distributions.uniform(new MersenneTwister(123), -bound, bound);
        return Nd4j.rand(shape, uniform);
    }

    /**
     * Creates an array of the given shape via {@code Nd4j.rand} with the bounds {@code min} and
     * {@code max}. The generator is a {@link MersenneTwister} freshly seeded with {@code 123}, so
     * results are repeatable across calls.
     *
     * @param shape shape of the returned array
     * @param min   lower bound passed to {@code Nd4j.rand}
     * @param max   upper bound passed to {@code Nd4j.rand}
     * @return a newly created array of the requested shape
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        // Explicit widening keeps the (int[], double, double, RandomGenerator) overload selected.
        return Nd4j.rand(shape, (double) min, (double) max, new MersenneTwister(123));
    }

    /**
     * Creates an {@code nIn x nOut} weight matrix according to {@code initScheme}.
     *
     * <ul>
     *   <li>{@code VI}: uniform values in {@code [-r, r)} with
     *       {@code r = sqrt(6) / sqrt(nIn + nOut + 1)}.</li>
     *   <li>{@code DISTRIBUTION}: each row is filled with {@code nOut} samples from {@code dist}.</li>
     *   <li>{@code SIZE}: delegates to {@link #uniformBasedOnInAndOut(int[], int, int)}.</li>
     *   <li>{@code ZERO}: an array created by {@code Nd4j.create} with no initial values
     *       (presumably zero-filled; confirm against ND4J).</li>
     * </ul>
     *
     * @param nIn        number of rows
     * @param nOut       number of columns
     * @param initScheme initialization scheme to apply
     * @param act        currently unused by this implementation
     * @param dist       distribution to sample from; required only for {@code DISTRIBUTION}
     * @return the initialized weight matrix
     * @throws NullPointerException  if {@code initScheme} is null, or if {@code dist} is null
     *                               when {@code initScheme} is {@code DISTRIBUTION}
     * @throws IllegalStateException if {@code initScheme} is not one of the handled values
     */
    public static INDArray initWeights(int nIn, int nOut, WeightInit initScheme, ActivationFunction act, RealDistribution dist) {
        switch (initScheme) {
            case VI: {
                double r = Math.sqrt(6.0) / Math.sqrt(nIn + nOut + 1);
                // The affine map x -> 2r*x - r sends U[0, 1) onto U[-r, r). It needs a *uniform*
                // source (Nd4j.rand); the previous randn source gave normally distributed weights
                // centred on -r instead.
                INDArray ret = Nd4j.rand(nIn, nOut);
                ret.muli(2.0 * r).subi(r);
                return ret;
            }
            case DISTRIBUTION: {
                if (dist == null) {
                    throw new NullPointerException("dist must not be null for DISTRIBUTION weight init");
                }
                // Allocate without random fill: every row is overwritten below.
                INDArray ret = Nd4j.create(new int[]{nIn, nOut});
                for (int i = 0; i < ret.rows(); ++i) {
                    ret.putRow(i, Nd4j.create(dist.sample(ret.columns())));
                }
                return ret;
            }
            case SIZE:
                return uniformBasedOnInAndOut(new int[]{nIn, nOut}, nIn, nOut);
            case ZERO:
                return Nd4j.create(new int[]{nIn, nOut});
            default:
                break;
        }
        throw new IllegalStateException("Illegal weight init value");
    }
}

