/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;
import gov.sandia.cognition.math.matrix.Matrix;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.log4j.Appender;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Layout;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.apache.log4j.Priority;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearLearnerParametersLineSearch;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

/**
 * Experiment driver that line-searches the learning-rate and regularisation
 * parameters of a {@link BilinearSparseOnlineLearner} on the data set found under
 * {@link #ROOT}.
 * <p>
 * For every fold produced by {@link #prepareFolds()}, each candidate parameter
 * setting is trained on the fold's TRAINING split and scored on its VALIDATION
 * split. The lowest-loss learner is then scored on the fold's TEST split. The log
 * file and every new best-so-far learner are written beneath a time-stamped
 * directory under {@link #OUTPUT_ROOT}.
 */
public class LambdaSearchAustrian {
    /** Number of rolling folds generated by {@link #prepareFolds()}. */
    private static final int NFOLDS = 1;
    /** Amount each successive fold's window grows by; also the length of each test block. */
    private static final int FOLD_STEP = 5;
    /** Size of the first fold's window (training + validation indices). */
    private static final int FOLD_BASE_SIZE = 48;
    /** Number of indices held out for validation in every fold. */
    private static final int FOLD_VALIDATION_SIZE = 8;
    private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
    private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
    /** log4j layout used for console output. */
    private static final String CONSOLE_PATTERN = "[%p->%C{1}] %m%n";
    /** log4j layout used for the experiment log file (adds a timestamp). */
    private static final String FILE_PATTERN = "[%d{HH:mm:ss} %p->%C{1}] %m%n";

    private final Logger logger = Logger.getLogger(this.getClass());
    /** Start time of this run; names the output directory (see {@link #currentOutputRoot()}). */
    private final long expStartTime = System.currentTimeMillis();

    public static void main(String[] args) throws IOException {
        LambdaSearchAustrian exp = new LambdaSearchAustrian();
        exp.performExperiment();
    }

