/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.Serializable;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.deeplearning4j.nn.learning.AdaGrad;
import org.deeplearning4j.optimize.LogisticRegressionOptimizer;
import org.deeplearning4j.optimize.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.VectorizedNonZeroStoppingConjugateGradient;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Logistic-regression layer whose parameters are a weight matrix {@code W}
 * (created as {@code nIn x nOut}) and a bias vector {@code b} (created with
 * {@code nOut} entries), both zero-initialised by the public constructors.
 *
 * <p>The instance keeps a reference to the most recent input and label matrices;
 * several methods ({@link #train(DoubleMatrix, DoubleMatrix, double)},
 * {@link #predict(DoubleMatrix)}, {@link #trainTillConvergence(DoubleMatrix, DoubleMatrix, double, int)})
 * overwrite those references as a side effect.
 *
 * <p>Only the {@code normalizeByInputRows} accessors are synchronized; no other
 * member is, so the class as a whole should not be assumed to be thread-safe.
 */
public class LogisticRegression
implements Serializable {
    private static final long serialVersionUID = -7065564817460914364L;
    private int nIn;
    private int nOut;
    private DoubleMatrix input;
    private DoubleMatrix labels;
    private DoubleMatrix W;
    private DoubleMatrix b;
    private double l2 = 0.01;
    private boolean useRegularization = true;
    private boolean useAdaGrad = false;
    private AdaGrad adaGrad;
    private boolean firstTimeThrough = false;
    private boolean normalizeByInputRows = false;
    private NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;

    /** Creates an empty instance; used by {@link #clone()}, which fills the fields in itself. */
    private LogisticRegression() {
    }

    /**
     * Creates a model with zero weights, zero bias and a fresh AdaGrad state.
     *
     * @param input  training input, may be {@code null} and supplied later
     * @param labels training labels, may be {@code null} and supplied later
     * @param nIn    number of input columns (rows of {@code W})
     * @param nOut   number of outputs (columns of {@code W}, entries of {@code b})
     */
    public LogisticRegression(DoubleMatrix input, DoubleMatrix labels, int nIn, int nOut) {
        this.input = input;
        this.labels = labels;
        this.nIn = nIn;
        this.nOut = nOut;
        this.W = DoubleMatrix.zeros(nIn, nOut);
        this.adaGrad = new AdaGrad(nIn, nOut);
        this.b = DoubleMatrix.zeros(nOut);
    }

    /** Creates a model with the given input and no labels. */
    public LogisticRegression(DoubleMatrix input, int nIn, int nOut) {
        this(input, null, nIn, nOut);
    }

    /** Creates a model with neither input nor labels. */
    public LogisticRegression(int nIn, int nOut) {
        this(null, null, nIn, nOut);
    }

    /**
     * Runs one training step on the currently stored input and labels.
     *
     * @param lr learning rate
     */
    public void train(double lr) {
        this.train(this.input, this.labels, lr);
    }

    /**
     * Replaces the AdaGrad state with a new one built for the given learning rate.
     *
     * @param lr learning rate handed to the new {@link AdaGrad}
     */
    public void resetAdaGrad(double lr) {
        // NOTE(review): firstTimeThrough is never set to true anywhere in this class,
        // so this guard always passes and every call resets the AdaGrad state.
        // Behaviour is kept as-is because the intended semantics are not visible here.
        if (!this.firstTimeThrough) {
            this.adaGrad = new AdaGrad(this.nIn, this.nOut, lr);
            this.firstTimeThrough = false;
        }
    }

    /**
     * Runs one training step on {@code x} against the currently stored labels.
     *
     * @param x  input matrix; becomes the stored input
     * @param lr learning rate
     */
    public void train(DoubleMatrix x, double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, this.labels);
        this.train(x, this.labels, lr);
    }

    /**
     * Stores {@code x} and {@code y} as the current data and delegates to
     * {@link #trainTillConvergence(double, int)}.
     *
     * @param x            input matrix
     * @param y            label matrix
     * @param learningRate learning rate passed to the optimizer
     * @param epochs       value passed to the optimizer's {@code optimize} call
     */
    public void trainTillConvergence(DoubleMatrix x, DoubleMatrix y, double learningRate, int epochs) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        this.trainTillConvergence(learningRate, epochs);
    }

    /**
     * Optimizes this model with the configured algorithm: conjugate gradient when
     * {@code optimizationAlgorithm} is {@code CONJUGATE_GRADIENT}, gradient ascent
     * otherwise (including when it is {@code null}). Both use a tolerance of 1e-5.
     *
     * @param learningRate learning rate handed to the {@link LogisticRegressionOptimizer}
     * @param numEpochs    value passed to the optimizer's {@code optimize} call
     */
    public void trainTillConvergence(double learningRate, int numEpochs) {
        LogisticRegressionOptimizer opt = new LogisticRegressionOptimizer(this, learningRate);
        if (this.optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            VectorizedNonZeroStoppingConjugateGradient g = new VectorizedNonZeroStoppingConjugateGradient(opt);
            g.setTolerance(1.0E-5);
            g.optimize(numEpochs);
        } else {
            VectorizedDeepLearningGradientAscent g = new VectorizedDeepLearningGradientAscent(opt);
            g.setTolerance(1.0E-5);
            g.optimize(numEpochs);
        }
    }

    /**
     * Moves this model's parameters towards those of {@code l}.
     *
     * <p>The difference {@code l - this} is added to this model's {@code W} and
     * {@code b}; when regularization is enabled the difference is first divided by
     * {@code batchSize}. The other model is left untouched.
     *
     * @param l         model to merge from; not modified
     * @param batchSize divisor applied to the difference when regularization is on
     */
    public void merge(LogisticRegression l, int batchSize) {
        // sub() allocates a new matrix. The previous in-place subi() overwrote the
        // other model's parameters with the difference (and zeroed this model's own
        // parameters when merging a model with itself).
        DoubleMatrix weightDelta = l.W.sub(this.W);
        DoubleMatrix biasDelta = l.b.sub(this.b);
        if (this.useRegularization) {
            // NOTE(review): averaging is keyed on useRegularization rather than on a
            // dedicated flag; kept as in the original.
            weightDelta.divi((double) batchSize);
            biasDelta.divi((double) batchSize);
        }
        this.W.addi(weightDelta);
        this.b.addi(biasDelta);
    }

    /**
     * Computes the cross-entropy objective of the stored input against the stored
     * labels, using {@link #predict(DoubleMatrix)} for the model output.
     *
     * <p>When regularization is enabled, {@code (2 / l2) * sum(W^2)} is added.
     *
     * @return negated mean of the column sums of the element-wise log-likelihood,
     *         plus the penalty term when regularization is enabled
     */
    public double negativeLogLikelihood() {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        DoubleMatrix z = this.predict(this.input);
        DoubleMatrix positiveTerm = this.labels.mul(MatrixUtil.log(z));
        DoubleMatrix negativeTerm = MatrixUtil.oneMinus(this.labels).mul(MatrixUtil.log(MatrixUtil.oneMinus(z)));
        double nll = -positiveTerm.add(negativeTerm).columnSums().mean();
        if (!this.useRegularization) {
            return nll;
        }
        // NOTE(review): the penalty is scaled by 2 / l2, so a smaller l2 gives a
        // larger penalty and l2 == 0 yields Infinity/NaN. A conventional L2 term
        // would be (l2 / 2) * sum(W^2); left unchanged because the optimizers that
        // consume this value are not visible here.
        double reg = 2.0 / this.l2 * MatrixFunctions.pow(this.W, 2.0).sum();
        return nll + reg;
    }

    /**
     * Stores {@code x} and {@code y} as the current data and applies one gradient
     * step to {@code W} and {@code b}.
     *
     * @param x  input matrix
     * @param y  label matrix
     * @param lr learning rate forwarded to {@link #getGradient(double)}
     */
    public void train(DoubleMatrix x, DoubleMatrix y, double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        LogisticRegressionGradient gradient = this.getGradient(lr);
        this.W.addi(gradient.getwGradient());
        // NOTE(review): getGradient builds the bias gradient from the full
        // per-row error matrix; confirm its shape is compatible with b here.
        this.b.addi(gradient.getbGradient());
    }

    /**
     * Returns a copy of this model with duplicated parameter, input and label
     * matrices and the same configuration flags.
     *
     * <p>The copy receives a fresh {@link AdaGrad} state rather than this model's
     * accumulated one.
     *
     * @return an independent copy of this model
     */
    @Override
    protected LogisticRegression clone() {
        LogisticRegression copy = new LogisticRegression();
        copy.nIn = this.nIn;
        copy.nOut = this.nOut;
        copy.W = this.W.dup();
        copy.b = this.b.dup();
        copy.l2 = this.l2;
        copy.useRegularization = this.useRegularization;
        copy.normalizeByInputRows = this.normalizeByInputRows;
        // Previously dropped: without these the copy silently fell back to plain
        // learning-rate updates and the default optimizer, and its adaGrad field was
        // null (NullPointerException once AdaGrad was enabled on the copy).
        copy.useAdaGrad = this.useAdaGrad;
        copy.optimizationAlgorithm = this.optimizationAlgorithm;
        copy.adaGrad = new AdaGrad(this.nIn, this.nOut);
        if (this.input != null) {
            copy.input = this.input.dup();
        }
        if (this.labels != null) {
            copy.labels = this.labels.dup();
        }
        return copy;
    }

    /**
     * Computes the gradient for the stored input and labels.
     *
     * <p>The error is {@code labels - sigmoid(input * W + b)}, optionally divided by
     * the number of input rows. The weight gradient is {@code input^T * error},
     * scaled element-wise by AdaGrad's learning rates when AdaGrad is enabled and by
     * {@code lr} otherwise.
     *
     * @param lr learning rate used when AdaGrad is disabled
     * @return the weight gradient and the (unscaled) error matrix as bias gradient
     */
    public LogisticRegressionGradient getGradient(double lr) {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        // NOTE(review): activation here is sigmoid while predict() uses softmax.
        DoubleMatrix activation = MatrixUtil.sigmoid(this.input.mmul(this.W).addRowVector(this.b));
        DoubleMatrix dy = this.labels.sub(activation);
        if (this.normalizeByInputRows) {
            dy.divi((double) this.input.rows);
        }
        DoubleMatrix wGradient = this.input.transpose().mmul(dy);
        if (this.useAdaGrad) {
            wGradient.muli(this.adaGrad.getLearningRates(wGradient));
        } else {
            wGradient.muli(lr);
        }
        // NOTE(review): the bias gradient is the raw error matrix - neither reduced
        // over rows nor scaled by the learning rate. Kept because consumers of
        // LogisticRegressionGradient are outside this file.
        return new LogisticRegressionGradient(wGradient, dy);
    }

    /**
     * Stores {@code x} as the current input and returns the model output for it.
     *
     * @param x input matrix
     * @return {@code softmax(x * W + b)}
     */
    public DoubleMatrix predict(DoubleMatrix x) {
        this.input = x;
        return MatrixUtil.softmax(x.mmul(this.W).addRowVector(this.b));
    }

    /** @return the configured number of inputs */
    public int getnIn() {
        return this.nIn;
    }

    /** Sets the number of inputs; does not resize {@code W}. */
    public void setnIn(int nIn) {
        this.nIn = nIn;
    }

    /** @return the configured number of outputs */
    public int getnOut() {
        return this.nOut;
    }

    /** Sets the number of outputs; does not resize {@code W} or {@code b}. */
    public void setnOut(int nOut) {
        this.nOut = nOut;
    }

    /** @return the stored input matrix (not a copy), possibly {@code null} */
    public DoubleMatrix getInput() {
        return this.input;
    }

    /** Replaces the stored input matrix. */
    public void setInput(DoubleMatrix input) {
        this.input = input;
    }

    /** @return the stored label matrix (not a copy), possibly {@code null} */
    public DoubleMatrix getLabels() {
        return this.labels;
    }

    /** Replaces the stored label matrix. */
    public void setLabels(DoubleMatrix labels) {
        this.labels = labels;
    }

    /** @return the weight matrix itself (not a copy) */
    public DoubleMatrix getW() {
        return this.W;
    }

    /** Replaces the weight matrix. */
    public void setW(DoubleMatrix w) {
        this.W = w;
    }

    /** @return the bias vector itself (not a copy) */
    public DoubleMatrix getB() {
        return this.b;
    }

    /** Replaces the bias vector. */
    public void setB(DoubleMatrix b) {
        this.b = b;
    }

    /** @return the L2 coefficient used by {@link #negativeLogLikelihood()} */
    public double getL2() {
        return this.l2;
    }

    /** Sets the L2 coefficient used by {@link #negativeLogLikelihood()}. */
    public void setL2(double l2) {
        this.l2 = l2;
    }

    /** @return whether the regularization term is applied */
    public boolean isUseRegularization() {
        return this.useRegularization;
    }

    /** Enables or disables the regularization term. */
    public void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    /** @return whether the error is divided by the number of input rows */
    public synchronized boolean isNormalizeByInputRows() {
        return this.normalizeByInputRows;
    }

    /** Sets whether the error is divided by the number of input rows. */
    public synchronized void setNormalizeByInputRows(boolean normalizeByInputRows) {
        this.normalizeByInputRows = normalizeByInputRows;
    }

    /** @return whether AdaGrad learning rates are used for the weight gradient */
    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    /** Enables or disables AdaGrad learning rates for the weight gradient. */
    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    /** @return the configured optimization algorithm, possibly {@code null} */
    public NeuralNetwork.OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.optimizationAlgorithm;
    }

    /** Sets the algorithm used by {@link #trainTillConvergence(double, int)}. */
    public void setOptimizationAlgorithm(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
        this.optimizationAlgorithm = optimizationAlgorithm;
    }

    /**
     * Fluent builder for {@link LogisticRegression}.
     *
     * <p>Unlike the constructors of the enclosing class, a builder that is not
     * configured otherwise produces a model with regularization disabled.
     */
    public static class Builder {
        private DoubleMatrix W;
        private DoubleMatrix b;
        // Matches the enclosing class default. The former implicit 0.0 made
        // negativeLogLikelihood() divide by zero whenever useRegularization(true)
        // was requested without an explicit withL2(...).
        private double l2 = 0.01;
        private int nIn;
        private int nOut;
        // No builder method assigns this field, so build() always passes null.
        private DoubleMatrix input;
        private boolean useRegularization;
        private boolean useAdaGrad = false;
        private boolean normalizeByInputRows = false;
        private NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;

        /** Selects the optimization algorithm of the built model. */
        public Builder optimizeBy(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
            this.optimizationAlgorithm = optimizationAlgorithm;
            return this;
        }

        /** Sets whether the error is divided by the number of input rows. */
        public Builder normalizeByInputRows(boolean normalizeByInputRows) {
            this.normalizeByInputRows = normalizeByInputRows;
            return this;
        }

        /** Sets whether AdaGrad learning rates are used. */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the L2 coefficient (defaults to 0.01). */
        public Builder withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        /** Sets whether the regularization term is applied (defaults to {@code false}). */
        public Builder useRegularization(boolean regularize) {
            this.useRegularization = regularize;
            return this;
        }

        /** Supplies a weight matrix to use instead of the zero-initialised one. */
        public Builder withWeights(DoubleMatrix W) {
            this.W = W;
            return this;
        }

        /** Supplies a bias vector to use instead of the zero-initialised one. */
        public Builder withBias(DoubleMatrix b) {
            this.b = b;
            return this;
        }

        /** Sets the number of inputs. */
        public Builder numberOfInputs(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of outputs. */
        public Builder numberOfOutputs(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /**
         * Creates a new model from the current builder settings.
         *
         * @return a new {@link LogisticRegression}; supplied weight and bias
         *         matrices are used by reference, not copied
         */
        public LogisticRegression build() {
            LogisticRegression model = new LogisticRegression(this.input, this.nIn, this.nOut);
            if (this.W != null) {
                model.W = this.W;
            }
            if (this.b != null) {
                model.b = this.b;
            }
            model.optimizationAlgorithm = this.optimizationAlgorithm;
            model.normalizeByInputRows = this.normalizeByInputRows;
            model.useRegularization = this.useRegularization;
            model.l2 = this.l2;
            model.useAdaGrad = this.useAdaGrad;
            return model;
        }
    }
}

