/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner.loss;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import org.openimaj.ml.linear.learner.loss.LossFunction;

/**
 * Adapts a column-wise {@link LossFunction} so that it can be applied to a
 * multi-column weight matrix.
 * <p>
 * The gradient is computed one column at a time. For column {@code i}, the
 * wrapped loss is given column {@code i} of {@code Y} (and of the bias, when
 * one is set) and its gradient is taken at column {@code i} of {@code W}. The
 * loss value is obtained by delegating directly to the wrapped function with
 * the full {@code Y} and bias.
 */
public class MatLossFunction extends LossFunction {
    /** The wrapped loss; receives every X/Y/bias update made on this wrapper. */
    private final LossFunction f;
    /** Factory used to allocate the sparse gradient matrix. */
    private final SparseMatrixFactoryMTJ spf;

    /**
     * @param f the loss function applied to each column of the target matrix
     */
    public MatLossFunction(LossFunction f) {
        this.f = f;
        this.spf = SparseMatrixFactoryMTJ.INSTANCE;
    }

    /**
     * Sets X on this wrapper and forwards it unchanged to the wrapped loss.
     */
    @Override
    public void setX(Matrix X) {
        super.setX(X);
        this.f.setX(X);
    }

    /**
     * Sets Y on this wrapper and forwards the full matrix to the wrapped loss.
     */
    @Override
    public void setY(Matrix Y) {
        super.setY(Y);
        this.f.setY(Y);
    }

    /**
     * Sets the bias on this wrapper and forwards the full matrix to the wrapped loss.
     */
    @Override
    public void setBias(Matrix bias) {
        super.setBias(bias);
        this.f.setBias(bias);
    }

    /**
     * Computes the gradient column by column using the wrapped loss.
     * <p>
     * For each column {@code i} of {@code Y}, the wrapped loss is pointed at
     * column {@code i} of {@code Y} (and of the bias, if set). Its gradient at
     * column {@code i} of {@code W} is then written into column {@code i} of the
     * result. When the method returns or throws, the wrapped loss is restored to
     * the full {@code Y}/bias. Without this, a later {@link #eval(Matrix)} would
     * silently operate on a single column.
     *
     * @param W weight matrix; column {@code i} is paired with column {@code i} of {@code Y}
     * @return a sparse matrix with the same dimensions as {@code W}
     * @throws IllegalStateException if {@code Y} has not been set
     */
    @Override
    public Matrix gradient(Matrix W) {
        if (this.Y == null) {
            throw new IllegalStateException("Y must be set before computing the gradient");
        }
        SparseMatrix ret = this.spf.createMatrix(W.getNumRows(), W.getNumColumns());
        // getSubMatrix takes inclusive row bounds, hence the "- 1".
        int lastRowY = this.Y.getNumRows() - 1;
        int lastRowW = W.getNumRows() - 1;
        try {
            for (int i = 0; i < this.Y.getNumColumns(); ++i) {
                this.f.setY(this.Y.getSubMatrix(0, lastRowY, i, i));
                if (this.bias != null) {
                    // NOTE(review): assumes bias has at least as many rows as Y — confirm
                    this.f.setBias(this.bias.getSubMatrix(0, lastRowY, i, i));
                }
                Matrix w = W.getSubMatrix(0, lastRowW, i, i);
                ret.setSubMatrix(0, i, this.f.gradient(w));
            }
        } finally {
            // The loop leaves the delegate pointing at one column; put back the full targets.
            restoreDelegateTargets();
        }
        return ret;
    }

    /**
     * Evaluates the loss by delegating to the wrapped function with the full
     * {@code Y} and bias of this wrapper.
     *
     * @param W the full weight matrix
     * @return the wrapped loss's value at {@code W}
     */
    @Override
    public double eval(Matrix W) {
        // gradient() narrows the delegate's Y/bias to single columns; undo that first.
        restoreDelegateTargets();
        return this.f.eval(W);
    }

    /**
     * @return always {@code true}: this loss operates on whole weight matrices
     */
    @Override
    public boolean isMatrixLoss() {
        return true;
    }

    /**
     * Re-synchronises the wrapped loss with this wrapper's full Y and bias.
     * <p>
     * Y is only pushed when non-null, so a delegate configured directly (never
     * through this wrapper) keeps its own Y.
     */
    private void restoreDelegateTargets() {
        if (this.Y != null) {
            this.f.setY(this.Y);
        }
        this.f.setBias(this.bias);
    }
}

