/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.deeplearning4j.optimize.OptimizableByGradientValueMatrix;
import org.jblas.DoubleMatrix;

/**
 * Adapts a {@link LogisticRegression} model to MALLET's {@link Optimizable.ByGradientValue}
 * interface and to {@link OptimizableByGradientValueMatrix}.
 *
 * <p>The model's parameters are exposed as a single flat vector: the entries of the matrix
 * returned by {@code getW()} come first, followed by the entries of {@code getB()}. Flat index
 * {@code i} maps to {@code getW().get(i)} when {@code i < getW().length} and to
 * {@code getB().get(i - getW().length)} otherwise. Writes are applied with {@code put} on the
 * matrices returned by those getters (presumably the model's live matrices — confirm in
 * {@code LogisticRegression}).
 */
public class LogisticRegressionOptimizer
implements Optimizable.ByGradientValue,
OptimizableByGradientValueMatrix {
    /** Model whose parameters, objective and gradient are exposed to the optimizer. */
    private final LogisticRegression logReg;
    /** Value passed to {@code LogisticRegression.getGradient(double)} on every gradient request. */
    private final double lr;

    /**
     * @param logReg model to optimize
     * @param lr     learning rate forwarded to the model's {@code getGradient(double)}
     */
    public LogisticRegressionOptimizer(LogisticRegression logReg, double lr) {
        this.logReg = logReg;
        this.lr = lr;
    }

    /** Returns the number of flat parameters: {@code getW().length + getB().length}. */
    @Override
    public int getNumParameters() {
        return this.logReg.getW().length + this.logReg.getB().length;
    }

    /**
     * Copies the flattened parameters (W entries, then B entries) into {@code buffer}.
     * Exactly {@code buffer.length} entries are written, so callers should pass an array of
     * {@link #getNumParameters()} elements; a longer array indexes past the end of B.
     */
    public void getParameters(double[] buffer) {
        DoubleMatrix w = this.logReg.getW();
        flattenInto(w, this.logReg.getB(), w.length, buffer);
    }

    /** Returns the flat parameter at {@code index} (see the class comment for the layout). */
    @Override
    public double getParameter(int index) {
        DoubleMatrix w = this.logReg.getW();
        if (index >= w.length) {
            return this.logReg.getB().get(index - w.length);
        }
        return w.get(index);
    }

    /**
     * Writes {@code params[i]} to flat parameter {@code i} for every index of {@code params}.
     * A shorter array updates only a prefix of the parameters.
     */
    public void setParameters(double[] params) {
        DoubleMatrix w = this.logReg.getW();
        int split = w.length;
        DoubleMatrix b = null; // fetched lazily, mirroring the original per-element lookup
        for (int i = 0; i < params.length; ++i) {
            if (i < split) {
                w.put(i, params[i]);
            } else {
                if (b == null) {
                    b = this.logReg.getB();
                }
                b.put(i - split, params[i]);
            }
        }
    }

    /** Writes {@code value} to flat parameter {@code index} (see the class comment). */
    @Override
    public void setParameter(int index, double value) {
        DoubleMatrix w = this.logReg.getW();
        if (index >= w.length) {
            this.logReg.getB().put(index - w.length, value);
        } else {
            w.put(index, value);
        }
    }

    /**
     * Fills {@code buffer} with the flattened gradient: the W-gradient entries for indices below
     * {@code getW().length}, then the B-gradient entries. The gradient is computed once per call.
     *
     * <p>NOTE(review): the gradient comes from {@code getGradient(lr)}; if the model scales it by
     * {@code lr}, optimizers that run their own line search receive a scaled gradient — confirm
     * this is intended. Also assumes the W/B gradients have at least as many entries as W/B.
     */
    public void getValueGradient(double[] buffer) {
        LogisticRegressionGradient grad = this.logReg.getGradient(this.lr);
        flattenInto(grad.getwGradient(), grad.getbGradient(), this.logReg.getW().length, buffer);
    }

    /**
     * Returns {@code -negativeLogLikelihood()}, i.e. the log-likelihood, so that maximizing
     * this value minimizes the model's negative log-likelihood.
     */
    @Override
    public double getValue() {
        return -this.logReg.negativeLogLikelihood();
    }

    /** Returns a new column vector holding all {@link #getNumParameters()} flat parameters. */
    @Override
    public DoubleMatrix getParameters() {
        DoubleMatrix params = new DoubleMatrix(this.getNumParameters());
        // Fill the vector's backing array directly; same values as put(i, getParameter(i)).
        this.getParameters(params.data);
        return params;
    }

    /** Writes every element of {@code params} (in linear order) to the flat parameters. */
    @Override
    public void setParameters(DoubleMatrix params) {
        this.setParameters(params.toArray());
    }

    /** Returns a new column vector holding the full flattened gradient. */
    @Override
    public DoubleMatrix getValueGradient() {
        DoubleMatrix ret = new DoubleMatrix(this.getNumParameters());
        this.getValueGradient(ret.data);
        return ret;
    }

    /**
     * Writes the concatenation of {@code head} and {@code tail} into {@code dest}:
     * {@code dest[i] = head.get(i)} for {@code i < split}, otherwise {@code tail.get(i - split)}.
     * Exactly {@code dest.length} entries are written; {@code tail} is only dereferenced once an
     * index reaches {@code split}.
     */
    private static void flattenInto(DoubleMatrix head, DoubleMatrix tail, int split, double[] dest) {
        for (int i = 0; i < dest.length; ++i) {
            dest[i] = i < split ? head.get(i) : tail.get(i - split);
        }
    }
}

