package org.deeplearning4j.clustering.randomprojection;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.clustering.kdtree.KDTree;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Doubles;

/* loaded from: input_file:org/deeplearning4j/clustering/randomprojection/RPUtils.class */
public class RPUtils {

    /**
     * Per-thread cache of distance ops, keyed by similarity-function name. {@link #getOp} reuses an
     * entry (re-binding its x/y/z arrays) as long as its all-distances flag matches the current
     * request; otherwise the entry is replaced with a newly constructed op.
     */
    private static final ThreadLocal<Map<String, DifferentialFunction>> functionInstances =
            ThreadLocal.withInitial(HashMap::new);

    /**
     * Returns a distance op for {@code str}, bound to the given operands and output array.
     *
     * <p>Recognised names are {@code "cosinedistance"}, {@code "cosinesimilarity"},
     * {@code "manhattan"}, {@code "jaccard"} and {@code "hamming"}; any other name yields a
     * Euclidean distance op. The op is constructed in all-distances mode when {@code x} and
     * {@code y} have different lengths.
     *
     * <p>The returned instance is cached per thread and is re-bound to new arrays by the next call
     * with the same name on the same thread, so callers must not hold on to it across calls.
     *
     * @param str    similarity-function name
     * @param x      first operand
     * @param y      second operand
     * @param result array the op writes its output into
     * @return the (possibly cached) op, usable as a {@link ReduceOp}
     */
    public static <T extends DifferentialFunction> DifferentialFunction getOp(String str, INDArray x, INDArray y,
                                                                              INDArray result) {
        Map<String, DifferentialFunction> ops = functionInstances.get();
        boolean allDistances = x.length() != y.length();

        DifferentialFunction cached = ops.get(str);
        if (cached != null && ((ReduceOp) cached).isComplexAccumulation() == allDistances) {
            // Re-bind every cached op (cosinedistance included) so no metric runs on stale inputs.
            ReduceOp op = (ReduceOp) cached;
            op.setX(x);
            op.setY(y);
            op.setZ(result);
            return cached;
        }

        DifferentialFunction created = createOp(str, x, y, result, allDistances);
        ops.put(str, created);
        return created;
    }

    /** Constructs a new, uncached distance op for {@code name}; unknown names fall back to Euclidean. */
    private static DifferentialFunction createOp(String name, INDArray x, INDArray y, INDArray result,
                                                 boolean allDistances) {
        switch (name) {
            case "cosinedistance":
                return new CosineDistance(x, y, result, allDistances, new int[0]);
            case "cosinesimilarity":
                return new CosineSimilarity(x, y, result, allDistances, new int[0]);
            case "manhattan":
                return new ManhattanDistance(x, y, result, allDistances, new int[0]);
            case "jaccard":
                return new JaccardDistance(x, y, result, allDistances, new int[0]);
            case "hamming":
                return new HammingDistance(x, y, result, allDistances, new int[0]);
            default:
                return new EuclideanDistance(x, y, result, allDistances, new int[0]);
        }
    }

    /**
     * Finds up to {@code n} nearest candidates of {@code x} among the rows of {@code data}, using
     * the leaves that {@code x} falls into in each of the given trees.
     *
     * @param x                  query vector
     * @param data               data set; candidate indices refer to its slices
     * @param trees              trees to query; must not be empty
     * @param n                  maximum number of results (negative values yield no results)
     * @param similarityFunction distance name understood by {@link #getOp}
     * @return (distance, index) pairs sorted by ascending distance
     * @throws ND4JIllegalArgumentException if {@code trees} is empty
     */
    public static List<Pair<Double, Integer>> queryAllWithDistances(INDArray x, INDArray data, List<RPTree> trees,
                                                                    int n, String similarityFunction) {
        if (trees.isEmpty()) {
            throw new ND4JIllegalArgumentException("Trees is empty!");
        }
        List<Pair<Double, Integer>> sorted =
                sortCandidates(x, data, getCandidates(x, trees, similarityFunction), similarityFunction);
        int limit = Math.max(0, Math.min(n, sorted.size()));
        return new ArrayList<>(sorted.subList(0, limit));
    }

