package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:lib/x-pack-core-7.17.13.jar:org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.class */
public final class InferenceHelpers {

    /**
     * The highest-ranked class selected by {@link #topClasses}.
     */
    public static class TopClassificationValue {
        private final int value;
        private final double probability;
        private final double score;

        TopClassificationValue(int value, double probability, double score) {
            this.value = value;
            this.probability = probability;
            this.score = score;
        }

        /** @return index of the selected class within the probability array */
        public int getValue() {
            return this.value;
        }

        /** @return the unweighted probability of the selected class */
        public double getProbability() {
            return this.probability;
        }

        /** @return the score used for ranking (probability times weight when weights were supplied) */
        public double getScore() {
            return this.score;
        }
    }

    private InferenceHelpers() {
    }

    /**
     * Ranks classes by score and returns the best one plus up to {@code numToInclude} ranked entries.
     *
     * <p>A class's score is its probability, multiplied by its weight when {@code classificationWeights}
     * is provided. Classes with equal scores keep their original relative order.
     *
     * @param probabilities         per-class probabilities; must contain at least one element
     * @param classificationLabels  optional labels, one per class; when {@code null} the class index
     *                              (as a string) is used as the label
     * @param classificationWeights optional per-class weights; must be the same length as {@code probabilities}
     * @param numToInclude          number of ranked entries to return; {@code 0} returns none, a negative
     *                              value returns all classes
     * @param predictionFieldType   converts each class index/label into the reported predicted value
     * @return the top class and the list of ranked top-class entries
     * @throws org.elasticsearch.ElasticsearchException (via {@code ExceptionsHelper.serverError}) when the
     *         inputs have inconsistent sizes or {@code probabilities} is empty
     */
    public static Tuple<TopClassificationValue, List<TopClassEntry>> topClasses(
        double[] probabilities,
        List<String> classificationLabels,
        @Nullable double[] classificationWeights,
        int numToInclude,
        PredictionFieldType predictionFieldType
    ) {
        if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
            throw ExceptionsHelper.serverError(
                "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
                null,
                probabilities.length,
                classificationLabels.size()
            );
        }
        if (probabilities.length == 0) {
            // Without this guard the top-class lookup below fails with an opaque ArrayIndexOutOfBoundsException.
            throw ExceptionsHelper.serverError(
                "model returned classification probabilities of size [{}]; at least one class is required",
                null,
                probabilities.length
            );
        }
        if (classificationWeights != null && classificationWeights.length != probabilities.length) {
            throw ExceptionsHelper.serverError(
                "classification weights of size [{}] is not equal to classification probabilities size [{}]",
                null,
                classificationWeights.length,
                probabilities.length
            );
        }

        final double[] scores = classificationWeights == null
            ? probabilities
            : IntStream.range(0, probabilities.length).mapToDouble(i -> probabilities[i] * classificationWeights[i]).toArray();

        // Stable sort of class indices by descending score; ties keep ascending index order.
        int[] sortedIndices = IntStream.range(0, scores.length)
            .boxed()
            .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
            .mapToInt(Integer::intValue)
            .toArray();

        int topIndex = sortedIndices[0];
        TopClassificationValue topClassificationValue = new TopClassificationValue(topIndex, probabilities[topIndex], scores[topIndex]);
        if (numToInclude == 0) {
            return Tuple.tuple(topClassificationValue, Collections.emptyList());
        }

        List<String> labels;
        if (classificationLabels == null) {
            labels = IntStream.range(0, probabilities.length).mapToObj(i -> String.valueOf(i)).collect(Collectors.toList());
        } else {
            labels = classificationLabels;
        }

