/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import com.google.common.base.Function;
import java.io.Serializable;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.OutputLayerGradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.optimizers.OutputLayerOptimizer;
import org.deeplearning4j.optimize.solvers.StochasticHessianFree;
import org.deeplearning4j.optimize.solvers.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.solvers.VectorizedNonZeroStoppingConjugateGradient;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;

/**
 * Output layer of a network: holds a label matrix next to the inherited
 * weights ({@code W}), bias ({@code b}) and input, and exposes training,
 * scoring and prediction through the {@link Classifier} contract.
 *
 * <p>Note that {@link #output(INDArray)} stores its argument as the layer's
 * current input, so every method that goes through it (for example
 * {@link #predict(INDArray)} and {@link #labelProbabilities(INDArray)})
 * replaces the previously stored input.
 */
public class OutputLayer extends BaseLayer implements Serializable, Classifier {
    private static final long serialVersionUID = -7065564817460914364L;

    /** Tolerance handed to every optimizer created by this layer. */
    private static final float OPTIMIZER_TOLERANCE = 0.001f;

    private INDArray labels;
    private AdaGrad adaGrad;
    private AdaGrad biasAdaGrad;

    /**
     * Creates a layer with a zero bias row of width {@code conf.getnOut()} and
     * fresh AdaGrad state for the weights and the bias.
     *
     * @param conf   layer configuration
     * @param input  input matrix handed to the base layer
     * @param labels label matrix to train against
     */
    public OutputLayer(NeuralNetConfiguration conf, INDArray input, INDArray labels) {
        super(conf, null, null, input);
        this.labels = labels;
        this.adaGrad = new AdaGrad(conf.getnIn(), conf.getnOut());
        this.b = Nd4j.zeros(1, conf.getnOut());
        this.biasAdaGrad = new AdaGrad(this.b.rows(), this.b.columns());
    }

    /**
     * Runs one update step on the stored input and labels.
     *
     * @param lr learning rate
     */
    public void train(double lr) {
        train(this.input, this.labels, lr);
    }

    /**
     * Runs one update step on {@code x} against the stored labels.
     *
     * @param x  input matrix; must have as many rows as the stored labels
     * @param lr learning rate
     */
    public void train(INDArray x, double lr) {
        syncAdaGradStepSize(lr);
        LinAlgExceptions.assertRows(x, this.labels);
        train(x, this.labels, lr);
    }

    /**
     * Stores {@code x}/{@code y} as the layer's input and labels and then
     * optimizes via {@link #trainTillConvergence(double, int)}.
     *
     * @param x            input matrix
     * @param y            label matrix; must have as many rows as {@code x}
     * @param learningRate learning rate
     * @param epochs       iteration count passed to the optimizer
     */
    public void trainTillConvergence(INDArray x, INDArray y, double learningRate, int epochs) {
        LinAlgExceptions.assertRows(x, y);
        syncAdaGradStepSize(learningRate);
        this.input = x;
        this.labels = y;
        trainTillConvergence(learningRate, epochs);
    }

    /**
     * Replaces the stored labels and optimizes with the algorithm selected by
     * the configuration: conjugate gradient, Hessian free, or (for anything
     * else) gradient ascent.
     *
     * @param labels       new label matrix
     * @param learningRate learning rate
     * @param numEpochs    iteration count passed to the optimizer
     * @param eval         training evaluator handed to the optimizer; may be {@code null}
     */
    public void trainTillConvergence(INDArray labels, double learningRate, int numEpochs, TrainingEvaluator eval) {
        this.labels = labels;
        OutputLayerOptimizer opt = new OutputLayerOptimizer(this, learningRate);
        syncAdaGradStepSize(learningRate);
        if (this.conf.getOptimizationAlgo() == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            optimizeWithConjugateGradient(opt, numEpochs, eval);
        } else if (this.conf.getOptimizationAlgo() == NeuralNetwork.OptimizationAlgorithm.HESSIAN_FREE) {
            StochasticHessianFree hessianFree = new StochasticHessianFree(opt, null);
            hessianFree.setTolerance(OPTIMIZER_TOLERANCE);
            hessianFree.setTrainingEvaluator(eval);
            hessianFree.optimize(numEpochs);
        } else {
            optimizeWithGradientAscent(opt, numEpochs, eval);
        }
    }

