/*
 * Decompiled with CFR 0.152.
 */
package tagbio.umap;

import java.util.List;
import java.util.Random;
import tagbio.umap.FlatTree;
import tagbio.umap.Heap;
import tagbio.umap.Matrix;
import tagbio.umap.NearestNeighborSearch;
import tagbio.umap.UmapProgress;
import tagbio.umap.Utils;
import tagbio.umap.metric.Metric;

/**
 * Builds an approximate k-nearest-neighbour graph by nearest-neighbour descent.
 * <p>
 * All distances are computed with the {@link Metric} supplied at construction time. The graph is
 * held in a {@link Heap} with one row per data point. It is seeded from random samples, and
 * optionally from random-projection tree leaves. It is then refined by repeatedly comparing pairs
 * of candidate neighbours.
 */
class NearestNeighborDescent {

    /** Convergence threshold used by the short {@code descent} overload. */
    private static final float DEFAULT_DELTA = 0.001f;
    /** Candidate rejection probability used by the short {@code descent} overload. */
    private static final float DEFAULT_RHO = 0.5f;

    /** Distance function used for every pairwise comparison. */
    final Metric mMetric;
    /** When true, a progress message is emitted at the start of each refinement iteration. */
    boolean mVerbose;

    /**
     * @param metric distance function used for all comparisons
     */
    NearestNeighborDescent(Metric metric) {
        this.mMetric = metric;
    }

    /**
     * Enables or disables per-iteration progress messages.
     *
     * @param flag true to emit a message via {@code Utils.message} at each iteration
     */
    void setVerbose(boolean flag) {
        this.mVerbose = flag;
    }

    /**
     * Runs nearest-neighbour descent with {@code delta = 0.001f} and {@code rho = 0.5f}.
     *
     * @see #descent(Matrix, int, Random, int, boolean, int, List, float, float)
     */
    Heap descent(Matrix data, int nNeighbors, Random random, int maxCandidates, boolean rpTreeInit, int nIters, List<FlatTree> forest) {
        return descent(data, nNeighbors, random, maxCandidates, rpTreeInit, nIters, forest, DEFAULT_DELTA, DEFAULT_RHO);
    }

    /**
     * Builds an approximate nearest-neighbour graph for the rows of {@code data}.
     *
     * @param data          points to index; one point per row
     * @param nNeighbors    number of neighbour slots kept per point
     * @param random        source of randomness for seeding, candidate selection and rejection
     * @param maxCandidates number of candidate slots examined per point in each refinement pass
     * @param rpTreeInit    if true and {@code forest} is non-null, also seed the graph with every
     *                      pair of points that share a tree leaf
     * @param nIters        maximum number of refinement passes
     * @param forest        trees whose leaves are used for seeding; may be null, in which case
     *                      tree seeding is skipped
     * @param delta         a pass whose count of successful pushes is at most
     *                      {@code delta * nNeighbors * data.rows()} ends the refinement early
     * @param rho           probability that a candidate slot is marked rejected in a pass; a
     *                      candidate pair is skipped only when both of its slots are marked
     * @return the graph after {@code Heap.deheapSort()}
     */
    Heap descent(Matrix data, int nNeighbors, Random random, int maxCandidates, boolean rpTreeInit, int nIters, List<FlatTree> forest, float delta, float rho) {
        final int nVertices = data.rows();
        final Heap currentGraph = new Heap(nVertices, nNeighbors);

        seedRandomly(currentGraph, data, nNeighbors, random);
        UmapProgress.update();

        // Guard against a null forest, matching the null handling in initialiseSearch.
        if (rpTreeInit && forest != null) {
            seedFromTrees(currentGraph, data, forest);
        }
        UmapProgress.update();

        // Reused across all vertices and passes to avoid re-allocating per vertex.
        final boolean[] rejectStatus = new boolean[maxCandidates];
        for (int n = 0; n < nIters; ++n) {
            if (mVerbose) {
                Utils.message("NearestNeighborDescent: " + (n + 1) + " / " + nIters);
            }
            final Heap candidates = currentGraph.buildCandidates(nVertices, nNeighbors, maxCandidates, random);
            final int updates = localJoin(currentGraph, candidates, data, maxCandidates, rho, random, rejectStatus);
            if (updates <= delta * nNeighbors * nVertices) {
                // Advance progress by this pass plus every pass that will now be skipped.
                UmapProgress.update(nIters - n);
                break;
            }
            UmapProgress.update();
        }
        return currentGraph.deheapSort();
    }

