package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Triple;
import java.text.NumberFormat;
import java.util.ArrayList;

/* loaded from: input_file:edu/stanford/nlp/stats/MultiClassPrecisionRecallStats.class */
/**
 * Precision, recall and F1 statistics for a multi-class classifier.
 *
 * <p>{@link #score(Classifier, GeneralDataset)} classifies every datum of a dataset and tallies,
 * for each label, true positives, false positives and false negatives. One label, the
 * <em>negative label</em> given at construction, is not scored:
 * <ul>
 *   <li>a correct prediction of it is not a true positive;</li>
 *   <li>a wrong prediction of it is not a false positive;</li>
 *   <li>a missed gold instance of it is not a false negative;</li>
 *   <li>it is skipped when the overall (micro-averaged) figures are summed.</li>
 * </ul>
 *
 * <p>Conventions for degenerate cases: precision or recall whose denominator is zero is reported as
 * {@code 1.0}. F1 is reported as {@code 0.0} when precision and recall are both {@code 0.0}.
 *
 * <p>The query methods read state built by {@code score}; call it (or use the scoring
 * constructor) first.
 *
 * @param <L> the label type
 */
public class MultiClassPrecisionRecallStats<L> implements Scorer<L> {
    /** Per-label true-positive counts, indexed by {@link #labelIndex}. */
    protected int[] tpCount;
    /** Per-label false-positive counts, indexed by {@link #labelIndex}. */
    protected int[] fpCount;
    /** Per-label false-negative counts, indexed by {@link #labelIndex}. */
    protected int[] fnCount;
    /** Union of the dataset's labels and the classifier's labels, rebuilt by each {@code score} call. */
    protected Index<L> labelIndex;
    /** The label that is treated as negative and excluded from scoring. */
    protected L negLabel;
    /** Position of {@link #negLabel} in {@link #labelIndex}, or -1 if it does not occur there. */
    protected int negIndex = -1;

    /**
     * Creates the statistics and immediately scores {@code classifier} on {@code generalDataset}.
     *
     * @param classifier the classifier to evaluate
     * @param generalDataset the gold-labelled data to evaluate on
     * @param l the negative label, excluded from scoring
     */
    public <F> MultiClassPrecisionRecallStats(Classifier<L, F> classifier, GeneralDataset<L, F> generalDataset, L l) {
        this.negLabel = l;
        score(classifier, generalDataset);
    }

    /**
     * Creates empty statistics. Call {@code score} before querying them.
     *
     * @param l the negative label, excluded from scoring
     */
    public MultiClassPrecisionRecallStats(L l) {
        this.negLabel = l;
    }

    /** Returns the label that is treated as negative and excluded from scoring. */
    public L getNegLabel() {
        return this.negLabel;
    }

    /**
     * Scores a probabilistic classifier by delegating to {@link #score(Classifier, GeneralDataset)}.
     *
     * @return the overall F1 after scoring
     */
    @Override // edu.stanford.nlp.stats.Scorer
    public <F> double score(ProbabilisticClassifier<L, F> probabilisticClassifier, GeneralDataset<L, F> generalDataset) {
        // The upcast makes overload resolution pick score(Classifier, ...) instead of recursing here.
        return score((Classifier<L, F>) probabilisticClassifier, generalDataset);
    }

