/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.LinkedList;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Limited-memory BFGS optimizer.
 *
 * <p>Keeps at most {@code m} curvature pairs (s = parameter delta, y = gradient delta) in the
 * search state and turns the current gradient (search-state key {@code "g"}) into a search
 * direction with the L-BFGS two-loop recursion.
 */
public class LBFGS
extends BaseOptimizer {
    /** Maximum number of (s, y, rho) curvature pairs kept in the history. */
    private final int m = 4;

    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    /**
     * Stores a unit-L2-norm copy of {@code gradient} under search-state key {@code "g"}.
     *
     * <p>A zero (or NaN) norm would make the scale factor 1/0, so in that case an unscaled copy
     * is stored instead.
     *
     * @param gradient the gradient to normalise; it is not modified
     * @return always {@code true}
     */
    @Override
    protected boolean preFirstStepProcess(INDArray gradient) {
        double norm = Nd4j.norm2(gradient).getDouble(0);
        INDArray scaled = norm > 0.0 ? gradient.mul(1.0 / norm) : gradient.dup();
        this.searchState.put("g", scaled);
        return true;
    }

    /**
     * Initialises the base search state, then the L-BFGS specific entries.
     *
     * <p>The entries are: empty {@code s}/{@code y}/{@code rho} histories, an {@code alpha}
     * scratch vector of length {@code m}, and copies of the current parameters and gradient
     * under {@code oldparams}/{@code oldgradient}.
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray gradient = (INDArray)this.searchState.get("g");
        INDArray params = (INDArray)this.searchState.get("params");
        this.searchState.put("s", new LinkedList<INDArray>());
        this.searchState.put("y", new LinkedList<INDArray>());
        this.searchState.put("rho", new LinkedList<Double>());
        this.searchState.put("alpha", Nd4j.create(this.m));
        this.searchState.put("oldparams", params.dup());
        this.searchState.put("oldgradient", gradient.dup());
    }

    /**
     * If the first step produced a zero step size, rebuilds the search state from the model's
     * current gradient and score, and resets the step to 1.0.
     */
    @Override
    protected void postFirstStep(INDArray gradient) {
        super.postFirstStep(gradient);
        if (this.step == 0.0) {
            log.info("Unable to step in that direction...resetting");
            this.setupSearchState(this.model.gradientAndScore());
            this.step = 1.0;
        }
    }

    /**
     * Overwrites the search-state gradient {@code "g"} in place with the L-BFGS search direction.
     *
     * <p>Steps:
     * <ol>
     *   <li>Records the newest curvature pair and evicts the oldest once more than {@code m}
     *       are held.</li>
     *   <li>Saves the raw parameters and gradient for the next call.</li>
     *   <li>Runs the two-loop recursion and negates the result.</li>
     * </ol>
     *
     * @param line not used by this implementation
     * @throws IllegalStateException if a stored s vector has a different length than the gradient
     */
    @Override
    @SuppressWarnings("unchecked")
    public void preProcessLine(INDArray line) {
        INDArray params = (INDArray)this.searchState.get("params");
        INDArray gradient = (INDArray)this.searchState.get("g");
        INDArray oldParameters = (INDArray)this.searchState.get("oldparams");
        INDArray oldGradient = (INDArray)this.searchState.get("oldgradient");
        LinkedList<INDArray> s = (LinkedList<INDArray>)this.searchState.get("s");
        LinkedList<INDArray> y = (LinkedList<INDArray>)this.searchState.get("y");
        LinkedList<Double> rho = (LinkedList<Double>)this.searchState.get("rho");
        INDArray alpha = (INDArray)this.searchState.get("alpha");

        // Newest curvature pair, built as fresh arrays so history entries never alias the
        // oldparams/oldgradient buffers that are overwritten below.
        // NOTE(review): y = old - new is the sign convention inherited from the original code;
        // confirm it matches the sign of "g" produced by BaseOptimizer.
        INDArray sNew = params.sub(oldParameters);
        INDArray yNew = oldGradient.sub(gradient);
        double sy = Nd4j.getBlasWrapper().dot(sNew, yNew) + Nd4j.EPS_THRESHOLD;
        double yy = Nd4j.getBlasWrapper().dot(yNew, yNew) + Nd4j.EPS_THRESHOLD;
        double gamma = sy / yy;

        s.addLast(sNew);
        y.addLast(yNew);
        rho.addLast(1.0 / sy);
        // Keep only the m most recent pairs; alpha has exactly m slots.
        while (s.size() > this.m) {
            s.removeFirst();
            y.removeFirst();
            rho.removeFirst();
        }
        if (s.size() != y.size()) {
            throw new IllegalStateException("S and y mis matched sizes");
        }

        // Remember the raw iterate and gradient BEFORE "g" is overwritten with the direction.
        oldParameters.assign(params);
        oldGradient.assign(gradient);

        // First loop, newest to oldest: q <- q - alpha_i * y_i.
        // BlasWrapper.axpy(a, x, dst) computes dst <- a*x + dst, so the target goes last.
        for (int i = s.size() - 1; i >= 0; --i) {
            INDArray sI = s.get(i);
            if (sI.length() != gradient.length()) {
                throw new IllegalStateException("Gradient and s length not equal");
            }
            double alphaI = rho.get(i) * Nd4j.getBlasWrapper().dot(sI, gradient);
            alpha.putScalar(i, alphaI);
            Nd4j.getBlasWrapper().axpy(-alphaI, y.get(i), gradient);
        }

        // Initial inverse-Hessian approximation H0 = gamma * I.
        gradient.muli(gamma);

        // Second loop, oldest to newest: r <- r + (alpha_i - beta_i) * s_i.
        for (int i = 0; i < s.size(); ++i) {
            double beta = rho.get(i) * Nd4j.getBlasWrapper().dot(y.get(i), gradient);
            Nd4j.getBlasWrapper().axpy(alpha.getDouble(i) - beta, s.get(i), gradient);
        }

        gradient.muli(-1.0);
    }

    /** Intentionally a no-op: L-BFGS keeps no per-step state beyond what preProcessLine updates. */
    @Override
    public void postStep() {
    }
}

