/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot.iterationlistener;

import java.util.ArrayList;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

/**
 * Iteration listener that evaluates a {@link MultiLayerNetwork} on a fixed evaluation set after
 * every iteration, records the resulting accuracy, and periodically renders the accumulated
 * accuracy series to an {@code accuracy.png} graph via a {@link NeuralNetPlotter}.
 *
 * <p>Despite its name, the {@code epochs} value supplied to the constructors is a render
 * frequency. It is compared against the counter passed to
 * {@link #iterationDone(Model, int)}. A graph is rendered whenever that counter is a positive
 * multiple of the frequency. It is also rendered on counter {@code 0} when {@code renderFirst}
 * is set.
 *
 * <p>Accuracy is computed on every call, so the listener must be constructed with a network and
 * evaluation data. The constructors that take only a plotter or flags do not supply these.
 * Listeners built that way fail with an {@link IllegalStateException} on the first iteration.
 */
public class AccuracyPlotterIterationListener
implements IterationListener {
    /** Render frequency; non-positive values disable periodic (non-first) rendering. */
    private final int epochs;
    private final INDArray input;
    private final MultiLayerNetwork network;
    private final INDArray labels;
    private final NeuralNetPlotter plotter;
    /** Whether to render when {@code iterationDone} is called with counter 0. */
    private final boolean renderFirst;
    /** Accuracy recorded on each {@code iterationDone} call, in call order. */
    private final ArrayList<Double> accuracy = new ArrayList<>();
    private boolean invoked = false;

    /**
     * Canonical constructor that all public constructors delegate to.
     *
     * @param epochs render frequency (see class documentation)
     * @param plotter plotter used for output; {@code null} selects a new default plotter
     * @param network network to evaluate; may be {@code null} for plotter-only listeners
     * @param input evaluation features; may be {@code null} for plotter-only listeners
     * @param labels evaluation labels; may be {@code null} for plotter-only listeners
     * @param renderFirst whether to render on iteration counter 0
     */
    private AccuracyPlotterIterationListener(int epochs, NeuralNetPlotter plotter, MultiLayerNetwork network,
            INDArray input, INDArray labels, boolean renderFirst) {
        this.epochs = epochs;
        this.plotter = plotter != null ? plotter : new NeuralNetPlotter();
        this.network = network;
        this.input = input;
        this.labels = labels;
        this.renderFirst = renderFirst;
    }

    @Override
    public boolean invoked() {
        return this.invoked;
    }

    @Override
    public void invoke() {
        this.invoked = true;
    }

    public AccuracyPlotterIterationListener(int epochs, boolean renderFirst) {
        this(epochs, null, null, null, null, renderFirst);
    }

    public AccuracyPlotterIterationListener(int epochs, NeuralNetPlotter plotter) {
        this(epochs, plotter, null, null, null, false);
    }

    public AccuracyPlotterIterationListener(int epochs, NeuralNetPlotter plotter, boolean renderFirst) {
        this(epochs, plotter, null, null, null, renderFirst);
    }

    /**
     * Creates a listener that evaluates {@code network} on the features and labels of {@code data}.
     */
    public AccuracyPlotterIterationListener(int epochs, MultiLayerNetwork network, DataSet data) {
        this(epochs, null, network, data.getFeatures(), data.getLabels(), false);
    }

    /**
     * Creates a listener that evaluates {@code network} on the features and labels of {@code data},
     * optionally rendering on iteration counter 0.
     */
    public AccuracyPlotterIterationListener(int epochs, MultiLayerNetwork network, DataSet data, boolean renderFirst) {
        this(epochs, null, network, data.getFeatures(), data.getLabels(), renderFirst);
    }

    /**
     * Creates a listener that evaluates {@code network} on the given input and labels.
     */
    public AccuracyPlotterIterationListener(int epochs, MultiLayerNetwork network, INDArray input, INDArray labels) {
        this(epochs, null, network, input, labels, false);
    }

    /**
     * Creates a plotter-only listener with the given render frequency.
     *
     * @param iterations render frequency (previously ignored by this constructor)
     */
    public AccuracyPlotterIterationListener(int iterations) {
        this(iterations, null, null, null, null, false);
    }

    /**
     * Evaluates the configured network on the configured input and labels.
     *
     * @return the {@link Evaluation#accuracy()} of the network's output against the labels
     * @throws IllegalStateException if no network or evaluation data was supplied at construction
     */
    private double calculateAccuracy() {
        if (this.network == null || this.input == null || this.labels == null) {
            throw new IllegalStateException(
                    "AccuracyPlotterIterationListener requires a network and evaluation input/labels; "
                            + "construct it with a MultiLayerNetwork and a DataSet or input/labels arrays");
        }
        Evaluation eval = new Evaluation();
        INDArray output = this.network.output(this.input);
        eval.eval(this.labels, output);
        return eval.accuracy();
    }

    /**
     * Decides whether the accuracy graph should be rendered for the given iteration counter.
     *
     * @param iteration counter passed to {@link #iterationDone(Model, int)}
     * @return {@code renderFirst} for counter 0. For a positive counter, {@code true} when the
     *     counter is a multiple of a positive render frequency. Otherwise {@code false}.
     */
    private boolean shouldRender(int iteration) {
        if (iteration == 0) {
            return this.renderFirst;
        }
        // Guard the modulus: a zero frequency would otherwise throw ArithmeticException.
        return iteration > 0 && this.epochs > 0 && iteration % this.epochs == 0;
    }

    /**
     * Records the current accuracy and, when due, renders the accumulated accuracy series.
     *
     * @param model the model reported by the training loop (not used; the configured network is
     *     evaluated instead)
     * @param iteration iteration counter supplied by the training loop
     */
    @Override
    public void iterationDone(Model model, int iteration) {
        this.accuracy.add(this.calculateAccuracy());
        if (this.shouldRender(iteration)) {
            this.invoke();
            String dataFilePath = this.plotter.writeArray(this.accuracy);
            this.plotter.renderGraph("accuracy", dataFilePath, this.plotter.getLayerGraphFilePath() + "accuracy.png");
        }
    }
}

