/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;

/**
 * Output layer that applies a softmax to the layer's pre-activation and is trained against a
 * label matrix using the loss function from the layer configuration.
 */
public class OutputLayer extends BaseLayer implements Serializable, Classifier {
    private static final long serialVersionUID = -7065564817460914364L;

    /** Target matrix used by score/gradient; set by fit(...) or setLabels(...). May be null. */
    private INDArray labels;

    public OutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public OutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Computes the configured loss of the current output against the current labels.
     *
     * @return the loss from {@code LossFunctions.score}; negated when the configured loss is
     *     RECONSTRUCTION_CROSSENTROPY
     */
    @Override
    public double score() {
        LinAlgExceptions.assertRows(input, labels);
        INDArray output = output(input);
        LossFunctions.LossFunction lossFunction = conf.getLossFunction();
        double loss = LossFunctions.score(labels, lossFunction, output, conf.getL2(), conf.isUseRegularization());
        // RECONSTRUCTION_CROSSENTROPY is the only loss reported with its sign flipped.
        return lossFunction == LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY ? -loss : loss;
    }

    /** Computes the configured loss for the current input/labels and stores it in {@code score}. */
    @Override
    public void setScore() {
        LinAlgExceptions.assertRows(input, labels);
        INDArray output = output(input);
        // NOTE(review): unlike score(), this stores the raw loss without negating it for
        // RECONSTRUCTION_CROSSENTROPY; confirm which sign convention readers of the field expect.
        this.score = LossFunctions.score(labels, conf.getLossFunction(), output, conf.getL2(), conf.isUseRegularization());
    }

    /**
     * Computes the gradient for the weight ("W") and bias ("b") parameters.
     *
     * @return a gradient holding "W" from the loss-specific rule and "b" as the dimension-0 mean
     *     of (labels - output)
     */
    @Override
    public Gradient gradient() {
        LinAlgExceptions.assertRows(input, labels);
        INDArray netOut = output(input);
        INDArray dy = labels.sub(netOut);
        INDArray wGradient = getWeightGradient();
        INDArray bGradient = dy.mean(0);
        DefaultGradient g = new DefaultGradient();
        g.gradientForVariable().put("W", wGradient);
        g.gradientForVariable().put("b", bGradient);
        return g;
    }

