package org.elasticsearch.xpack.core.ml.dataframe.evaluation.common;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Base class for the {@code auc_roc} evaluation metric.
 * <p>
 * Provides the {@link AucRocPoint} and {@link Result} value types plus static helpers that:
 * <ul>
 *   <li>turn a {@link Percentiles} aggregation into a 99-element array ({@link #percentilesArray}),</li>
 *   <li>build a ROC curve from two such arrays ({@link #buildAucRocCurve}), and</li>
 *   <li>integrate that curve with the trapezoidal rule ({@link #calculateAucScore}).</li>
 * </ul>
 * Only {@link #getName()} of the {@link EvaluationMetric} contract is implemented here.
 */
public abstract class AbstractAucRoc implements EvaluationMetric {

    /** Name under which this metric is parsed and reported. */
    public static final ParseField NAME = new ParseField("auc_roc");

    /** Number of percentile values (1st through 99th) handled by {@link #percentilesArray} and {@link #buildAucRocCurve}. */
    private static final int PERCENTILE_COUNT = 99;

    /**
     * A single point of a ROC curve: true-positive rate, false-positive rate and the threshold that produced them.
     * <p>
     * Natural ordering is by descending threshold, then ascending FPR, then ascending TPR.
     * {@link #equals}/{@link #hashCode} use {@link Double#compare} semantics, so they agree with each other
     * (also for {@code NaN} and {@code -0.0}) and with {@link #compareTo}.
     */
    public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {

        private static final String TPR = "tpr";
        private static final String FPR = "fpr";
        private static final String THRESHOLD = "threshold";

        /** Natural ordering, built once rather than on every {@link #compareTo} call. */
        private static final Comparator<AucRocPoint> ORDERING = Comparator.<AucRocPoint>comparingDouble(p -> p.threshold)
            .reversed()
            .thenComparingDouble(p -> p.fpr)
            .thenComparingDouble(p -> p.tpr);

        final double tpr;
        final double fpr;
        final double threshold;

        /**
         * @param tpr       true-positive rate
         * @param fpr       false-positive rate
         * @param threshold threshold at which the rates apply
         */
        public AucRocPoint(double tpr, double fpr, double threshold) {
            this.tpr = tpr;
            this.fpr = fpr;
            this.threshold = threshold;
        }

        /** Reads a point in the order written by {@link #writeTo}: tpr, fpr, threshold. */
        private AucRocPoint(StreamInput in) throws IOException {
            this.tpr = in.readDouble();
            this.fpr = in.readDouble();
            this.threshold = in.readDouble();
        }

        @Override
        public int compareTo(AucRocPoint other) {
            return ORDERING.compare(this, other);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(tpr);
            out.writeDouble(fpr);
            out.writeDouble(threshold);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(TPR, tpr);
            builder.field(FPR, fpr);
            builder.field(THRESHOLD, threshold);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            AucRocPoint that = (AucRocPoint) obj;
            // Double.compare instead of ==: hashCode() goes through Double.hashCode, which treats NaN as equal
            // to itself and distinguishes 0.0 from -0.0; == does neither and would break the hashCode contract.
            return Double.compare(tpr, that.tpr) == 0
                && Double.compare(fpr, that.fpr) == 0
                && Double.compare(threshold, that.threshold) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tpr, fpr, threshold);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    /**
     * A rate-versus-threshold curve backed by 99 percentile values, where index {@code i} holds the
     * {@code (i + 1)}-th percentile. Each of its points is paired with a rate interpolated from another curve.
     */
    private static class RateThresholdCurve {

        /** Percentile values; {@link Arrays#binarySearch} in {@link #interpolateRate} requires ascending order. */
        private final double[] percentiles;
        /** If true this curve's own rate becomes the TPR of the produced points, otherwise the FPR. */
        private final boolean isTp;

        private RateThresholdCurve(double[] percentiles, boolean isTp) {
            this.percentiles = percentiles;
            this.isTp = isTp;
        }

        /** Rate at index {@code i}: {@code 1 - 0.01 * (i + 1)}, i.e. the share lying above the (i+1)-th percentile. */
        private double getRate(int i) {
            return 1.0d - (0.01d * (i + 1));
        }

        /**
         * Threshold at index {@code i}: the percentile value lowered by one ulp and clamped at 0.
         * Presumably the ulp makes values equal to the percentile count as above the threshold.
         */
        private double getThreshold(int i) {
            return Math.max(0.0d, percentiles[i] - Math.ulp(percentiles[i]));
        }

        /**
         * Rate of this curve at an arbitrary threshold.
         *
         * @return the exact rate when {@code threshold} equals a percentile value; 0 above every percentile;
         *         1 below every percentile; otherwise a linear interpolation between the two neighbouring percentiles
         */
        private double interpolateRate(double threshold) {
            int index = Arrays.binarySearch(percentiles, threshold);
            if (index >= 0) {
                return getRate(index);
            }
            // Not found: binarySearch returned -(insertionPoint) - 1.
            int upper = -index - 1;
            int lower = upper - 1;
            if (upper >= percentiles.length) {
                return 0.0d;
            }
            if (lower < 0) {
                return 1.0d;
            }
            return interpolate(threshold, percentiles[lower], getRate(lower), percentiles[upper], getRate(upper));
        }

        /** Produces one point per percentile of this curve, taking the complementary rate from {@code other}. */
        private List<AucRocPoint> scanPoints(RateThresholdCurve other) {
            List<AucRocPoint> points = new ArrayList<>(percentiles.length);
            for (int i = 0; i < percentiles.length; i++) {
                double rate = getRate(i);
                double threshold = getThreshold(i);
                double otherRate = other.interpolateRate(threshold);
                points.add(isTp ? new AucRocPoint(rate, otherRate, threshold) : new AucRocPoint(otherRate, rate, threshold));
            }
            return points;
        }
    }

    /** Result of the {@code auc_roc} metric: the AUC value and, optionally, the curve it was computed from. */
    public static class Result implements EvaluationMetricResult {

        public static final String NAME = "auc_roc_result";
        private static final String VALUE = "value";
        private static final String CURVE = "curve";

        private final double value;
        private final List<AucRocPoint> curve;

        /**
         * @param value the AUC score
         * @param curve the curve points; must not be null, may be empty (then omitted from XContent)
         */
        public Result(double value, List<AucRocPoint> curve) {
            this.value = value;
            this.curve = Objects.requireNonNull(curve);
        }

        /** Reads a result in the order written by {@link #writeTo}. */
        public Result(StreamInput in) throws IOException {
            this.value = in.readDouble();
            this.curve = in.readList(AucRocPoint::new);
        }

        public double getValue() {
            return value;
        }

        /** @return an unmodifiable view of the curve points */
        public List<AucRocPoint> getCurve() {
            return Collections.unmodifiableList(curve);
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public String getMetricName() {
            return AbstractAucRoc.NAME.getPreferredName();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(value);
            out.writeList(curve);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(VALUE, value);
            if (!curve.isEmpty()) {
                builder.field(CURVE, (Iterable<?>) curve);
            }
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Result that = (Result) obj;
            // Double.compare keeps equals consistent with Double.hashCode used by hashCode().
            return Double.compare(value, that.value) == 0 && Objects.equals(curve, that.curve);
        }

        @Override
        public int hashCode() {
            return Objects.hash(value, curve);
        }
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Copies the values of a percentiles aggregation into an array indexed by {@code (int) percent - 1}.
     * NOTE(review): assumes the aggregation requested the integer percents 1..99; other percents would land on
     * the wrong index or overflow the array.
     *
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if any value is NaN or infinite;
     *         an infinite value would otherwise turn the derived threshold into NaN
     */
    public static double[] percentilesArray(Percentiles percentiles) {
        double[] result = new double[PERCENTILE_COUNT];
        percentiles.forEach(percentile -> {
            double value = percentile.getValue();
            if (!Double.isFinite(value)) {
                throw ExceptionsHelper.badRequestException(
                    "[{}] requires all the percentile values to be finite numbers",
                    NAME.getPreferredName()
                );
            }
            result[((int) percentile.getPercent()) - 1] = value;
        });
        return result;
    }

    /**
     * Builds a ROC curve from two percentile arrays of length {@value #PERCENTILE_COUNT}.
     * <p>
     * Each array contributes one point per percentile, with its partner rate interpolated from the other array.
     * The points are sorted by their natural ordering, and points sharing a threshold are averaged.
     * The result is framed by (tpr=0, fpr=0, threshold=1) and (tpr=1, fpr=1, threshold=0).
     *
     * @param tpPercentiles percentiles whose rates become the TPR of the points
     * @param fpPercentiles percentiles whose rates become the FPR of the points
     */
    public static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
        assert tpPercentiles.length == fpPercentiles.length;
        assert tpPercentiles.length == PERCENTILE_COUNT;
        RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
        RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);

        List<AucRocPoint> points = new ArrayList<>(tpPercentiles.length + fpPercentiles.length);
        points.addAll(tpCurve.scanPoints(fpCurve));
        points.addAll(fpCurve.scanPoints(tpCurve));
        Collections.sort(points);
        List<AucRocPoint> collapsed = collapseEqualThresholdPoints(points);

        List<AucRocPoint> curve = new ArrayList<>(collapsed.size() + 2);
        curve.add(new AucRocPoint(0.0d, 0.0d, 1.0d));
        curve.addAll(collapsed);
        curve.add(new AucRocPoint(1.0d, 1.0d, 0.0d));
        return curve;
    }

    /**
     * Replaces each run of consecutive points with the same threshold ({@code ==}) by their average point.
     * Only adjacent points are merged, so callers should pass points already sorted by threshold.
     */
    static List<AucRocPoint> collapseEqualThresholdPoints(List<AucRocPoint> points) {
        List<AucRocPoint> collapsed = new ArrayList<>();
        List<AucRocPoint> group = new ArrayList<>();
        for (AucRocPoint point : points) {
            if (!group.isEmpty() && group.get(0).threshold != point.threshold) {
                collapsed.add(calculateAveragePoint(group));
                group = new ArrayList<>();
            }
            group.add(point);
        }
        if (!group.isEmpty()) {
            collapsed.add(calculateAveragePoint(group));
        }
        return collapsed;
    }

    /**
     * Component-wise mean of the given points. A single point is returned unchanged.
     *
     * @throws IllegalArgumentException if {@code points} is empty
     */
    private static AucRocPoint calculateAveragePoint(List<AucRocPoint> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("points must not be empty");
        }
        if (points.size() == 1) {
            return points.get(0);
        }
        double tprSum = 0.0d;
        double fprSum = 0.0d;
        double thresholdSum = 0.0d;
        for (AucRocPoint point : points) {
            tprSum += point.tpr;
            fprSum += point.fpr;
            thresholdSum += point.threshold;
        }
        int count = points.size();
        return new AucRocPoint(tprSum / count, fprSum / count, thresholdSum / count);
    }

    /**
     * Area under the curve by the trapezoidal rule: the sum over consecutive points of
     * {@code (fpr_i - fpr_{i-1}) * (tpr_i + tpr_{i-1}) / 2}. Returns 0 for fewer than two points.
     */
    public static double calculateAucScore(List<AucRocPoint> curve) {
        double auc = 0.0d;
        for (int i = 1; i < curve.size(); i++) {
            AucRocPoint previous = curve.get(i - 1);
            AucRocPoint current = curve.get(i);
            auc += ((current.fpr - previous.fpr) * (current.tpr + previous.tpr)) / 2.0d;
        }
        return auc;
    }

    /** Linear interpolation: the y at {@code x} on the line through {@code (x1, y1)} and {@code (x2, y2)}. */
    public static double interpolate(double x, double x1, double y1, double x2, double y2) {
        return y1 + (((x - x1) * (y2 - y1)) / (x2 - x1));
    }
}
