/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.weights;

import java.util.Objects;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers that create initial weight arrays according to a {@link WeightInit} scheme and
 * copy them into a caller-supplied parameter view.
 */
public class WeightInitUtil {
    /** Memory order used by {@link #initWeights(int[], WeightInit, Distribution, INDArray)}. */
    public static final char DEFAULT_WEIGHT_INIT_ORDER = 'f';

    /**
     * Creates an array of the given shape with values drawn from a uniform distribution on the
     * symmetric interval {@code [-b, b]} where {@code b = 4 * sqrt(6 / (nIn + nOut))}.
     *
     * @param shape shape of the array to create
     * @param nIn   first term of the {@code nIn + nOut} sum that sizes the interval
     * @param nOut  second term of the {@code nIn + nOut} sum that sizes the interval
     * @return a newly created array of {@code shape}
     */
    public static INDArray uniformBasedOnInAndOut(int[] shape, int nIn, int nOut) {
        // The interval is symmetric around zero, so compute the bound once.
        double bound = 4.0 * Math.sqrt(6.0 / (double) (nOut + nIn));
        Distribution uniform = Nd4j.getDistributions().createUniform(-bound, bound);
        return Nd4j.rand(shape, uniform);
    }

    /**
     * Creates an array of the given shape filled with uniform random values between {@code min}
     * and {@code max}, using the RNG returned by {@link Nd4j#getRandom()}.
     *
     * @param shape shape of the array to create
     * @param min   lower end of the sampling range
     * @param max   upper end of the sampling range
     * @return a newly created array of {@code shape}
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        Random rng = Nd4j.getRandom();
        // Explicit widening keeps overload resolution on the (int[], double, double, Random) variant.
        return Nd4j.rand(shape, (double) min, (double) max, rng);
    }

    /**
     * Same as {@link #initWeights(int[], WeightInit, Distribution, char, INDArray)} using
     * {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     */
    public static INDArray initWeights(int[] shape, WeightInit initScheme, Distribution dist, INDArray paramView) {
        return initWeights(shape, initScheme, dist, DEFAULT_WEIGHT_INIT_ORDER, paramView);
    }

    /**
     * Initializes weights of the given shape with {@code initScheme}, writes them into
     * {@code paramView}, and returns {@code paramView} reshaped to {@code shape}.
     *
     * <p>The freshly sampled weights are flattened in {@code order} before being assigned, so
     * {@code paramView} must have exactly as many elements as the product of {@code shape}.
     *
     * @param shape      shape of the weight array; for several schemes {@code shape[0]} (and for
     *                   SIZE/XAVIER also {@code shape[1]}) enter the scaling formula
     * @param initScheme initialization scheme to apply
     * @param dist       distribution to sample from; only read when {@code initScheme} is
     *                   {@link WeightInit#DISTRIBUTION}
     * @param order      memory order used for sampling, flattening and the returned reshape
     * @param paramView  array that receives the initialized values
     * @return {@code paramView} reshaped to {@code shape} in {@code order}
     * @throws NullPointerException  if {@code paramView} is null, or {@code dist} is null for
     *                               {@link WeightInit#DISTRIBUTION}
     * @throws IllegalStateException if {@code initScheme} is not handled
     * @throws RuntimeException      if {@code paramView}'s length differs from the number of
     *                               initialized weights
     */
    public static INDArray initWeights(int[] shape, WeightInit initScheme, Distribution dist, char order,
                    INDArray paramView) {
        // Fail before consuming any randomness if there is nowhere to write the result.
        Objects.requireNonNull(paramView, "paramView must not be null");

        INDArray weights = sample(shape, initScheme, dist, order);
        INDArray flat = Nd4j.toFlattened(order, new INDArray[] {weights});
        if (flat.length() != paramView.length()) {
            throw new RuntimeException("ParamView length does not match initialized weights length (paramView: "
                            + paramView.length() + ", weights: " + flat.length() + ")");
        }
        paramView.assign(flat);
        return paramView.reshape(order, shape);
    }

    /** Draws a newly created weight array of {@code shape} according to {@code initScheme}. */
    private static INDArray sample(int[] shape, WeightInit initScheme, Distribution dist, char order) {
        switch (initScheme) {
            case DISTRIBUTION: {
                Objects.requireNonNull(dist, "dist must not be null when initScheme is DISTRIBUTION");
                return dist.sample(shape);
            }
            case NORMALIZED: {
                // (U - 0.5) / shape[0]
                INDArray ret = Nd4j.rand(order, shape);
                ret.subi(0.5).divi(shape[0]);
                return ret;
            }
            case RELU: {
                // N(0, 1) * sqrt(2 / shape[0])
                return Nd4j.randn(order, shape).muli(FastMath.sqrt(2.0 / (double) shape[0]));
            }
            case SIZE: {
                // Sampled in Nd4j's default order; the caller re-flattens in `order` anyway.
                return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
            }
            case UNIFORM: {
                // U(-a, a) with a = 1 / shape[0]
                double a = 1.0 / (double) shape[0];
                return Nd4j.rand(order, shape).muli(2.0 * a).subi(a);
            }
            case VI: {
                // U(-r, r) with r = sqrt(6) / sqrt(sum of all dimensions + 1)
                INDArray ret = Nd4j.rand(order, shape);
                int dimSum = 0;
                for (int dim : shape) {
                    dimSum += dim;
                }
                double r = Math.sqrt(6.0) / Math.sqrt(dimSum + 1);
                ret.muli(2.0 * r).subi(r);
                return ret;
            }
            case XAVIER: {
                // N(0, 1) / sqrt(shape[0] + shape[1])
                return Nd4j.randn(order, shape).divi(FastMath.sqrt((double) (shape[0] + shape[1])));
            }
            case ZERO: {
                return Nd4j.create(shape, order);
            }
            default: {
                throw new IllegalStateException("Illegal weight init value: " + initScheme);
            }
        }
    }

    /**
     * Convenience overload for a 2-D weight matrix of shape {@code [nIn, nOut]}, using
     * {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     */
    public static INDArray initWeights(int nIn, int nOut, WeightInit initScheme, Distribution dist, INDArray paramView) {
        return initWeights(new int[] {nIn, nOut}, initScheme, dist, paramView);
    }
}

