/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.gradient.multilayer;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.gradient.multilayer.MultiLayerGradientListener;
import org.deeplearning4j.nn.gradient.MultiLayerGradient;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gradient listener that keeps a small sliding window of the most recently received
 * {@link MultiLayerGradient}s and re-plots their weight gradients every time a new one arrives.
 *
 * <p>The window holds at most {@code MAX_PLOTTED_GRADIENTS} entries; when it overflows, the oldest
 * entry is evicted first.
 *
 * <p>No synchronization is performed: the window is a plain {@link ArrayList}, so concurrent
 * callers must coordinate externally.
 */
public class WeightPlotListener
implements MultiLayerGradientListener {
    private static final long serialVersionUID = -2476819215506562426L;
    private static final Logger log = LoggerFactory.getLogger(WeightPlotListener.class);

    /** Maximum number of recent gradients retained in the window (and therefore plotted). */
    private static final int MAX_PLOTTED_GRADIENTS = 5;

    /** Sliding window of received gradients, ordered oldest first. */
    private final List<MultiLayerGradient> gradients = new ArrayList<MultiLayerGradient>();

    /**
     * Appends {@code gradient} to the sliding window, evicts the oldest entries until at most
     * {@code MAX_PLOTTED_GRADIENTS} remain, then re-plots the window via {@link #plot()}.
     *
     * @param gradient the latest multi-layer gradient; stored by reference, not copied
     */
    @Override
    public void onMultiLayerGradient(MultiLayerGradient gradient) {
        gradients.add(gradient);
        // Drop from the front so the window always holds the newest entries.
        while (gradients.size() > MAX_PLOTTED_GRADIENTS) {
            gradients.remove(0);
        }
        plot();
    }

    /**
     * Plots one matrix per gradient currently in the window. Each matrix is the weight gradient
     * ({@code getwGradient()}) of the entry at index 0 of that gradient's {@code getGradients()}
     * list (presumably the first layer; confirm against {@code MultiLayerGradient}). Each matrix is
     * labelled with its position in the window, where {@code "0"} is the oldest.
     */
    public void plot() {
        int count = gradients.size();
        DoubleMatrix[] matrices = new DoubleMatrix[count];
        String[] names = new String[count];
        log.info("Plotting {} matrices", count);
        for (int i = 0; i < count; i++) {
            names[i] = String.valueOf(i);
            // NOTE(review): assumes every MultiLayerGradient carries at least one layer gradient;
            // an empty getGradients() list would throw IndexOutOfBoundsException here.
            matrices[i] = gradients.get(i).getGradients().get(0).getwGradient();
        }
        NeuralNetPlotter plotter = new NeuralNetPlotter();
        plotter.plotMatrices(names, matrices);
    }
}