    /** Returns {@link #gradient()} paired with {@link #score()}, computed in that order. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(gradient(), score());
    }

    /**
     * Computes the weight gradient as input^T multiplied by an error term chosen by the configured
     * loss function.
     *
     * @throws IllegalStateException if the configured loss function has no rule here
     */
    private INDArray getWeightGradient() {
        // Computed up front for every loss (even MCXENT, which ignores it) to keep the original
        // sequence of output()/preOutput() calls.
        INDArray z = output(input);
        LossFunctions.LossFunction lossFunction = conf.getLossFunction();
        switch (lossFunction) {
            case MCXENT: {
                INDArray preOut = preOutput(input);
                // NOTE(review): output() runs softmax with dimension 1, but this uses dimension 0;
                // confirm which axis is intended.
                INDArray pYGivenX = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preOut), 0);
                return input.transpose().mmul(labels.sub(pYGivenX));
            }
            case XENT: {
                INDArray xEntDiff = z.sub(labels);
                return input.transpose().mmul(xEntDiff.div(z.mul(z.rsub((Number) 1))));
            }
            case MSE:
                return input.transpose().mmul(labels.sub(z).neg());
            case EXPLL:
                return input.transpose().mmul(labels.rsub((Number) 1).divi(z));
            case RMSE_XENT:
            case SQUARED_LOSS:
                return input.transpose().mmul(Transforms.pow(labels.sub(z), (Number) 2));
            case NEGATIVELOGLIKELIHOOD:
                return input.transpose().mmul(Transforms.log(z).negi());
            default:
                throw new IllegalStateException("Invalid loss function: " + lossFunction);
        }
    }

    /** Scores a data set by delegating to {@link #score(INDArray, INDArray)}. */
    @Override
    public double score(DataSet data) {
        return score(data.getFeatureMatrix(), data.getLabels());
    }

    /**
     * Evaluates the layer's predictions for {@code examples} against {@code labels}.
     *
     * @return the F1 score reported by {@link Evaluation}
     */
    @Override
    public double score(INDArray examples, INDArray labels) {
        Evaluation eval = new Evaluation();
        eval.eval(labels, labelProbabilities(examples));
        return eval.f1();
    }

    /**
     * Returns the number of label columns.
     *
     * <p>When no labels have been set yet, falls back to the configured output width
     * ({@code nOut}). This lets {@link #fit(INDArray, int[])} work on a freshly constructed layer
     * instead of failing with a NullPointerException.
     */
    @Override
    public int numLabels() {
        if (labels == null) {
            return conf.getnOut();
        }
        return labels.columns();
    }

    /** Fits the layer on each data set produced by {@code iter}, in iteration order. */
    @Override
    public void fit(DataSetIterator iter) {
        while (iter.hasNext()) {
            fit((DataSet) iter.next());
        }
    }

    /**
     * Predicts a class index per row of {@code d}.
     *
     * @return for each row, the index of the largest-magnitude output (via BLAS iamax); the
     *     output is a softmax, so this is the most probable class
     */
    @Override
    public int[] predict(INDArray d) {
        INDArray output = output(d);
        int[] ret = new int[d.rows()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
        }
        return ret;
    }

    /** Returns the softmax output for {@code examples}. */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        return output(examples);
    }

    /**
     * Stores {@code examples} and {@code labels} (by reference, not copied) as the training state
     * and runs a {@link Solver} built from this layer's configuration and listeners.
     */
    @Override
    public void fit(INDArray examples, INDArray labels) {
        this.input = examples;
        this.labels = labels;
        Solver solver = new Solver.Builder().configure(conf()).listeners(getIterationListeners()).model(this).build();
        solver.optimize();
    }

    /** Fits on the feature matrix and labels of {@code data}. */
    @Override
    public void fit(DataSet data) {
        fit(data.getFeatureMatrix(), data.getLabels());
    }

    /** Fits using integer class labels, converted to an outcome matrix with {@link #numLabels()} columns. */
    @Override
    public void fit(INDArray examples, int[] labels) {
        INDArray outcomeMatrix = FeatureUtil.toOutcomeMatrix(labels, numLabels());
        fit(examples, outcomeMatrix);
    }

    /**
     * Clears base-layer state, then destroys the labels' data buffer and drops the reference.
     *
     * <p>NOTE(review): {@code labels} is stored by reference in fit(...)/setLabels(...), so this
     * destroys the caller's buffer too; confirm callers do not reuse it.
     */
    @Override
    public void clear() {
        super.clear();
        if (labels != null) {
            labels.data().destroy();
            labels = null;
        }
    }

    /** Returns the pre-activation (before softmax) for {@code data}. */
    @Override
    public INDArray transform(INDArray data) {
        return preOutput(data);
    }

    /**
     * Copies a flattened parameter vector into this layer: the first {@code nIn * nOut} entries go
     * to "W" and the remainder to "b".
     */
    @Override
    public void setParams(INDArray params) {
        int weightLength = conf.getnIn() * conf.getnOut();
        INDArray wParams = params.get(NDArrayIndex.interval(0, weightLength));
        INDArray W = getParam("W");
        W.assign(wParams);
        INDArray bias = getParam("b");
        bias.assign(params.get(NDArrayIndex.interval(weightLength, params.length())).dup());
    }

    /** No-op in this class: fitting without labels is not implemented here. */
    @Override
    public void fit(INDArray data) {
    }

    /** No-op in this class. */
    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Applies softmax (dimension 1) to the pre-activation of {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} is null
     */
    public INDArray output(INDArray x) {
        if (x == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        INDArray preOutput = preOutput(x);
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preOutput), 1);
    }

    /** Returns the current label matrix, or null if none has been set. */
    public INDArray getLabels() {
        return labels;
    }

    /** Sets the label matrix (stored by reference) used by score and gradient computations. */
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }
}

