/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.weights;

import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiServer;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.providers.ObjectMapperProvider;
import org.deeplearning4j.ui.weights.ModelAndGradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HistogramIterationListener
implements IterationListener {
    private static final Logger log = LoggerFactory.getLogger(HistogramIterationListener.class);
    private Client client = (Client)((Client)ClientBuilder.newClient().register(JacksonJsonProvider.class)).register((Object)new ObjectMapperProvider());
    private WebTarget target;
    private int iterations = 1;
    private ArrayList<Double> scoreHistory = new ArrayList();
    private List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<Map<String, List<Double>>>();
    private List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<Map<String, List<Double>>>();
    private boolean openBrowser;
    private boolean firstIteration = true;
    private String path;
    private String subPath;

    public HistogramIterationListener(int iterations) {
        this(iterations, true, "weights");
    }

    public HistogramIterationListener(int iterations, boolean openBrowser, String subPath) {
        int port = -1;
        try {
            UiServer server = UiServer.getInstance();
            port = server.getPort();
        }
        catch (Exception e) {
            log.error("Error initializing UI server", (Throwable)e);
            throw new RuntimeException(e);
        }
        this.iterations = iterations;
        this.target = this.client.target("http://localhost:" + port).path(subPath).path("update");
        this.openBrowser = openBrowser;
        this.path = "http://localhost:" + port + "/" + subPath;
        this.subPath = subPath;
        System.out.println("UI Histogram: " + this.path);
    }

    public boolean invoked() {
        return false;
    }

    public void invoke() {
    }

    public void iterationDone(Model model, int iteration) {
        if (iteration % this.iterations == 0) {
            Map grad = model.gradient().gradientForVariable();
            if (this.meanMagHistoryParams.size() == 0) {
                int maxLayerIdx = -1;
                for (String s : grad.keySet()) {
                    maxLayerIdx = Math.max(maxLayerIdx, HistogramIterationListener.indexFromString(s));
                }
                if (maxLayerIdx == -1) {
                    maxLayerIdx = 0;
                }
                for (int i = 0; i <= maxLayerIdx; ++i) {
                    this.meanMagHistoryParams.add(new LinkedHashMap());
                    this.meanMagHistoryUpdates.add(new LinkedHashMap());
                }
            }
            LinkedHashMap<String, INDArray> newGrad = new LinkedHashMap<String, INDArray>();
            for (Map.Entry entry : grad.entrySet()) {
                String param = (String)entry.getKey();
                String string = "param_" + (String)param;
                newGrad.put(string, ((INDArray)entry.getValue()).dup());
                Map<String, List<Double>> map = this.meanMagHistoryUpdates.get(HistogramIterationListener.indexFromString(param));
                List<Double> list = map.get(string);
                if (list == null) {
                    list = new ArrayList<Double>();
                    map.put(string, list);
                }
                double meanMag = ((INDArray)entry.getValue()).norm1Number().doubleValue() / (double)((INDArray)entry.getValue()).length();
                list.add(meanMag);
            }
            Map params = model.paramTable();
            LinkedHashMap<String, INDArray> newParams = new LinkedHashMap<String, INDArray>();
            for (Map.Entry entry : params.entrySet()) {
                String param = (String)entry.getKey();
                String newName = "param_" + param;
                newParams.put(newName, ((INDArray)entry.getValue()).dup());
                Map<String, List<Double>> map = this.meanMagHistoryParams.get(HistogramIterationListener.indexFromString(param));
                List<Double> list = map.get(newName);
                if (list == null) {
                    list = new ArrayList<Double>();
                    map.put(newName, list);
                }
                double meanMag = ((INDArray)entry.getValue()).norm1Number().doubleValue() / (double)((INDArray)entry.getValue()).length();
                list.add(meanMag);
            }
            double score = model.score();
            this.scoreHistory.add(score);
            ModelAndGradient g = new ModelAndGradient();
            g.setGradients(newGrad);
            g.setParameters(newParams);
            g.setScore(score);
            g.setScores(this.scoreHistory);
            g.setPath(this.subPath);
            g.setUpdateMagnitudes(this.meanMagHistoryUpdates);
            g.setParamMagnitudes(this.meanMagHistoryParams);
            g.setLastUpdateTime(System.currentTimeMillis());
            Response resp = this.target.request(new String[]{"application/json"}).accept(new String[]{"application/json"}).post(Entity.entity((Object)g, (String)"application/json"));
            log.debug("{}", (Object)resp);
            if (this.openBrowser && this.firstIteration) {
                UiUtils.tryOpenBrowser(this.path, log);
                this.firstIteration = false;
            }
        }
    }

    private static int indexFromString(String str) {
        int underscore = str.indexOf("_");
        if (underscore == -1) {
            return -1;
        }
        String subStr = str.substring(0, underscore);
        try {
            return Integer.parseInt(subStr);
        }
        catch (NumberFormatException e) {
            return -1;
        }
    }
}

