/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.gradientcheck;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradientCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Compares the backprop gradient of a {@link MultiLayerNetwork} against a numerical gradient
     * estimated by central differences. For every parameter index i the numerical gradient is
     * {@code (score(p_i + epsilon) - score(p_i - epsilon)) / (2 * epsilon)}, with all other
     * parameters held at their original values.
     *
     * <p>The network's parameters are restored to their original values before this method
     * returns, including on early return and when an exception is thrown. The given input and
     * labels remain set on the network.
     *
     * @param mln              network to check; its output layer must be a {@link BaseOutputLayer}
     * @param epsilon          perturbation size, in (0, 0.1]; usually around 1e-4
     * @param maxRelError      largest relative error accepted for a parameter, in (0, 0.25]
     * @param print            if true, log the outcome for every parameter plus a summary line
     * @param exitOnFirstError if true, return false as soon as one parameter fails
     * @param input            input set on the network via {@code setInput}
     * @param labels           labels set on the network via {@code setLabels}
     * @param useUpdater       if true, pass the backprop gradient through the network's updater
     *                         (with iteration 0) before it is compared
     * @return true if no parameter exceeded {@code maxRelError} (NaN errors count as failures)
     * @throws IllegalArgumentException if epsilon or maxRelError is out of range, or the output
     *                                  layer is not a {@link BaseOutputLayer}
     * @throws IllegalStateException    if a numerical gradient evaluates to NaN
     */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, boolean useUpdater) {
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        if (!(mln.getOutputLayer() instanceof BaseOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }

        mln.setInput(input);
        mln.setLabels(labels);
        mln.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
        if (useUpdater) {
            Updater updater = UpdaterCreator.getUpdater(mln);
            updater.update(mln, gradAndScore.getFirst(), 0);
        }

        // Snapshot both arrays: if they are views into the network's internal buffers, the
        // setParameters()/computeGradientAndScore() calls below could otherwise overwrite them.
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
        INDArray originalParams = mln.params().dup();
        // Single working copy; each parameter is perturbed and then reset in place.
        INDArray params = originalParams.dup();

        int nParams = mln.numParams();
        int totalNFailures = 0;
        double maxError = 0.0;
        try {
            for (int i = 0; i < nParams; ++i) {
                double origValue = params.getDouble(i);

                params.putScalar(i, origValue + epsilon);
                mln.setParameters(params);
                mln.computeGradientAndScore();
                double scorePlus = mln.score();

                // Derived from the saved value rather than (x + eps) - 2*eps to avoid rounding drift.
                params.putScalar(i, origValue - epsilon);
                mln.setParameters(params);
                mln.computeGradientAndScore();
                double scoreMinus = mln.score();

                // Reset so that only one parameter is ever perturbed at a time.
                params.putScalar(i, origValue);

                double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
                if (Double.isNaN(numericalGradient)) {
                    throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
                }

                double backpropGradient = gradientToCheck.getDouble(i);
                double relError = relativeError(backpropGradient, numericalGradient);
                if (relError > maxError) {
                    maxError = relError;
                }

                if (relError > maxRelError || Double.isNaN(relError)) {
                    if (print) {
                        log.info("Param " + i + " FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    }
                    if (exitOnFirstError) {
                        return false;
                    }
                    ++totalNFailures;
                } else if (print) {
                    log.info("Param " + i + " passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
                }
            }
        } finally {
            // Leave the network with its original (unperturbed) parameters on every exit path.
            mln.setParameters(originalParams);
        }

        if (print) {
            int nPass = nParams - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    /**
     * Symmetric relative error {@code |a - b| / (|a| + |b|)} between a backprop and a numerical
     * gradient, defined as 0 when both are exactly zero. Returns NaN if either input is NaN.
     */
    private static double relativeError(double backpropGradient, double numericalGradient) {
        if (backpropGradient == 0.0 && numericalGradient == 0.0) {
            return 0.0;
        }
        return Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
    }
}

