/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.EvaluationUtils;
import org.deeplearning4j.eval.ROCBinary;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.Not;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.serde.RowVectorDeserializer;
import org.nd4j.linalg.lossfunctions.serde.RowVectorSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

public class EvaluationBinary
extends BaseEvaluation<EvaluationBinary> {
    public static final int DEFAULT_PRECISION = 4;
    public static final double DEFAULT_EDGE_VALUE = 0.0;
    private int[] countTruePositive;
    private int[] countFalsePositive;
    private int[] countTrueNegative;
    private int[] countFalseNegative;
    private ROCBinary rocBinary;
    private List<String> labels;
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray decisionThreshold;

    public EvaluationBinary(INDArray decisionThreshold) {
        if (decisionThreshold != null) {
            if (!decisionThreshold.isRowVector()) {
                throw new IllegalArgumentException("Decision threshold array must be a row vector; got array with shape " + Arrays.toString(decisionThreshold.shape()));
            }
            if (decisionThreshold.minNumber().doubleValue() < 0.0) {
                throw new IllegalArgumentException("Invalid decision threshold array: minimum value is less than 0");
            }
            if (decisionThreshold.maxNumber().doubleValue() > 1.0) {
                throw new IllegalArgumentException("invalid decision threshold array: maximum value is greater than 1.0");
            }
            this.decisionThreshold = decisionThreshold;
        }
    }

    public EvaluationBinary(int size, Integer rocBinarySteps) {
        this.countTruePositive = new int[size];
        this.countFalsePositive = new int[size];
        this.countTrueNegative = new int[size];
        this.countFalseNegative = new int[size];
        if (rocBinarySteps != null) {
            this.rocBinary = new ROCBinary(rocBinarySteps);
        }
    }

    @Override
    public void eval(INDArray labels, INDArray networkPredictions) {
        this.eval(labels, networkPredictions, (INDArray)null);
    }

    @Override
    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
        if (labelsMask == null || labelsMask.rank() == 2) {
            super.evalTimeSeries(labels, predictions, labelsMask);
            return;
        }
        if (labelsMask.rank() != 3) {
            throw new IllegalArgumentException("Labels must: must be rank 2 or 3. Got: " + labelsMask.rank());
        }
        INDArray l2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels);
        INDArray p2d = EvaluationUtils.reshapeTimeSeriesTo2d(predictions);
        INDArray m2d = EvaluationUtils.reshapeTimeSeriesTo2d(labelsMask);
        this.eval(l2d, p2d, m2d);
    }

    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
        INDArray classPredictions;
        if (this.countTruePositive != null && this.countTruePositive.length != labels.size(1)) {
            throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with size " + this.countTruePositive.length + ", got labels array with size " + labels.size(1));
        }
        if (labels.rank() == 3) {
            this.evalTimeSeries(labels, networkPredictions, maskArray);
            return;
        }
        if (this.decisionThreshold != null) {
            classPredictions = Nd4j.createUninitialized((int[])networkPredictions.shape());
            Nd4j.getExecutioner().exec((Op)new BroadcastGreaterThan(networkPredictions, this.decisionThreshold, classPredictions, new int[]{1}));
        } else {
            classPredictions = networkPredictions.gt((Number)0.5);
        }
        INDArray notLabels = Nd4j.getExecutioner().execAndReturn((TransformOp)new Not(labels.dup()));
        INDArray notClassPredictions = Nd4j.getExecutioner().execAndReturn((TransformOp)new Not(classPredictions.dup()));
        INDArray truePositives = classPredictions.mul(labels);
        INDArray trueNegatives = notClassPredictions.mul(notLabels);
        INDArray falsePositives = classPredictions.mul(notLabels);
        INDArray falseNegatives = notClassPredictions.mul(labels);
        if (maskArray != null) {
            truePositives.muli(maskArray);
            trueNegatives.muli(maskArray);
            falsePositives.muli(maskArray);
            falseNegatives.muli(maskArray);
        }
        int[] tpCount = truePositives.sum(new int[]{0}).data().asInt();
        int[] tnCount = trueNegatives.sum(new int[]{0}).data().asInt();
        int[] fpCount = falsePositives.sum(new int[]{0}).data().asInt();
        int[] fnCount = falseNegatives.sum(new int[]{0}).data().asInt();
        if (this.countTruePositive == null) {
            int l = tpCount.length;
            this.countTruePositive = new int[l];
            this.countFalsePositive = new int[l];
            this.countTrueNegative = new int[l];
            this.countFalseNegative = new int[l];
        }
        EvaluationBinary.addInPlace(this.countTruePositive, tpCount);
        EvaluationBinary.addInPlace(this.countFalsePositive, fpCount);
        EvaluationBinary.addInPlace(this.countTrueNegative, tnCount);
        EvaluationBinary.addInPlace(this.countFalseNegative, fnCount);
        if (this.rocBinary != null) {
            this.rocBinary.eval(labels, networkPredictions, maskArray);
        }
    }

    @Override
    public void merge(EvaluationBinary other) {
        if (other.countTruePositive == null) {
            return;
        }
        if (this.countTruePositive == null) {
            this.countTruePositive = other.countTruePositive;
            this.countFalsePositive = other.countFalsePositive;
            this.countTrueNegative = other.countTrueNegative;
            this.countFalseNegative = other.countFalseNegative;
            this.rocBinary = other.rocBinary;
        } else {
            if (this.countTruePositive.length != other.countTruePositive.length) {
                throw new IllegalStateException("Cannot merge EvaluationBinary instances with different sizes. This size: " + this.countTruePositive.length + ", other size: " + other.countTruePositive.length);
            }
            EvaluationBinary.addInPlace(this.countTruePositive, other.countTruePositive);
            EvaluationBinary.addInPlace(this.countTrueNegative, other.countTrueNegative);
            EvaluationBinary.addInPlace(this.countFalsePositive, other.countFalsePositive);
            EvaluationBinary.addInPlace(this.countFalseNegative, other.countFalseNegative);
            if (this.rocBinary != null) {
                this.rocBinary.merge(other.rocBinary);
            }
        }
    }

    @Override
    public void reset() {
        this.countTruePositive = null;
    }

    private static void addInPlace(int[] addTo, int[] toAdd) {
        for (int i = 0; i < addTo.length; ++i) {
            int n = i;
            addTo[n] = addTo[n] + toAdd[i];
        }
    }

    public int numLabels() {
        if (this.countTruePositive == null) {
            return -1;
        }
        return this.countTruePositive.length;
    }

    public void setLabelNames(List<String> labels) {
        if (labels == null) {
            this.labels = null;
            return;
        }
        this.labels = new ArrayList<String>(labels);
    }

    public int totalCount(int outputNum) {
        this.assertIndex(outputNum);
        return this.countTruePositive[outputNum] + this.countTrueNegative[outputNum] + this.countFalseNegative[outputNum] + this.countFalsePositive[outputNum];
    }

    public int truePositives(int outputNum) {
        this.assertIndex(outputNum);
        return this.countTruePositive[outputNum];
    }

    public int trueNegatives(int outputNum) {
        this.assertIndex(outputNum);
        return this.countTrueNegative[outputNum];
    }

    public int falsePositives(int outputNum) {
        this.assertIndex(outputNum);
        return this.countFalsePositive[outputNum];
    }

    public int falseNegatives(int outputNum) {
        this.assertIndex(outputNum);
        return this.countFalseNegative[outputNum];
    }

    public double averageAccuracy() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.accuracy(i);
        }
        return ret /= (double)this.numLabels();
    }

    public double accuracy(int outputNum) {
        this.assertIndex(outputNum);
        return (double)(this.countTruePositive[outputNum] + this.countTrueNegative[outputNum]) / (double)this.totalCount(outputNum);
    }

    public double averagePrecision() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.precision(i);
        }
        return ret /= (double)this.numLabels();
    }

    public double precision(int outputNum) {
        this.assertIndex(outputNum);
        return (double)this.countTruePositive[outputNum] / (double)(this.countTruePositive[outputNum] + this.countFalsePositive[outputNum]);
    }

    public double averageRecall() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.recall(i);
        }
        return ret /= (double)this.numLabels();
    }

    public double recall(int outputNum) {
        this.assertIndex(outputNum);
        return (double)this.countTruePositive[outputNum] / (double)(this.countTruePositive[outputNum] + this.countFalseNegative[outputNum]);
    }

    public double averageF1() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.f1(i);
        }
        return ret /= (double)this.numLabels();
    }

    public double fBeta(double beta, int outputNum) {
        this.assertIndex(outputNum);
        double precision = this.precision(outputNum);
        double recall = this.recall(outputNum);
        return EvaluationUtils.fBeta(beta, precision, recall);
    }

    public double f1(int outputNum) {
        return this.fBeta(1.0, outputNum);
    }

    public double matthewsCorrelation(int outputNum) {
        this.assertIndex(outputNum);
        return EvaluationUtils.matthewsCorrelation(this.truePositives(outputNum), this.falsePositives(outputNum), this.falseNegatives(outputNum), this.trueNegatives(outputNum));
    }

    public double gMeasure(int output) {
        double precision = this.precision(output);
        double recall = this.recall(output);
        return EvaluationUtils.gMeasure(precision, recall);
    }

    public double falsePositiveRate(int classLabel) {
        return this.recall(classLabel);
    }

    public double falsePositiveRate(int classLabel, double edgeCase) {
        double fpCount = this.falsePositives(classLabel);
        double tnCount = this.trueNegatives(classLabel);
        return EvaluationUtils.falsePositiveRate((long)fpCount, (long)tnCount, edgeCase);
    }

    public double falseNegativeRate(Integer classLabel) {
        return this.falseNegativeRate(classLabel, 0.0);
    }

    public double falseNegativeRate(Integer classLabel, double edgeCase) {
        double fnCount = this.falseNegatives(classLabel);
        double tpCount = this.truePositives(classLabel);
        return EvaluationUtils.falseNegativeRate((long)fnCount, (long)tpCount, edgeCase);
    }

    public ROCBinary getROCBinary() {
        return this.rocBinary;
    }

    private void assertIndex(int outputNum) {
        if (this.countTruePositive == null) {
            throw new UnsupportedOperationException("EvaluationBinary does not have any stats: eval must be called first");
        }
        if (outputNum < 0 || outputNum >= this.countTruePositive.length) {
            throw new IllegalArgumentException("Invalid input: output number must be between 0 and " + (outputNum - 1) + ". Got index: " + outputNum);
        }
    }

    @Override
    public String stats() {
        return this.stats(4);
    }

    public String stats(int printPrecision) {
        StringBuilder sb = new StringBuilder();
        int maxLabelsLength = 15;
        if (this.labels != null) {
            for (String s : this.labels) {
                maxLabelsLength = Math.max(s.length(), maxLabelsLength);
            }
        }
        String subPattern = "%-12." + printPrecision + "f";
        String pattern = "%-" + (maxLabelsLength + 5) + "s" + subPattern + subPattern + subPattern + subPattern + "%-8d%-7d%-7d%-7d%-7d";
        String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-12s%-12s%-12s%-8s%-7s%-7s%-7s%-7s";
        List<String> headerNames = Arrays.asList("Label", "Accuracy", "F1", "Precision", "Recall", "Total", "TP", "TN", "FP", "FN");
        if (this.rocBinary != null) {
            patternHeader = patternHeader + "%-12s";
            pattern = pattern + subPattern;
            headerNames = new ArrayList<String>(headerNames);
            headerNames.add("AUC");
        }
        String header = String.format(patternHeader, headerNames.toArray());
        sb.append(header);
        if (this.countTrueNegative != null) {
            for (int i = 0; i < this.countTrueNegative.length; ++i) {
                int totalCount = this.totalCount(i);
                double acc = this.accuracy(i);
                double f1 = this.f1(i);
                double precision = this.precision(i);
                double recall = this.recall(i);
                String label = this.labels == null ? String.valueOf(i) : this.labels.get(i);
                List<Object> args = Arrays.asList(label, acc, f1, precision, recall, totalCount, this.truePositives(i), this.trueNegatives(i), this.falsePositives(i), this.falseNegatives(i));
                if (this.rocBinary != null) {
                    args = new ArrayList<Object>(args);
                    args.add(this.rocBinary.calculateAUC(i));
                }
                sb.append("\n").append(String.format(pattern, args.toArray()));
            }
            if (this.decisionThreshold != null) {
                sb.append("\nPer-output decision thresholds: ").append(Arrays.toString(this.decisionThreshold.dup().data().asFloat()));
            }
        } else {
            sb.append("\n-- No Data --\n");
        }
        return sb.toString();
    }

    public static EvaluationBinary fromJson(String json) {
        return EvaluationBinary.fromJson(json, EvaluationBinary.class);
    }

    public static EvaluationBinary fromYaml(String yaml) {
        return EvaluationBinary.fromYaml(yaml, EvaluationBinary.class);
    }

    public EvaluationBinary() {
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof EvaluationBinary)) {
            return false;
        }
        EvaluationBinary other = (EvaluationBinary)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (!Arrays.equals(this.getCountTruePositive(), other.getCountTruePositive())) {
            return false;
        }
        if (!Arrays.equals(this.getCountFalsePositive(), other.getCountFalsePositive())) {
            return false;
        }
        if (!Arrays.equals(this.getCountTrueNegative(), other.getCountTrueNegative())) {
            return false;
        }
        if (!Arrays.equals(this.getCountFalseNegative(), other.getCountFalseNegative())) {
            return false;
        }
        ROCBinary this$rocBinary = this.getROCBinary();
        ROCBinary other$rocBinary = other.getROCBinary();
        if (this$rocBinary == null ? other$rocBinary != null : !((Object)this$rocBinary).equals(other$rocBinary)) {
            return false;
        }
        List<String> this$labels = this.getLabels();
        List<String> other$labels = other.getLabels();
        if (this$labels == null ? other$labels != null : !((Object)this$labels).equals(other$labels)) {
            return false;
        }
        INDArray this$decisionThreshold = this.getDecisionThreshold();
        INDArray other$decisionThreshold = other.getDecisionThreshold();
        return !(this$decisionThreshold == null ? other$decisionThreshold != null : !this$decisionThreshold.equals(other$decisionThreshold));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof EvaluationBinary;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + super.hashCode();
        result = result * 59 + Arrays.hashCode(this.getCountTruePositive());
        result = result * 59 + Arrays.hashCode(this.getCountFalsePositive());
        result = result * 59 + Arrays.hashCode(this.getCountTrueNegative());
        result = result * 59 + Arrays.hashCode(this.getCountFalseNegative());
        ROCBinary $rocBinary = this.getROCBinary();
        result = result * 59 + ($rocBinary == null ? 43 : ((Object)$rocBinary).hashCode());
        List<String> $labels = this.getLabels();
        result = result * 59 + ($labels == null ? 43 : ((Object)$labels).hashCode());
        INDArray $decisionThreshold = this.getDecisionThreshold();
        result = result * 59 + ($decisionThreshold == null ? 43 : $decisionThreshold.hashCode());
        return result;
    }

    public int[] getCountTruePositive() {
        return this.countTruePositive;
    }

    public int[] getCountFalsePositive() {
        return this.countFalsePositive;
    }

    public int[] getCountTrueNegative() {
        return this.countTrueNegative;
    }

    public int[] getCountFalseNegative() {
        return this.countFalseNegative;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public INDArray getDecisionThreshold() {
        return this.decisionThreshold;
    }

    public void setCountTruePositive(int[] countTruePositive) {
        this.countTruePositive = countTruePositive;
    }

    public void setCountFalsePositive(int[] countFalsePositive) {
        this.countFalsePositive = countFalsePositive;
    }

    public void setCountTrueNegative(int[] countTrueNegative) {
        this.countTrueNegative = countTrueNegative;
    }

    public void setCountFalseNegative(int[] countFalseNegative) {
        this.countFalseNegative = countFalseNegative;
    }

    public void setRocBinary(ROCBinary rocBinary) {
        this.rocBinary = rocBinary;
    }

    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    public void setDecisionThreshold(INDArray decisionThreshold) {
        this.decisionThreshold = decisionThreshold;
    }

    @Override
    public String toString() {
        return "EvaluationBinary(countTruePositive=" + Arrays.toString(this.getCountTruePositive()) + ", countFalsePositive=" + Arrays.toString(this.getCountFalsePositive()) + ", countTrueNegative=" + Arrays.toString(this.getCountTrueNegative()) + ", countFalseNegative=" + Arrays.toString(this.getCountFalseNegative()) + ", rocBinary=" + this.getROCBinary() + ", labels=" + this.getLabels() + ", decisionThreshold=" + this.getDecisionThreshold() + ")";
    }
}

