/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Optimizer whose search direction is taken directly from the current gradient.
 *
 * <p>The gradient and direction are exchanged with the base optimizer through the
 * {@code searchState} map under the keys {@code "g"} (gradient) and
 * {@code "searchDirection"} (direction). NOTE(review): the direction is stored without a
 * sign flip; presumably the configured {@link StepFunction} decides whether to move with
 * or against it. Confirm against the step function in use.
 */
public class LineGradientDescent
extends BaseOptimizer {
    private static final long serialVersionUID = 6336124657542062284L;

    /**
     * Creates the optimizer using the base optimizer's default termination conditions.
     *
     * @param conf              network configuration, forwarded to {@code BaseOptimizer}
     * @param stepFunction      step function, forwarded to {@code BaseOptimizer}
     * @param trainingListeners listeners, forwarded to {@code BaseOptimizer}
     * @param model             model to optimize, forwarded to {@code BaseOptimizer}
     */
    public LineGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<TrainingListener> trainingListeners, Model model) {
        super(conf, stepFunction, trainingListeners, model);
    }

    /**
     * Creates the optimizer with an explicit set of termination conditions.
     *
     * @param conf                  network configuration, forwarded to {@code BaseOptimizer}
     * @param stepFunction          step function, forwarded to {@code BaseOptimizer}
     * @param trainingListeners     listeners, forwarded to {@code BaseOptimizer}
     * @param terminationConditions termination conditions, forwarded to {@code BaseOptimizer}
     * @param model                 model to optimize, forwarded to {@code BaseOptimizer}
     */
    public LineGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<TrainingListener> trainingListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, trainingListeners, terminationConditions, model);
    }

    /**
     * Seeds the search direction with a copy of the gradient currently stored under
     * {@code "g"}.
     *
     * <p>A copy is stored so that later in-place updates to the direction do not alias
     * the stored gradient.
     */
    @Override
    public void preProcessLine() {
        this.searchState.put("searchDirection", ((INDArray) this.searchState.get("g")).dup());
    }

    /**
     * Refreshes the search state after a step has been taken.
     *
     * <p>The new search direction is a copy of {@code gradient}. If the L2 norm of
     * {@code gradient} exceeds {@code stepMax}, that copy is rescaled so its norm equals
     * {@code stepMax}. An independent, unscaled copy of {@code gradient} is also stored
     * under {@code "g"}.
     *
     * @param gradient the latest gradient; it is not modified by this method
     */
    @Override
    public void postStep(INDArray gradient) {
        final double gradientNorm = Nd4j.getBlasWrapper().level1().nrm2(gradient);

        // Work on a private copy; the caller's gradient array must stay untouched.
        INDArray direction = gradient.dup();
        boolean exceedsLimit = gradientNorm > this.stepMax;
        if (exceedsLimit) {
            // In-place rescale so the direction's norm is clamped to stepMax.
            direction = direction.muli((Number) (this.stepMax / gradientNorm));
        }
        this.searchState.put("searchDirection", direction);

        // Separate copy: the direction above may have been scaled in place.
        this.searchState.put("g", gradient.dup());
    }
}

