/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.base.DeepLearningTest;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.dbn.CDBN;
import org.deeplearning4j.dbn.DBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line harness that loads one of a few built-in datasets, trains a multi-layer network
 * on it batch by batch, and logs the elapsed time plus the {@link Evaluation} statistics.
 *
 * <p>Usage: {@code DataSetTester <algorithm> <dataset> [numExamples]} where {@code algorithm} is
 * one of {@code sda}, {@code dbn}, {@code cdbn} and {@code dataset} is one of {@code lfw},
 * {@code iris}, {@code mnist}.
 *
 * <p>Note: each batch is evaluated on the same data it was just trained on, so the reported
 * statistics are in-sample figures rather than a held-out test score.
 */
public class DataSetTester
extends DeepLearningTest {
    /** Hidden layer sizes used for every network built by this tester. */
    private static final int[] HIDDEN_LAYER_SIZES = new int[]{200, 200, 200};
    private static final Logger log = LoggerFactory.getLogger(DataSetTester.class);
    private static final String USAGE =
            "Usage: DataSetTester <algorithm: sda|dbn|cdbn> <dataset: lfw|iris|mnist> [numExamples]";

    private final String dataset;
    private final String algorithm;
    /** Number of examples to load, or {@code null} to use each dataset's default size. */
    private final Integer numExamples;

    /**
     * Creates a tester that loads a limited number of examples.
     *
     * @param dataset     dataset name: {@code lfw}, {@code iris} or {@code mnist}
     * @param algorithm   algorithm name: {@code sda}, {@code dbn} or {@code cdbn}
     * @param numExamples number of examples to load, or {@code null} for the dataset default
     */
    public DataSetTester(String dataset, String algorithm, Integer numExamples) {
        this.dataset = dataset;
        this.algorithm = algorithm;
        this.numExamples = numExamples;
    }

    /**
     * Creates a tester that loads each dataset at its default size.
     *
     * @param dataset   dataset name: {@code lfw}, {@code iris} or {@code mnist}
     * @param algorithm algorithm name: {@code sda}, {@code dbn} or {@code cdbn}
     */
    public DataSetTester(String dataset, String algorithm) {
        this(dataset, algorithm, null);
    }

    /**
     * Entry point.
     *
     * @param args {@code algorithm dataset [numExamples]}
     * @throws IllegalArgumentException if fewer than two arguments are supplied
     * @throws NumberFormatException    if {@code numExamples} is not an integer
     * @throws Exception                propagated from dataset loading or training
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException(USAGE);
        }
        String algorithm = args[0];
        String dataset = args[1];
        Integer numExamples = args.length > 2 ? Integer.valueOf(args[2]) : null;
        new DataSetTester(dataset, algorithm, numExamples).run();
    }

    /**
     * Loads the configured dataset, builds the network, then trains and evaluates it on each batch
     * in turn, logging the elapsed seconds and the evaluation statistics.
     *
     * @throws IllegalStateException if the dataset or algorithm name is not recognized, or the
     *                               dataset is empty
     * @throws Exception             propagated from dataset loading
     */
    public void run() throws Exception {
        List<Pair<DoubleMatrix, DoubleMatrix>> batches =
                this.numExamples != null ? this.loadDataset(this.numExamples) : this.loadDataset();
        BaseMultiLayerNetwork neuralNet = this.getNeuralNet(batches);
        // Timing covers training + evaluation only, not dataset loading or network construction.
        long start = System.currentTimeMillis();
        Evaluation evaluation = new Evaluation();
        for (Pair<DoubleMatrix, DoubleMatrix> batch : batches) {
            neuralNet.trainNetwork(batch.getFirst(), batch.getSecond(), this.getOtherParams());
            // In-sample: predict on the very batch the network was just trained on.
            DoubleMatrix predicted = neuralNet.predict(batch.getFirst());
            evaluation.eval(batch.getSecond(), predicted);
        }
        long elapsedMillis = System.currentTimeMillis() - start;
        log.info("Ended in {} seconds", TimeUnit.MILLISECONDS.toSeconds(elapsedMillis));
        log.info(evaluation.stats());
    }

    /**
     * Returns the extra positional arguments forwarded to {@code trainNetwork} for the configured
     * algorithm. NOTE(review): the meaning of each slot is defined by
     * {@code BaseMultiLayerNetwork.trainNetwork}, not here — confirm there before changing values.
     *
     * @return the parameter array, or {@code null} if the algorithm is not recognized (in practice
     *         unreachable from {@link #run()}, because {@link #algorithmForClass()} rejects unknown
     *         algorithms first)
     */
    private Object[] getOtherParams() {
        if ("sda".equals(this.algorithm)) {
            return new Object[]{0.1, 0.3, 500, 0.1, 200};
        }
        if ("dbn".equals(this.algorithm) || "cdbn".equals(this.algorithm)) {
            return new Object[]{1, 0.1, 500, 0.1, 200};
        }
        return null;
    }

    /**
     * Builds a network sized from the first batch of {@code dataset}, using the class selected by
     * {@link #algorithmForClass()} and a fixed RNG seed so initialization is reproducible.
     *
     * @param dataset non-empty list of (input, label) batches
     * @return the constructed network
     * @throws IllegalStateException if the dataset is empty or the algorithm is not recognized
     */
    private BaseMultiLayerNetwork getNeuralNet(List<Pair<DoubleMatrix, DoubleMatrix>> dataset) {
        Pair<Integer, Integer> dims = this.numInputsOutcomes(dataset);
        // Explicit cast: build()'s result is not statically typed as BaseMultiLayerNetwork here
        // (presumably a raw generic builder — confirm), so returning it directly does not compile.
        return (BaseMultiLayerNetwork) new BaseMultiLayerNetwork.Builder()
                .hiddenLayerSizes(HIDDEN_LAYER_SIZES)
                .numberOfInputs(dims.getFirst())
                .numberOfOutPuts(dims.getSecond())
                .withRng((RandomGenerator) new MersenneTwister(123))
                .withClazz(this.algorithmForClass())
                .build();
    }

    /**
     * Maps the configured algorithm name to the network class to instantiate.
     *
     * @return the network class for {@code sda}, {@code cdbn} or {@code dbn}
     * @throws IllegalStateException if the algorithm name is not recognized
     */
    private Class<? extends BaseMultiLayerNetwork> algorithmForClass() {
        if ("sda".equals(this.algorithm)) {
            return BaseMultiLayerNetwork.class;
        }
        if ("cdbn".equals(this.algorithm)) {
            return CDBN.class;
        }
        if ("dbn".equals(this.algorithm)) {
            return DBN.class;
        }
        throw new IllegalStateException("No algorithm found: " + this.algorithm);
    }

    /**
     * Infers (number of inputs, number of outcomes) from the first batch of the dataset.
     *
     * @throws IllegalStateException if the dataset is null or empty
     */
    private Pair<Integer, Integer> numInputsOutcomes(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalStateException("Dataset is empty; cannot infer network dimensions");
        }
        return this.numInputsOutcomes(list.get(0));
    }

    /**
     * Returns (column count of the input matrix, column count of the label matrix) for a batch.
     */
    private Pair<Integer, Integer> numInputsOutcomes(Pair<DoubleMatrix, DoubleMatrix> pair) {
        int numInputs = pair.getFirst().columns;
        int numOutcomes = pair.getSecond().columns;
        return new Pair<Integer, Integer>(numInputs, numOutcomes);
    }

    /**
     * Loads the configured dataset limited by {@code numExamples}. The iris dataset ignores the
     * limit and is always loaded whole as a single batch.
     *
     * @param numExamples number of examples requested; NOTE(review): for mnist this is passed as
     *                    the second argument of {@code getMnistExampleBatches(1, numExamples)} —
     *                    confirm that helper's argument semantics
     * @return list of (input, label) batches
     * @throws IllegalStateException if the dataset name is not recognized
     * @throws Exception             propagated from the underlying loaders
     */
    private List<Pair<DoubleMatrix, DoubleMatrix>> loadDataset(int numExamples) throws Exception {
        if ("lfw".equals(this.dataset)) {
            return DataSetTester.getFirstFaces(numExamples);
        }
        if ("iris".equals(this.dataset)) {
            return Collections.singletonList(DataSetTester.getIris());
        }
        if ("mnist".equals(this.dataset)) {
            return this.getMnistExampleBatches(1, numExamples);
        }
        throw new IllegalStateException("Unknown dataset: " + this.dataset);
    }

    /**
     * Loads the configured dataset at its default size.
     *
     * @return list of (input, label) batches
     * @throws IllegalStateException if the dataset name is not recognized
     * @throws Exception             propagated from the underlying loaders
     */
    private List<Pair<DoubleMatrix, DoubleMatrix>> loadDataset() throws Exception {
        if ("lfw".equals(this.dataset)) {
            return this.getFaces();
        }
        if ("iris".equals(this.dataset)) {
            return Collections.singletonList(DataSetTester.getIris());
        }
        if ("mnist".equals(this.dataset)) {
            // NOTE(review): (10, 6000) are the defaults for getMnistExampleBatches; confirm which
            // argument is the batch size and which the batch/example count.
            return this.getMnistExampleBatches(10, 6000);
        }
        throw new IllegalStateException("Unknown dataset: " + this.dataset);
    }
}