    /**
     * Optimizes on the stored input and labels: conjugate gradient when the
     * configuration asks for it, gradient ascent otherwise.
     *
     * <p>NOTE(review): unlike the overload that takes a label matrix, this one
     * has no HESSIAN_FREE branch and silently uses gradient ascent for it.
     * Behaviour is kept as-is; confirm whether that is intentional.
     *
     * @param learningRate learning rate
     * @param numEpochs    iteration count passed to the optimizer
     * @param eval         training evaluator handed to the optimizer; may be {@code null}
     */
    public void trainTillConvergence(double learningRate, int numEpochs, TrainingEvaluator eval) {
        OutputLayerOptimizer opt = new OutputLayerOptimizer(this, learningRate);
        syncAdaGradStepSize(learningRate);
        if (this.conf.getOptimizationAlgo() == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            optimizeWithConjugateGradient(opt, numEpochs, eval);
        } else {
            optimizeWithGradientAscent(opt, numEpochs, eval);
        }
    }

    /**
     * Same as {@link #trainTillConvergence(double, int, TrainingEvaluator)}
     * without a training evaluator.
     *
     * @param learningRate learning rate
     * @param numEpochs    iteration count passed to the optimizer
     */
    public void trainTillConvergence(double learningRate, int numEpochs) {
        trainTillConvergence(learningRate, numEpochs, null);
    }

    /** Pushes the given learning rate into both AdaGrad instances. */
    private void syncAdaGradStepSize(double lr) {
        this.adaGrad.setMasterStepSize(lr);
        this.biasAdaGrad.setMasterStepSize(lr);
    }

    /** Runs the conjugate-gradient solver, capped at {@code numEpochs} iterations. */
    private void optimizeWithConjugateGradient(OutputLayerOptimizer opt, int numEpochs, TrainingEvaluator eval) {
        VectorizedNonZeroStoppingConjugateGradient solver = new VectorizedNonZeroStoppingConjugateGradient(opt);
        solver.setTolerance(OPTIMIZER_TOLERANCE);
        solver.setTrainingEvaluator(eval);
        solver.setMaxIterations(numEpochs);
        solver.optimize(numEpochs);
    }

    /** Runs the gradient-ascent solver for {@code numEpochs} iterations. */
    private void optimizeWithGradientAscent(OutputLayerOptimizer opt, int numEpochs, TrainingEvaluator eval) {
        VectorizedDeepLearningGradientAscent solver = new VectorizedDeepLearningGradientAscent(opt);
        solver.setTolerance(OPTIMIZER_TOLERANCE);
        solver.setTrainingEvaluator(eval);
        solver.optimize(numEpochs);
    }