    /**
     * Classifies every datum in {@code generalDataset} and recomputes all counts from scratch.
     *
     * <p>The label index is rebuilt as the union of the dataset's labels and
     * {@code classifier.labels()}, so counts from any earlier call are discarded.
     *
     * @param classifier the classifier to evaluate
     * @param generalDataset the gold-labelled data to evaluate on
     * @return the overall F1 after scoring, as returned by {@link #getFMeasure()}
     */
    public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> generalDataset) {
        int n = generalDataset.size();

        // Collect all predictions first, in dataset order.
        ArrayList<L> guesses = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            guesses.add(classifier.classOf(generalDataset.getRVFDatum(i)));
        }

        // Gold label ids are positions in the dataset's own label index.
        Index<L> datasetIndex = generalDataset.labelIndex;
        int[] goldIds = generalDataset.getLabelsArray();

        this.labelIndex = new HashIndex<>();
        this.labelIndex.addAll(generalDataset.labelIndex().objectsList());
        this.labelIndex.addAll(classifier.labels());

        int size = this.labelIndex.size();
        this.tpCount = new int[size];
        this.fpCount = new int[size];
        this.fnCount = new int[size];
        this.negIndex = this.labelIndex.indexOf(this.negLabel);

        for (int i = 0; i < n; i++) {
            int guess = this.labelIndex.indexOf(guesses.get(i));
            int gold = this.labelIndex.indexOf(datasetIndex.get(goldIds[i]));
            if (guess == gold) {
                // A correct answer only earns a true positive when it is not the negative label.
                if (guess != this.negIndex) {
                    this.tpCount[guess]++;
                }
            } else {
                if (guess != this.negIndex) {
                    this.fpCount[guess]++;
                }
                if (gold != this.negIndex) {
                    this.fnCount[gold]++;
                }
            }
        }
        return getFMeasure();
    }

    /**
     * Returns precision for one label together with its counts.
     *
     * @param l a label seen by the last {@code score} call
     * @return (precision, true positives, false positives); precision is 1.0 when both counts are 0
     * @throws IllegalArgumentException if {@code l} is not in the label index
     */
    public Triple<Double, Integer, Integer> getPrecisionInfo(L l) {
        return labelInfo(l, this.fpCount);
    }

    /**
     * Returns precision for one label.
     *
     * @throws IllegalArgumentException if {@code l} is not in the label index
     */
    public double getPrecision(L l) {
        return getPrecisionInfo(l).first();
    }

    /**
     * Returns micro-averaged precision over all non-negative labels together with the summed counts.
     *
     * @return (precision, true positives, false positives); precision is 1.0 when both sums are 0
     */
    public Triple<Double, Integer, Integer> getPrecisionInfo() {
        return overallInfo(this.fpCount);
    }

    /** Returns micro-averaged precision over all non-negative labels. */
    public double getPrecision() {
        return getPrecisionInfo().first();
    }

    /**
     * Formats overall precision as {@code "p  (tp/tp+fp)"}.
     *
     * @param i maximum number of fraction digits for {@code p}
     */
    public String getPrecisionDescription(int i) {
        return describe(i, getPrecisionInfo());
    }

    /**
     * Formats precision for one label as {@code "p  (tp/tp+fp)"}.
     *
     * @param i maximum number of fraction digits for {@code p}
     * @param l the label to describe
     */
    public String getPrecisionDescription(int i, L l) {
        return describe(i, getPrecisionInfo(l));
    }

    /**
     * Returns recall for one label together with its counts.
     *
     * @param l a label seen by the last {@code score} call
     * @return (recall, true positives, false negatives); recall is 1.0 when both counts are 0
     * @throws IllegalArgumentException if {@code l} is not in the label index
     */
    public Triple<Double, Integer, Integer> getRecallInfo(L l) {
        return labelInfo(l, this.fnCount);
    }

    /**
     * Returns recall for one label.
     *
     * @throws IllegalArgumentException if {@code l} is not in the label index
     */
    public double getRecall(L l) {
        return getRecallInfo(l).first();
    }

    /**
     * Returns micro-averaged recall over all non-negative labels together with the summed counts.
     *
     * @return (recall, true positives, false negatives); recall is 1.0 when both sums are 0
     */
    public Triple<Double, Integer, Integer> getRecallInfo() {
        return overallInfo(this.fnCount);
    }

    /** Returns micro-averaged recall over all non-negative labels. */
    public double getRecall() {
        return getRecallInfo().first();
    }

    /**
     * Formats overall recall as {@code "r  (tp/tp+fn)"}.
     *
     * @param i maximum number of fraction digits for {@code r}
     */
    public String getRecallDescription(int i) {
        return describe(i, getRecallInfo());
    }

    /**
     * Formats recall for one label as {@code "r  (tp/tp+fn)"}.
     *
     * @param i maximum number of fraction digits for {@code r}
     * @param l the label to describe
     */
    public String getRecallDescription(int i, L l) {
        return describe(i, getRecallInfo(l));
    }

    /**
     * Returns F1 (harmonic mean of precision and recall) for one label.
     * Returns 0.0 rather than NaN when precision and recall are both 0.
     *
     * @throws IllegalArgumentException if {@code l} is not in the label index
     */
    public double getFMeasure(L l) {
        return harmonicMean(getPrecision(l), getRecall(l));
    }

    /**
     * Returns F1 from the micro-averaged precision and recall.
     * Returns 0.0 rather than NaN when precision and recall are both 0.
     */
    public double getFMeasure() {
        return harmonicMean(getPrecision(), getRecall());
    }

    /**
     * Formats overall F1.
     *
     * @param i maximum number of fraction digits
     */
    public String getF1Description(int i) {
        return fractionFormat(i).format(getFMeasure());
    }

    /**
     * Formats F1 for one label.
     *
     * @param i maximum number of fraction digits
     * @param l the label to describe
     */
    public String getF1Description(int i, L l) {
        return fractionFormat(i).format(getFMeasure(l));
    }

    /**
     * Builds a multi-line report with per-label and overall precision, recall and F1.
     * Null labels and the negative label get no per-label section.
     *
     * @param i maximum number of fraction digits for every figure
     */
    @Override // edu.stanford.nlp.stats.Scorer
    public String getDescription(int i) {
        StringBuilder sb = new StringBuilder();
        sb.append("--- PR Stats ---").append("\n");
        for (L l : this.labelIndex) {
            if (l != null && !l.equals(this.negLabel)) {
                sb.append("** ").append(l.toString()).append(" **\n");
                sb.append("\tPrec:   ").append(getPrecisionDescription(i, l)).append("\n");
                sb.append("\tRecall: ").append(getRecallDescription(i, l)).append("\n");
                sb.append("\tF1:     ").append(getF1Description(i, l)).append("\n");
            }
        }
        sb.append("** Overall **\n");
        sb.append("\tPrec:   ").append(getPrecisionDescription(i)).append("\n");
        sb.append("\tRecall: ").append(getRecallDescription(i)).append("\n");
        sb.append("\tF1:     ").append(getF1Description(i));
        return sb.toString();
    }

    /**
     * Returns {@code tp / (tp + other)} computed in floating point, or 1.0 when both counts are 0.
     * The cast matters: plain int division would truncate every ratio to 0 or 1.
     */
    private static double ratio(int tp, int other) {
        int denominator = tp + other;
        return denominator == 0 ? 1.0d : ((double) tp) / denominator;
    }

    /** Returns the harmonic mean of precision and recall, or 0.0 when both are 0 (avoids 0/0 = NaN). */
    private static double harmonicMean(double precision, double recall) {
        double sum = precision + recall;
        return sum == 0.0d ? 0.0d : (2.0d * precision * recall) / sum;
    }

    /** Returns the position of {@code l} in the label index; throws if it is absent. */
    private int indexOfLabel(L l) {
        int index = this.labelIndex.indexOf(l);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown label: " + l);
        }
        return index;
    }

    /** Packs (ratio, tp, other) for one label; {@code other} is fpCount or fnCount. */
    private Triple<Double, Integer, Integer> labelInfo(L l, int[] other) {
        int index = indexOfLabel(l);
        int tp = this.tpCount[index];
        int rest = other[index];
        return new Triple<>(ratio(tp, rest), tp, rest);
    }

    /** Sums tp and {@code other} over all non-negative labels and packs (ratio, tp, other). */
    private Triple<Double, Integer, Integer> overallInfo(int[] other) {
        int tp = 0;
        int rest = 0;
        for (int index = 0; index < this.labelIndex.size(); index++) {
            if (index != this.negIndex) {
                tp += this.tpCount[index];
                rest += other[index];
            }
        }
        return new Triple<>(ratio(tp, rest), tp, rest);
    }

    /** Returns a locale-default number format limited to {@code digits} fraction digits. */
    private static NumberFormat fractionFormat(int digits) {
        NumberFormat format = NumberFormat.getNumberInstance();
        format.setMaximumFractionDigits(digits);
        return format;
    }

    /** Formats an info triple as {@code "ratio  (hits/hits+misses)"}. */
    private static String describe(int digits, Triple<Double, Integer, Integer> info) {
        int hits = info.second();
        return fractionFormat(digits).format(info.first()) + "  (" + hits + "/" + (hits + info.third()) + ")";
    }
}
