/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.evaluation;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import java.util.List;
import org.apache.log4j.Logger;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.util.pair.Pair;

/**
 * Evaluates a bilinear learner by the square root of its loss summed over all supplied
 * (X, Y) pairs, divided by the total number of columns across all Y matrices (counted
 * here as the number of tasks).
 */
public class RootMeanSumLossEvaluator
extends BilinearEvaluator {
    Logger logger = Logger.getLogger(RootMeanSumLossEvaluator.class);

    /**
     * Evaluates the learner's current U, W and bias against {@code data}, using the
     * learner's own parameters.
     *
     * @param data (X, Y) matrix pairs to evaluate on
     * @return the value computed by
     *         {@link #sumLoss(List, Matrix, Matrix, Matrix, BilinearLearnerParameters)}
     */
    @Override
    public double evaluate(List<Pair<Matrix>> data) {
        return this.sumLoss(data, this.learner.getU(), this.learner.getW(),
                this.learner.getBias(), this.learner.getParams());
    }

    /**
     * Computes {@code sqrt(sum of per-pair loss / total number of Y columns)}.
     * <p>
     * For each pair, the prediction {@code u^T * X^T * w} is passed to the loss as X,
     * the result of {@code BilinearSparseOnlineLearner.expandY(Y)} is passed as Y, and
     * {@code bias} is set on the loss when it is non-null.
     * <p>
     * NOTE(review): when the configured loss already reports {@code isMatrixLoss()},
     * the instance stored in {@code params} is itself mutated via setY/setX/setBias.
     *
     * @param pairs  (X, Y) matrix pairs to evaluate
     * @param u      the U matrix of the bilinear model
     * @param w      the W matrix of the bilinear model
     * @param bias   optional bias; when {@code null} no bias is set on the loss
     * @param params parameters supplying the {@code "loss"} and {@code "forcesparcity"} values
     * @return the root of the summed loss divided by the total Y column count;
     *         {@code NaN} when no columns were seen (e.g. {@code pairs} is empty)
     */
    public double sumLoss(List<Pair<Matrix>> pairs, Matrix u, Matrix w, Matrix bias, BilinearLearnerParameters params) {
        LossFunction loss = (LossFunction) params.getTyped("loss");
        if (!loss.isMatrixLoss()) {
            loss = new MatLossFunction(loss);
        }
        // Read from the supplied params (like "loss" above) rather than from this.learner,
        // so callers passing their own parameters get a consistent configuration.
        final boolean forceSparcity = (Boolean) params.getTyped("forcesparcity");
        if (forceSparcity) {
            u = CFMatrixUtils.asSparseColumn(u);
            w = CFMatrixUtils.asSparseColumn(w);
        }
        // u is fixed for the whole evaluation, so transpose it once instead of per pair.
        final Matrix ut = u.transpose();

        double total = 0.0;
        int ntasks = 0;
        int i = 0;
        for (Pair<Matrix> pair : pairs) {
            final Matrix x = pair.firstObject();
            final Matrix y = pair.secondObject();
            final Matrix yExpanded = BilinearSparseOnlineLearner.expandY(y);
            final Matrix projected = CFMatrixUtils.fastdot(ut, x.transpose());
            final Matrix expectedAll = CFMatrixUtils.fastdot(projected, w);
            loss.setY(yExpanded);
            loss.setX(expectedAll);
            if (bias != null) {
                loss.setBias(bias);
            }
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Testing pair: " + i);
            }
            total += loss.eval(null);
            ntasks += y.getNumColumns();
            i++;
        }
        return Math.sqrt(total / ntasks);
    }
}