    /**
     * Scores the stored input against the stored labels with the configured
     * loss function. NaN entries of the output are replaced by
     * {@code Nd4j.EPS_THRESHOLD} before scoring.
     *
     * @return the loss reported by {@code LossFunctions.score}
     */
    @Override
    public double score() {
        LinAlgExceptions.assertRows(this.input, this.labels);
        INDArray output = output(this.input);
        BooleanIndexing.applyWhere(output, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
        assert (!Nd4j.hasInvalidNumber(output)) : "Invalid number on output!";
        // The original branched on RECONSTRUCTION_CROSSENTROPY but both arms
        // made this exact call, so the branch carried no behaviour.
        return LossFunctions.score(
                this.labels,
                this.conf.getLossFunction(),
                output,
                this.conf.getL2(),
                this.conf.isUseRegularization());
    }

    /**
     * Stores {@code x}/{@code y} as input and labels, then adds one gradient
     * step to the weights and the bias.
     *
     * @param x  input matrix
     * @param y  label matrix; must have as many rows as {@code x}
     * @param lr learning rate
     */
    public void train(INDArray x, INDArray y, double lr) {
        syncAdaGradStepSize(lr);
        // Validate the arguments themselves; the previously stored
        // input/labels are about to be replaced and may not even be set yet.
        LinAlgExceptions.assertRows(x, y);
        this.input = x;
        this.labels = y;
        OutputLayerGradient gradient = getGradient(lr);
        this.W.addi(gradient.getwGradient());
        this.b.addi(gradient.getbGradient());
    }

    /**
     * Returns a copy of this layer with duplicated weights, bias, input and
     * labels.
     *
     * <p>The two AdaGrad instances are shared with the original rather than
     * copied.
     *
     * @return the copied layer
     */
    @Override
    public Layer clone() {
        INDArray inputCopy = this.input != null ? this.input.dup() : null;
        INDArray labelsCopy = this.labels != null ? this.labels.dup() : null;
        // The constructor takes (conf, input, labels); weights and bias have
        // to be carried over explicitly or the copy would lose them.
        OutputLayer copy = new OutputLayer(this.conf, inputCopy, labelsCopy);
        if (this.W != null) {
            copy.W = this.W.dup();
        }
        if (this.b != null) {
            copy.b = this.b.dup();
        }
        copy.biasAdaGrad = this.biasAdaGrad;
        copy.adaGrad = this.adaGrad;
        return copy;
    }

    /**
     * Computes the weight and bias updates for the stored input and labels.
     * Both are scaled either by AdaGrad learning rates or by {@code lr},
     * depending on the configuration, and optionally divided by their norm.
     *
     * <p>NOTE(review): the label/output difference is divided by the row
     * count twice and then averaged over rows; kept as in the original,
     * confirm the intended normalization.
     *
     * @param lr learning rate
     * @return the weight and bias gradients
     */
    public OutputLayerGradient getGradient(double lr) {
        LinAlgExceptions.assertRows(this.input, this.labels);
        syncAdaGradStepSize(lr);

        INDArray netOut = output(this.input);
        INDArray dy = this.labels.sub(netOut);
        dy.divi(this.input.rows());

        INDArray wGradient = getWeightGradient();
        if (this.conf.isUseAdaGrad()) {
            wGradient.muli(this.adaGrad.getLearningRates(wGradient));
            dy.muliRowVector(this.biasAdaGrad.getLearningRates(dy.mean(0)));
        } else {
            wGradient.muli(lr);
            dy.muli(lr);
        }

        dy.divi(this.input.rows());
        INDArray bGradient = dy.mean(0);

        if (this.conf.isConstrainGradientToUnitNorm()) {
            wGradient.divi(wGradient.norm2(Integer.MAX_VALUE));
            bGradient.divi(bGradient.norm2(Integer.MAX_VALUE));
        }
        return new OutputLayerGradient(wGradient, bGradient);
    }

    /**
     * Computes the weight gradient for the configured loss function as
     * {@code input^T * delta}, where {@code delta} depends on the loss.
     *
     * <p>NOTE(review): the sign of {@code delta} differs between cases
     * (MCXENT uses {@code labels - p}, XENT/MSE use {@code z - labels});
     * confirm against the update direction used by the callers.
     *
     * @throws IllegalStateException if the loss function has no case here
     */
    private INDArray getWeightGradient() {
        // Evaluated up front for every loss, exactly as before: output() also
        // stores the input and applies dropout when configured.
        INDArray z = output(this.input);
        INDArray inputT = this.input.transpose();
        switch (this.conf.getLossFunction()) {
            case MCXENT: {
                INDArray preOut = preOutput(this.input);
                INDArray pyGivenX = (INDArray) Activations.softMaxRows().apply(preOut);
                return inputT.mmul(this.labels.sub(pyGivenX));
            }
            case XENT: {
                INDArray diff = z.sub(this.labels);
                return inputT.mmul(diff.div(z.mul(z.rsub(1))));
            }
            case MSE: {
                return inputT.mmul(this.labels.sub(z).neg());
            }
            case EXPLL: {
                return inputT.mmul(this.labels.rsub(1).divi(z));
            }
            case RMSE_XENT:
            case SQUARED_LOSS: {
                return inputT.mmul(Transforms.pow(this.labels.sub(z), 2));
            }
            case NEGATIVELOGLIKELIHOOD: {
                return inputT.mmul(Transforms.log(z).negi());
            }
        }
        throw new IllegalStateException("Invalid loss function");
    }

    /**
     * Scores the layer on a data set; see {@link #score(INDArray, INDArray)}.
     *
     * @param data features and labels to evaluate on
     * @return the F1 score
     */
    @Override
    public double score(DataSet data) {
        return score(data.getFeatureMatrix(), data.getLabels());
    }

    /**
     * Evaluates the predicted label probabilities for {@code examples}
     * against {@code labels}.
     *
     * @param examples input matrix
     * @param labels   expected labels
     * @return the F1 score reported by {@link Evaluation}
     */
    @Override
    public double score(INDArray examples, INDArray labels) {
        Evaluation eval = new Evaluation();
        eval.eval(labels, labelProbabilities(examples));
        return eval.f1();
    }

    /**
     * Returns the number of label columns. When no label matrix has been set
     * yet, the configured output size is used instead of failing.
     *
     * @return the number of labels
     */
    @Override
    public int numLabels() {
        if (this.labels == null) {
            return this.conf.getnOut();
        }
        return this.labels.columns();
    }

    /**
     * Predicts one class index per row of {@code d}, chosen with the BLAS
     * {@code iamax} of that row's output.
     *
     * @param d input matrix
     * @return one predicted index per input row
     */
    @Override
    public int[] predict(INDArray d) {
        INDArray output = output(d);
        int[] predictions = new int[d.rows()];
        for (int row = 0; row < predictions.length; ++row) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(output.getRow(row));
        }
        return predictions;
    }

