/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import com.google.common.base.Function;
import java.util.Collection;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Nonlinear conjugate-gradient optimizer built on {@link BaseOptimizer}.
 *
 * <p>After every step, {@link #postStep()} mixes the previous search direction ({@code "h"}) into
 * the current vector ({@code "xi"}) using a coefficient of the Polak-Ribiere form
 * {@code xi.(xi - g) / g.g}. All working arrays and scalars live in {@code searchState}.
 * NOTE(review): this assumes {@code "xi"} holds the newest gradient/direction and {@code "g"} the
 * previous one when {@code postStep} runs. That is populated by {@link BaseOptimizer}; confirm there.
 */
public class ConjugateGradient extends BaseOptimizer {
    private static final Logger logger = LoggerFactory.getLogger(ConjugateGradient.class);

    /**
     * Creates the optimizer; all arguments are forwarded unchanged to {@link BaseOptimizer}.
     */
    public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    /**
     * Creates the optimizer with explicit termination conditions; all arguments are forwarded
     * unchanged to {@link BaseOptimizer}.
     */
    public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    /** No preprocessing of the search line is performed by this optimizer. */
    @Override
    public void preProcessLine(INDArray line) {
    }

    /**
     * Updates the conjugate search direction after a step.
     *
     * <p>Reads {@code "g"}, {@code "xi"} and {@code "h"} from {@code searchState}, then:
     * <ol>
     *   <li>computes {@code gg = g.g}, {@code dgg = xi.(xi - g)} and {@code gam = dgg / gg}
     *       ({@code gam} falls back to 0 when the ratio is not a finite number, e.g. {@code gg == 0});</li>
     *   <li>copies {@code xi} into {@code g};</li>
     *   <li>sets {@code h = gam * h + xi}, replacing NaN entries with {@code Nd4j.EPS_THRESHOLD};</li>
     *   <li>if {@code h} has a positive dot product with {@code xi}, copies {@code h} into {@code xi};
     *       otherwise logs a warning and resets {@code h} to {@code xi}.</li>
     * </ol>
     * The scalars {@code "gg"}, {@code "dgg"}, {@code "gam"} and the arrays {@code "g"}, {@code "xi"},
     * {@code "h"} are written back to {@code searchState}. The arrays are updated in place.
     */
    @Override
    public void postStep() {
        INDArray g = (INDArray) this.searchState.get("g");
        INDArray xi = (INDArray) this.searchState.get("xi");
        INDArray h = (INDArray) this.searchState.get("h");
        if (h == null) {
            // No previous direction: start from a private copy of xi. Aliasing g here (as before)
            // let the in-place updates below silently overwrite g as well.
            h = xi.dup();
        }

        double gg = Transforms.pow(g, 2).sum(Integer.MAX_VALUE).getDouble(0);
        double dgg = xi.mul(xi.sub(g)).sum(Integer.MAX_VALUE).getDouble(0);
        double gam = dgg / gg;
        // gg == 0 (or NaN inputs) would make gam NaN/Infinity. Infinity is not caught by the NaN
        // replacement below, so restart from plain xi instead (gam = 0).
        if (Double.isNaN(gam) || Double.isInfinite(gam)) {
            gam = 0.0;
        }
        this.searchState.put("gg", gg);
        this.searchState.put("dgg", dgg);
        this.searchState.put("gam", gam);

        g.assign(xi);
        h.assign(h.mul(gam).addi(xi));
        BooleanIndexing.applyWhere(h, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
        LinAlgExceptions.assertValidNum(h);
        if (Nd4j.getBlasWrapper().dot(xi, h) > 0.0) {
            xi.assign(h);
        } else {
            // New direction does not agree with xi: drop the conjugate term.
            logger.warn("Reverting back to GA");
            h.assign(xi);
        }
        this.searchState.put("g", g);
        this.searchState.put("xi", xi);
        // h already holds the updated direction. Re-applying gam here (xi + gam * h) would scale it
        // by (1 + gam) and double-count the conjugate term on the next iteration.
        this.searchState.put("h", h);
    }

    /**
     * Initializes the search state via {@link BaseOptimizer#setupSearchState(Pair)}. Then seeds
     * {@code "h"} and {@code "xi"} with independent copies of the {@code "g"} entry and resets the
     * scalar entries {@code "gg"}, {@code "gam"} and {@code "dgg"} to zero.
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray gradient = (INDArray) this.searchState.get("g");
        this.searchState.put("h", gradient.dup());
        this.searchState.put("xi", gradient.dup());
        this.searchState.put("gg", 0.0);
        this.searchState.put("gam", 0.0);
        this.searchState.put("dgg", 0.0);
    }
}