    /**
     * Runs the parameter line search on every fold and logs the test-set loss of
     * each fold's best learner.
     *
     * @throws IOException if preparing the experiment log/output directory fails,
     *         or if propagated from the data generator
     */
    public void performExperiment() throws IOException {
        final List<BillMatlabFileDataGenerator.Fold> folds = this.prepareFolds();
        // NOTE(review): the meaning of the 98 and false arguments is defined by
        // BillMatlabFileDataGenerator's constructor; confirm there before changing them.
        final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
                new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
                new File(dataFromRoot("unnormalised.mat")), 98, false, folds);
        this.prepareExperimentLog();
        final RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        for (int i = 0; i < bmfdg.nFolds(); ++i) {
            this.logger.info("Starting Fold: " + i);
            final BilinearSparseOnlineLearner best = this.lineSearchParams(i, bmfdg);
            this.logger.debug("Best params found! Starting test...");
            bmfdg.setFold(i, BillMatlabFileDataGenerator.Mode.TEST);
            eval.setLearner(best);
            final double ev = ((BilinearEvaluator) eval).evaluate(bmfdg.generateAll());
            this.logger.debug("Test RMSE: " + ev);
        }
    }

    /**
     * Trains a fresh learner for every candidate from {@link #parameterLineSearch()}
     * on the TRAINING split of {@code fold}. Returns the learner with the lowest
     * loss on that fold's VALIDATION split. Each time a new best is found, it is
     * persisted via {@link #saveFoldParameterLearner}.
     *
     * @param fold   index of the fold to search on
     * @param source data generator; this method changes its current fold and mode
     * @return the learner with the lowest validation loss (never {@code null})
     * @throws IllegalStateException if no candidate produced a validation loss below
     *         {@link Double#MAX_VALUE}, e.g. an empty search space or all-NaN losses
     */
    private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator source) {
        final RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        final List<BilinearLearnerParameters> candidates = this.parameterLineSearch();
        BilinearSparseOnlineLearner best = null;
        double bestScore = Double.MAX_VALUE;
        this.logger.info("Optimising params, searching: " + candidates.size());
        for (int j = 0; j < candidates.size(); ++j) {
            final BilinearLearnerParameters next = candidates.get(j);
            this.logger.info(String.format("Optimising params %d/%d", j + 1, candidates.size()));
            this.logger.debug("Current Params:\n" + next.toString());

            final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(next);
            source.setFold(fold, BillMatlabFileDataGenerator.Mode.TRAINING);
            this.logger.debug("Training...");
            Pair<Matrix> pair;
            while ((pair = source.generate()) != null) {
                learner.process((Matrix) pair.firstObject(), (Matrix) pair.secondObject());
            }

            this.logger.debug("Generating score of validation set");
            source.setFold(fold, BillMatlabFileDataGenerator.Mode.VALIDATION);
            eval.setLearner(learner);
            final double loss = ((BilinearEvaluator) eval).evaluate(source.generateAll());
            this.logger.debug("Total RMSE: " + loss);
            this.logger.debug("U sparcity: " + CFMatrixUtils.sparsity((Matrix) learner.getU()));
            this.logger.debug("W sparcity: " + CFMatrixUtils.sparsity((Matrix) learner.getW()));

            // Strict comparison: on ties the earlier candidate is kept; NaN never wins.
            if (loss < bestScore) {
                this.logger.info("New best score detected!");
                bestScore = loss;
                best = learner;
                this.logger.info("New Best Config:\n" + best.getParams());
                this.logger.info("New Best Loss:" + loss);
                this.saveFoldParameterLearner(fold, j, learner);
            }
        }
        if (best == null) {
            // Fail here with context instead of handing a null learner to the evaluator.
            throw new IllegalStateException(String.format(
                    "No candidate parameters produced a usable validation loss for fold %d (%d candidates searched)",
                    fold, candidates.size()));
        }
        return best;
    }

    /**
     * Writes {@code learner} beneath the current experiment output root as
     * {@code fold_<fold>/learner_<j>} (via OpenIMAJ {@code IOUtils.writeBinary}).
     * Also writes {@code fold_<fold>/learner_<j>.mat}, containing the {@code u} and
     * {@code w} matrices and, when present, {@code b}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         creating the directory or writing either file
     */
    private void saveFoldParameterLearner(int fold, int j, BilinearSparseOnlineLearner learner) {
        final File foldDir = new File(String.format("%s/fold_%d", this.currentOutputRoot(), fold));
        final File learnerOut = new File(foldDir, String.format("learner_%d", j));
        final File learnerOutMat = new File(foldDir, String.format("learner_%d.mat", j));
        try {
            // mkdirs() also returns false when the directory already exists, so re-check.
            if (!foldDir.mkdirs() && !foldDir.isDirectory()) {
                throw new IOException("Couldn't create fold output directory: " + foldDir);
            }
            IOUtils.writeBinary(learnerOut, (WriteableBinary) learner);
            final List<MLArray> data = new ArrayList<MLArray>();
            data.add(CFMatrixUtils.toMLArray("u", (Matrix) learner.getU()));
            data.add(CFMatrixUtils.toMLArray("w", (Matrix) learner.getW()));
            if (learner.getBias() != null) {
                data.add(CFMatrixUtils.toMLArray("b", (Matrix) learner.getBias()));
            }
            // JMatIO's MatFileWriter writes the file from its constructor; no handle is needed.
            new MatFileWriter(learnerOutMat, data);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Expands the base parameters from {@link #prepareParams()} into the list of
     * candidate settings to try. Each searched key is currently given a single value.
     */
    private List<BilinearLearnerParameters> parameterLineSearch() {
        final BilinearLearnerParametersLineSearch iter = new BilinearLearnerParametersLineSearch(this.prepareParams());
        iter.addIteration("eta0u", Doubles.asList(1.0E-4));
        iter.addIteration("eta0w", Doubles.asList(0.005));
        iter.addIteration("biaseta0", Doubles.asList(50.0));
        iter.addIteration("lambda_u", Doubles.asList(1.0E-5));
        iter.addIteration("lambda_w", Doubles.asList(1.0E-5));
        final List<BilinearLearnerParameters> candidates = new ArrayList<BilinearLearnerParameters>();
        for (BilinearLearnerParameters param : iter) {
            candidates.add(param);
        }
        return candidates;
    }

    /**
     * Builds {@link #NFOLDS} rolling folds over contiguous sample indices.
     * <p>
     * Fold {@code i} covers the window {@code [0, total)}, where
     * {@code total = FOLD_BASE_SIZE + i * FOLD_STEP}. A block of
     * {@link #FOLD_VALIDATION_SIZE} indices placed around the window's midpoint is
     * held out for validation, and the rest of the window is training. The
     * {@link #FOLD_STEP} indices just after the window, {@code [total, total + FOLD_STEP)},
     * form the test set. With the current constants (one fold) this gives
     * training = [0,19) and [27,48), validation = [19,27), test = [48,53).
     */
    private List<BillMatlabFileDataGenerator.Fold> prepareFolds() {
        final List<BillMatlabFileDataGenerator.Fold> folds = new ArrayList<BillMatlabFileDataGenerator.Fold>();
        for (int i = 0; i < NFOLDS; ++i) {
            final int total = i * FOLD_STEP + FOLD_BASE_SIZE;
            final int[] training = new int[total - FOLD_VALIDATION_SIZE];
            final int[] test = new int[FOLD_STEP];
            final int[] validation = new int[FOLD_VALIDATION_SIZE];
            // Validation starts half its length before the rounded midpoint of the window.
            final int validationStart = (int) Math.round(total / 2.0) - 1 - FOLD_VALIDATION_SIZE / 2;
            int j = 0;
            int traini = 0;
            while (j < validationStart) {
                training[traini++] = j++;
            }
            for (int k = 0; k < validation.length; ++k) {
                validation[k] = j++;
            }
            while (j < total) {
                training[traini++] = j++;
            }
            for (int k = 0; k < test.length; ++k) {
                test[k] = j++;
            }
            folds.add(new BillMatlabFileDataGenerator.Fold(training, test, validation));
        }
        return folds;
    }

    /**
     * Creates the base learner configuration shared by all candidates.
     * <p>
     * The keys set to {@code null} are exactly those varied in
     * {@link #parameterLineSearch()}. Presumably
     * {@code BilinearLearnerParametersLineSearch} fills them in for each candidate;
     * confirm there.
     */
    private BilinearLearnerParameters prepareParams() {
        final BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", null);
        params.put("eta0w", null);
        params.put("lambda_u", null);
        params.put("lambda_w", null);
        params.put("biaseta0", null);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("winitstrat", new SparseZerosInitStrategy());
        params.put("uinitstrat", new SparseZerosInitStrategy());
        params.put("loss", new MatSquareLossFunction());
        return params;
    }

    /**
     * Resolves {@code data} relative to {@link #ROOT}.
     * <p>
     * {@code ROOT} already ends in {@code '/'}, so the returned string contains
     * {@code "//"}. {@link java.io.File} normalises the duplicate separator.
     */
    public static String dataFromRoot(String data) {
        return String.format("%s/%s", ROOT, data);
    }

    /**
     * Attaches two root-logger appenders: a console appender (INFO and above) and a
     * file appender (DEBUG and above) writing to {@code <experiment root>/log}. Any
     * existing log file is deleted first.
     *
     * @throws IOException if the experiment root cannot be created or the log file
     *         cannot be opened
     */
    protected void prepareExperimentLog() throws IOException {
        final ConsoleAppender console = new ConsoleAppender();
        console.setLayout(new PatternLayout(CONSOLE_PATTERN));
        console.setThreshold(Level.INFO);
        console.activateOptions();
        Logger.getRootLogger().addAppender(console);

        final File expRoot = this.prepareExperimentRoot();
        final File logFile = new File(expRoot, "log");
        if (logFile.exists()) {
            logFile.delete();
        }
        final FileAppender file = new FileAppender(new PatternLayout(FILE_PATTERN), logFile.getAbsolutePath());
        file.setThreshold(Level.DEBUG);
        file.activateOptions();
        Logger.getRootLogger().addAppender(file);
        this.logger.info("Experiment root: " + expRoot);
    }

    /**
     * Returns the directory for this run (see {@link #currentOutputRoot()}),
     * creating it if it does not exist.
     *
     * @throws IOException if the directory does not exist and cannot be created
     */
    public File prepareExperimentRoot() throws IOException {
        final File expRoot = new File(this.currentOutputRoot());
        if (expRoot.exists() && expRoot.isDirectory()) {
            return expRoot;
        }
        this.logger.debug("Experiment root: " + expRoot);
        if (!expRoot.mkdirs()) {
            throw new IOException("Couldn't prepare experiment output");
        }
        return expRoot;
    }

    /** Output directory for this run: {@code OUTPUT_ROOT/<experiment set>/<start millis>}. */
    private String currentOutputRoot() {
        return String.format("%s/%s/%s", OUTPUT_ROOT, this.getExperimentSetName(),
                String.valueOf(this.currentExperimentTime()));
    }

    /** Start time of this run in milliseconds since the epoch. */
    private long currentExperimentTime() {
        return this.expStartTime;
    }

    /** Relative directory name grouping all runs of this experiment. */
    private String getExperimentSetName() {
        return "streamingBilinear/optimiselambda";
    }
}

