package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.fixedhistogram.FixedDoubleHistogram;
import com.facebook.presto.operator.aggregation.state.PrecisionRecallState;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.google.common.collect.Streams;
import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * Base class for aggregations built from weighted (prediction, outcome) pairs.
 *
 * <p>Each input row adds its weight to one of two histograms over the prediction range [0, 1]:
 * the "true" histogram when the row's outcome is {@code true}, the "false" histogram otherwise.
 * Both histograms use the caller-supplied bucket count, which must be the same for every row.
 * {@link #getResultsIterator} walks the buckets and reports weighted confusion-matrix counts per
 * bucket threshold.
 */
public abstract class PrecisionRecallAggregation {
    private static final double DEFAULT_WEIGHT = 1.0d;
    private static final double MIN_PREDICTION_VALUE = 0.0d;
    private static final double MAX_PREDICTION_VALUE = 1.0d;
    // Predictions equal to MAX_PREDICTION_VALUE are clamped down to this value before insertion.
    // NOTE(review): presumably because FixedDoubleHistogram excludes its upper bound — confirm.
    private static final double MAX_PREDICTION_VALUE_FOR_HISTOGRAM = 0.99999999999d;
    private static final String ILLEGAL_PREDICTION_VALUE_MESSAGE =
            String.format("Prediction value must be between %s and %s", MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE);
    private static final String NEGATIVE_WEIGHT_MESSAGE = "Weights must be non-negative";
    private static final String INCONSISTENT_BUCKET_COUNT_MESSAGE = "Bucket count must be constant";

    /**
     * Weighted confusion-matrix counts for one histogram bucket, as produced by
     * {@link #getResultsIterator}.
     */
    public static class BucketResult {
        private final double threshold;
        private final double positive;
        private final double negative;
        private final double truePositive;
        private final double trueNegative;
        private final double falsePositive;
        private final double falseNegative;

        /**
         * Creates a result row. The descriptions below are how {@link #getResultsIterator} fills it.
         *
         * @param threshold left edge of the bucket
         * @param positive total weight of rows whose outcome was {@code true}
         * @param negative total weight of rows whose outcome was {@code false}
         * @param truePositive true-outcome weight in this bucket and every later bucket
         * @param trueNegative false-outcome weight in all earlier buckets
         * @param falsePositive false-outcome weight in this bucket and every later bucket
         * @param falseNegative true-outcome weight in all earlier buckets
         */
        public BucketResult(
                double threshold,
                double positive,
                double negative,
                double truePositive,
                double trueNegative,
                double falsePositive,
                double falseNegative) {
            this.threshold = threshold;
            this.positive = positive;
            this.negative = negative;
            this.truePositive = truePositive;
            this.trueNegative = trueNegative;
            this.falsePositive = falsePositive;
            this.falseNegative = falseNegative;
        }

        public double getThreshold() {
            return threshold;
        }

        public double getPositive() {
            return positive;
        }

        public double getNegative() {
            return negative;
        }

        public double getTruePositive() {
            return truePositive;
        }

        public double getTrueNegative() {
            return trueNegative;
        }

        public double getFalsePositive() {
            return falsePositive;
        }

        public double getFalseNegative() {
            return falseNegative;
        }
    }

    /**
     * Adds one weighted observation to the state.
     *
     * @param state aggregation state; its two histograms are created lazily on the first row
     * @param bucketCount number of histogram buckets; must be identical for every row
     * @param outcome the row's actual label; selects the true or the false histogram
     * @param prediction predicted value; must lie in [0, 1]
     * @param weight non-negative weight of the row
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if the prediction is NaN or
     *         outside [0, 1], the weight is NaN or negative, or the bucket count differs from the
     *         one the state was created with
     */
    @InputFunction
    public static void input(
            @AggregationState PrecisionRecallState state,
            @SqlType("bigint") long bucketCount,
            @SqlType("boolean") boolean outcome,
            @SqlType("double") double prediction,
            @SqlType("double") double weight) {
        // NaN compares false against every bound, so it must be rejected explicitly.
        if (Double.isNaN(prediction) || prediction < MIN_PREDICTION_VALUE || prediction > MAX_PREDICTION_VALUE) {
            throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, ILLEGAL_PREDICTION_VALUE_MESSAGE);
        }
        if (Double.isNaN(weight) || weight < 0.0d) {
            throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, NEGATIVE_WEIGHT_MESSAGE);
        }

        if (state.getTrueWeights() == null) {
            state.setTrueWeights(new FixedDoubleHistogram((int) bucketCount, MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE));
            state.setFalseWeights(new FixedDoubleHistogram((int) bucketCount, MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE));
        }
        // Compared as long, so (assuming getBucketCount() reports the constructor argument) this also
        // rejects a bucketCount that the (int) cast above truncated.
        if (bucketCount != state.getTrueWeights().getBucketCount()) {
            throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, INCONSISTENT_BUCKET_COUNT_MESSAGE);
        }

        double histogramValue = Math.min(prediction, MAX_PREDICTION_VALUE_FOR_HISTOGRAM);
        FixedDoubleHistogram target = outcome ? state.getTrueWeights() : state.getFalseWeights();
        target.add(histogramValue, weight);
    }

    /**
     * Adds one observation with the default weight of {@value #DEFAULT_WEIGHT}.
     *
     * @see #input(PrecisionRecallState, long, boolean, double, double)
     */
    @InputFunction
    public static void input(
            @AggregationState PrecisionRecallState state,
            @SqlType("bigint") long bucketCount,
            @SqlType("boolean") boolean outcome,
            @SqlType("double") double prediction) {
        input(state, bucketCount, outcome, prediction, DEFAULT_WEIGHT);
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>If {@code otherState} has no histograms, nothing happens. If only {@code state} is empty,
     * it receives copies of {@code otherState}'s histograms. Otherwise the histograms are merged
     * bucket-wise.
     *
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if both states hold histograms
     *         with different bucket counts
     */
    @CombineFunction
    public static void combine(
            @AggregationState PrecisionRecallState state,
            @AggregationState PrecisionRecallState otherState) {
        FixedDoubleHistogram otherTrueWeights = otherState.getTrueWeights();
        if (otherTrueWeights == null) {
            return;
        }
        if (state.getTrueWeights() == null) {
            // NOTE(review): m4359clone() is the decompiler-assigned name of FixedDoubleHistogram's copy
            // method; rename to clone() if this file is rebuilt against the original sources.
            state.setTrueWeights(otherTrueWeights.m4359clone());
            state.setFalseWeights(otherState.getFalseWeights().m4359clone());
            return;
        }
        // Same user-facing error as input() reports for a per-row bucket-count mismatch.
        if (state.getTrueWeights().getBucketCount() != otherTrueWeights.getBucketCount()) {
            throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, INCONSISTENT_BUCKET_COUNT_MESSAGE);
        }
        state.getTrueWeights().mergeWith(otherTrueWeights);
        state.getFalseWeights().mergeWith(otherState.getFalseWeights());
    }

    /**
     * Returns one {@link BucketResult} per histogram bucket, in the histograms' iteration order.
     *
     * <p>Iteration stops after the last bucket that still carries true-outcome weight, or when the
     * buckets run out. If the state never received input, or all weight belongs to false outcomes,
     * the iterator is empty. The returned iterator does not support {@code remove()}.
     *
     * @param state aggregation state; read, not modified
     * @return a lazy iterator over the per-bucket confusion-matrix counts
     */
    public static Iterator<BucketResult> getResultsIterator(@AggregationState PrecisionRecallState state) {
        FixedDoubleHistogram trueWeights = state.getTrueWeights();
        if (trueWeights == null) {
            return Collections.emptyIterator();
        }
        FixedDoubleHistogram falseWeights = state.getFalseWeights();

        // Totals are summed in bucket order with the same plain addition used for the running sums
        // below. Once every true-weight bucket has been consumed, runningTrueWeight therefore equals
        // totalTrueWeight exactly and hasNext() turns false. A compensated sum such as
        // DoubleStream.sum() can differ in the last bits and emit spurious trailing buckets.
        double totalTrueWeight = totalWeight(trueWeights);
        double totalFalseWeight = totalWeight(falseWeights);
        Iterator<FixedDoubleHistogram.Bucket> trueBuckets = trueWeights.iterator();
        Iterator<FixedDoubleHistogram.Bucket> falseBuckets = falseWeights.iterator();

        return new Iterator<BucketResult>() {
            // Weight of each outcome in the buckets already emitted, i.e. strictly below the next threshold.
            private double runningTrueWeight;
            private double runningFalseWeight;

            @Override
            public boolean hasNext() {
                return trueBuckets.hasNext() && totalTrueWeight > runningTrueWeight;
            }

            @Override
            public BucketResult next() {
                if (!trueBuckets.hasNext() || !falseBuckets.hasNext()) {
                    throw new NoSuchElementException();
                }
                FixedDoubleHistogram.Bucket trueBucket = trueBuckets.next();
                FixedDoubleHistogram.Bucket falseBucket = falseBuckets.next();

                BucketResult result = new BucketResult(
                        trueBucket.getLeft(),
                        totalTrueWeight,
                        totalFalseWeight,
                        totalTrueWeight - runningTrueWeight,
                        runningFalseWeight,
                        totalFalseWeight - runningFalseWeight,
                        runningTrueWeight);

                runningTrueWeight += trueBucket.getWeight();
                runningFalseWeight += falseBucket.getWeight();
                return result;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /** Sums the bucket weights of {@code histogram} in iteration order using plain addition. */
    private static double totalWeight(FixedDoubleHistogram histogram) {
        double total = 0.0d;
        Iterator<FixedDoubleHistogram.Bucket> buckets = histogram.iterator();
        while (buckets.hasNext()) {
            total += buckets.next().getWeight();
        }
        return total;
    }
}