        int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
        List<TopClassEntry> topClassEntries = new ArrayList<>(count);
        for (int rank = 0; rank < count; rank++) {
            int idx = sortedIndices[rank];
            topClassEntries.add(
                new TopClassEntry(predictionFieldType.transformPredictedValue((double) idx, labels.get(idx)), probabilities[idx], scores[idx])
            );
        }
        return Tuple.tuple(topClassificationValue, topClassEntries);
    }

    /**
     * Maps a predicted class index to its label.
     *
     * @param inferenceValue       the predicted class index
     * @param classificationLabels optional labels; when {@code null} the index itself is returned as a string
     * @return the label at {@code inferenceValue}, or its string form when no labels are given
     * @throws org.elasticsearch.ElasticsearchException (via {@code ExceptionsHelper.serverError}) when the
     *         index is out of range for the provided labels
     */
    public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
        if (classificationLabels == null) {
            return String.valueOf(inferenceValue);
        }
        int index = inferenceValue.intValue();
        if (index < 0 || index >= classificationLabels.size()) {
            throw ExceptionsHelper.serverError(
                "model returned classification value of [{}] which is not a valid index in classification labels [{}]",
                null,
                inferenceValue,
                classificationLabels
            );
        }
        return classificationLabels.get(index);
    }

    /**
     * Converts a {@link Number} or numeric {@link String} to a {@link Double}.
     *
     * @return the double value, or {@code null} for any other type, an empty string, or (with assertions
     *         disabled) an unparseable string
     */
    public static Double toDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        if (value instanceof String) {
            return stringToDouble((String) value);
        }
        return null;
    }

    /** Parses {@code str}; empty input yields {@code null}, malformed input trips an assertion or yields {@code null}. */
    private static Double stringToDouble(String str) {
        if (str.isEmpty()) {
            return null;
        }
        try {
            return Double.valueOf(str);
        } catch (NumberFormatException e) {
            assert false : "value is not properly formatted double [" + str + "]";
            return null;
        }
    }

    /**
     * Re-keys feature importances from processed feature names to original feature names, summing the
     * importance arrays of processed features that map to the same original feature.
     *
     * @param processedFeatureToOriginalFeatureMap mapping from processed to original feature name; features
     *                                             absent from the map keep their own name
     * @param featureImportances                   importance arrays keyed by processed feature name; arrays
     *                                             that are summed together must have equal length
     * @return {@code featureImportances} unchanged when the mapping is null/empty, otherwise a new map keyed
     *         by original feature name. The input arrays are never modified.
     */
    public static Map<String, double[]> decodeFeatureImportances(
        Map<String, String> processedFeatureToOriginalFeatureMap,
        Map<String, double[]> featureImportances
    ) {
        if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
            return featureImportances;
        }
        Map<String, double[]> originalFeatureImportance = new HashMap<>();
        featureImportances.forEach((feature, importance) -> {
            String originalFeature = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
            // Copy on first insertion so the in-place summation below never mutates the caller's arrays.
            originalFeatureImportance.compute(
                originalFeature,
                (name, accumulated) -> accumulated == null ? importance.clone() : sumDoubleArrays(accumulated, importance)
            );
        });
        return originalFeatureImportance;
    }

    /**
     * Converts regression feature importances into result objects using the first element of each array.
     */
    public static List<RegressionFeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
        List<RegressionFeatureImportance> importances = new ArrayList<>(featureImportance.size());
        featureImportance.forEach((feature, importance) -> importances.add(new RegressionFeatureImportance(feature, importance[0])));
        return importances;
    }

    /**
     * Converts classification feature importances into result objects with one entry per class.
     *
     * <p>A single-element array is expanded into two classes: class 1 receives the value and class 0
     * receives its negation. Otherwise element {@code i} is attributed to class {@code i}.
     *
     * @param featureImportance    importance arrays keyed by feature name
     * @param classificationLabels optional class labels; {@code null} labels are passed through when absent
     * @param predictionFieldType  converts class index/label to the reported value; defaults to
     *                             {@code PredictionFieldType.STRING} when {@code null}
     */
    public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
        Map<String, double[]> featureImportance,
        @Nullable List<String> classificationLabels,
        @Nullable PredictionFieldType predictionFieldType
    ) {
        List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
        PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
        featureImportance.forEach((feature, importance) -> {
            if (importance.length == 1) {
                importances.add(
                    new ClassificationFeatureImportance(
                        feature,
                        Arrays.asList(
                            new ClassificationFeatureImportance.ClassImportance(
                                fieldType.transformPredictedValue(0.0, labelAt(classificationLabels, 0)),
                                -importance[0]
                            ),
                            new ClassificationFeatureImportance.ClassImportance(
                                fieldType.transformPredictedValue(1.0, labelAt(classificationLabels, 1)),
                                importance[0]
                            )
                        )
                    )
                );
                return;
            }
            assert classificationLabels == null || classificationLabels.size() == importance.length;
            List<ClassificationFeatureImportance.ClassImportance> classImportances = new ArrayList<>(importance.length);
            for (int i = 0; i < importance.length; i++) {
                classImportances.add(
                    new ClassificationFeatureImportance.ClassImportance(
                        fieldType.transformPredictedValue((double) i, labelAt(classificationLabels, i)),
                        importance[i]
                    )
                );
            }
            importances.add(new ClassificationFeatureImportance(feature, classImportances));
        });
        return importances;
    }

    /** Returns {@code labels.get(index)}, or {@code null} when no labels were supplied. */
    private static String labelAt(@Nullable List<String> labels, int index) {
        return labels == null ? null : labels.get(index);
    }

    /**
     * Adds {@code inc} into {@code sumTo} element-wise, in place.
     *
     * @return {@code sumTo}, after modification
     */
    public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
        return sumDoubleArrays(sumTo, inc, 1);
    }

    /**
     * Adds {@code inc * weight} into {@code sumTo} element-wise, in place. Both arrays must be non-null and
     * the same length (checked only when assertions are enabled).
     *
     * @return {@code sumTo}, after modification
     */
    public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
        assert sumTo != null && inc != null && sumTo.length == inc.length;
        for (int i = 0; i < inc.length; i++) {
            sumTo[i] += inc[i] * weight;
        }
        return sumTo;
    }

    /**
     * Divides every element of {@code xs} by {@code v}, in place. An empty array is left untouched, even
     * when {@code v} is zero.
     *
     * @throws IllegalArgumentException when {@code xs} is non-empty and {@code v} is zero
     */
    public static void divMut(double[] xs, int v) {
        if (xs.length == 0) {
            return;
        }
        if (v == 0) {
            throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
        }
        for (int i = 0; i < xs.length; i++) {
            xs[i] /= v;
        }
    }
}
