/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gnu.trove.set.hash.TIntHashSet;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

/**
 * Cross-fold experiment driver for a {@link BilinearSparseOnlineLearner}.
 *
 * <p>Data comes from a {@link BillMatlabFileDataGenerator}. A single learner is
 * initialised once and then trained incrementally across all folds; it is never
 * re-initialised between folds. For every fold the learner is:
 * <ol>
 * <li>trained on the fold's training stream;</li>
 * <li>evaluated on the fold's test stream with a {@link RootMeanSumLossEvaluator};</li>
 * <li>written to disk under {@code FOLD_ROOT(fold)}.</li>
 * </ol>
 */
public class BillAustrianExperiments
extends BilinearExperiment {
    /**
     * Runs the experiment.
     *
     * @param args ignored
     * @throws IOException if the data cannot be read or a learner cannot be saved
     */
    public static void main(String[] args) throws IOException {
        new BillAustrianExperiments().performExperiment();
    }

    /**
     * Builds the learner configuration, then trains, evaluates and saves the
     * learner once per fold.
     *
     * @throws IOException if the data cannot be read or a learner cannot be saved
     */
    @Override
    public void performExperiment() throws IOException {
        final BilinearLearnerParameters params = createParameters();
        // NOTE(review): the meaning of 98 and true is defined by
        // BillMatlabFileDataGenerator's constructor -- confirm there.
        final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(this.MATLAB_DATA()), 98, true);
        this.prepareExperimentLog(params);
        final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
        learner.reinitParams();
        // In-fold training indices already fed to the learner. The set persists
        // across folds so that an index trained in an earlier fold is skipped later.
        // Presumably fold training sets are cumulative prefixes -- TODO confirm.
        final TIntHashSet seenTraining = new TIntHashSet();
        for (int i = 0; i < bmfdg.nFolds(); ++i) {
            this.logger.debug("Fold: " + i);
            bmfdg.setFold(i, BillMatlabFileDataGenerator.Mode.TEST);
            final ArrayList<Pair<Matrix>> testpairs = readAll(bmfdg);
            this.logger.debug("...training");
            bmfdg.setFold(i, BillMatlabFileDataGenerator.Mode.TRAINING);
            final int itemsRead = trainOnUnseen(bmfdg, learner, seenTraining);
            evaluateAndSave(i, itemsRead, learner, testpairs);
        }
    }

    /** Creates the fixed hyper-parameter set used by this experiment. */
    private static BilinearLearnerParameters createParameters() {
        final BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", 0.02);
        params.put("eta0w", 0.02);
        params.put("lambda", 0.001);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("biaseta0", 0.5);
        params.put("winitstrat", new SingleValueInitStrat(0.1));
        params.put("uinitstrat", new SparseZerosInitStrategy());
        params.put("loss", new MatSquareLossFunction());
        return params;
    }

    /** Collects every pair the generator yields until it returns {@code null}. */
    private static ArrayList<Pair<Matrix>> readAll(BillMatlabFileDataGenerator bmfdg) {
        final ArrayList<Pair<Matrix>> pairs = new ArrayList<Pair<Matrix>>();
        Pair<Matrix> next;
        while ((next = bmfdg.generate()) != null) {
            pairs.add(next);
        }
        return pairs;
    }

    /**
     * Feeds the generator's current stream to the learner. Items whose in-fold
     * index is already in {@code seenTraining} are skipped; newly trained indices
     * are added to the set.
     *
     * @return the number of items read from the stream
     */
    private int trainOnUnseen(BillMatlabFileDataGenerator bmfdg, BilinearSparseOnlineLearner learner, TIntHashSet seenTraining) {
        int j = 0;
        Pair<Matrix> next;
        // Stop on end-of-stream BEFORE touching seenTraining. The end index is
        // not a real item; marking it as seen would make a later (longer) fold
        // silently skip the genuine item at that index.
        while ((next = bmfdg.generate()) != null) {
            if (seenTraining.contains(j)) {
                this.logger.debug("...skipping item " + j);
            } else {
                seenTraining.add(j);
                this.logger.debug("...trying item " + j);
                learner.process(next.firstObject(), next.secondObject());
                this.logger.debug("...done processing item " + j);
            }
            ++j;
        }
        return j;
    }

    /**
     * Evaluates the learner on {@code testpairs}, writes it to
     * {@code FOLD_ROOT(fold)/learner_<item>}, and logs sparsity, bias and loss.
     *
     * @throws IOException if the learner cannot be written
     */
    private void evaluateAndSave(int fold, int item, BilinearSparseOnlineLearner learner, ArrayList<Pair<Matrix>> testpairs) throws IOException {
        final Matrix u = learner.getU();
        final Matrix w = learner.getW();
        // Null-safe read: an absent "bias" entry is treated as disabled.
        final boolean biasMode = Boolean.TRUE.equals(learner.getParams().getTyped("bias"));
        // Dense snapshot of the bias, taken only when it will be logged.
        // NOTE(review): getBias() may be null when bias mode is off -- confirm.
        final Matrix bias = biasMode ? MatrixFactory.getDenseDefault().copyMatrix(learner.getBias()) : null;
        final RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        eval.setLearner(learner);
        final double loss = eval.evaluate(testpairs);
        this.logger.debug(String.format("Saving learner, Fold %d, Item %d", fold, item));
        final File learnerOut = new File(this.FOLD_ROOT(fold), String.format("learner_%d", item));
        IOUtils.writeBinary(learnerOut, (WriteableBinary) learner);
        this.logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(w));
        this.logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(u));
        if (biasMode) {
            this.logger.debug("Bias: " + CFMatrixUtils.diag(bias));
        }
        this.logger.debug(String.format("... loss: %f", loss));
    }
}

