/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Nonlinear conjugate-gradient optimizer using the Polak-Ribiere update with the coefficient
 * clamped at zero (often called "PR+").
 *
 * <p>Two entries of {@code searchState} are maintained:
 * <ul>
 *   <li>the most recent gradient, under {@code "g"};</li>
 *   <li>the current search direction, under {@code "searchDirection"}.</li>
 * </ul>
 *
 * <p>After each step, {@link #postStep(INDArray)} computes
 * <pre>
 *   gamma     = max(0, ((g_k - g_{k-1}) . g_k) / (g_{k-1} . g_{k-1}))
 *   direction = g_k + gamma * direction_{k-1}
 * </pre>
 * When gamma is 0, the new direction equals the current gradient; the debug log calls this an
 * "SGD line search".
 */
public class ConjugateGradient extends BaseOptimizer {
    private static final long serialVersionUID = -1269296013474864091L;
    private static final Logger logger = LoggerFactory.getLogger(ConjugateGradient.class);

    /** {@code searchState} key holding the gradient from the most recent step. */
    private static final String GRADIENT_STATE_KEY = "g";
    /** {@code searchState} key holding the current search direction. */
    private static final String SEARCH_DIRECTION_STATE_KEY = "searchDirection";

    /**
     * Creates a conjugate-gradient optimizer. All arguments are forwarded unchanged to
     * {@link BaseOptimizer}.
     *
     * @param conf network configuration
     * @param stepFunction step function, forwarded to the base optimizer
     * @param iterationListeners listeners, forwarded to the base optimizer
     * @param model model being optimized
     */
    public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction,
            Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    /**
     * Creates a conjugate-gradient optimizer with explicit termination conditions. All
     * arguments are forwarded unchanged to {@link BaseOptimizer}.
     *
     * @param conf network configuration
     * @param stepFunction step function, forwarded to the base optimizer
     * @param iterationListeners listeners, forwarded to the base optimizer
     * @param terminationConditions termination conditions, forwarded to the base optimizer
     * @param model model being optimized
     */
    public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction,
            Collection<IterationListener> iterationListeners,
            Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    /**
     * Sets the search direction in {@code searchState} to the values of the stored gradient.
     *
     * <p>If no direction exists yet, a copy of the gradient is stored. Otherwise, the existing
     * direction array is overwritten in place.
     *
     * <p>NOTE(review): if {@link BaseOptimizer} calls this before every line search (not only
     * once up front), it discards the conjugate direction computed by {@link #postStep(INDArray)}.
     * Confirm the call order in BaseOptimizer.
     */
    @Override
    public void preProcessLine() {
        INDArray gradient = (INDArray) this.searchState.get(GRADIENT_STATE_KEY);
        INDArray searchDir = (INDArray) this.searchState.get(SEARCH_DIRECTION_STATE_KEY);
        if (searchDir == null) {
            // Store a copy. postStep() updates the search direction in place (muli/addi); if it
            // aliased the gradient array, that gradient would be scaled (or zeroed) as well.
            this.searchState.put(SEARCH_DIRECTION_STATE_KEY, gradient.dup());
        } else {
            searchDir.assign(gradient);
        }
    }

    /**
     * Updates the search direction after a step, then records {@code gradient} and the new
     * direction in {@code searchState}.
     *
     * <p>The previous search-direction array is updated in place and reused as the new
     * direction. If no previous gradient or direction is stored, the search restarts from a
     * copy of {@code gradient}.
     *
     * @param gradient gradient computed after the step
     */
    @Override
    public void postStep(INDArray gradient) {
        INDArray gLast = (INDArray) this.searchState.get(GRADIENT_STATE_KEY);
        INDArray searchDirLast = (INDArray) this.searchState.get(SEARCH_DIRECTION_STATE_KEY);
        INDArray searchDir;
        if (gLast == null || searchDirLast == null) {
            // No history to build on: restart along the current gradient.
            searchDir = gradient.dup();
        } else {
            double gamma = polakRibiereGamma(gradient, gLast);
            // searchDir = gradient + gamma * searchDirLast, computed in place on searchDirLast.
            searchDir = searchDirLast.muli(gamma).addi(gradient);
        }
        this.searchState.put(GRADIENT_STATE_KEY, gradient);
        this.searchState.put(SEARCH_DIRECTION_STATE_KEY, searchDir);
    }

    /**
     * Computes the clamped Polak-Ribiere coefficient
     * {@code max(0, ((gradient - gLast) . gradient) / (gLast . gLast))}.
     *
     * @param gradient current gradient
     * @param gLast gradient from the previous step
     * @return the coefficient, or 0.0 (restart along the gradient) when the ratio is
     *     non-positive or undefined (NaN/infinite, e.g. because {@code gLast . gLast == 0})
     */
    private static double polakRibiereGamma(INDArray gradient, INDArray gLast) {
        double dgg = Nd4j.getBlasWrapper().dot(gradient.sub(gLast), gradient);
        double gg = Nd4j.getBlasWrapper().dot(gLast, gLast);
        if (dgg <= 0.0) {
            logger.debug("Polak-Ribiere gamma <= 0.0; using gamma=0.0 -> SGD line search. dgg={}, gg={}", dgg, gg);
            return 0.0;
        }
        double gamma = dgg / gg;
        if (Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            // gg == 0 (previous gradient all zeros), NaN input, or overflow. Math.max(NaN, 0.0)
            // is NaN, and Infinity * 0 is NaN, so an unguarded gamma would poison searchDir.
            logger.debug("Polak-Ribiere gamma is undefined; using gamma=0.0 -> SGD line search. dgg={}, gg={}", dgg, gg);
            return 0.0;
        }
        return gamma;
    }
}

