/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner.matlib.loss;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.Vector;
import org.apache.log4j.Logger;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;

public class MatSquareLossFunction
extends LossFunction {
    Logger logger = Logger.getLogger(MatSquareLossFunction.class);

    @Override
    public Matrix gradient(Matrix W) {
        Matrix ret = W.newInstance();
        Matrix resid = MatlibMatrixUtils.dotProduct((Matrix)this.X, (Matrix)W);
        if (this.bias != null) {
            MatlibMatrixUtils.plusInplace((Matrix)resid, (Matrix)this.bias);
        }
        MatlibMatrixUtils.minusInplace((Matrix)resid, (Matrix)this.Y);
        for (int t = 0; t < resid.columnCount(); ++t) {
            Vector row = this.X.row(t);
            row.times(resid.get(t, t));
            MatlibMatrixUtils.setSubMatrixCol((Matrix)ret, (int)0, (int)t, (Vector)row);
        }
        return ret;
    }

    @Override
    public double eval(Matrix W) {
        Matrix resid = null;
        resid = W == null ? this.X : MatlibMatrixUtils.dotProduct((Matrix)this.X, (Matrix)W);
        Matrix vnobias = MatlibMatrixUtils.copy((Matrix)this.X);
        if (this.bias != null) {
            MatlibMatrixUtils.plusInplace((Matrix)resid, (Matrix)this.bias);
        }
        Matrix v = MatlibMatrixUtils.copy((Matrix)resid);
        MatlibMatrixUtils.minusInplace((Matrix)resid, (Matrix)this.Y);
        double retval = 0.0;
        for (int t = 0; t < resid.columnCount(); ++t) {
            double loss = resid.get(t, t);
            retval += loss * loss;
            this.logger.debug((Object)String.format("yr=%d,y=%3.2f,v=%3.2f,v(no bias)=%2.5f,error=%2.5f,serror=%2.5f", t, this.Y.get(t, t), v.get(t, t), vnobias.get(t, t), loss, loss * loss));
        }
        return retval;
    }

    @Override
    public boolean isMatrixLoss() {
        return true;
    }
}

