/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.Serializable;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.deeplearning4j.optimize.LogisticRegressionOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.deeplearning4j.util.NonZeroStoppingConjugateGradient;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Multi-output logistic regression layer backed by jblas matrices.
 *
 * <p>Holds a weight matrix {@code W} of shape {@code nIn x nOut} and a bias vector {@code b} of
 * length {@code nOut}; the public constructors initialise both to zeros. The current training
 * batch ({@code input} / {@code labels}) is stored on the instance and is what
 * {@link #getGradient(double)} and {@link #negativeLogLikelihood()} operate on.
 *
 * <p>Thread-safety: every instance method is {@code synchronized} on this object. Getters return
 * the live internal matrices, not copies, so callers can (and the optimizer may) mutate them.
 */
public class LogisticRegression implements Serializable {
    private static final long serialVersionUID = -7065564817460914364L;

    /** Number of input features (rows of {@code W}). */
    private int nIn;
    /** Number of outputs (columns of {@code W}, length of {@code b}). */
    private int nOut;
    /** Current training batch; one example per row. May be null until set. */
    private DoubleMatrix input;
    /** Targets for {@link #input}; must match it per MatrixUtil.complainAboutMissMatchedMatrices. */
    private DoubleMatrix labels;
    /** Weight matrix, {@code nIn x nOut}. */
    private DoubleMatrix W;
    /** Bias vector, length {@code nOut}. */
    private DoubleMatrix b;
    /** L2 weight-decay coefficient used by {@link #negativeLogLikelihood()}. */
    private double l2 = 0.01;
    /** Enables the L2 penalty in the objective and batch-size scaling in the gradient. */
    private boolean useRegularization = true;

    /** Bare constructor used by {@link #clone()}, which assigns every field itself. */
    private LogisticRegression() {
    }

    /**
     * Creates a model with zero-initialised weights and bias.
     *
     * @param input  training inputs, one example per row (may be null)
     * @param labels training targets matching {@code input} (may be null)
     * @param nIn    number of input features
     * @param nOut   number of outputs
     */
    public LogisticRegression(DoubleMatrix input, DoubleMatrix labels, int nIn, int nOut) {
        this.input = input;
        this.labels = labels;
        this.nIn = nIn;
        this.nOut = nOut;
        this.W = DoubleMatrix.zeros(nIn, nOut);
        this.b = DoubleMatrix.zeros(nOut);
    }

    /** Creates a model with the given inputs and no labels yet. */
    public LogisticRegression(DoubleMatrix input, int nIn, int nOut) {
        this(input, null, nIn, nOut);
    }

    /** Creates a model with neither inputs nor labels set. */
    public LogisticRegression(int nIn, int nOut) {
        this(null, null, nIn, nOut);
    }

    /** Performs one gradient step on the stored {@code input}/{@code labels}. */
    public synchronized void train(double lr) {
        this.train(this.input, this.labels, lr);
    }

    /** Performs one gradient step on {@code x} against the stored labels. */
    public synchronized void train(DoubleMatrix x, double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, this.labels);
        this.train(x, this.labels, lr);
    }

    /**
     * Stores {@code x}/{@code y} as the training batch, then runs the conjugate-gradient optimizer.
     *
     * @see #trainTillConvergence(double, int)
     */
    public synchronized void trainTillConvergence(DoubleMatrix x, DoubleMatrix y, double learningRate, int epochs) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        this.trainTillConvergence(learningRate, epochs);
    }

    /**
     * Optimises this model on the stored batch with {@link NonZeroStoppingConjugateGradient},
     * passing {@code numEpochs} to its {@code optimize} call.
     */
    public synchronized void trainTillConvergence(double learningRate, int numEpochs) {
        LogisticRegressionOptimizer opt = new LogisticRegressionOptimizer(this, learningRate);
        NonZeroStoppingConjugateGradient g = new NonZeroStoppingConjugateGradient(opt);
        g.optimize(numEpochs);
    }

    /**
     * Moves this model's parameters toward {@code l}'s by {@code (l - this) / batchSize}.
     *
     * <p>{@code l} is not modified. NOTE(review): {@code l}'s fields are read without holding
     * {@code l}'s monitor.
     *
     * @param l         model to merge in
     * @param batchSize divisor for the parameter difference; must be positive
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public synchronized void merge(LogisticRegression l, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        // sub() (not subi()) returns a fresh matrix, so l's parameters stay intact and
        // merging a model into itself is a no-op instead of zeroing the weights.
        this.W.addi(l.W.sub(this.W).divi(batchSize));
        this.b.addi(l.b.sub(this.b).divi(batchSize));
    }

    /**
     * Returns the cross-entropy of the softmax activations against {@code labels}.
     *
     * <p>The element-wise term {@code labels*log(p) + (1-labels)*log(1-p)} is summed over rows,
     * averaged over columns and negated. When regularisation is enabled, the L2 penalty
     * {@code (l2 / 2) * sum(W^2)} is added.
     */
    public synchronized double negativeLogLikelihood() {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        DoubleMatrix sigAct = MatrixUtil.softmax(this.input.mmul(this.W).addRowVector(this.b));
        DoubleMatrix logLikelihood = this.labels.mul(MatrixUtil.log(sigAct))
                .add(MatrixUtil.oneMinus(this.labels).mul(MatrixUtil.log(MatrixUtil.oneMinus(sigAct))));
        double nll = -logLikelihood.columnSums().mean();
        if (this.useRegularization) {
            // Standard weight-decay penalty; previously 2 / l2, which inverted the coefficient.
            nll += this.l2 / 2.0 * MatrixFunctions.pow(this.W, 2.0).sum();
        }
        return nll;
    }

    /**
     * Stores {@code x}/{@code y} as the training batch and applies one gradient step.
     *
     * @param x  inputs, one example per row
     * @param y  targets matching {@code x}
     * @param lr learning rate applied to both weight and bias updates
     */
    public synchronized void train(DoubleMatrix x, DoubleMatrix y, double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        LogisticRegressionGradient gradient = this.getGradient(lr);
        this.W.addi(gradient.getwGradient());
        // The bias gradient carries one row per example; collapse it to a single row (matching the
        // row-summed weight gradient) and scale by lr. Adding it unreduced failed jblas' length
        // check for any batch larger than one row.
        this.b.addi(gradient.getbGradient().columnSums().mul(lr));
    }

    /** Returns a deep copy: matrices are duplicated, scalars copied. */
    @Override
    protected synchronized LogisticRegression clone() {
        LogisticRegression reg = new LogisticRegression();
        reg.nIn = this.nIn;
        reg.nOut = this.nOut;
        reg.W = this.W.dup();
        reg.b = this.b.dup();
        reg.l2 = this.l2;
        reg.useRegularization = this.useRegularization;
        if (this.input != null) {
            reg.input = this.input.dup();
        }
        if (this.labels != null) {
            reg.labels = this.labels.dup();
        }
        return reg;
    }

    /**
     * Computes the update for the stored {@code input}/{@code labels} batch.
     *
     * <p>With {@code dy = labels - sigmoid(input*W + b)}, the weight term is
     * {@code input^T * dy * lr} and the bias term is {@code dy} itself (one row per example, not
     * scaled by {@code lr}). When regularisation is enabled, {@code dy} is first divided by the
     * number of input rows; the {@code l2} coefficient is not used here.
     *
     * <p>NOTE(review): this uses sigmoid activations, whereas {@link #predict} and
     * {@link #negativeLogLikelihood()} use softmax. Confirm which is intended.
     */
    public synchronized LogisticRegressionGradient getGradient(double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        DoubleMatrix pYGivenX = MatrixUtil.sigmoid(this.input.mmul(this.W).addRowVector(this.b));
        DoubleMatrix dy = this.labels.sub(pYGivenX);
        if (this.useRegularization) {
            dy.divi(this.input.rows);
        }
        DoubleMatrix wGradient = this.input.transpose().mmul(dy).mul(lr);
        return new LogisticRegressionGradient(wGradient, dy);
    }

    /** Returns {@code softmax(x*W + b)}; does not modify the model. */
    public synchronized DoubleMatrix predict(DoubleMatrix x) {
        return MatrixUtil.softmax(x.mmul(this.W).addRowVector(this.b));
    }

    /** @return number of input features */
    public synchronized int getnIn() {
        return this.nIn;
    }

    /** Sets the recorded input count; does not resize {@code W}. */
    public synchronized void setnIn(int nIn) {
        this.nIn = nIn;
    }

    /** @return number of outputs */
    public synchronized int getnOut() {
        return this.nOut;
    }

    /** Sets the recorded output count; does not resize {@code W} or {@code b}. */
    public synchronized void setnOut(int nOut) {
        this.nOut = nOut;
    }

    /** @return the stored input batch (live reference, may be null) */
    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    /** Replaces the stored input batch. */
    public synchronized void setInput(DoubleMatrix input) {
        this.input = input;
    }

    /** @return the stored labels (live reference, may be null) */
    public synchronized DoubleMatrix getLabels() {
        return this.labels;
    }

    /** Replaces the stored labels. */
    public synchronized void setLabels(DoubleMatrix labels) {
        this.labels = labels;
    }

    /** @return the weight matrix (live reference) */
    public synchronized DoubleMatrix getW() {
        return this.W;
    }

    /** Replaces the weight matrix. */
    public synchronized void setW(DoubleMatrix w) {
        this.W = w;
    }

    /** @return the bias vector (live reference) */
    public synchronized DoubleMatrix getB() {
        return this.b;
    }

    /** Replaces the bias vector. */
    public synchronized void setB(DoubleMatrix b) {
        this.b = b;
    }

    /** @return the L2 coefficient */
    public synchronized double getL2() {
        return this.l2;
    }

    /** Sets the L2 coefficient. */
    public synchronized void setL2(double l2) {
        this.l2 = l2;
    }

    /** @return whether regularisation is enabled */
    public synchronized boolean isUseRegularization() {
        return this.useRegularization;
    }

    /** Enables or disables regularisation. */
    public synchronized void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    /**
     * Fluent builder for {@link LogisticRegression}.
     *
     * <p>Unlike the direct constructors, {@link #build()} always copies {@code l2} and
     * {@code useRegularization} from the builder. Left unset, these are {@code 0.0} and
     * {@code false}.
     */
    public static class Builder {
        private DoubleMatrix W;
        private DoubleMatrix b;
        private double l2;
        private int nIn;
        private int nOut;
        private DoubleMatrix input;
        private boolean useRegularization;

        /** Sets the L2 coefficient. */
        public Builder withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        /** Enables or disables regularisation. */
        public Builder useRegularization(boolean regularize) {
            this.useRegularization = regularize;
            return this;
        }

        /** Uses {@code W} as the initial weights instead of zeros. */
        public Builder withWeights(DoubleMatrix W) {
            this.W = W;
            return this;
        }

        /** Uses {@code b} as the initial bias instead of zeros. */
        public Builder withBias(DoubleMatrix b) {
            this.b = b;
            return this;
        }

        /** Sets the number of input features. */
        public Builder numberOfInputs(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of outputs. */
        public Builder numberOfOutputs(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Sets the initial input batch (previously there was no way to set it). */
        public Builder withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /** Builds a new model from the configured values; each call returns a fresh instance. */
        public LogisticRegression build() {
            LogisticRegression ret = new LogisticRegression(this.input, this.nIn, this.nOut);
            if (this.W != null) {
                ret.W = this.W;
            }
            if (this.b != null) {
                ret.b = this.b;
            }
            ret.useRegularization = this.useRegularization;
            ret.l2 = this.l2;
            return ret;
        }
    }
}

