/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.kdtree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.kdtree.HyperRect;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A k-d tree over fixed-length {@link INDArray} vectors supporting insertion, deletion,
 * radius search ({@link #knn(INDArray, double)}) and nearest-neighbour search
 * ({@link #nn(INDArray)}).
 *
 * <p>Points are ordered by a cyclic "superkey": at a node whose discriminator is {@code d},
 * two points are compared on coordinate {@code d}, ties are broken on {@code d+1}, and so on,
 * wrapping around to coordinate 0. Distinct points therefore always have a strict order, and
 * points that are equal on every coordinate are stored only once.
 */
public class KDTree
implements Serializable {
    /** Direction constant: a point that sorts after a node belongs in its right subtree. */
    public static final int GREATER = 1;
    /** Direction constant: a point that sorts before a node belongs in its left subtree. */
    public static final int LESS = 0;

    private KDNode root;
    /** Number of coordinates every stored point must have. */
    private final int dims;
    /** Number of points currently stored. */
    private int size = 0;
    /**
     * Region created from the first inserted point and grown with {@code enlargeTo} on every
     * insert; {@code null} while the tree is empty. It is not shrunk on delete, so it may
     * over-cover the stored points. NOTE(review): that is only safe for pruning if
     * {@code HyperRect.minDistance} is a lower bound on the distance to points inside the
     * region — confirm against HyperRect.
     */
    private HyperRect rect;

    /**
     * Creates an empty tree.
     *
     * @param dims the length every inserted point must have
     * @throws IllegalArgumentException if {@code dims < 1} (the discriminator cycles modulo dims)
     */
    public KDTree(int dims) {
        if (dims < 1) {
            throw new IllegalArgumentException("dims must be positive, got " + dims);
        }
        this.dims = dims;
    }

    /**
     * Inserts {@code point}. Inserting a point equal on every coordinate to one already stored
     * is a no-op.
     *
     * @param point a vector of length {@code dims}
     * @throws IllegalArgumentException if {@code point} is not a vector of length {@code dims}
     */
    public void insert(INDArray point) {
        if (!point.isVector() || point.length() != this.dims) {
            throw new IllegalArgumentException("Point must be a vector of length " + this.dims);
        }
        if (this.root == null) {
            this.root = new KDNode(point);
            this.rect = new HyperRect(HyperRect.point(point));
            ++this.size;
            return;
        }
        KDNode node = this.root;
        int disc = 0;
        while (true) {
            int cmp = this.compareSuperKey(point, node.getPoint(), disc);
            if (cmp == 0) {
                // Identical point already stored.
                return;
            }
            KDNode child = cmp < 0 ? node.getLeft() : node.getRight();
            if (child == null) {
                KDNode inserted = new KDNode(point);
                if (cmp < 0) {
                    node.setLeft(inserted);
                } else {
                    node.setRight(inserted);
                }
                inserted.setParent(node);
                break;
            }
            node = child;
            disc = (disc + 1) % this.dims;
        }
        this.rect.enlargeTo(point);
        ++this.size;
    }

    /**
     * Removes the stored point equal (on every coordinate) to {@code point}.
     *
     * @param point the point to remove
     * @return a detached node holding the removed point, or {@code null} if no equal point
     *     was stored
     */
    public KDNode delete(INDArray point) {
        KDNode node = this.root;
        int disc = 0;
        while (node != null) {
            int cmp = this.compareSuperKey(point, node.getPoint(), disc);
            if (cmp == 0) {
                break;
            }
            node = cmp < 0 ? node.getLeft() : node.getRight();
            disc = (disc + 1) % this.dims;
        }
        if (node == null) {
            return null;
        }
        // Capture the point now: deleting an inner node overwrites its point in place.
        KDNode removed = new KDNode(node.getPoint());
        if (node == this.root) {
            this.root = this.delete(this.root, disc);
        } else {
            this.delete(node, disc);
        }
        --this.size;
        if (this.root == null) {
            this.rect = null;
        }
        return removed;
    }

    /**
     * Returns every stored point whose Euclidean distance to {@code point} is at most
     * {@code distance}, paired with that distance and sorted by ascending distance.
     *
     * @param point the query point
     * @param distance the inclusive search radius
     * @return the matching (distance, point) pairs; empty if none match or the tree is empty
     */
    public List<Pair<Double, INDArray>> knn(INDArray point, double distance) {
        ArrayList<Pair<Double, INDArray>> best = new ArrayList<Pair<Double, INDArray>>();
        this.knn(this.root, point, this.rect, distance, best, 0);
        Collections.sort(best, new Comparator<Pair<Double, INDArray>>(){

            @Override
            public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
                return Double.compare(o1.getFirst(), o2.getFirst());
            }
        });
        return best;
    }

    /** Radius search over the subtree at {@code node}, whose points lie inside {@code rect}. */
    private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best, int _disc) {
        if (node == null || rect.minDistance(point) > dist) {
            return;
        }
        int discNext = (_disc + 1) % this.dims;
        double distance = distance(point, node.getPoint());
        if (distance <= dist) {
            best.add(new Pair<Double, INDArray>(distance, node.getPoint()));
        }
        // The subtrees are partitioned by the node's point, not by the query point.
        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);
        this.knn(node.getLeft(), point, lower, dist, best, discNext);
        this.knn(node.getRight(), point, upper, dist, best, discNext);
    }

    /**
     * Finds the stored point closest (Euclidean distance) to {@code point}.
     *
     * @param point the query point
     * @return (distance, nearest point), or ({@code POSITIVE_INFINITY}, {@code null}) when the
     *     tree is empty
     */
    public Pair<Double, INDArray> nn(INDArray point) {
        return this.nn(this.root, point, this.rect, Double.POSITIVE_INFINITY, null, 0);
    }

    /**
     * Nearest-neighbour search over the subtree at {@code node}.
     *
     * @param dist distance of the best candidate found so far
     * @param best best candidate found so far (may be {@code null})
     * @return the best (distance, point) after searching this subtree; unchanged input when
     *     the subtree is empty or pruned
     */
    private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, int _disc) {
        if (node == null || rect.minDistance(point) > dist) {
            // Nothing in this region can beat the current best.
            return new Pair<Double, INDArray>(dist, best);
        }
        int discNext = (_disc + 1) % this.dims;
        double candidate = distance(point, node.getPoint());
        if (candidate < dist) {
            dist = candidate;
            best = node.getPoint();
        }
        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);
        boolean leftFirst = point.getDouble(_disc) < node.point.getDouble(_disc);
        // Search the side containing the query first so the bound tightens before the far side.
        Pair<Double, INDArray> near = leftFirst
                ? this.nn(node.getLeft(), point, lower, dist, best, discNext)
                : this.nn(node.getRight(), point, upper, dist, best, discNext);
        return leftFirst
                ? this.nn(node.getRight(), point, upper, near.getFirst(), near.getSecond(), discNext)
                : this.nn(node.getLeft(), point, lower, near.getFirst(), near.getSecond(), discNext);
    }

    /**
     * Removes {@code target}'s point from the tree.
     *
     * <p>A leaf is unlinked from its parent. An inner node instead takes the point of its
     * in-order neighbour along {@code disc' superkey order:
     * <ul>
     *   <li>the minimum of the right subtree, if there is one;</li>
     *   <li>otherwise the maximum of the left subtree.</li>
     * </ul>
     * The neighbour's old node is then deleted recursively at its own discriminator.
     *
     * @param target the node to remove
     * @param disc the discriminator of {@code target}'s level
     * @return {@code null} if {@code target} was a leaf and has been unlinked, otherwise
     *     {@code target} itself (now holding the replacement point)
     */
    private KDNode delete(KDNode target, int disc) {
        if (target.getLeft() == null && target.getRight() == null) {
            KDNode parent = target.getParent();
            if (parent != null) {
                if (parent.getLeft() == target) {
                    parent.setLeft(null);
                } else {
                    parent.setRight(null);
                }
            }
            return null;
        }
        int childDisc = (disc + 1) % this.dims;
        Pair<KDNode, Integer> replacement = target.getRight() != null
                ? this.min(target.getRight(), disc, childDisc)
                : this.max(target.getLeft(), disc, childDisc);
        KDNode source = replacement.getFirst();
        target.point = source.point;
        // The recursive call unlinks the replacement's node itself when it is a leaf.
        this.delete(source, replacement.getSecond());
        return target;
    }

    /**
     * Finds the node with the greatest point, by superkey order starting at {@code disc}, in
     * the subtree rooted at {@code node}.
     *
     * @param disc the discriminator the maximum is taken along
     * @param _disc the discriminator of {@code node}'s level
     * @return the maximal node paired with its own level's discriminator
     */
    private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % this.dims;
        Pair<KDNode, Integer> best = new Pair<KDNode, Integer>(node, _disc);
        if (node.getRight() != null) {
            Pair<KDNode, Integer> candidate = this.max(node.getRight(), disc, discNext);
            if (this.compareSuperKey(candidate.getFirst().getPoint(), best.getFirst().getPoint(), disc) > 0) {
                best = candidate;
            }
        }
        // When this level splits on disc, the left subtree is entirely smaller and can be skipped.
        if (_disc != disc && node.getLeft() != null) {
            Pair<KDNode, Integer> candidate = this.max(node.getLeft(), disc, discNext);
            if (this.compareSuperKey(candidate.getFirst().getPoint(), best.getFirst().getPoint(), disc) > 0) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Finds the node with the smallest point, by superkey order starting at {@code disc}, in
     * the subtree rooted at {@code node}.
     *
     * @param disc the discriminator the minimum is taken along
     * @param _disc the discriminator of {@code node}'s level
     * @return the minimal node paired with its own level's discriminator
     */
    private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % this.dims;
        Pair<KDNode, Integer> best = new Pair<KDNode, Integer>(node, _disc);
        if (node.getLeft() != null) {
            Pair<KDNode, Integer> candidate = this.min(node.getLeft(), disc, discNext);
            if (this.compareSuperKey(candidate.getFirst().getPoint(), best.getFirst().getPoint(), disc) < 0) {
                best = candidate;
            }
        }
        // When this level splits on disc, the right subtree is entirely larger and can be skipped.
        if (_disc != disc && node.getRight() != null) {
            Pair<KDNode, Integer> candidate = this.min(node.getRight(), disc, discNext);
            if (this.compareSuperKey(candidate.getFirst().getPoint(), best.getFirst().getPoint(), disc) < 0) {
                best = candidate;
            }
        }
        return best;
    }

    /** Returns the number of points currently stored. */
    public int size() {
        return this.size;
    }

    /**
     * Compares two points by cyclic superkey: coordinate {@code disc} first, then
     * {@code disc+1}, ..., wrapping around to 0, until a coordinate differs.
     *
     * @return a negative value, zero, or a positive value as {@code a} sorts before, equal to,
     *     or after {@code b}; zero only when no coordinate is strictly smaller or larger
     */
    private int compareSuperKey(INDArray a, INDArray b, int disc) {
        for (int k = 0; k < this.dims; ++k) {
            int i = (disc + k) % this.dims;
            double ai = a.getDouble(i);
            double bi = b.getDouble(i);
            if (ai < bi) {
                return -1;
            }
            if (ai > bi) {
                return 1;
            }
        }
        return 0;
    }

    /** Euclidean distance between {@code a} and {@code b}, computed by the ND4J executioner. */
    private static double distance(INDArray a, INDArray b) {
        return Nd4j.getExecutioner().execAndReturn((Accumulation)new EuclideanDistance(a, b)).currentResult().doubleValue();
    }

    /**
     * A tree node: one stored point plus links to its children and parent. Serializable so
     * that the enclosing {@link KDTree} (which holds the root node) can actually be serialized.
     */
    public static class KDNode
    implements Serializable {
        private INDArray point;
        private KDNode left;
        private KDNode right;
        private KDNode parent;

        public KDNode(INDArray point) {
            this.point = point;
        }

        public INDArray getPoint() {
            return this.point;
        }

        public KDNode getLeft() {
            return this.left;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public KDNode getRight() {
            return this.right;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public KDNode getParent() {
            return this.parent;
        }

        public void setParent(KDNode parent) {
            this.parent = parent;
        }
    }
}