    /**
     * Returns the layer output for {@code examples}.
     *
     * @param examples input matrix
     * @return the activated output
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        return output(examples);
    }

    /**
     * Trains on the given examples and labels using the configured learning
     * rate and iteration count.
     *
     * @param examples input matrix
     * @param labels   label matrix
     */
    @Override
    public void fit(INDArray examples, INDArray labels) {
        trainTillConvergence(examples, labels, this.conf.getLr(), this.conf.getNumIterations());
    }

    /**
     * Trains on the features and labels of {@code data}.
     *
     * @param data training data
     */
    @Override
    public void fit(DataSet data) {
        fit(data.getFeatureMatrix(), data.getLabels());
    }

    /**
     * Same as {@link #fit(INDArray, INDArray)}; {@code params} is ignored.
     */
    @Override
    public void fit(INDArray examples, INDArray labels, Object[] params) {
        fit(examples, labels);
    }

    /**
     * Same as {@link #fit(DataSet)}; {@code params} is ignored.
     */
    @Override
    public void fit(DataSet data, Object[] params) {
        fit(data);
    }

    /**
     * Trains on class indices, expanding them to an outcome matrix with
     * {@link #numLabels()} columns.
     *
     * @param examples input matrix
     * @param labels   one class index per example
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        INDArray outcomeMatrix = FeatureUtil.toOutcomeMatrix(labels, numLabels());
        fit(examples, outcomeMatrix);
    }

    /**
     * Same as {@link #fit(INDArray, int[])}; {@code params} is ignored.
     *
     * <p>The outcome matrix is sized by {@link #numLabels()}; the number of
     * examples ({@code labels.length}) is not a class count.
     */
    @Override
    public void fit(INDArray examples, int[] labels, Object[] params) {
        fit(examples, labels);
    }

    /** Does nothing. */
    @Override
    public void iterate(INDArray examples, int[] labels, Object[] params) {
    }

    /**
     * Returns the pre-activation output for {@code data}.
     */
    @Override
    public INDArray transform(INDArray data) {
        return preOutput(data);
    }

