/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.solvers.VectorizedBackTrackLineSearch;
import org.deeplearning4j.util.OptimizerMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gradient-ascent optimizer over an {@link OptimizableByGradientValueMatrix}.
 *
 * <p>Each iteration does four things:
 * <ol>
 *   <li>Reads the current value gradient.</li>
 *   <li>Rescales the gradient so its 2-norm does not exceed {@link #getStpmax()}.</li>
 *   <li>Passes it to a {@link VectorizedBackTrackLineSearch}, which is expected to move the
 *   optimizable's parameters along that direction and returns the step size to use next time.</li>
 *   <li>Re-reads the objective value.</li>
 * </ol>
 *
 * <p>The run stops when any of the following happens:
 * <ul>
 *   <li>The relative change in the objective drops below {@code tolerance}.</li>
 *   <li>An attached {@link TrainingEvaluator} asks to stop.</li>
 *   <li>The iteration budget is exhausted.</li>
 * </ul>
 */
public class VectorizedDeepLearningGradientAscent implements OptimizerMatrix {
    /** Notified after every completed iteration (including the converging one); may be null. */
    private IterationListener listener;
    /** Whether the most recent call to {@link #optimize(int)} met the convergence criterion. */
    boolean converged = false;
    /** The objective being maximized. */
    OptimizableByGradientValueMatrix optimizable;
    /** Informational only: reported in the per-iteration log line, not used by the update. */
    private double maxStep = 1.0;
    /** Default initial line-search step size (0.2) used by the convenience constructors. */
    static final double initialStepSize = (double) 0.2f;
    /** Relative tolerance for the value-change convergence test in {@link #optimize(int)}. */
    double tolerance = 1.0E-5f;
    /** Iteration budget used by {@link #optimize()}. */
    int maxIterations = 200;
    /** Line search that performs the actual step along the (clipped) gradient. */
    VectorizedBackTrackLineSearch lineMaximizer;
    /** Upper bound on the 2-norm of the search direction handed to the line search. */
    double stpmax = 100.0;
    private static final Logger logger = LoggerFactory.getLogger(VectorizedDeepLearningGradientAscent.class);
    /** Guards the relative convergence test against division-like blow-up when values are ~0. */
    final double eps = 1.0E-10f;
    /** Current step size; updated by the line search after every successful iteration. */
    double step = initialStepSize;
    /** The configured initial step size, as reported by {@link #getInitialStepSize()}. */
    private double initialStep = initialStepSize;
    /** Optional early-stopping hook; may be null. */
    TrainingEvaluator eval;

    /**
     * Creates an optimizer for {@code function} starting from the given line-search step size.
     *
     * @param function        objective to maximize
     * @param initialStepSize step size handed to the first line search
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix function, double initialStepSize) {
        this.optimizable = function;
        this.lineMaximizer = new VectorizedBackTrackLineSearch(function);
        // NOTE: the line search's absolute x-tolerance is captured here only; a later
        // setTolerance() call affects just the outer convergence test.
        this.lineMaximizer.setAbsTolx(this.tolerance);
        this.initialStep = initialStepSize;
        this.step = initialStepSize;
    }

    /**
     * Creates an optimizer with the default initial step size and an iteration listener.
     *
     * @param function objective to maximize
     * @param listener notified after each iteration; may be null
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix function, IterationListener listener) {
        this(function, VectorizedDeepLearningGradientAscent.initialStepSize);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with an explicit initial step size and an iteration listener.
     *
     * @param function        objective to maximize
     * @param initialStepSize step size handed to the first line search
     * @param listener        notified after each iteration; may be null
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix function, double initialStepSize, IterationListener listener) {
        this(function, initialStepSize);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with the default initial step size and no listener.
     *
     * @param function objective to maximize
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix function) {
        this(function, VectorizedDeepLearningGradientAscent.initialStepSize);
    }

    @Override
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** @return the objective being optimized */
    public OptimizableByGradientValueMatrix getOptimizable() {
        return this.optimizable;
    }

    /** @return whether the most recent {@link #optimize(int)} run met the convergence test */
    @Override
    public boolean isConverged() {
        return this.converged;
    }

    /** @return the line search used to take each step */
    public VectorizedBackTrackLineSearch getLineMaximizer() {
        return this.lineMaximizer;
    }

    /**
     * Sets the relative tolerance used by the value-change convergence test.
     *
     * @param tolerance relative tolerance
     */
    @Override
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /** @return the configured initial line-search step size */
    public double getInitialStepSize() {
        return this.initialStep;
    }

    /**
     * Sets the initial line-search step size and resets the current step to it.
     *
     * @param initialStepSize new initial step size
     */
    public void setInitialStepSize(double initialStepSize) {
        this.initialStep = initialStepSize;
        this.step = initialStepSize;
    }

    /** @return the maximum 2-norm allowed for the search direction */
    public double getStpmax() {
        return this.stpmax;
    }

    /** @param stpmax maximum 2-norm allowed for the search direction */
    public void setStpmax(double stpmax) {
        this.stpmax = stpmax;
    }

    /**
     * Runs {@link #optimize(int)} with the configured iteration budget.
     *
     * @return true if the run converged or was stopped by the training evaluator
     */
    @Override
    public boolean optimize() {
        return this.optimize(this.maxIterations);
    }

    /**
     * Performs up to {@code numIterations} gradient-ascent iterations.
     *
     * <p>An iteration whose line search throws {@link InvalidStepException} is logged and skipped.
     * In that case the listener is not notified and the gradient is not recomputed.
     *
     * @param numIterations maximum number of iterations to run
     * @return true if the value change fell below tolerance or the training evaluator requested
     *     a stop; false if the iteration budget was exhausted
     */
    @Override
    public boolean optimize(int numIterations) {
        // Report the outcome of this run only, not a previous one.
        this.converged = false;
        double fp = this.optimizable.getValue();
        INDArray xi = this.optimizable.getValueGradient(0);
        for (int iterations = 0; iterations < numIterations; ++iterations) {
            if (logger.isInfoEnabled()) {
                logger.info("At iteration {}, cost = {}, scaled = {} step = {}, gradient infty-norm = {}",
                        iterations, fp, this.maxStep, this.step, xi.normmax(Integer.MAX_VALUE));
            }
            this.optimizable.setCurrentIteration(iterations);
            clipToMaxNorm(xi);
            try {
                this.step = this.lineMaximizer.optimize(xi, iterations, this.step);
            } catch (InvalidStepException e) {
                // Best effort: skip this iteration and retry with the same direction and step.
                logger.warn("Error during computation", e);
                continue;
            }
            double fret = this.optimizable.getValue();
            if (valueChangeBelowTolerance(fret, fp)) {
                logger.info("Gradient Ascent: Value difference {} below tolerance; saying converged.",
                        Math.abs(fret - fp));
                this.converged = true;
                notifyIterationDone(iterations);
                return true;
            }
            fp = fret;
            xi = this.optimizable.getValueGradient(iterations);
            notifyIterationDone(iterations);
            if (this.eval != null && this.eval.shouldStop(iterations)) {
                return true;
            }
        }
        return false;
    }

    /** Scales {@code direction} in place so that its 2-norm does not exceed {@link #stpmax}. */
    private void clipToMaxNorm(INDArray direction) {
        double norm = direction.norm2(Integer.MAX_VALUE).getDouble(0);
        if (norm > this.stpmax) {
            logger.info("*** Step 2-norm {} greater than max {}  Scaling...", norm, this.stpmax);
            direction.muli((Number) (this.stpmax / norm));
        }
    }

    /** Relative convergence test: 2|f_new - f_old| <= tolerance * (|f_new| + |f_old| + eps). */
    private boolean valueChangeBelowTolerance(double fret, double fp) {
        return 2.0 * Math.abs(fret - fp) <= this.tolerance * (Math.abs(fret) + Math.abs(fp) + this.eps);
    }

    /** Notifies the iteration listener, if one is attached. */
    private void notifyIterationDone(int iteration) {
        if (this.listener != null) {
            this.listener.iterationDone(iteration);
        }
    }

    /**
     * Sets the value reported as "scaled" in the per-iteration log line. It does not affect
     * the optimization itself.
     *
     * @param v value to report
     */
    public void setMaxStepSize(double v) {
        this.maxStep = v;
    }

    /**
     * Attaches an optional early-stopping hook that is consulted after each completed iteration.
     *
     * @param eval evaluator, or null to disable early stopping
     */
    @Override
    public void setTrainingEvaluator(TrainingEvaluator eval) {
        this.eval = eval;
    }
}

