/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.impl;

import com.google.flatbuffers.Table;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.primitives.Pair;

public class UIListener
extends BaseListener {
    private FileMode fileMode;
    private File logFile;
    private int lossPlotFreq;
    private int performanceStatsFrequency;
    private int updateRatioFrequency;
    private UpdateRatio updateRatioType;
    private int histogramFrequency;
    private HistogramType[] histogramTypes;
    private int opProfileFrequency;
    private Map<Pair<String, Integer>, List<Evaluation.Metric>> trainEvalMetrics;
    private int trainEvalFrequency;
    private TestEvaluation testEvaluation;
    private int learningRateFrequency;
    private MultiDataSet currentIterDataSet;
    private LogFileWriter writer;
    private boolean wroteLossNames;
    private boolean wroteLearningRateName;
    private Set<String> relevantOpsForEval;
    private Map<Pair<String, Integer>, Evaluation> epochTrainEval;
    private boolean wroteEvalNames;
    private boolean wroteEvalNamesIter;
    private int firstUpdateRatioIter = -1;
    private boolean checkStructureForRestore;

    private UIListener(Builder b) {
        this.fileMode = b.fileMode;
        this.logFile = b.logFile;
        this.lossPlotFreq = b.lossPlotFreq;
        this.performanceStatsFrequency = b.performanceStatsFrequency;
        this.updateRatioFrequency = b.updateRatioFrequency;
        this.updateRatioType = b.updateRatioType;
        this.histogramFrequency = b.histogramFrequency;
        this.histogramTypes = b.histogramTypes;
        this.opProfileFrequency = b.opProfileFrequency;
        this.trainEvalMetrics = b.trainEvalMetrics;
        this.trainEvalFrequency = b.trainEvalFrequency;
        this.testEvaluation = b.testEvaluation;
        this.learningRateFrequency = b.learningRateFrequency;
        switch (this.fileMode) {
            case CREATE: {
                Preconditions.checkState((!this.logFile.exists() ? 1 : 0) != 0, (String)"Log file already exists and fileMode is set to CREATE: %s\nEither delete the existing file, specify a path that doesn't exist, or set the UIListener to another mode such as CREATE_OR_APPEND", (Object)this.logFile);
                break;
            }
            case APPEND: {
                Preconditions.checkState((boolean)this.logFile.exists(), (String)"Log file does not exist and fileMode is set to APPEND: %s\nEither specify a path to an existing log file for this model, or set the UIListener to another mode such as CREATE_OR_APPEND", (Object)this.logFile);
            }
        }
        if (this.logFile.exists()) {
            this.restoreLogFile();
        }
    }

    protected void restoreLogFile() {
        if (this.logFile.length() == 0L && this.fileMode == FileMode.CREATE_OR_APPEND || this.fileMode == FileMode.APPEND) {
            this.logFile.delete();
            return;
        }
        try {
            this.writer = new LogFileWriter(this.logFile);
        }
        catch (IOException e) {
            throw new RuntimeException("Error restoring existing log file at path: " + this.logFile.getAbsolutePath(), e);
        }
        if (this.fileMode == FileMode.APPEND || this.fileMode == FileMode.CREATE_OR_APPEND) {
            LogFileWriter.StaticInfo si;
            try {
                si = this.writer.readStatic();
            }
            catch (IOException e) {
                throw new RuntimeException("Error restoring existing log file, static info at path: " + this.logFile.getAbsolutePath(), e);
            }
            List<Pair<UIStaticInfoRecord, Table>> staticList = si.getData();
            if (si != null) {
                for (int i = 0; i < staticList.size(); ++i) {
                    UIStaticInfoRecord r = (UIStaticInfoRecord)((Object)staticList.get(i).getFirst());
                    if (r.infoType() != 0) continue;
                    this.checkStructureForRestore = true;
                }
            }
        }
    }

    protected void checkStructureForRestore(SameDiff sd) {
        LogFileWriter.StaticInfo si;
        try {
            si = this.writer.readStatic();
        }
        catch (IOException e) {
            throw new RuntimeException("Error restoring existing log file, static info at path: " + this.logFile.getAbsolutePath(), e);
        }
        List<Pair<UIStaticInfoRecord, Table>> staticList = si.getData();
        if (si != null) {
            UIGraphStructure structure = null;
            for (int i = 0; i < staticList.size(); ++i) {
                UIStaticInfoRecord r = (UIStaticInfoRecord)((Object)staticList.get(i).getFirst());
                if (r.infoType() != 0) continue;
                structure = (UIGraphStructure)((Object)staticList.get(i).getSecond());
                break;
            }
            if (structure != null) {
                int nInFile = structure.inputsLength();
                ArrayList<String> phs = new ArrayList<String>(nInFile);
                for (int i = 0; i < nInFile; ++i) {
                    phs.add(structure.inputs(i));
                }
                List<String> actPhs = sd.inputs();
                if (actPhs.size() != phs.size() || !actPhs.containsAll(phs)) {
                    throw new IllegalStateException("Error continuing collection of UI stats in existing model file " + this.logFile.getAbsolutePath() + ": Model structure differs. Existing (file) model placeholders: " + phs + " vs. current model placeholders: " + actPhs + ". To disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
                }
                int nVarsFile = structure.variablesLength();
                ArrayList<String> vars = new ArrayList<String>(nVarsFile);
                for (int i = 0; i < nVarsFile; ++i) {
                    vars.add(structure.variables(i).name());
                }
                List<SDVariable> sdVars = sd.variables();
                ArrayList<String> varNames = new ArrayList<String>(sdVars.size());
                for (SDVariable v : sdVars) {
                    varNames.add(v.name());
                }
                if (varNames.size() != vars.size() || !varNames.containsAll(vars)) {
                    int countDifferent = 0;
                    ArrayList<String> different = new ArrayList<String>();
                    for (String s : varNames) {
                        if (vars.contains(s)) continue;
                        ++countDifferent;
                        if (different.size() >= 10) continue;
                        different.add(s);
                    }
                    StringBuilder msg = new StringBuilder();
                    msg.append("Error continuing collection of UI stats in existing model file ").append(this.logFile.getAbsolutePath()).append(": Current model structure differs vs. model structure in file - ").append(countDifferent).append(" variable names differ.");
                    if (different.size() == countDifferent) {
                        msg.append("\nVariables in new model not present in existing (file) model: ").append(different);
                    } else {
                        msg.append("\nFirst 10 variables in new model not present in existing (file) model: ").append(different);
                    }
                    msg.append("\nTo disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
                    throw new IllegalStateException(msg.toString());
                }
            }
        }
        this.checkStructureForRestore = false;
    }

    protected void initalizeWriter(SameDiff sd) {
        try {
            this.initializeHelper(sd);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    protected void initializeHelper(SameDiff sd) throws IOException {
        this.writer = new LogFileWriter(this.logFile);
        this.writer.writeGraphStructure(sd);
        this.writer.writeFinishStaticMarker();
    }

    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    @Override
    public void epochStart(SameDiff sd, At at) {
        this.epochTrainEval = null;
    }

    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (this.epochTrainEval != null) {
            long time = System.currentTimeMillis();
            for (Map.Entry<Pair<String, Integer>, Evaluation> e : this.epochTrainEval.entrySet()) {
                String n = "evaluation/" + (String)e.getKey().getFirst();
                List<Evaluation.Metric> l = this.trainEvalMetrics.get(e.getKey());
                for (Evaluation.Metric m : l) {
                    String mName = n + "/train/" + m.toString().toLowerCase();
                    if (!this.wroteEvalNames && !this.writer.registeredEventName(mName)) {
                        this.writer.registerEventNameQuiet(mName);
                    }
                    double score = e.getValue().scoreForMetric(m);
                    try {
                        this.writer.writeScalarEvent(mName, LogFileWriter.EventSubtype.EVALUATION, time, at.iteration(), at.epoch(), score);
                    }
                    catch (IOException ex) {
                        throw new RuntimeException("Error writing to log file", ex);
                    }
                }
                this.wroteEvalNames = true;
            }
        }
        this.epochTrainEval = null;
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlMs) {
        if (this.writer == null) {
            this.initalizeWriter(sd);
        }
        if (this.checkStructureForRestore) {
            this.checkStructureForRestore(sd);
        }
        this.currentIterDataSet = data;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        long time = System.currentTimeMillis();
        if (!this.wroteLossNames) {
            String n;
            for (String s : loss.getLossNames()) {
                String n2 = "losses/" + s;
                if (this.writer.registeredEventName(n2)) continue;
                this.writer.registerEventNameQuiet(n2);
            }
            if (loss.numLosses() > 1 && !this.writer.registeredEventName(n = "losses/totalLoss")) {
                this.writer.registerEventNameQuiet(n);
            }
            this.wroteLossNames = true;
        }
        List<String> lossNames = loss.getLossNames();
        double[] lossVals = loss.getLosses();
        for (int i = 0; i < lossVals.length; ++i) {
            try {
                String eventName = "losses/" + lossNames.get(i);
                this.writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.LOSS, time, at.iteration(), at.epoch(), lossVals[i]);
                continue;
            }
            catch (IOException e) {
                throw new RuntimeException("Error writing to log file", e);
            }
        }
        if (lossVals.length > 1) {
            double total = loss.totalLoss();
            try {
                String eventName = "losses/totalLoss";
                this.writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.LOSS, time, at.iteration(), at.epoch(), total);
            }
            catch (IOException e) {
                throw new RuntimeException("Error writing to log file", e);
            }
        }
        this.currentIterDataSet = null;
        if (this.learningRateFrequency > 0) {
            IUpdater u;
            if (!this.wroteLearningRateName) {
                String name = "learningRate";
                if (!this.writer.registeredEventName(name)) {
                    this.writer.registerEventNameQuiet(name);
                }
                this.wroteLearningRateName = true;
            }
            if (at.iteration() % this.learningRateFrequency == 0 && (u = sd.getTrainingConfig().getUpdater()).hasLearningRate()) {
                double lr = u.getLearningRate(at.iteration(), at.epoch());
                try {
                    this.writer.writeScalarEvent("learningRate", LogFileWriter.EventSubtype.LEARNING_RATE, time, at.iteration(), at.epoch(), lr);
                }
                catch (IOException e) {
                    throw new RuntimeException("Error writing to log file");
                }
            }
        }
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
        if (at.operation() == Operation.TRAINING && this.trainEvalMetrics != null && this.trainEvalMetrics.size() > 0) {
            long time = System.currentTimeMillis();
            if (this.relevantOpsForEval == null) {
                this.relevantOpsForEval = new HashSet<String>();
                for (Pair<String, Integer> p : this.trainEvalMetrics.keySet()) {
                    Variable v = sd.getVariables().get(p.getFirst());
                    String opName = v.getOutputOfOp();
                    Preconditions.checkState((opName != null ? 1 : 0) != 0, (String)"Cannot evaluate on variable of type %s - variable name: \"%s\"", (Object)((Object)v.getVariable().getVariableType()), (Object)opName);
                    this.relevantOpsForEval.add(v.getOutputOfOp());
                }
            }
            if (!this.relevantOpsForEval.contains(op.getName())) {
                return;
            }
            if (this.epochTrainEval == null) {
                this.epochTrainEval = new HashMap<Pair<String, Integer>, Evaluation>();
                for (Pair<String, Integer> p : this.trainEvalMetrics.keySet()) {
                    this.epochTrainEval.put(p, new Evaluation());
                }
            }
            boolean wrote = false;
            for (Pair<String, Integer> p : this.trainEvalMetrics.keySet()) {
                int idx = op.getOutputsOfOp().indexOf(p.getFirst());
                INDArray out = outputs[idx];
                INDArray label = this.currentIterDataSet.getLabels((Integer)p.getSecond());
                INDArray mask = this.currentIterDataSet.getLabelsMaskArray((Integer)p.getSecond());
                this.epochTrainEval.get(p).eval(label, out, mask);
                if (this.trainEvalFrequency <= 0 || at.iteration() <= 0 || at.iteration() % this.trainEvalFrequency != 0) continue;
                for (Evaluation.Metric m : this.trainEvalMetrics.get(p)) {
                    String n = "evaluation/train_iter/" + (String)p.getKey() + "/" + m.toString().toLowerCase();
                    if (!this.wroteEvalNamesIter) {
                        if (!this.writer.registeredEventName(n)) {
                            this.writer.registerEventNameQuiet(n);
                        }
                        wrote = true;
                    }
                    double score = this.epochTrainEval.get(p).scoreForMetric(m);
                    try {
                        this.writer.writeScalarEvent(n, LogFileWriter.EventSubtype.EVALUATION, time, at.iteration(), at.epoch(), score);
                    }
                    catch (IOException e) {
                        throw new RuntimeException("Error writing to log file");
                    }
                }
            }
            this.wroteEvalNamesIter = wrote;
        }
    }

    @Override
    public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
        if (this.writer == null) {
            this.initalizeWriter(sd);
        }
        if (this.updateRatioFrequency > 0 && at.iteration() % this.updateRatioFrequency == 0) {
            double updates;
            double params;
            String name;
            if (this.firstUpdateRatioIter < 0) {
                this.firstUpdateRatioIter = at.iteration();
            }
            if (this.firstUpdateRatioIter == at.iteration() && !this.writer.registeredEventName(name = "logUpdateRatio/" + v.getName())) {
                this.writer.registerEventNameQuiet(name);
            }
            if (this.updateRatioType == UpdateRatio.L2) {
                params = v.getVariable().getArr().norm2Number().doubleValue();
                updates = update.norm2Number().doubleValue();
            } else {
                params = v.getVariable().getArr().norm1Number().doubleValue();
                updates = update.norm1Number().doubleValue();
            }
            double ratio = updates / params;
            ratio = params == 0.0 ? 0.0 : Math.max(-10.0, Math.log10(ratio));
            try {
                String name2 = "logUpdateRatio/" + v.getName();
                this.writer.writeScalarEvent(name2, LogFileWriter.EventSubtype.LOSS, System.currentTimeMillis(), at.iteration(), at.epoch(), ratio);
            }
            catch (IOException e) {
                throw new RuntimeException("Error writing to log file", e);
            }
        }
    }

    public static Builder builder(File logFile) {
        return new Builder(logFile);
    }

    public static class TestEvaluation {
    }

    public static class Builder {
        private FileMode fileMode = FileMode.CREATE_OR_APPEND;
        private File logFile;
        private int lossPlotFreq = 1;
        private int performanceStatsFrequency = -1;
        private int updateRatioFrequency = -1;
        private UpdateRatio updateRatioType = UpdateRatio.MEAN_MAGNITUDE;
        private int histogramFrequency = -1;
        private HistogramType[] histogramTypes;
        private int opProfileFrequency = -1;
        private Map<Pair<String, Integer>, List<Evaluation.Metric>> trainEvalMetrics;
        private int trainEvalFrequency = 10;
        private TestEvaluation testEvaluation = null;
        private int learningRateFrequency = 10;

        public Builder(@NonNull File logFile) {
            if (logFile == null) {
                throw new NullPointerException("logFile is marked @NonNull but is null");
            }
            this.logFile = logFile;
        }

        public Builder fileMode(FileMode fileMode) {
            this.fileMode = fileMode;
            return this;
        }

        public Builder plotLosses(int frequency) {
            this.lossPlotFreq = frequency;
            return this;
        }

        public Builder performanceStats(int frequency) {
            this.performanceStatsFrequency = frequency;
            return this;
        }

        public Builder trainEvaluationMetrics(String name, int labelIdx, Evaluation.Metric ... metrics) {
            Pair p;
            if (this.trainEvalMetrics == null) {
                this.trainEvalMetrics = new LinkedHashMap<Pair<String, Integer>, List<Evaluation.Metric>>();
            }
            if (!this.trainEvalMetrics.containsKey(p = new Pair((Object)name, (Object)labelIdx))) {
                this.trainEvalMetrics.put((Pair<String, Integer>)p, new ArrayList());
            }
            List<Evaluation.Metric> l = this.trainEvalMetrics.get(p);
            for (Evaluation.Metric m : metrics) {
                if (l.contains(m)) continue;
                l.add(m);
            }
            return this;
        }

        public Builder trainAccuracy(String name, int labelIdx) {
            return this.trainEvaluationMetrics(name, labelIdx, Evaluation.Metric.ACCURACY);
        }

        public Builder trainF1(String name, int labelIdx) {
            return this.trainEvaluationMetrics(name, labelIdx, Evaluation.Metric.F1);
        }

        public Builder trainEvalFrequency(int trainEvalFrequency) {
            this.trainEvalFrequency = trainEvalFrequency;
            return this;
        }

        public Builder updateRatios(int frequency) {
            return this.updateRatios(frequency, UpdateRatio.MEAN_MAGNITUDE);
        }

        public Builder updateRatios(int frequency, UpdateRatio ratioType) {
            this.updateRatioFrequency = frequency;
            this.updateRatioType = ratioType;
            return this;
        }

        public Builder histograms(int frequency, HistogramType ... types) {
            this.histogramFrequency = frequency;
            this.histogramTypes = types;
            return this;
        }

        public Builder profileOps(int frequency) {
            this.opProfileFrequency = frequency;
            return this;
        }

        public Builder testEvaluation(TestEvaluation testEvalConfig) {
            this.testEvaluation = testEvalConfig;
            return this;
        }

        public Builder learningRate(int frequency) {
            this.learningRateFrequency = frequency;
            return this;
        }

        public UIListener build() {
            return new UIListener(this);
        }
    }

    public static enum HistogramType {
        PARAMETERS,
        PARAMETER_GRADIENTS,
        PARAMETER_UPDATES,
        ACTIVATIONS,
        ACTIVATION_GRADIENTS;

    }

    public static enum UpdateRatio {
        L2,
        MEAN_MAGNITUDE;

    }

    public static enum FileMode {
        CREATE,
        APPEND,
        CREATE_OR_APPEND,
        CREATE_APPEND_NOCHECK;

    }
}

