/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.io.Serializable;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * AdaGrad per-element step-size calculator.
 *
 * <p>Each call to {@link #getLearningRates(INDArray)} adds the element-wise square of the
 * supplied gradient to {@link #historicalGradient}. It then returns
 * {@code masterStepSize * |gradient| / (sqrt(historicalGradient) + fudgeFactor)}.
 *
 * <p>{@code getLearningRates} mutates instance state without synchronization. Only the
 * {@code decayLr} accessors are {@code synchronized}, so instances should not be shared
 * across threads without external locking.
 */
public class AdaGrad implements Serializable {
    protected static final long serialVersionUID = -4754127927704099888L;
    /** Global learning rate multiplied into every returned step size. */
    protected double masterStepSize = 0.1;
    /** Running element-wise sum of squared gradients. */
    public INDArray historicalGradient;
    /** Result of the most recent {@link #getLearningRates(INDArray)} call. */
    public INDArray adjustedGradient;
    /** Added to the square-rooted history to avoid division by zero. */
    public double fudgeFactor = 1.0E-6;
    /** Gradient passed to the most recent {@link #getLearningRates(INDArray)} call. */
    public INDArray gradient;
    /** Shape used to allocate the initial history and adjusted-gradient arrays. */
    public int[] shape;
    /** Number of {@link #getLearningRates(INDArray)} calls made so far. */
    protected int numIterations = 0;
    /** Learning-rate decay factor; not read by this class. */
    protected double lrDecay = 0.95;
    /** Decay flag; this class only stores it and exposes it via accessors. */
    protected boolean decayLr;
    /** Lower bound for the learning rate; not read by this class. */
    protected double minLearningRate = 1.0E-4;

    /**
     * Creates an optimizer for a {@code rows x cols} parameter matrix.
     *
     * @param rows  number of rows of the parameter matrix
     * @param cols  number of columns of the parameter matrix
     * @param gamma master step size (global learning rate)
     */
    public AdaGrad(int rows, int cols, double gamma) {
        this(new int[]{rows, cols}, gamma);
    }

    /**
     * Creates an optimizer for parameters of the given shape with master step size 0.1.
     *
     * @param shape shape of the parameter array; stored by reference in {@link #shape}
     */
    public AdaGrad(int[] shape) {
        this(shape, 0.1);
    }

    /**
     * Creates an optimizer for parameters of arbitrary shape with a custom master step size.
     *
     * @param shape shape of the parameter array; stored by reference in {@link #shape}
     * @param gamma master step size (global learning rate)
     */
    public AdaGrad(int[] shape, double gamma) {
        this.shape = shape;
        // These overridable hooks intentionally run before masterStepSize is assigned (same
        // order as before). An override therefore observes the field default of 0.1, not gamma.
        this.createHistoricalGradient();
        this.createAdjustedGradient();
        this.masterStepSize = gamma;
        this.decayLr = false;
    }

    /**
     * Creates an optimizer for a {@code rows x cols} parameter matrix with master step size 0.1.
     *
     * @param rows number of rows of the parameter matrix
     * @param cols number of columns of the parameter matrix
     */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, 0.1);
    }

    /** Allocates {@link #historicalGradient} with {@link #shape}. */
    protected void createHistoricalGradient() {
        this.historicalGradient = Nd4j.create(this.shape);
    }

    /** Allocates {@link #adjustedGradient} with {@link #shape}. */
    protected void createAdjustedGradient() {
        this.adjustedGradient = Nd4j.create(this.shape);
    }

    /**
     * Accumulates the squared gradient into the history and returns per-element step sizes.
     *
     * <p>After adding {@code gradient^2} to {@link #historicalGradient}, this computes
     * {@code masterStepSize * |gradient| / (sqrt(historicalGradient) + fudgeFactor)}. If the
     * history is missing or its length differs from the gradient's, it is first reset to zeros
     * with the gradient's shape.
     *
     * @param gradient the current gradient; also stored in {@link #gradient}
     * @return the per-element step sizes, also stored in {@link #adjustedGradient}
     * @throws NullPointerException if {@code gradient} is {@code null}
     */
    public INDArray getLearningRates(INDArray gradient) {
        this.gradient = Objects.requireNonNull(gradient, "gradient");
        // NOTE(review): this assumes Transforms.pow/sqrt/abs return new arrays rather than
        // modifying their argument in place. Confirm for the ND4J version in use.
        INDArray squaredGradient = Transforms.pow(gradient, 2);
        if (this.historicalGradient == null || this.historicalGradient.length() != gradient.length()) {
            // Allocate with the gradient's full shape so gradients of any rank (not only 2-D
            // matrices) get a history they can be accumulated into element-wise.
            this.historicalGradient = Nd4j.zeros(gradient.shape());
        }
        this.historicalGradient.addi(squaredGradient);
        ++this.numIterations;
        INDArray sqrtGradient = Transforms.sqrt(this.historicalGradient).addi(this.fudgeFactor);
        INDArray div = Transforms.abs(gradient).divi(sqrtGradient);
        this.adjustedGradient = div.muli(this.masterStepSize);
        return this.adjustedGradient;
    }

    /** @return the master step size (global learning rate) */
    public double getMasterStepSize() {
        return this.masterStepSize;
    }

    /** @param masterStepSize the new master step size (global learning rate) */
    public void setMasterStepSize(double masterStepSize) {
        this.masterStepSize = masterStepSize;
    }

    /** @return the stored decay flag */
    public synchronized boolean isDecayLr() {
        return this.decayLr;
    }

    /** @param decayLr the new decay flag */
    public synchronized void setDecayLr(boolean decayLr) {
        this.decayLr = decayLr;
    }
}

