/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ConfusionMatrix<T extends Comparable<? super T>>
implements Serializable {
    private Map<T, Multiset<T>> matrix;
    private List<T> classes;

    public ConfusionMatrix(List<T> classes) {
        this.matrix = new HashMap<T, Multiset<T>>();
        this.classes = classes;
    }

    public ConfusionMatrix() {
    }

    public ConfusionMatrix(ConfusionMatrix<T> other) {
        this(other.getClasses());
        this.add(other);
    }

    public void add(T actual, T predicted) {
        this.add(actual, predicted, 1);
    }

    public void add(T actual, T predicted, int count) {
        if (this.matrix.containsKey(actual)) {
            this.matrix.get(actual).add(predicted, count);
        } else {
            HashMultiset counts = HashMultiset.create();
            counts.add(predicted, count);
            this.matrix.put(actual, (Multiset<T>)counts);
        }
    }

    public void add(ConfusionMatrix<T> other) {
        for (Comparable actual : other.matrix.keySet()) {
            Multiset<T> counts = other.matrix.get(actual);
            for (Comparable predicted : counts.elementSet()) {
                int count = counts.count((Object)predicted);
                this.add(actual, predicted, count);
            }
        }
    }

    public List<T> getClasses() {
        return this.classes;
    }

    public int getCount(T actual, T predicted) {
        if (!this.matrix.containsKey(actual)) {
            return 0;
        }
        return this.matrix.get(actual).count(predicted);
    }

    public int getPredictedTotal(T predicted) {
        int total = 0;
        for (Comparable actual : this.classes) {
            total += this.getCount(actual, predicted);
        }
        return total;
    }

    public int getActualTotal(T actual) {
        if (!this.matrix.containsKey(actual)) {
            return 0;
        }
        int total = 0;
        for (Comparable elem : this.matrix.get(actual).elementSet()) {
            total += this.matrix.get(actual).count((Object)elem);
        }
        return total;
    }

    public String toString() {
        return this.matrix.toString();
    }

    public String toCSV() {
        StringBuilder builder = new StringBuilder();
        builder.append(",,Predicted Class,\n");
        builder.append(",,");
        for (Comparable predicted : this.classes) {
            builder.append(String.format("%s,", predicted));
        }
        builder.append("Total\n");
        String firstColumnLabel = "Actual Class,";
        for (Comparable actual : this.classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = ",";
            builder.append(String.format("%s,", actual));
            for (Comparable predicted : this.classes) {
                builder.append(this.getCount(actual, predicted));
                builder.append(",");
            }
            builder.append(this.getActualTotal(actual));
            builder.append("\n");
        }
        builder.append(",Total,");
        for (Comparable predicted : this.classes) {
            builder.append(this.getPredictedTotal(predicted));
            builder.append(",");
        }
        builder.append("\n");
        return builder.toString();
    }

    public String toHTML() {
        StringBuilder builder = new StringBuilder();
        int numClasses = this.classes.size();
        builder.append("<table>\n");
        builder.append("<tr><th class=\"empty-space\" colspan=\"2\" rowspan=\"2\">");
        builder.append(String.format("<th class=\"predicted-class-header\" colspan=\"%d\">Predicted Class</th></tr>\n", numClasses + 1));
        builder.append("<tr>");
        for (Comparable predicted : this.classes) {
            builder.append("<th class=\"predicted-class-header\">");
            builder.append(predicted);
            builder.append("</th>");
        }
        builder.append("<th class=\"predicted-class-header\">Total</th>");
        builder.append("</tr>\n");
        String firstColumnLabel = String.format("<tr><th class=\"actual-class-header\" rowspan=\"%d\">Actual Class</th>", numClasses + 1);
        for (Comparable actual : this.classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = "<tr>";
            builder.append(String.format("<th class=\"actual-class-header\" >%s</th>", actual));
            for (Comparable predicted : this.classes) {
                builder.append("<td class=\"count-element\">");
                builder.append(this.getCount(actual, predicted));
                builder.append("</td>");
            }
            builder.append("<td class=\"count-element\">");
            builder.append(this.getActualTotal(actual));
            builder.append("</td>");
            builder.append("</tr>\n");
        }
        builder.append("<tr><th class=\"actual-class-header\">Total</th>");
        for (Comparable predicted : this.classes) {
            builder.append("<td class=\"count-element\">");
            builder.append(this.getPredictedTotal(predicted));
            builder.append("</td>");
        }
        builder.append("<td class=\"empty-space\"></td>\n");
        builder.append("</tr>\n");
        builder.append("</table>\n");
        return builder.toString();
    }

    public static void main(String[] args) {
        ConfusionMatrix<String> confusionMatrix = new ConfusionMatrix<String>(Arrays.asList("a", "b", "c"));
        confusionMatrix.add("a", "a", 88);
        confusionMatrix.add("a", "b", 10);
        confusionMatrix.add("b", "a", 14);
        confusionMatrix.add("b", "b", 40);
        confusionMatrix.add("b", "c", 6);
        confusionMatrix.add("c", "a", 18);
        confusionMatrix.add("c", "b", 10);
        confusionMatrix.add("c", "c", 12);
        ConfusionMatrix<String> confusionMatrix2 = new ConfusionMatrix<String>(confusionMatrix);
        confusionMatrix2.add(confusionMatrix);
        System.out.println(confusionMatrix2.toHTML());
        System.out.println(confusionMatrix2.toCSV());
    }
}

