/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Optimizer that uses (a copy of) the current gradient as its search direction.
 *
 * <p>After each step the direction is rescaled so that its L2 norm does not exceed
 * {@code stepMax}. Gradients whose norm is already within that bound are used unchanged.
 */
public class LineGradientDescent extends BaseOptimizer {
    private static final long serialVersionUID = 6336124657542062284L;

    /**
     * Creates the optimizer with the default termination conditions of {@link BaseOptimizer}.
     *
     * @param conf               network configuration forwarded to the base optimizer
     * @param stepFunction       step function forwarded to the base optimizer
     * @param iterationListeners listeners forwarded to the base optimizer
     * @param model              model being optimized
     */
    public LineGradientDescent(
            NeuralNetConfiguration conf,
            StepFunction stepFunction,
            Collection<IterationListener> iterationListeners,
            Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    /**
     * Creates the optimizer with explicit termination conditions.
     *
     * @param conf                  network configuration forwarded to the base optimizer
     * @param stepFunction          step function forwarded to the base optimizer
     * @param iterationListeners    listeners forwarded to the base optimizer
     * @param terminationConditions termination conditions forwarded to the base optimizer
     * @param model                 model being optimized
     */
    public LineGradientDescent(
            NeuralNetConfiguration conf,
            StepFunction stepFunction,
            Collection<IterationListener> iterationListeners,
            Collection<TerminationCondition> terminationConditions,
            Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    /**
     * Seeds the search direction with a copy of the gradient stored under {@code "g"}.
     */
    @Override
    public void preProcessLine() {
        final INDArray storedGradient = (INDArray) searchState.get("g");
        searchState.put("searchDirection", storedGradient.dup());
    }

    /**
     * Records a norm-capped copy of {@code gradient} as the next search direction, and an
     * unmodified copy under {@code "g"}.
     *
     * @param gradient gradient produced by the step just taken; it is not modified
     */
    @Override
    public void postStep(INDArray gradient) {
        // Integer.MAX_VALUE as the dimension presumably means "all elements" in ND4J — confirm.
        final int[] allDimensions = {Integer.MAX_VALUE};
        final double gradientNorm = gradient.norm2(allDimensions).getDouble(0);

        INDArray direction = gradient.dup();
        if (!(gradientNorm <= stepMax)) {
            // Shrink the copy so its L2 norm equals stepMax; the original gradient is untouched.
            final Number shrink = stepMax / gradientNorm;
            direction = direction.muli(shrink);
        }
        searchState.put("searchDirection", direction);
        searchState.put("g", gradient.dup());
    }
}

