/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;

/**
 * Gradient updater applying Nesterov accelerated momentum.
 *
 * <p>A velocity buffer {@code v} is kept across calls. Every call to
 * {@link #getGradient(INDArray, int)} first advances the velocity as
 * {@code v = momentum * vPrev - lr * gradient} and then returns
 * {@code momentum * vPrev + (-momentum - 1) * v}.
 *
 * <p>NOTE(review): that returned value is the negation of the textbook step
 * {@code x += -mu * v_prev + (1 + mu) * v}. Presumably the caller applies it as
 * {@code params -= update}; confirm against the {@code GradientUpdater} contract.
 *
 * <p>Instances hold mutable state and use no synchronization.
 */
public class Nesterovs implements Serializable, GradientUpdater {
    /** Momentum coefficient (mu). The 0.5 initializer is overwritten by every constructor. */
    private double momentum = 0.5;
    /** Velocity buffer; {@code null} until the first {@link #getGradient} call. */
    private INDArray v;
    /** Learning rate applied to the incoming gradient. */
    private double lr;

    /**
     * Creates an updater with an explicit momentum and learning rate.
     *
     * @param momentum momentum coefficient (mu)
     * @param lr       learning rate that scales each gradient
     */
    public Nesterovs(double momentum, double lr) {
        this.lr = lr;
        this.momentum = momentum;
    }

    /**
     * Creates an updater with the given momentum and a learning rate of 0.1.
     *
     * @param momentum momentum coefficient (mu)
     */
    public Nesterovs(double momentum) {
        this(momentum, 0.1);
    }

    /**
     * Advances the velocity with {@code gradient} and returns the Nesterov update.
     *
     * <p>Side effects:
     * <ul>
     *   <li>Replaces the stored velocity with a new array.</li>
     *   <li>Overwrites the previous velocity buffer in place.</li>
     *   <li>Returns that same buffer as the update.</li>
     * </ul>
     * The {@code gradient} argument itself is not modified.
     *
     * @param gradient  raw gradient; its shape sizes the velocity on the first call
     * @param iteration current iteration number (unused by this updater)
     * @return {@code momentum * vPrev + (-momentum - 1) * vNew}
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        // First call: start from a zero velocity shaped like the gradient.
        if (v == null) {
            v = Nd4j.zeros(gradient.shape());
        }
        final INDArray previous = v;

        // vNew = mu * vPrev - lr * g. Both products are out-of-place, so `previous`
        // is still unmodified once this line completes.
        final INDArray next = previous.mul(momentum).subi(gradient.mul(lr));
        v = next;

        // Recycle the old velocity buffer for the result; safe because `v` now points at `next`.
        final INDArray scaledPrevious = previous.muli(momentum);
        return scaledPrevious.addi(next.mul(-momentum - 1.0));
    }

    /** @return the momentum coefficient (mu) */
    public double getMomentum() {
        return momentum;
    }

    /** @param momentum the new momentum coefficient (mu) */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** @return the learning rate applied to incoming gradients */
    public double getLr() {
        return lr;
    }

    /** @param lr the new learning rate */
    public void setLr(double lr) {
        this.lr = lr;
    }
}