    /**
     * Same as {@link #queryAllWithDistances} but returns only the row indices, in ascending
     * distance order, as a 1-d array.
     *
     * @throws ND4JIllegalArgumentException if {@code trees} is empty
     */
    public static INDArray queryAll(INDArray x, INDArray data, List<RPTree> trees, int n, String similarityFunction) {
        List<Pair<Double, Integer>> nearest = queryAllWithDistances(x, data, trees, n, similarityFunction);
        INDArray result = Nd4j.create(nearest.size());
        for (int i = 0; i < nearest.size(); i++) {
            result.putScalar(i, nearest.get(i).getSecond().intValue());
        }
        return result;
    }

    /**
     * Scores each candidate row of {@code data} by its distance to {@code x} and sorts ascending.
     * Consecutive duplicate candidate indices are scored only once.
     *
     * @return (distance, index) pairs sorted by ascending distance
     */
    public static List<Pair<Double, Integer>> sortCandidates(INDArray x, INDArray data, List<Integer> candidates,
                                                             String similarityFunction) {
        List<Pair<Double, Integer>> scored = new ArrayList<>(candidates.size());
        Integer previous = null;
        for (Integer candidate : candidates) {
            // Compare against the previous candidate *value* (not its list position).
            if (!candidate.equals(previous)) {
                double distance = computeDistance(similarityFunction, data.slice(candidate.intValue()), x);
                scored.add(Pair.of(distance, candidate));
            }
            previous = candidate;
        }
        scored.sort((a, b) -> Double.compare(a.getFirst(), b.getFirst()));
        return scored;
    }

    /**
     * Returns the distinct candidate indices for {@code x} across all trees, sorted ascending, as a
     * 1-d array (empty when no candidates are found).
     */
    public static INDArray getAllCandidates(INDArray x, List<RPTree> trees, String similarityFunction) {
        List<Integer> candidates = getCandidates(x, trees, similarityFunction);
        Collections.sort(candidates);

        List<Integer> distinct = new ArrayList<>(candidates.size());
        for (Integer candidate : candidates) {
            if (distinct.isEmpty() || !distinct.get(distinct.size() - 1).equals(candidate)) {
                distinct.add(candidate);
            }
        }

        INDArray result = Nd4j.create(distinct.size());
        for (int i = 0; i < distinct.size(); i++) {
            result.putScalar(i, distinct.get(i).intValue());
        }
        return result;
    }

    /**
     * Collects the indices stored in the leaf reached by {@code x} in each tree, de-duplicated
     * while preserving first-seen order.
     */
    public static List<Integer> getCandidates(INDArray x, List<RPTree> trees, String similarityFunction) {
        LinkedHashSet<Integer> seen = new LinkedHashSet<>();
        for (RPTree tree : trees) {
            seen.addAll(query(tree.getRoot(), tree.getRpHyperPlanes(), x, similarityFunction).getIndices());
        }
        return new ArrayList<>(seen);
    }

    /**
     * Descends from {@code from} to a leaf. At each internal node, goes left when the distance
     * between {@code x} and the hyperplane for that node's depth is {@code <=} the node's median,
     * otherwise right.
     *
     * @return the leaf node reached
     */
    public static RPNode query(RPNode from, RPHyperPlanes planes, INDArray x, String similarityFunction) {
        RPNode node = from;
        while (node.getLeft() != null || node.getRight() != null) {
            double distance = computeDistance(similarityFunction, x, planes.getHyperPlaneAt(node.getDepth()));
            node = distance <= node.getMedian() ? node.getLeft() : node.getRight();
        }
        return node;
    }

