/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import com.google.common.base.Function;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.stepfunctions.NegativeGradientStepFunction;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.LineOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Backtracking line search along a supplied search direction.
 *
 * <p>Starting from a full step of {@code 1.0}, {@link #optimize} repeatedly applies the
 * configured {@link StepFunction} to a copy of the parameters, scores the model at the
 * candidate point, and shrinks the step until the score passes a sufficient-change test
 * against the score at the start, the step becomes too small, or the iteration budget is
 * exhausted.
 *
 * <p>The search minimizes the score when the step function is a
 * {@link NegativeDefaultStepFunction} or {@link NegativeGradientStepFunction}; for any other
 * step function it maximizes.
 */
public class BackTrackLineSearch
implements LineOptimizer {
    private static final Logger log = LoggerFactory.getLogger(BackTrackLineSearch.class);
    private Model layer;
    private StepFunction stepFunction;
    private ConvexOptimizer optimizer;
    private int maxIterations;
    /** Upper bound on the 2-norm of the search direction; longer directions are rescaled in place. */
    double stepMax = 100.0;
    /** Re-derived from the step function type at the start of every {@link #optimize} call. */
    private boolean minObjectiveFunction = true;
    /** Numerator of the minimum step size (see {@code stepMin} in {@link #optimize}). */
    private double relTolx = 1.0E-7f;
    /** Stored via {@link #setAbsTolx(double)}; not read anywhere in this class. */
    private double absTolx = 1.0E-4f;
    /** Factor applied to {@code step * slope} in the sufficient decrease/increase test. */
    protected final double ALF = 1.0E-4f;

    /**
     * Creates a line search that uses the given step function.
     *
     * @param layer        model whose parameters are set and scored during the search; its
     *                     configuration supplies the maximum number of line search iterations
     * @param stepFunction function applied to move candidate parameters along the direction
     * @param optimizer    owning optimizer (stored, not otherwise used by this class)
     */
    public BackTrackLineSearch(Model layer, StepFunction stepFunction, ConvexOptimizer optimizer) {
        this.layer = layer;
        this.stepFunction = stepFunction;
        this.optimizer = optimizer;
        this.maxIterations = layer.conf().getMaxNumLineSearchIterations();
    }

    /**
     * Creates a line search with a {@link NegativeDefaultStepFunction}, i.e. one that minimizes.
     *
     * @param optimizable model whose parameters are set and scored during the search
     * @param optimizer   owning optimizer
     */
    public BackTrackLineSearch(Model optimizable, ConvexOptimizer optimizer) {
        this(optimizable, new NegativeDefaultStepFunction(), optimizer);
        log.debug("Objective function automatically set to minimize. Set stepFunction in neural net configuration to change default settings.");
    }

    public void setStepMax(double stepMax) {
        this.stepMax = stepMax;
    }

    public double getStepMax() {
        return this.stepMax;
    }

    public void setRelTolx(double tolx) {
        this.relTolx = tolx;
    }

    public void setAbsTolx(double tolx) {
        this.absTolx = tolx;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * Installs {@code parameters} on the model, recomputes gradient and score, and returns the score.
     *
     * <p>When {@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is set, NaN entries of {@code parameters}
     * are first replaced with {@code Nd4j.EPS_THRESHOLD} through
     * {@link BooleanIndexing#applyWhere}, which operates on the array that was passed in.
     *
     * @param parameters parameters to apply to the model
     * @return the model score after the parameters have been applied
     */
    public double setScoreFor(INDArray parameters) {
        if (Nd4j.ENFORCE_NUMERICAL_STABILITY) {
            BooleanIndexing.applyWhere(parameters, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
        }
        layer.setParams(parameters);
        layer.computeGradientAndScore();
        return layer.score();
    }

    /**
     * Searches for a step size along {@code searchDirection}.
     *
     * <p>Side effects: if the 2-norm of {@code searchDirection} exceeds {@code stepMax}, the
     * direction is rescaled <em>in place</em> to that norm. The model's parameters are
     * overwritten with each candidate; on every path that returns {@code 0.0} the original
     * {@code parameters} are re-applied to the model before returning.
     *
     * @param parameters      current parameters; candidates are built from copies of this array
     * @param gradients       gradient at {@code parameters}
     * @param searchDirection direction to search along (may be rescaled in place)
     * @return the accepted step size, the best step size seen if the iteration budget ran out
     *         after an improvement, or {@code 0.0} if no acceptable step was found
     * @throws InvalidStepException  declared by the signature; not thrown by this implementation
     * @throws IllegalArgumentException if a new trial step equals the previous one
     * @throws IllegalStateException if the sufficient-change test passes while the score moved
     *         in the wrong direction, or if two consecutive trial steps coincide during
     *         interpolation
     */
    @Override
    public double optimize(INDArray parameters, INDArray gradients, INDArray searchDirection) throws InvalidStepException {
        minObjectiveFunction = stepFunction instanceof NegativeDefaultStepFunction
                || stepFunction instanceof NegativeGradientStepFunction;

        Level1 l1Blas = Nd4j.getBlasWrapper().level1();
        double sum = l1Blas.nrm2(searchDirection);
        double slope = -1.0 * Nd4j.getBlasWrapper().dot(searchDirection, gradients);
        log.debug("slope = {}", slope);

        // Largest |gradient| / (|parameter| after ScalarSetValue with 1) over all entries;
        // it sets the scale for the smallest step worth trying.
        INDArray maxOldParams = Transforms.abs(parameters);
        Nd4j.getExecutioner().exec(new ScalarSetValue(maxOldParams, 1));
        INDArray testMatrix = Transforms.abs(gradients).divi(maxOldParams);
        double test = testMatrix.max(Integer.MAX_VALUE).getDouble(0);

        double step = 1.0;
        double stepMin = relTolx / test;
        double oldStep = 0.0;
        double step2 = 0.0;
        double scoreAtStart = layer.score();
        double score2 = scoreAtStart;
        double bestScore = scoreAtStart;
        double bestStepSize = 1.0;

        if (log.isTraceEnabled()) {
            double norm1 = l1Blas.asum(searchDirection);
            int infNormIdx = l1Blas.iamax(searchDirection);
            // iamax picks the entry of largest magnitude, which may be negative; report its magnitude.
            double infNorm = Math.abs(searchDirection.getDouble(infNormIdx));
            log.trace("ENTERING BACKTRACK\n");
            log.trace("Entering BackTrackLineSearch, value = " + scoreAtStart + ",\ndirection.oneNorm:" + norm1 + "  direction.infNorm:" + infNorm);
        }

        if (sum > stepMax) {
            log.warn("Attempted step too big. scaling: sum= {}, stepMax= {}", sum, stepMax);
            searchDirection.muli(stepMax / sum);
        }

        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            if (log.isTraceEnabled()) {
                log.trace("BackTrack loop iteration {} : step={}, oldStep={}", iteration, step, oldStep);
                log.trace("before step, x.1norm: {} \nstep: {} \noldStep: {}", parameters.norm1(Integer.MAX_VALUE), step, oldStep);
            }
            if (step == oldStep) {
                throw new IllegalArgumentException("Current step == oldStep");
            }

            // Apply the trial step to a copy so the caller's parameters stay untouched.
            INDArray candidateParameters = parameters.dup('f');
            stepFunction.step(candidateParameters, searchDirection, step);
            oldStep = step;
            if (log.isTraceEnabled()) {
                double norm1 = l1Blas.asum(candidateParameters);
                log.trace("after step, x.1norm: " + norm1);
            }

            // Give up when the step is below the minimum, or when the Eps comparison flags
            // every candidate entry as equal to the corresponding original entry.
            if (step < stepMin
                    || Nd4j.getExecutioner()
                            .execAndReturn(new Eps(parameters, candidateParameters,
                                    Shape.toOffsetZeroCopy(candidateParameters, 'f'), candidateParameters.length()))
                            .sum(Integer.MAX_VALUE).getDouble(0) == candidateParameters.length()) {
                double restoredScore = setScoreFor(parameters);
                log.debug("EXITING BACKTRACK: Jump too small (stepMin = {}). Exiting and using original params. Score = {}", stepMin, restoredScore);
                return 0.0;
            }

            double score = setScoreFor(candidateParameters);
            log.debug("Model score after step = {}", score);

            // Remember the best step so far in case the loop ends on maxIterations.
            boolean improved = minObjectiveFunction ? score < bestScore : score > bestScore;
            if (improved) {
                bestScore = score;
                bestStepSize = step;
            }

            if (minObjectiveFunction && score <= scoreAtStart + ALF * step * slope) {
                log.debug("Sufficient decrease (Wolfe cond.), exiting backtrack on iter {}: score={}, scoreAtStart={}", iteration, score, scoreAtStart);
                if (score > scoreAtStart) {
                    throw new IllegalStateException("Function did not decrease: score = " + score + " > " + scoreAtStart + " = oldScore");
                }
                return step;
            }
            if (!minObjectiveFunction && score >= scoreAtStart + ALF * step * slope) {
                log.debug("Sufficient increase (Wolfe cond.), exiting backtrack on iter {}: score={}, bestScore={}", iteration, score, scoreAtStart);
                if (score < scoreAtStart) {
                    throw new IllegalStateException("Function did not increase: score = " + score + " < " + scoreAtStart + " = scoreAtStart");
                }
                return step;
            }

            double tmpStep;
            if (Double.isInfinite(score) || Double.isInfinite(score2) || Double.isNaN(score) || Double.isNaN(score2)) {
                // Non-finite score: interpolation is meaningless, so just shrink the step.
                // No further stepMin test is needed here; step already passed it above in this
                // iteration and has not changed since.
                log.warn("Value is infinite after jump. oldStep={}. score={}, score2={}. Scaling back step size...", oldStep, score, score2);
                tmpStep = 0.2 * step;
            } else {
                // Express score changes so that "worse" is positive for both search senses.
                double delta = minObjectiveFunction ? score - scoreAtStart : scoreAtStart - score;
                double delta2 = minObjectiveFunction ? score2 - scoreAtStart : scoreAtStart - score2;
                tmpStep = interpolateStep(step, step2, delta, delta2, slope);
            }

            step2 = step;
            score2 = score;
            log.debug("tmpStep: {}", tmpStep);
            // Never shrink by more than (roughly) a factor of ten in one iteration.
            step = Math.max(tmpStep, 0.1f * step);
        }

        boolean bestBeatsStart = minObjectiveFunction ? bestScore < scoreAtStart : bestScore > scoreAtStart;
        if (bestBeatsStart) {
            log.debug("Exited line search after maxIterations termination condition; bestStepSize={}, bestScore={}, scoreAtStart={}", bestStepSize, bestScore, scoreAtStart);
            return bestStepSize;
        }
        log.debug("Exited line search after maxIterations termination condition; score did not improve (bestScore={}, scoreAtStart={}). Resetting parameters", bestScore, scoreAtStart);
        setScoreFor(parameters);
        return 0.0;
    }

    /**
     * Proposes the next trial step from the scores observed at the last one or two steps.
     *
     * <p>When {@code step == 1.0} (the initial trial) only the current point is used; otherwise
     * the two most recent points are combined. Any proposal from the two-point form is capped
     * at {@code 0.5 * step}.
     *
     * @param step   current trial step
     * @param step2  previous trial step
     * @param delta  signed score change at {@code step} relative to the starting score
     * @param delta2 signed score change at {@code step2} relative to the starting score
     * @param slope  directional slope computed at the start of the search
     * @return the proposed next step
     * @throws IllegalStateException if {@code step == step2} in the two-point form
     */
    private static double interpolateStep(double step, double step2, double delta, double delta2, double slope) {
        if (step == 1.0) {
            return -slope / (2.0 * (delta - slope));
        }
        if (step == step2) {
            throw new IllegalStateException("FAILURE: dividing by step-step2 which equals 0. step=" + step);
        }
        double rhs1 = delta - step * slope;
        double rhs2 = delta2 - step2 * slope;
        double stepSquared = step * step;
        double step2Squared = step2 * step2;
        double a = (rhs1 / stepSquared - rhs2 / step2Squared) / (step - step2);
        double b = (-step2 * rhs1 / stepSquared + step * rhs2 / step2Squared) / (step - step2);

        double proposed;
        if (a == 0.0) {
            proposed = -slope / (2.0 * b);
        } else {
            double disc = b * b - 3.0 * a * slope;
            if (disc < 0.0) {
                proposed = 0.5 * step;
            } else if (b <= 0.0) {
                proposed = (-b + FastMath.sqrt(disc)) / (3.0 * a);
            } else {
                proposed = -slope / (b + FastMath.sqrt(disc));
            }
        }
        return proposed > 0.5 * step ? 0.5 * step : proposed;
    }
}

