/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.optimizers;

import org.deeplearning4j.nn.gradient.OutputLayerGradient;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Adapts an {@link OutputLayer} to the {@link OptimizableByGradientValueMatrix} interface.
 *
 * <p>The layer's parameters are exposed as one flat index space: indices
 * {@code 0 .. W.length()-1} address the weight matrix {@code W}, and the following
 * {@code B.length()} indices address the bias {@code B}. Gradients are laid out the same way.
 */
public class OutputLayerOptimizer
implements OptimizableByGradientValueMatrix {
    /** Layer whose parameters are read, written and differentiated. */
    private final OutputLayer logReg;
    /** Learning rate forwarded to {@code logReg.getGradient(lr)}. */
    private final double lr;
    /** Last iteration index reported by the caller; -1 until first set. Not read within this class. */
    private int currIteration = -1;

    /**
     * @param logReg the output layer to optimize
     * @param lr learning rate passed to the layer when computing gradients
     */
    public OutputLayerOptimizer(OutputLayer logReg, double lr) {
        this.logReg = logReg;
        this.lr = lr;
    }

    /** Records the current iteration index. */
    @Override
    public void setCurrentIteration(int value) {
        this.currIteration = value;
    }

    /** @return total number of parameters: {@code W.length() + B.length()} */
    @Override
    public int getNumParameters() {
        return this.logReg.getW().length() + this.logReg.getB().length();
    }

    /**
     * Copies the first {@code buffer.length} parameters (W entries, then B entries) into
     * {@code buffer}.
     */
    public void getParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = this.getParameter(i);
        }
    }

    /**
     * Reads one parameter from the flat index space.
     *
     * @param index flat index; values below {@code W.length()} address W, the rest address B
     * @return the parameter value widened to {@code double}
     */
    @Override
    public double getParameter(int index) {
        int wLength = this.logReg.getW().length();
        if (index >= wLength) {
            return scalarValue(this.logReg.getB(), index - wLength);
        }
        return scalarValue(this.logReg.getW(), index);
    }

    /** Writes {@code params[i]} to flat parameter index {@code i} for every entry of {@code params}. */
    public void setParameters(double[] params) {
        for (int i = 0; i < params.length; ++i) {
            this.setParameter(i, params[i]);
        }
    }

    /**
     * Writes one parameter in the flat index space (W first, then B).
     *
     * @param index flat parameter index
     * @param value new value
     */
    @Override
    public void setParameter(int index, double value) {
        int wLength = this.logReg.getW().length();
        if (index >= wLength) {
            this.logReg.getB().putScalar(index - wLength, value);
        } else {
            this.logReg.getW().putScalar(index, value);
        }
    }

    /**
     * Computes the layer gradient once and copies its first {@code buffer.length} entries into
     * {@code buffer}. The W gradient comes first, then the B gradient.
     */
    public void getValueGradient(double[] buffer) {
        OutputLayerGradient grad = this.logReg.getGradient(this.lr);
        // Hoisted out of the loop: neither the layout nor the gradient changes while copying.
        INDArray wGradient = grad.getwGradient();
        INDArray bGradient = grad.getbGradient();
        int wLength = this.logReg.getW().length();
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = i < wLength ? scalarValue(wGradient, i) : scalarValue(bGradient, i - wLength);
        }
    }

    /**
     * @return the negated layer score. The negation presumably lets a maximizing optimizer
     *     minimize the score; confirm against the optimizer in use.
     */
    @Override
    public double getValue() {
        return -this.logReg.score();
    }

    /** @return the layer's parameters as returned by {@code logReg.params()} */
    @Override
    public INDArray getParameters() {
        return this.logReg.params();
    }

    /**
     * Installs {@code params} into the layer.
     *
     * <p>When the layer configuration has {@code constrainGradientToUnitNorm} set,
     * {@code params} is first divided IN PLACE by {@code params.normmax(Integer.MAX_VALUE)}.
     * The caller's array is therefore modified. (Integer.MAX_VALUE presumably selects a
     * whole-array reduction; confirm against the nd4j version in use.)
     */
    @Override
    public void setParameters(INDArray params) {
        if (this.logReg.conf().isConstrainGradientToUnitNorm()) {
            params.divi(params.normmax(Integer.MAX_VALUE));
        }
        this.logReg.setParams(params);
    }

    /**
     * Records {@code currIteration}, computes the layer gradient and returns it flattened from
     * {@code {wGradient, bGradient}}.
     *
     * @throws IllegalStateException if either gradient's length differs from its parameter's length
     */
    @Override
    public INDArray getValueGradient(int currIteration) {
        this.currIteration = currIteration;
        OutputLayerGradient grad = this.logReg.getGradient(this.lr);
        INDArray wGradient = grad.getwGradient();
        INDArray bGradient = grad.getbGradient();
        if (this.logReg.getW().length() != wGradient.length()) {
            throw new IllegalStateException("Illegal length for gradient");
        }
        if (this.logReg.getB().length() != bGradient.length()) {
            throw new IllegalStateException("Illegal length for gradient");
        }
        return Nd4j.toFlattened(new INDArray[]{wGradient, bGradient});
    }

    /**
     * Reads element {@code index} of {@code array} as a {@code double}.
     *
     * <p>Converts through {@link Number} instead of casting to {@code Double}. If the backend
     * boxes elements as {@code Float} (NOTE(review): possible for float-backed arrays, confirm),
     * a {@code Double} cast would throw {@link ClassCastException}.
     */
    private static double scalarValue(INDArray array, int index) {
        return ((Number) array.getScalar(index).element()).doubleValue();
    }
}