    /**
     * Computes distances along dimension 1 (e.g. one distance per row) into {@code result}.
     *
     * <p>Uses a fresh, uncached op: setting dimensions on the thread-local cached instance would
     * otherwise leak into later {@link #computeDistance} calls for the same metric.
     *
     * @return the op's output array
     */
    public static INDArray computeDistanceMulti(String str, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) createOp(str, x, y, result, x.length() != y.length());
        op.setDimensions(new int[]{1});
        Nd4j.getExecutioner().exec(op);
        return op.z();
    }

    /**
     * Computes a single distance between {@code x} and {@code y}, writing into {@code result}.
     *
     * @return the first element of the op's output
     */
    public static double computeDistance(String str, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(str, x, y, result);
        Nd4j.getExecutioner().exec(op);
        return op.z().getDouble(0L);
    }

    /** Computes a single distance between {@code x} and {@code y} using a new scalar output array. */
    public static double computeDistance(String str, INDArray x, INDArray y) {
        return computeDistance(str, x, y, Nd4j.scalar(0.0d));
    }

    /**
     * Recursively splits {@code from} by the median distance to the hyperplane for {@code depth}
     * until a node holds at most {@code maxSize} indices or a split would leave one side empty.
     *
     * @param tree               tree the new child nodes belong to
     * @param from               node to split; its index list is cleared once it has children
     * @param planes             hyperplanes; a random one is added when none exists for this depth
     * @param data               data set; node indices refer to its slices
     * @param maxSize            leaf-size threshold
     * @param depth              depth of {@code from}
     * @param similarityFunction distance name understood by {@link #getOp}
     */
    public static void buildTree(RPTree tree, RPNode from, RPHyperPlanes planes, INDArray data, int maxSize,
                                 int depth, String similarityFunction) {
        List<Integer> indices = from.getIndices();
        if (indices.size() <= maxSize) {
            slimNode(from);
            return;
        }

        RPNode left = new RPNode(tree, depth + 1);
        RPNode right = new RPNode(tree, depth + 1);

        if (planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) {
            planes.addRandomHyperPlane();
        }
        INDArray hyperPlane = planes.getHyperPlaneAt(depth);

        // Compute each point's distance once; it is needed both for the median and for the split.
        List<Double> distances = new ArrayList<>(indices.size());
        for (int i = 0; i < indices.size(); i++) {
            distances.add(computeDistance(similarityFunction, hyperPlane, data.slice(indices.get(i).intValue())));
        }
        List<Double> sorted = new ArrayList<>(distances);
        Collections.sort(sorted);
        from.setMedian(sorted.get(sorted.size() / 2));

        double median = from.getMedian();
        for (int i = 0; i < indices.size(); i++) {
            RPNode target = distances.get(i) <= median ? left : right;
            target.getIndices().add(indices.get(i));
        }

        if (left.getIndices().isEmpty() || right.getIndices().isEmpty()) {
            // Degenerate split: keep this node as a leaf with its indices intact.
            slimNode(from);
            return;
        }

        from.setLeft(left);
        from.setRight(right);
        slimNode(from);
        buildTree(tree, left, planes, data, maxSize, depth + 1, similarityFunction);
        buildTree(tree, right, planes, data, maxSize, depth + 1, similarityFunction);
    }

    /** Appends every leaf of {@code tree} to {@code list}, left subtree before right. */
    public static void scanForLeaves(List<RPNode> list, RPTree tree) {
        scanForLeaves(list, tree.getRoot());
    }

    /** Appends every leaf under {@code current} (inclusive) to {@code list}, left subtree before right. */
    public static void scanForLeaves(List<RPNode> list, RPNode current) {
        RPNode left = current.getLeft();
        RPNode right = current.getRight();
        if (left == null && right == null) {
            list.add(current);
        }
        if (left != null) {
            scanForLeaves(list, left);
        }
        if (right != null) {
            scanForLeaves(list, right);
        }
    }

    /** Clears the index list of an internal node (one with both children); leaves keep theirs. */
    public static void slimNode(RPNode node) {
        if (node.getLeft() != null && node.getRight() != null) {
            node.getIndices().clear();
        }
    }
}
