/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.rntn;

import java.util.List;
import java.util.SortedSet;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.deeplearning4j.models.rntn.RNTN;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates an {@link RNTN} against labelled trees by accumulating a confusion
 * matrix of gold labels versus predicted labels over every internal (non-leaf)
 * node that carries a prediction.
 *
 * <p>Counts accumulate across successive calls to {@link #eval(RNTN, List)};
 * {@link #stats()} renders the current totals.
 */
public class RNTNEval {
    // Currently unused within this class.
    private static final Logger log = LoggerFactory.getLogger(RNTNEval.class);

    /** Counts of (gold label, predicted label) pairs recorded so far. */
    private final ConfusionMatrix<Integer> cf = new ConfusionMatrix<>();

    /**
     * Forward-propagates each tree through {@code rntn} and records the
     * resulting per-node predictions in the confusion matrix.
     *
     * @param rntn  the network used to compute predictions
     * @param trees the labelled trees to evaluate
     */
    public void eval(RNTN rntn, List<Tree> trees) {
        for (Tree t : trees) {
            rntn.forwardPropagateTree(t);
            count(t);
        }
    }

    /**
     * Recursively records (gold label, predicted label) pairs for {@code tree}
     * and its descendants.
     *
     * <p>Leaves are never counted. A node whose prediction is {@code null} is
     * skipped together with its entire subtree, because the recursion into
     * children only happens after that check.
     * NOTE(review): confirm that skipping the subtree (rather than just this
     * node) is intended.
     *
     * @param tree the (sub)tree to record
     */
    private void count(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        if (tree.prediction() == null) {
            return;
        }
        for (Tree child : tree.children()) {
            count(child);
        }
        int treeGoldLabel = tree.goldLabel();
        // iamax is the BLAS "index of the entry with the largest absolute value",
        // used here as the arg-max over the prediction vector.
        int predictionLabel = Nd4j.getBlasWrapper().iamax(tree.prediction());
        cf.add(treeGoldLabel, predictionLabel);
    }

    /**
     * Renders every non-zero (actual, predicted) cell of the confusion matrix,
     * one entry per line, iterating classes in the set's sorted order.
     *
     * @return a human-readable report; just {@code "\n"} if nothing was counted
     */
    public String stats() {
        StringBuilder builder = new StringBuilder().append("\n");
        // Typed (not raw) so the enhanced-for loops can bind Integer elements.
        SortedSet<Integer> classes = cf.getClasses();
        for (Integer clazz : classes) {
            for (Integer clazz2 : classes) {
                int count = cf.getCount(clazz, clazz2);
                if (count == 0) {
                    continue;
                }
                builder.append("\nActual Class ").append(clazz)
                        .append(" was predicted with Predicted ").append(clazz2)
                        .append(" with count ").append(count)
                        .append(" times\n");
            }
        }
        return builder.toString();
    }
}

