/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.solvers.StochasticHessianFree;
import org.deeplearning4j.optimize.solvers.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.solvers.VectorizedNonZeroStoppingConjugateGradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fine-tunes a {@link BaseMultiLayerNetwork} with back propagation.
 *
 * <p>When line search is disabled, training repeatedly calls
 * {@link BaseMultiLayerNetwork#backPropStep()}. That happens either for a fixed number of epochs
 * (when {@link BaseMultiLayerNetwork#isForceNumEpochs()} is true) or until the score stops
 * improving. When line search is enabled, training is delegated to a solver selected by the
 * network's configured optimization algorithm.
 *
 * <p>The class also adapts the network to {@link OptimizableByGradientValueMatrix} so those solvers
 * can drive it:
 * <ul>
 *   <li>the objective value is the negated network score;</li>
 *   <li>the gradient comes from {@link BaseMultiLayerNetwork#getBackPropGradient2()}.</li>
 * </ul>
 */
public class BackPropOptimizer
implements Serializable,
OptimizableByGradientValueMatrix {
    /** Smallest score improvement per step that keeps convergence-mode training going. */
    private static final double CHANGE_TOLERANCE = 1.0E-6;
    private static final Logger log = LoggerFactory.getLogger(BackPropOptimizer.class);

    private BaseMultiLayerNetwork network;
    /** Cached parameter count; -1 until first computed by {@link #getNumParameters()}. */
    private int length = -1;
    private double lr;
    private int epochs;
    /** Last value passed to {@link #setCurrentIteration(int)}; not read by this class. */
    private int currentIteration = -1;

    /**
     * Creates an optimizer for the given network.
     *
     * @param network the network to train; it is mutated in place
     * @param lr      learning rate passed to the output layer's {@code trainTillConvergence}
     *                from {@link #setParameters(INDArray)}
     * @param epochs  maximum number of back-prop steps in non-line-search mode; also passed to
     *                {@code trainTillConvergence}
     */
    public BackPropOptimizer(BaseMultiLayerNetwork network, double lr, int epochs) {
        this.network = network;
        this.lr = lr;
        this.epochs = epochs;
    }

    @Override
    public void setCurrentIteration(int value) {
        this.currentIteration = value;
    }

    /**
     * Runs back-propagation training on the wrapped network.
     *
     * @param eval       training evaluator handed to the line-search solver (unused otherwise)
     * @param numEpochs  iteration budget for the line-search solver. NOTE(review): in
     *                   non-line-search mode the constructor's {@code epochs} is used instead.
     * @param lineSearch {@code true} to delegate to a solver chosen by the configured
     *                   optimization algorithm, {@code false} to step back-prop directly
     */
    public void optimize(TrainingEvaluator eval, int numEpochs, boolean lineSearch) {
        if (lineSearch) {
            this.optimizeWithLineSearch(eval, numEpochs);
            return;
        }
        double initialScore = this.network.score();
        log.info("BEGIN BACKPROP WITH SCORE OF {}", initialScore);
        if (this.network.isForceNumEpochs()) {
            this.backPropForFixedEpochs();
        } else {
            this.backPropUntilConvergence(initialScore);
        }
    }

    /** Performs exactly {@code epochs} back-prop steps, periodically clearing AdaGrad history. */
    private void backPropForFixedEpochs() {
        for (int i = 0; i < this.epochs; ++i) {
            int resetInterval = this.network.getDefaultConfiguration().getResetAdaGradIterations();
            // A zero interval would make the modulo throw ArithmeticException; treat it as "never reset".
            if (resetInterval != 0 && i % resetInterval == 0) {
                this.network.getOutputLayer().getAdaGrad().historicalGradient = null;
            }
            this.network.backPropStep();
            log.info("Iteration {} error {}", i, this.network.score());
        }
    }

    /**
     * Steps back-prop while the score keeps decreasing, for at most {@code epochs} steps.
     *
     * <p>Training stops in three cases:
     * <ul>
     *   <li>A step fails to lower the score (including when the score is NaN). The network is
     *       restored to the best snapshot seen so far.</li>
     *   <li>The improvement falls below {@link #CHANGE_TOLERANCE}.</li>
     *   <li>The step budget is exhausted.</li>
     * </ul>
     *
     * @param initialScore the network's score before any step is taken
     */
    private void backPropUntilConvergence(double initialScore) {
        double bestScore = initialScore;
        BaseMultiLayerNetwork best = this.network.clone();
        for (int iteration = 0; iteration < this.epochs; ++iteration) {
            this.network.backPropStep();
            double score = this.network.score();
            // Negated comparison so a NaN score also counts as "no improvement" and triggers a revert,
            // including on the final allowed iteration.
            if (!(score < bestScore)) {
                this.network.update(best);
                log.info("Reverting to best score {}", bestScore);
                return;
            }
            if (bestScore - score < CHANGE_TOLERANCE) {
                log.info("Not enough of a change on back prop...breaking");
                return;
            }
            bestScore = score;
            log.info("New score {}", bestScore);
            best = this.network.clone();
        }
        log.info("Backprop number of iterations max hit; converging");
    }

    /** Delegates training to the solver matching the network's configured optimization algorithm. */
    private void optimizeWithLineSearch(TrainingEvaluator eval, int numEpochs) {
        NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm =
                this.network.getDefaultConfiguration().getOptimizationAlgo();
        if (optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            VectorizedNonZeroStoppingConjugateGradient g = new VectorizedNonZeroStoppingConjugateGradient(this);
            g.setTrainingEvaluator(eval);
            g.setMaxIterations(numEpochs);
            g.optimize(numEpochs);
        } else if (optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.HESSIAN_FREE) {
            StochasticHessianFree s = new StochasticHessianFree(this, this.network);
            s.setTrainingEvaluator(eval);
            s.setMaxIterations(numEpochs);
            s.optimize(numEpochs);
        } else {
            VectorizedDeepLearningGradientAscent g = new VectorizedDeepLearningGradientAscent(this);
            g.setTrainingEvaluator(eval);
            g.optimize(numEpochs);
        }
    }

    /**
     * Returns the negated network score.
     *
     * <p>NOTE(review): the sign flip presumably lets ascent-style solvers minimize the score by
     * maximizing this value; confirm against the solver implementations.
     */
    @Override
    public double getValue() {
        return -this.network.score();
    }

    /** Returns the number of network parameters, computed once and then cached. */
    @Override
    public int getNumParameters() {
        if (this.length < 0) {
            this.length = this.getParameters().length();
        }
        return this.length;
    }

    /** Intentionally a no-op: single-parameter updates are not supported by this adapter. */
    @Override
    public void setParameter(int index, double value) {
    }

    /** Returns the network's parameters as reported by {@link BaseMultiLayerNetwork#params()}. */
    @Override
    public INDArray getParameters() {
        return this.network.params();
    }

    /**
     * Always returns {@code 0.0}. Single-parameter access is not implemented by this adapter.
     * NOTE(review): callers that rely on this will silently see zeros.
     */
    @Override
    public double getParameter(int index) {
        return 0.0;
    }

    /**
     * Loads {@code params} into the network, then calls the output layer's
     * {@code trainTillConvergence} with this optimizer's learning rate and epoch count.
     */
    @Override
    public void setParameters(INDArray params) {
        this.network.setParameters(params);
        this.network.getOutputLayer().trainTillConvergence(this.lr, this.epochs);
    }

    /**
     * Returns the first element of {@link BaseMultiLayerNetwork#getBackPropGradient2()}.
     * The {@code iteration} argument is ignored.
     */
    @Override
    public INDArray getValueGradient(int iteration) {
        return this.network.getBackPropGradient2().getFirst();
    }
}