    /**
     * Returns the weights followed by the bias, stacked horizontally from
     * their linear views.
     */
    @Override
    public INDArray params() {
        return Nd4j.hstack(new INDArray[]{this.W.linearView(), this.b.linearView()});
    }

    /**
     * Returns {@code nIn * nOut + nOut}: the weight count plus the bias count.
     */
    @Override
    public int numParams() {
        return this.conf.getnIn() * this.conf.getnOut() + this.conf.getnOut();
    }

    /**
     * Splits a flat parameter vector: the first {@code nIn * nOut} entries
     * are written into the existing weights, the remainder becomes the bias.
     *
     * @param params flat parameters laid out as produced by {@link #params()}
     */
    @Override
    public void setParams(INDArray params) {
        int weightCount = this.conf.getnIn() * this.conf.getnOut();
        INDArray wParams = params.get(new NDArrayIndex[]{NDArrayIndex.interval(0, weightCount)});
        INDArray wLinear = getW().linearView();
        for (int i = 0; i < wParams.length(); ++i) {
            wLinear.putScalar(i, wParams.getDouble(i));
        }
        setB(params.get(new NDArrayIndex[]{NDArrayIndex.interval(weightCount, params.length())}));
    }

    /**
     * Not supported: this layer needs labels to train.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data, Object[] params) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported: this layer needs labels to train.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException();
    }

    /** Does nothing. */
    @Override
    public void iterate(INDArray input, Object[] params) {
    }

    @Override
    public String toString() {
        return "OutputLayer{labels=" + this.labels + ", adaGrad=" + this.adaGrad + ", biasAdaGrad=" + this.biasAdaGrad + "} " + super.toString();
    }

    /**
     * Stores {@code x} as the layer input and returns the configured
     * activation applied to {@code preOutput(x)}, after
     * {@code applyDropOutIfNecessary} has been run on the result.
     *
     * @param x input matrix; must not be {@code null}
     * @return the activated output
     * @throws IllegalArgumentException if {@code x} is {@code null}
     */
    public INDArray output(INDArray x) {
        if (x == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.input = x;
        INDArray preOutput = preOutput(x);
        INDArray activated = (INDArray) this.conf.getActivationFunction().apply(preOutput);
        applyDropOutIfNecessary(activated);
        return activated;
    }

    public INDArray getLabels() {
        return this.labels;
    }

    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    public AdaGrad getBiasAdaGrad() {
        return this.biasAdaGrad;
    }

    public AdaGrad getAdaGrad() {
        return this.adaGrad;
    }

    public void setAdaGrad(AdaGrad adaGrad) {
        this.adaGrad = adaGrad;
    }

    public void setBiasAdaGrad(AdaGrad biasAdaGrad) {
        this.biasAdaGrad = biasAdaGrad;
    }

    /**
     * Fluent builder for {@link OutputLayer}. A configuration is required
     * because the layer constructor reads sizes from it; weights and bias are
     * optional overrides applied after construction.
     */
    public static class Builder {
        private INDArray W;
        private NeuralNetConfiguration conf;
        private INDArray b;
        private INDArray input;
        private INDArray labels;

        public Builder configure(NeuralNetConfiguration conf) {
            this.conf = conf;
            return this;
        }

        public Builder input(INDArray input) {
            this.input = input;
            return this;
        }

        public Builder withLabels(INDArray labels) {
            this.labels = labels;
            return this;
        }

        public Builder withWeights(INDArray W) {
            this.W = W;
            return this;
        }

        public Builder withBias(INDArray b) {
            this.b = b;
            return this;
        }

        /**
         * Creates a new layer from the collected settings. Weights and bias
         * replace the constructor defaults only when they were supplied.
         *
         * @return the configured layer
         */
        public OutputLayer build() {
            OutputLayer layer = new OutputLayer(this.conf, this.input, this.labels);
            if (this.W != null) {
                layer.W = this.W;
            }
            if (this.b != null) {
                layer.b = this.b;
            }
            layer.conf = this.conf;
            return layer;
        }
    }
}