    /**
     * Pushes each point paired with the indices from {@code Utils.rejectionSample}, in both directions.
     */
    private void seedRandomly(Heap graph, Matrix data, int nNeighbors, Random random) {
        final int nVertices = data.rows();
        for (int i = 0; i < nVertices; ++i) {
            final float[] iRow = data.row(i);
            for (final int index : Utils.rejectionSample(nNeighbors, nVertices, random)) {
                final float d = mMetric.distance(iRow, data.row(index));
                graph.push(i, d, index, true);
                graph.push(index, d, i, true);
            }
        }
    }

    /** Pushes every pair of points that share a leaf in any tree of {@code forest}, in both directions. */
    private void seedFromTrees(Heap graph, Matrix data, List<FlatTree> forest) {
        for (final FlatTree tree : forest) {
            for (final int[] leaf : tree.getIndices()) {
                for (int i = 0; i < leaf.length; ++i) {
                    final float[] iRow = data.row(leaf[i]);
                    for (int j = i + 1; j < leaf.length; ++j) {
                        final float d = mMetric.distance(iRow, data.row(leaf[j]));
                        graph.push(leaf[i], d, leaf[j], true);
                        graph.push(leaf[j], d, leaf[i], true);
                    }
                }
            }
        }
    }

    /**
     * Runs one refinement pass, comparing pairs of each vertex's candidates and pushing results into {@code graph}.
     *
     * @return the number of {@code graph.push} calls in this pass that returned true
     *         (presumably pushes that changed the heap — confirm against {@code Heap.push})
     */
    private int localJoin(Heap graph, Heap candidates, Matrix data, int maxCandidates, float rho, Random random, boolean[] rejectStatus) {
        final int nVertices = data.rows();
        int updates = 0;
        for (int i = 0; i < nVertices; ++i) {
            for (int j = 0; j < maxCandidates; ++j) {
                rejectStatus[j] = random.nextFloat() < rho;
            }
            for (int j = 0; j < maxCandidates; ++j) {
                final int p = candidates.index(i, j);
                if (p < 0) {
                    continue; // empty candidate slot
                }
                // k runs up to and including j, so each candidate is also paired with itself.
                for (int k = 0; k <= j; ++k) {
                    final int q = candidates.index(i, k);
                    if (q < 0
                            || (rejectStatus[j] && rejectStatus[k])
                            || (!candidates.isNew(i, j) && !candidates.isNew(i, k))) {
                        continue;
                    }
                    final float d = mMetric.distance(data.row(p), data.row(q));
                    if (graph.push(p, d, q, true)) {
                        ++updates;
                    }
                    if (graph.push(q, d, p, true)) {
                        ++updates;
                    }
                }
            }
        }
        return updates;
    }

    /**
     * Creates and seeds a result heap for querying {@code data} with {@code queryPoints}.
     * <p>
     * The heap has one row per query point and {@code nNeighbors} slots per row. It is filled by
     * {@code nn.randomInit}, then by {@code nn.treeInit} once per tree in {@code forest}. Tree
     * seeding is skipped when {@code forest} is null.
     *
     * @return the seeded heap
     */
    static Heap initialiseSearch(List<FlatTree> forest, Matrix data, Matrix queryPoints, int nNeighbors, NearestNeighborSearch nn, Random random) {
        final Heap results = new Heap(queryPoints.rows(), nNeighbors);
        nn.randomInit(nNeighbors, data, queryPoints, results, random);
        if (forest != null) {
            for (final FlatTree tree : forest) {
                nn.treeInit(tree, data, queryPoints, results, random);
            }
        }
        return results;
    }
}

