/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.plot.FilterRenderer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/**
 * Renders diagnostic plots (weight/gradient histograms, activation plots and filter images) for
 * network layers.
 *
 * <p>Plot input is written as comma-separated text files under {@code <java.io.tmpdir>/data/}.
 * The plotting script (classpath resource {@code scripts/plot.py}) is copied to
 * {@code <java.io.tmpdir>/graphs/plot.py} and run through the {@code python} command; images are
 * saved below a per-JVM session directory {@code <java.io.tmpdir>/graphs/<random-uuid>/}.
 *
 * <p>All path state is static, so the current layer directory ({@link #getLayerGraphFilePath()})
 * is shared by every instance in the JVM. No synchronization is performed.
 */
public class NeuralNetPlotter implements Serializable {
    /** Plotting script bundled on the classpath; classpath resource paths always use '/'. */
    private static final ClassPathResource script = new ClassPathResource("scripts/plot.py");
    private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);
    /** Random id that keeps graphs from different JVM runs in separate directories. */
    private static final String ID_FOR_SESSION = UUID.randomUUID().toString();
    private static final String localPath = System.getProperty("java.io.tmpdir") + File.separator;
    private static final String dataFilePath = localPath + "data" + File.separator;
    private static final String graphPath = localPath + "graphs" + File.separator;
    private static final String graphFilePath = graphPath + ID_FOR_SESSION + File.separator;
    // Must stay below the path fields: static initializers run in textual order and
    // loadIntoTmp() reads dataFilePath, graphPath and graphFilePath.
    private static final String localPlotPath = NeuralNetPlotter.loadIntoTmp();
    /** Directory used as the prefix for per-layer plot files; shared by all instances. */
    private static String layerGraphFilePath = graphFilePath;

    /**
     * Returns the directory that layer plots are currently written to.
     *
     * <p>The value is used as a raw prefix for file names (e.g. {@code "weightHistograms.png"}),
     * so it should end with a file separator.
     *
     * @return the current layer graph directory
     */
    public String getLayerGraphFilePath() {
        return layerGraphFilePath;
    }

    /**
     * Sets the directory that layer plots are written to. The value is static, so this affects
     * every instance of this class.
     *
     * @param newPath directory prefix; should end with a file separator
     */
    public void setLayerGraphFilePath(String newPath) {
        layerGraphFilePath = newPath;
    }

    /** Logs where plot data files are stored. */
    public static void printDataFilePath() {
        log.info("Data stored at " + dataFilePath);
    }

    /** Logs where graph images are stored, with a reminder that the folder is not cleaned up. */
    public static void printGraphFilePath() {
        log.warn("Graphs stored at " + graphFilePath + ". " + "Warning: You must manually delete the folder when you are done.");
    }

    /**
     * Creates the data and session graph directories and copies the plotting script into the
     * shared graphs directory (unless a copy already exists there).
     *
     * @return absolute path of the copied {@code plot.py}
     * @throws IllegalStateException if the script cannot be read or written
     */
    private static String loadIntoTmp() {
        NeuralNetPlotter.setupDirectory(dataFilePath);
        NeuralNetPlotter.setupDirectory(graphFilePath);
        NeuralNetPlotter.printDataFilePath();
        NeuralNetPlotter.printGraphFilePath();
        File plotPath = new File(graphPath, "plot.py");
        plotPath.deleteOnExit();
        if (!plotPath.exists()) {
            try (InputStream in = script.getInputStream()) {
                List<?> lines = IOUtils.readLines(in);
                FileUtils.writeLines(plotPath, lines);
            } catch (IOException e) {
                throw new IllegalStateException("Unable to load python file", e);
            }
        }
        return plotPath.getAbsolutePath();
    }

    /**
     * Ensures {@code path} exists as a directory, creating missing parent directories too.
     * Failure is logged rather than thrown; later file writes into the directory will then fail.
     *
     * @param path directory to create
     */
    protected static void setupDirectory(String path) {
        File dir = new File(path);
        // mkdirs (not mkdir): graphFilePath is two levels below tmpdir and "graphs" may not exist.
        // Re-check isDirectory in case another process created it concurrently.
        if (!dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
            log.warn("Unable to create directory " + path);
        }
    }

    /**
     * Points subsequent layer plots at the session sub-directory named
     * {@code <layer index><layer class name without package>}, creating it if needed.
     *
     * @param layer layer whose plots will be written next
     */
    public void updateGraphDirectory(Layer layer) {
        String className = layer.getClass().getName();
        String layerName = Integer.toString(layer.getIndex()) + className.substring(className.lastIndexOf('.') + 1);
        // graphFilePath already ends with a separator.
        String newPath = graphFilePath + layerName + File.separator;
        NeuralNetPlotter.setupDirectory(newPath);
        // Always switch, including when the directory already exists (e.g. when a layer is
        // revisited); otherwise plots would keep going to the previously selected layer.
        this.setLayerGraphFilePath(newPath);
    }

    /**
     * Formats a value with ten decimals using '.' as the decimal separator, independent of the
     * JVM default locale (a ',' separator would corrupt the comma-separated data files).
     */
    private static String formatValue(double value) {
        return String.format(Locale.ROOT, "%.10f", value);
    }

    /**
     * Writes {@code matrix} to a new temporary data file, one matrix row per line with values
     * separated by commas. The file is deleted when the JVM exits.
     *
     * @param matrix values to write
     * @return path of the written file
     * @throws RuntimeException wrapping any {@link IOException}
     */
    protected String writeMatrix(INDArray matrix) {
        String tmpFilePath = dataFilePath + UUID.randomUUID().toString();
        File write = new File(tmpFilePath);
        write.deleteOnExit();
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(write, true))) {
            for (int i = 0; i < matrix.rows(); ++i) {
                INDArray row = matrix.getRow(i);
                StringBuilder sb = new StringBuilder();
                for (int j = 0; j < row.length(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(formatValue(row.getDouble(j)));
                }
                sb.append('\n');
                bos.write(sb.toString().getBytes(StandardCharsets.UTF_8));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return tmpFilePath;
    }

    /**
     * Writes the numeric values of {@code data} to a new temporary data file as a single
     * comma-separated line without a trailing newline. An empty list produces an empty file. The
     * file is deleted when the JVM exits.
     *
     * @param data values to write; every element must be a {@link Number}
     * @return path of the written file
     * @throws ClassCastException if an element is not a {@link Number}
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public String writeArray(ArrayList<?> data) {
        String tmpFilePath = dataFilePath + UUID.randomUUID().toString();
        File write = new File(tmpFilePath);
        write.deleteOnExit();
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (Object value : data) {
            if (!first) {
                sb.append(',');
            }
            first = false;
            sb.append(formatValue(((Number) value).doubleValue()));
        }
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(write, true))) {
            bos.write(sb.toString().getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return tmpFilePath;
    }

    /**
     * Runs the plotting script as {@code python plot.py <action> <dataPath> <saveFilePath>} and
     * logs its output.
     *
     * @param action plot type understood by the script (e.g. "histogram", "activations")
     * @param dataPath data argument passed to the script
     * @param saveFilePath image file the script should write
     * @throws RuntimeException if the process cannot be started or its output read
     */
    public void renderGraph(String action, String dataPath, String saveFilePath) {
        this.runPlotScript(action, dataPath, saveFilePath);
    }

    /**
     * Like {@link #renderGraph(String, String, String)}, additionally passing the feature width
     * and height as the fourth and fifth script arguments.
     */
    public void renderGraph(String action, String dataPath, String saveFilePath, int feature_width, int feature_height) {
        this.runPlotScript(action, dataPath, saveFilePath, String.valueOf(feature_width), String.valueOf(feature_height));
    }

    /**
     * Starts {@code python <localPlotPath> <action> <extraArgs...>}, waits for it to finish and
     * logs its stdout (info) and, if non-empty, its stderr (error).
     *
     * <p>Arguments are passed as a list so paths containing spaces stay single arguments, and
     * stderr is drained on a separate thread so a full pipe buffer cannot deadlock the child.
     */
    private void runPlotScript(String action, String... extraArgs) {
        log.info("Rendering " + action + " graphs for data analysis...");
        List<String> command = new ArrayList<String>();
        command.add("python");
        command.add(localPlotPath);
        command.add(action);
        command.addAll(Arrays.asList(extraArgs));
        try {
            Process process = new ProcessBuilder(command).start();
            // The script is given no input; close stdin so a read would see EOF instead of blocking.
            process.getOutputStream().close();
            StreamDrainer stderr = new StreamDrainer(process.getErrorStream());
            stderr.start();
            List<?> stdout;
            try (InputStream in = process.getInputStream()) {
                stdout = IOUtils.readLines(in);
            }
            stderr.join();
            log.info("Std out " + stdout.toString());
            if (stderr.failure != null) {
                throw stderr.failure;
            }
            if (!stderr.lines.isEmpty()) {
                log.error("Std error " + stderr.lines.toString());
            }
        } catch (IOException e) {
            throw new RuntimeException("Unable to render " + action + " graphs", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while rendering " + action + " graphs", e);
        }
    }

    /**
     * Writes each matrix (flattened with {@code ravel()}) to a data file and renders them in one
     * plot. The script receives a single comma-joined argument
     * {@code file0,title0,file1,title1,...}, so titles should not contain commas.
     *
     * @param plotType script action, e.g. "histogram"
     * @param titles one title per matrix, in the same order
     * @param matrices matrices to plot
     * @param saveFilePath image file to write
     * @throws IllegalArgumentException if the titles and matrices counts differ
     */
    public void graphPlotType(String plotType, List<String> titles, INDArray[] matrices, String saveFilePath) {
        if (titles.size() != matrices.length) {
            throw new IllegalArgumentException("Titles and matrix lengths must be equal");
        }
        String[] path = new String[matrices.length * 2];
        for (int i = 0; i < matrices.length; ++i) {
            path[2 * i] = this.writeMatrix(matrices[i].ravel());
            path[2 * i + 1] = titles.get(i);
        }
        String dataPath = StringUtils.join(path, ",");
        this.renderGraph(plotType, dataPath, saveFilePath);
    }

    /**
     * Plots histograms of every parameter listed in the layer configuration, followed by the
     * matching gradient (titled {@code <name>-gradient}), into
     * {@code <layer graph dir>/weightHistograms.png}.
     *
     * @param network layer whose parameters are plotted
     * @param gradient gradient holding an entry for every configured variable
     * @throws IllegalArgumentException if {@code gradient} lacks a configured variable
     */
    public void plotWeightHistograms(Layer network, Gradient gradient) {
        List<String> variables = network.conf().variables();
        // Titles are built in the same order as the matrices; the earlier version sorted them
        // alphabetically, which could mislabel plots when the configured order differed.
        List<String> titles = new ArrayList<String>(variables.size() * 2);
        INDArray[] variablesAndGradients = new INDArray[variables.size() * 2];
        int count = 0;
        for (String variable : variables) {
            titles.add(variable);
            variablesAndGradients[count++] = network.getParam(variable);
        }
        for (String variable : variables) {
            INDArray variableGradient = gradient.getGradientFor(variable);
            if (variableGradient == null) {
                throw new IllegalArgumentException("No gradient for variable " + variable);
            }
            titles.add(variable + "-gradient");
            variablesAndGradients[count++] = variableGradient;
        }
        this.graphPlotType("histogram", titles, variablesAndGradients, layerGraphFilePath + "weightHistograms.png");
    }

    /** Plots weight histograms using the layer's own {@code gradient()}. */
    public void plotWeightHistograms(Layer network) {
        this.plotWeightHistograms(network, network.gradient());
    }

    /**
     * Plots the layer's {@code activationMean()} into
     * {@code <layer graph dir>/activationPlot.png}.
     *
     * @throws IllegalStateException if the layer has no input set
     */
    public void plotActivations(Layer layer) {
        if (layer.input() == null) {
            throw new IllegalStateException("Unable to plot; missing input");
        }
        INDArray hbiasMean = layer.activationMean();
        String dataPath = this.writeMatrix(hbiasMean);
        this.renderGraph("activations", dataPath, layerGraphFilePath + "activationPlot.png");
    }

    /**
     * Renders a copy of the layer's "W" parameter with {@link FilterRenderer} into
     * {@code <layer graph dir>/renderFilter.png}. Failures are logged and ignored (best effort).
     *
     * @param layer layer providing the "W" parameter
     * @param patchesPerRow passed to the renderer for two-dimensional weights only
     */
    public void renderFilter(Layer layer, int patchesPerRow) {
        INDArray weight = layer.getParam("W");
        INDArray w = weight.dup();
        FilterRenderer render = new FilterRenderer();
        try {
            if (w.shape().length > 2) {
                // NOTE(review): this branch passes w.slices() where the 2-D branch passes
                // patchesPerRow, and ignores patchesPerRow; confirm against FilterRenderer.
                INDArray render2 = w.transpose();
                render.renderFilters(render2, layerGraphFilePath + "renderFilter.png", w.columns(), w.rows(), w.slices());
            } else {
                render.renderFilters(w, layerGraphFilePath + "renderFilter.png", (int)Math.sqrt(w.rows()), (int)Math.sqrt(w.columns()), patchesPerRow);
            }
        } catch (Exception e) {
            // Deliberately best-effort: plotting must not abort training. The logger already
            // records the stack trace, so no printStackTrace().
            log.error("Unable to plot filter, continuing...", (Throwable)e);
        }
    }

    /** Plots weight/gradient histograms and activations for {@code layer}. */
    public void plotNetworkGradient(Layer layer, Gradient gradient) {
        this.plotWeightHistograms(layer, gradient);
        this.plotActivations(layer);
    }

    /** Plots histograms of "W" and the given gradient, then the layer's activations. */
    public void plotNetworkGradient(Layer layer, INDArray gradient) {
        this.graphPlotType("histogram", Arrays.asList("W", "w-gradient"), new INDArray[]{layer.getParam("W"), gradient}, layerGraphFilePath + "weightHistograms.png");
        this.plotActivations(layer);
    }

    /**
     * Reads a child-process stream to EOF on its own thread. Reading stdout and stderr one after
     * the other on one thread can deadlock once the unread pipe's OS buffer fills.
     * Results are safe to read after {@link #join()}.
     */
    private static final class StreamDrainer extends Thread {
        private final InputStream stream;
        private List<?> lines = new ArrayList<Object>();
        private IOException failure;

        StreamDrainer(InputStream stream) {
            super("plot.py-stderr");
            this.stream = stream;
            setDaemon(true);
        }

        @Override
        public void run() {
            try (InputStream in = stream) {
                lines = IOUtils.readLines(in);
            } catch (IOException e) {
                failure = e;
            }
        }
    }
}

