/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Accumulates classification results and reports accuracy, precision, recall and F1.
 *
 * <p>Each call to {@link #eval(INDArray, INDArray)} receives a matrix of actual outcomes and a
 * matrix of predictions with one example per row. The class of a row is the column index
 * holding that row's largest value. Each call updates the per-class true/false
 * positive/negative counters and the confusion matrix.
 *
 * <p>Ratios whose denominator would be zero are reported as {@code 0.0} rather than NaN.
 *
 * <p>Instances hold mutable counters with no synchronization.
 */
public class Evaluation implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(Evaluation.class);

    private Counter<Integer> truePositives = new Counter<Integer>();
    private Counter<Integer> falsePositives = new Counter<Integer>();
    private Counter<Integer> trueNegatives = new Counter<Integer>();
    private Counter<Integer> falseNegatives = new Counter<Integer>();
    // Null after the no-arg constructor until the first eval() call infers the classes.
    private ConfusionMatrix<Integer> confusion;
    private int numRowCounter;
    private List<Integer> classLabels = new ArrayList<Integer>();

    /**
     * Creates an evaluator whose set of classes is inferred from the labels passed to the
     * first {@link #eval(INDArray, INDArray)} call.
     */
    public Evaluation() {
    }

    /**
     * Creates an evaluator for the classes {@code 0 .. numClasses - 1}.
     *
     * @param numClasses number of output classes (columns of the outcome matrices)
     */
    public Evaluation(int numClasses) {
        for (int i = 0; i < numClasses; ++i) {
            this.classLabels.add(i);
        }
        this.confusion = new ConfusionMatrix<Integer>(this.classLabels);
        this.numRowCounter = 0;
    }

    /**
     * Adds one batch of results to the running statistics.
     *
     * @param realOutcomes actual outcomes, one example per row
     * @param guesses      predicted outcomes, one example per row
     * @throws IllegalArgumentException if the two matrices differ in length; no state is
     *                                  modified in that case
     */
    public void eval(INDArray realOutcomes, INDArray guesses) {
        // Validate first so that a rejected call leaves every counter untouched.
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        this.numRowCounter += realOutcomes.shape()[0];
        if (this.confusion == null) {
            log.warn("Creating confusion matrix based on classes passed in . Will assume the label distribution passed in is indicative of the overall dataset");
            HashSet<Integer> classes = new HashSet<Integer>();
            for (int i = 0; i < realOutcomes.rows(); ++i) {
                // Use the same argmax as the counting loop below so the inferred classes
                // are exactly the ones that will be recorded as actual classes.
                classes.add(argMax(realOutcomes.getRow(i)));
            }
            this.confusion = new ConfusionMatrix<Integer>(new ArrayList<Integer>(classes));
        }
        for (int i = 0; i < realOutcomes.rows(); ++i) {
            int actual = argMax(realOutcomes.getRow(i));
            int predicted = argMax(guesses.getRow(i));
            this.addToConfusion(actual, predicted);
            if (actual == predicted) {
                this.incrementTruePositives(predicted);
            } else {
                this.incrementFalseNegatives(actual);
                this.incrementFalsePositives(predicted);
            }
            // Every known class that is neither the actual nor the predicted one is a true negative.
            for (Integer clazz : this.confusion.getClasses()) {
                if (clazz != actual && clazz != predicted) {
                    this.incrementTrueNegatives(clazz);
                }
            }
        }
    }

    /**
     * Returns the column index of the largest value in {@code row}.
     * Ties, and NaN values that never compare greater, resolve to the lowest index.
     */
    private static int argMax(INDArray row) {
        int bestIndex = 0;
        double bestValue = row.getDouble(0);
        for (int col = 1; col < row.columns(); ++col) {
            double value = row.getDouble(col);
            if (value > bestValue) {
                bestValue = value;
                bestIndex = col;
            }
        }
        return bestIndex;
    }

    /** Returns the known classes, or an empty list if no confusion matrix exists yet. */
    private List<Integer> classes() {
        return this.confusion == null ? new ArrayList<Integer>() : this.confusion.getClasses();
    }

    /**
     * Builds a human-readable report: every non-zero confusion-matrix cell followed by
     * accuracy, precision, recall and F1 (formatted to at most four decimals).
     *
     * @return the report text
     */
    public String stats() {
        StringBuilder builder = new StringBuilder().append("\n");
        List<Integer> classes = this.classes();
        for (Integer clazz : classes) {
            for (Integer clazz2 : classes) {
                int count = this.confusion.getCount(clazz, clazz2);
                if (count == 0) continue;
                builder.append("\nActual Class " + clazz + " was predicted with Predicted " + clazz2 + " with count " + count + " times\n");
            }
        }
        DecimalFormat df = new DecimalFormat("#.####");
        builder.append("\n==========================Scores========================================");
        builder.append("\n Accuracy:  " + df.format(this.accuracy()));
        builder.append("\n Precision: " + df.format(this.precision()));
        builder.append("\n Recall:    " + df.format(this.recall()));
        builder.append("\n F1 Score:  " + df.format(this.f1()));
        builder.append("\n===========================================================================");
        return builder.toString();
    }

    /**
     * Precision for one class: {@code tp / (tp + fp)}.
     *
     * @param classLabel the class index
     * @return the precision, or 0.0 if the class has no true positives
     */
    public double precision(Integer classLabel) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fpCount = this.falsePositives.getCount(classLabel);
        if (tpCount == 0.0) {
            return 0.0;
        }
        return tpCount / (tpCount + fpCount);
    }

    /**
     * Macro-averaged precision over the classes that have at least one true positive.
     * Classes without true positives contribute 0 to the sum and are not counted.
     *
     * @return the averaged precision, or 0.0 if no class has a true positive
     */
    public double precision() {
        double precisionAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.classes()) {
            precisionAcc += this.precision(classLabel);
            if (this.truePositives.getCount(classLabel) > 0.0) {
                classCount++;
            }
        }
        return classCount == 0 ? 0.0 : precisionAcc / classCount;
    }

    /**
     * Recall for one class: {@code tp / (tp + fn)}.
     *
     * @param classLabel the class index
     * @return the recall, or 0.0 if the class has no true positives
     */
    public double recall(Integer classLabel) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fnCount = this.falseNegatives.getCount(classLabel);
        if (tpCount == 0.0) {
            return 0.0;
        }
        return tpCount / (tpCount + fnCount);
    }

    /**
     * Macro-averaged recall over the classes that have at least one true positive.
     *
     * @return the averaged recall, or 0.0 if no class has a true positive
     */
    public double recall() {
        double recallAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.classes()) {
            recallAcc += this.recall(classLabel);
            if (this.truePositives.getCount(classLabel) > 0.0) {
                classCount++;
            }
        }
        return classCount == 0 ? 0.0 : recallAcc / classCount;
    }

    /**
     * F1 score for one class: the harmonic mean of that class's precision and recall.
     *
     * @param classLabel the class index
     * @return the F1 score, or 0.0 if either precision or recall is 0
     */
    public double f1(Integer classLabel) {
        double precision = this.precision(classLabel);
        // Per-class F1 must use the per-class recall, not the macro average.
        double recall = this.recall(classLabel);
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /**
     * F1 score computed from the macro-averaged precision and recall.
     *
     * @return the F1 score, or 0.0 if either average is 0
     */
    public double f1() {
        double precision = this.precision();
        double recall = this.recall();
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /**
     * Fraction of evaluated rows whose predicted class equals the actual class.
     *
     * @return the accuracy, or 0.0 if no rows have been evaluated
     */
    public double accuracy() {
        if (this.numRowCounter == 0) {
            return 0.0;
        }
        return this.truePositives() / this.getNumRowCounter();
    }

    /** @return total true positives across all classes */
    public double truePositives() {
        return this.truePositives.totalCount();
    }

    /** @return total true negatives across all classes */
    public double trueNegatives() {
        return this.trueNegatives.totalCount();
    }

    /** @return total false positives across all classes */
    public double falsePositives() {
        return this.falsePositives.totalCount();
    }

    /** @return total false negatives across all classes */
    public double falseNegatives() {
        return this.falseNegatives.totalCount();
    }

    /** @return true negatives plus false positives */
    public double negative() {
        return this.trueNegatives() + this.falsePositives();
    }

    /** @return true positives plus false negatives */
    public double positive() {
        return this.truePositives() + this.falseNegatives();
    }

    /** Records one true positive for {@code classLabel}. */
    public void incrementTruePositives(Integer classLabel) {
        this.truePositives.incrementCount(classLabel, 1.0);
    }

    /** Records one true negative for {@code classLabel}. */
    public void incrementTrueNegatives(Integer classLabel) {
        // Previously incremented truePositives by mistake.
        this.trueNegatives.incrementCount(classLabel, 1.0);
    }

    /** Records one false negative for {@code classLabel}. */
    public void incrementFalseNegatives(Integer classLabel) {
        this.falseNegatives.incrementCount(classLabel, 1.0);
    }

    /** Records one false positive for {@code classLabel}. */
    public void incrementFalsePositives(Integer classLabel) {
        this.falsePositives.incrementCount(classLabel, 1.0);
    }

    /**
     * Adds one (actual, predicted) observation to the confusion matrix.
     * The confusion matrix must already exist: it is created by the int constructor or by
     * the first eval() call.
     */
    public void addToConfusion(Integer real, Integer guess) {
        this.confusion.add(real, guess);
    }

    /**
     * @param clazz the class index
     * @return the number of rows whose actual class is {@code clazz}, or 0 before any data exists
     */
    public int classCount(Integer clazz) {
        if (this.confusion == null) {
            return 0;
        }
        return this.confusion.getActualTotal(clazz);
    }

    /** @return the number of rows evaluated so far */
    public double getNumRowCounter() {
        return this.numRowCounter;
    }
}

