/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.weights;

import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers that create weight arrays according to a {@link WeightInit} scheme.
 */
public class WeightInitUtil {
    /**
     * Creates an array of the given shape by drawing values with {@code Nd4j.rand},
     * subtracting 0.5 from every element and dividing by {@code nIn}.
     *
     * @param shape shape of the array to create
     * @param nIn   divisor applied after centring (presumably the layer's number of inputs)
     * @return a newly created array
     */
    public static INDArray normalized(int[] shape, int nIn) {
        return Nd4j.rand(shape).subi(0.5).divi(nIn);
    }

    /**
     * Samples an array from {@code Nd4j.getDistributions().createUniform(-b, b)} where
     * {@code b = 4 * sqrt(6 / (nIn + nOut))}.
     *
     * @param shape shape of the array to create
     * @param nIn   number of inputs used in the bound
     * @param nOut  number of outputs used in the bound
     * @return a newly created array
     */
    public static INDArray uniformBasedOnInAndOut(int[] shape, int nIn, int nOut) {
        // The range is symmetric, so compute the bound once and negate it for the lower end.
        double bound = 4.0 * Math.sqrt(6.0 / (double) (nOut + nIn));
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-bound, bound));
    }

    /**
     * Creates an array with {@code Nd4j.rand(shape, min, max, Nd4j.getRandom())}
     * (presumably uniform values between {@code min} and {@code max}; confirm against ND4J).
     *
     * @param shape shape of the array to create
     * @param min   lower bound passed to {@code Nd4j.rand}
     * @param max   upper bound passed to {@code Nd4j.rand}
     * @return a newly created array
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        // Widen explicitly so the (int[], double, double, Random) overload is selected.
        return Nd4j.rand(shape, (double) min, (double) max, Nd4j.getRandom());
    }

    /**
     * Creates a weight array of the given shape using {@code initScheme}.
     *
     * <p>Schemes that read {@code shape[0]} (NORMALIZED, UNIFORM) require at least one
     * dimension. Schemes that read {@code shape[0]} and {@code shape[1]} (XAVIER, SIZE)
     * require at least two.
     *
     * @param shape      shape of the array to create
     * @param initScheme initialization scheme to apply
     * @param dist       distribution to sample from; only used (and required) for
     *                   {@link WeightInit#DISTRIBUTION}
     * @return a newly created array
     * @throws NullPointerException     if {@code initScheme} is null, or if {@code dist} is
     *                                  null and the scheme is DISTRIBUTION
     * @throws IllegalArgumentException if {@code shape} has too few dimensions for the scheme
     * @throws IllegalStateException    if the scheme is not handled here
     */
    public static INDArray initWeights(int[] shape, WeightInit initScheme, Distribution dist) {
        switch (initScheme) {
            case NORMALIZED: {
                requireRank(shape, 1, initScheme);
                INDArray ret = Nd4j.rand(shape, Nd4j.getRandom());
                return ret.subi(0.5).divi(shape[0]);
            }
            case XAVIER: {
                requireRank(shape, 2, initScheme);
                // Divide the Nd4j.randn draws by sqrt(shape[0] + shape[1]).
                return Nd4j.randn(shape).divi(FastMath.sqrt((double) (shape[0] + shape[1])));
            }
            case UNIFORM: {
                requireRank(shape, 1, initScheme);
                double a = 1.0 / (double) shape[0];
                return Nd4j.rand(shape, -a, a, Nd4j.getRandom());
            }
            case VI: {
                INDArray ret = Nd4j.rand(shape, Nd4j.getRandom());
                int len = 0;
                for (int dim : shape) {
                    len += dim;
                }
                double r = Math.sqrt(6.0) / Math.sqrt(len + 1);
                // In place, map each drawn value x to 2*r*x - r. If Nd4j.rand draws from
                // [0, 1), the result lies in [-r, r) (assumption; confirm against ND4J).
                ret.muli(2).muli(r).subi(r);
                return ret;
            }
            case DISTRIBUTION: {
                Objects.requireNonNull(dist, "dist must not be null for WeightInit.DISTRIBUTION");
                return dist.sample(shape);
            }
            case SIZE: {
                requireRank(shape, 2, initScheme);
                return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
            }
            case ZERO: {
                return Nd4j.create(shape);
            }
            default:
                break;
        }
        throw new IllegalStateException("Illegal weight init value: " + initScheme);
    }

    /**
     * Creates an {@code nIn x nOut} weight array using {@code initScheme}.
     *
     * @param nIn        size of the first dimension
     * @param nOut       size of the second dimension
     * @param initScheme initialization scheme to apply
     * @param dist       distribution used by {@link WeightInit#DISTRIBUTION}; may be null otherwise
     * @return a newly created array
     */
    public static INDArray initWeights(int nIn, int nOut, WeightInit initScheme, Distribution dist) {
        return initWeights(new int[] {nIn, nOut}, initScheme, dist);
    }

    /** Fails fast with a descriptive message instead of an index error when {@code shape} is too short. */
    private static void requireRank(int[] shape, int minRank, WeightInit scheme) {
        Objects.requireNonNull(shape, "shape must not be null");
        if (shape.length < minRank) {
            throw new IllegalArgumentException("WeightInit." + scheme + " requires a shape with at least "
                    + minRank + " dimension(s), got " + Arrays.toString(shape));
        }
    }
}

