/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborSet;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.AtomicFixedBitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.agrona.collections.IntArrayList;
import org.agrona.collections.IntArrayQueue;

/**
 * Builds an {@link OnHeapGraphIndex} by inserting nodes one at a time, possibly from many threads.
 *
 * <p>For each inserted node, {@link #addGraphNode(int, VectorFloat)} runs a beam search over the
 * current graph. It merges the results with the other insertions that are in flight at the same
 * moment. It then hands the merged candidates to the node's {@link ConcurrentNeighborSet}
 * ({@code insertDiverse} followed by {@code backlink}).
 *
 * <p>{@link #build(RandomAccessVectorValues)} performs those insertions in a parallel stream and
 * then calls {@link #cleanup()}. {@code cleanup()} does the following:
 * <ul>
 *   <li>removes nodes marked deleted;</li>
 *   <li>enforces the degree limit;</li>
 *   <li>reattaches nodes not reachable from the entry point;</li>
 *   <li>refreshes the entry point.</li>
 * </ul>
 *
 * <p>Two pools are used. {@code simdExecutor} runs insertion, orphan reconnection and the
 * rescoring step of {@link #removeDeletedNodes()}. {@code parallelExecutor} runs the remaining
 * graph-wide passes.
 */
public class GraphIndexBuilder
implements AutoCloseable {
    /**
     * Initial insertion countdown before the first entry-point refresh. It is also used as the
     * countdown when no entry node could be chosen.
     */
    private static final int DEFAULT_ENTRY_REFRESH_INTERVAL = 10_000;

    /** Maximum number of reachability/reconnect passes made by {@link #reconnectOrphanedNodes()}. */
    private static final int MAX_RECONNECT_PASSES = 5;

    private final int beamWidth;
    private final ExplicitThreadLocal<NodeArray> naturalScratch;
    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
    private final int dimension;
    private final float neighborOverflow;
    private final float alpha;
    @VisibleForTesting
    final OnHeapGraphIndex graph;

    /** Ids of nodes whose {@link #addGraphNode(int, VectorFloat)} call has started but not finished. */
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();

    private BuildScoreProvider scoreProvider;
    private final ForkJoinPool simdExecutor;
    private final ForkJoinPool parallelExecutor;
    private final ExplicitThreadLocal<GraphSearcher> searchers;

    /** Countdown of insertions remaining until the entry point is next recomputed. */
    private final AtomicInteger updateEntryNodeIn = new AtomicInteger(DEFAULT_ENTRY_REFRESH_INTERVAL);

    /**
     * Creates a builder that scores vectors exactly via {@code vectorValues}.
     *
     * <p>Uses {@link PhysicalCoreExecutor#pool()} as the SIMD executor and
     * {@link ForkJoinPool#commonPool()} as the parallel executor.
     *
     * @param vectorValues the vectors to be indexed
     * @param similarityFunction similarity used to score vector pairs
     * @param M maximum number of neighbors per node
     * @param beamWidth beam width used by every search the builder performs
     * @param neighborOverflow overflow factor passed to {@code backlink} when a node is inserted
     * @param alpha diversity parameter passed to each {@link ConcurrentNeighborSet}
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     */
    public GraphIndexBuilder(RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha) {
        this(BuildScoreProvider.randomAccessScoreProvider(vectorValues, similarityFunction), vectorValues.dimension(), M, beamWidth, neighborOverflow, alpha, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Creates a builder with an explicit score provider and explicit executors.
     *
     * @param scoreProvider provider used for every search and neighbor-scoring step
     * @param dimension vector dimension; only consulted by {@code maybeImproveOlderNode}
     * @param M maximum number of neighbors per node
     * @param beamWidth beam width used by every search the builder performs
     * @param neighborOverflow overflow factor passed to {@code backlink} when a node is inserted
     * @param alpha diversity parameter passed to each {@link ConcurrentNeighborSet}
     * @param simdExecutor pool for insertion, orphan reconnection and post-delete rescoring
     * @param parallelExecutor pool for the remaining graph-wide passes
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, float alpha, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.scoreProvider = scoreProvider;
        this.dimension = dimension;
        this.neighborOverflow = neighborOverflow;
        this.alpha = alpha;
        this.beamWidth = beamWidth;
        this.simdExecutor = simdExecutor;
        this.parallelExecutor = parallelExecutor;
        // the factory reads this.scoreProvider when invoked, so setBuildScoreProvider affects nodes created later
        this.graph = new OnHeapGraphIndex(M, (node, m) -> new ConcurrentNeighborSet((int) node, (int) m, this.scoreProvider, alpha));
        this.searchers = ExplicitThreadLocal.withInitial(() -> new GraphSearcher(this.graph));
        // scratch arrays must hold either a full beam of results or a full neighbor list plus one
        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
    }

    /**
     * Replaces the score provider.
     *
     * <p>Subsequent searches, scoring and newly created neighbor sets use the new provider.
     *
     * @param bsp the new score provider
     */
    public void setBuildScoreProvider(BuildScoreProvider bsp) {
        this.scoreProvider = bsp;
    }

    /**
     * Inserts every ordinal {@code 0 .. ravv.size() - 1} of {@code ravv} in parallel on the SIMD
     * executor, then runs {@link #cleanup()}.
     *
     * @param ravv the vectors to index; workers read them through {@code ravv.threadLocalSupplier()}
     * @return the built graph (the same instance as {@link #getGraph()})
     */
    public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
        Supplier<RandomAccessVectorValues> vectors = ravv.threadLocalSupplier();
        int size = ravv.size();
        simdExecutor.submit(() -> IntStream.range(0, size)
                .parallel()
                .forEach(node -> addGraphNode(node, vectors.get().getVector(node))))
            .join();
        cleanup();
        return graph;
    }

    /**
     * Finalizes the graph after insertions.
     *
     * <p>The steps are:
     * <ol>
     *   <li>validate the entry node and drop nodes marked deleted;</li>
     *   <li>enforce the degree limit on every neighbor set;</li>
     *   <li>reattach unreachable nodes;</li>
     *   <li>recompute the entry point.</li>
     * </ol>
     * Does nothing on an empty graph.
     */
    public void cleanup() {
        if (graph.size() == 0) {
            return;
        }
        graph.validateEntryNode();

        removeDeletedNodes();
        if (graph.size() == 0) {
            // every node was deleted; nothing left to tidy
            return;
        }

        parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound())
                .parallel()
                .forEach(node -> {
                    ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
                    if (neighbors != null) {
                        neighbors.enforceDegree();
                    }
                }))
            .join();

        reconnectOrphanedNodes();
        updateEntryPoint();
    }

    /**
     * Repeatedly finds nodes not reachable from the entry point and links each one from some other
     * node. It stops after {@link #MAX_RECONNECT_PASSES} passes, or earlier when a pass finds no
     * orphans.
     */
    private void reconnectOrphanedNodes() {
        // search results are cached across passes, so an orphan that stays orphaned is searched once
        ConcurrentHashMap<Integer, NodeArray> searchPathNeighbors = new ConcurrentHashMap<>();
        for (int pass = 0; pass < MAX_RECONNECT_PASSES; pass++) {
            // mark everything reachable from the entry point
            AtomicFixedBitSet connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
            connectedNodes.set(graph.entry());
            NodeArray entryNeighbors = graph.getNeighbors(graph.entry()).getCurrent();
            parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size())
                    .parallel()
                    .forEach(i -> findConnected(connectedNodes, entryNeighbors.node[i])))
                .join();

            // each target node accepts at most one new in-edge per pass
            AtomicInteger nReconnected = new AtomicInteger();
            Set<Integer> connectionTargets = ConcurrentHashMap.newKeySet();
            simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound())
                    .parallel()
                    .forEach(node -> reconnectIfOrphaned(node, connectedNodes, connectionTargets, searchPathNeighbors, nReconnected)))
                .join();

            if (nReconnected.get() == 0) {
                break;
            }
        }
    }

    /**
     * Attaches {@code node} to the graph if it is present but was not marked reachable.
     *
     * <p>It first tries the node's own current neighbors. Otherwise it uses (cached) results of a
     * search for the node. Each orphan visited increments {@code nReconnected}, whether or not a
     * link could be made.
     */
    private void reconnectIfOrphaned(int node,
                                     AtomicFixedBitSet connectedNodes,
                                     Set<Integer> connectionTargets,
                                     ConcurrentHashMap<Integer, NodeArray> searchPathNeighbors,
                                     AtomicInteger nReconnected) {
        if (connectedNodes.get(node) || !graph.containsNode(node)) {
            return;
        }
        nReconnected.incrementAndGet();

        // first choice: get an in-edge from one of the orphan's own out-neighbors
        NodeArray neighbors = graph.getNeighbors(node).getCurrent();
        if (connectToClosestNeighbor(node, neighbors, connectionTargets)) {
            return;
        }

        // otherwise search the graph for the orphan and use the search results as link sources
        NodeArray candidates = searchPathNeighbors.get(node);
        if (candidates == null) {
            SearchResult result;
            try (GraphSearcher gs = searchers.get()) {
                Bits notSelf = new ExcludingBits(node);
                SearchScoreProvider ssp = scoreProvider.searchProviderFor(node);
                int ep = graph.entry();
                result = gs.searchInternal(ssp, beamWidth, 0.0f, 0.0f, ep, notSelf);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            candidates = toScratchCandidates(result.getNodes(), new NodeArray(result.getNodes().length));
            searchPathNeighbors.put(node, candidates);
        }
        connectToClosestNeighbor(node, candidates, connectionTargets);
    }

    /**
     * Adds {@code node} as a (non-diverse) neighbor of the first entry of {@code neighbors} that has
     * not already been claimed as a connection target in this pass.
     *
     * <p>Entries are tried in array order; presumably this is best-first, since the arrays come from
     * {@code getCurrent()} or from search results.
     *
     * @return {@code true} if an edge was added
     */
    private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set<Integer> connectionTargets) {
        for (int i = 0; i < neighbors.size(); i++) {
            int neighborNode = neighbors.node[i];
            if (!connectionTargets.add(neighborNode)) {
                continue;
            }
            graph.getNeighbors(neighborNode).insertNotDiverse(node, neighbors.score[i]);
            return true;
        }
        return false;
    }

    /**
     * Breadth-first traversal from {@code start} that sets every reached node in
     * {@code connectedNodes}.
     *
     * <p>{@code getAndSet} lets concurrent traversals share the bitset without revisiting nodes.
     */
    private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
        IntArrayQueue queue = new IntArrayQueue();
        queue.addInt(start);
        try (OnHeapGraphIndex.ConcurrentGraphIndexView view = graph.getView()) {
            while (!queue.isEmpty()) {
                int next = queue.pollInt();
                if (connectedNodes.getAndSet(next)) {
                    continue;
                }
                NodesIterator it = view.getNeighborsIterator(next);
                while (it.hasNext()) {
                    queue.addInt(it.nextInt());
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** @return the graph being built (live, not a copy) */
    public OnHeapGraphIndex getGraph() {
        return graph;
    }

    /** @return the number of {@code addGraphNode} calls currently executing */
    public int insertsInProgress() {
        return insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} using its vector from {@code ravv}.
     *
     * @deprecated use {@link #addGraphNode(int, VectorFloat)} with {@code ravv.getVector(node)}
     */
    @Deprecated
    public long addGraphNode(int node, RandomAccessVectorValues ravv) {
        return addGraphNode(node, ravv.getVector(node));
    }

    /**
     * Inserts {@code node} with the given vector.
     *
     * <p>This may be called concurrently for different nodes. The node's candidates are the results
     * of a beam search from the current entry point, plus every other insertion in progress when
     * this one started.
     *
     * @param node ordinal of the node to add
     * @param vector the node's vector, used to score candidates
     * @return {@code graph.ramBytesUsedOneNode()}
     * @throws RuntimeException wrapping any exception raised while searching or closing the searcher
     */
    public long addGraphNode(int node, VectorFloat<?> vector) {
        ConcurrentNeighborSet newNodeNeighbors = graph.addNode(node);

        insertionsInProgress.add(node);
        // snapshot includes this node itself; getConcurrentCandidates skips it
        ConcurrentSkipListSet<Integer> inProgressBefore = insertionsInProgress.clone();
        try (GraphSearcher gs = searchers.get()) {
            NodeArray naturalScratchPooled = naturalScratch.get();
            NodeArray concurrentScratchPooled = concurrentScratch.get();
            int ep = graph.entry();
            ExcludingBits bits = new ExcludingBits(node);
            SearchScoreProvider ssp = scoreProvider.searchProviderFor(vector);
            SearchResult result = gs.searchInternal(ssp, beamWidth, 0.0f, 0.0f, ep, bits);

            NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
            NodeArray concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction());
            updateNeighbors(newNodeNeighbors, natural, concurrent);

            maybeUpdateEntryPoint(node);
            maybeImproveOlderNode();
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            insertionsInProgress.remove(node);
        }
        return graph.ramBytesUsedOneNode();
    }

    /**
     * Refines one random older node, but only for low-dimensional data in large graphs.
     *
     * <p>The conditions are {@code dimension <= 3} and more than 20,000 nodes. Up to 3 random ids
     * below {@code graph.size()} are tried, and the first one that is present and not deleted is
     * passed to {@link #improveConnections(int)}.
     */
    private void maybeImproveOlderNode() {
        if (dimension <= 3 && graph.size() > 20000) {
            for (int attempt = 0; attempt < 3; attempt++) {
                int olderNode = ThreadLocalRandom.current().nextInt(graph.size());
                if (graph.containsNode(olderNode) && !graph.getDeletedNodes().get(olderNode)) {
                    improveConnections(olderNode);
                    break;
                }
            }
        }
    }

    /**
     * Lets the graph adopt {@code node} as its initial entry if it has none. Recomputes the entry
     * point when the insertion countdown reaches zero.
     */
    private void maybeUpdateEntryPoint(int node) {
        graph.maybeSetInitialEntryNode(node);
        if (updateEntryNodeIn.decrementAndGet() == 0) {
            updateEntryPoint();
        }
    }

    /**
     * Moves the entry point to the approximate medoid and reschedules the next refresh.
     *
     * <p>When an entry was found, its connections are improved and the next refresh happens after
     * another {@code graph.size()} insertions. Otherwise the next refresh happens after
     * {@link #DEFAULT_ENTRY_REFRESH_INTERVAL} insertions.
     */
    private void updateEntryPoint() {
        int newEntryNode = approximateMedioid();
        graph.updateEntryNode(newEntryNode);
        if (newEntryNode >= 0) {
            improveConnections(newEntryNode);
            updateEntryNodeIn.addAndGet(graph.size());
        } else {
            updateEntryNodeIn.addAndGet(DEFAULT_ENTRY_REFRESH_INTERVAL);
        }
    }

    /**
     * Re-searches the graph for {@code node}, merges the results into its neighbor set with
     * {@code insertDiverse}, then backlinks with overflow {@code 1.0}.
     *
     * @param node an existing node
     * @throws RuntimeException wrapping any exception raised while searching or closing the searcher
     */
    public void improveConnections(int node) {
        SearchResult result;
        NodeArray naturalScratchPooled;
        try (GraphSearcher gs = searchers.get()) {
            naturalScratchPooled = naturalScratch.get();
            int ep = graph.entry();
            ExcludingBits bits = new ExcludingBits(node);
            SearchScoreProvider ssp = scoreProvider.searchProviderFor(node);
            result = gs.searchInternal(ssp, beamWidth, 0.0f, 0.0f, ep, bits);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
        ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
        neighbors.insertDiverse(natural);
        neighbors.backlink(graph::getNeighbors, 1.0f);
    }

    /**
     * Marks {@code node} deleted in the graph.
     *
     * <p>It is physically removed by the next {@link #removeDeletedNodes()} (also run by
     * {@link #cleanup()}).
     */
    public void markNodeDeleted(int node) {
        graph.markDeleted(node);
    }

    /**
     * Physically removes every node currently marked deleted.
     *
     * <p>Live nodes that pointed at a deleted node are repaired: they get replacement edges drawn
     * from that node's live neighbors. If none exist, random live nodes are used instead. If the
     * entry point was deleted, it is recomputed.
     *
     * @return an estimate of the bytes freed ({@code removed * graph.ramBytesUsedOneNode()})
     */
    public synchronized long removeDeletedNodes() {
        ThreadSafeGrowableBitSet toDelete = graph.getDeletedNodes().copy();
        int nRemoved = toDelete.cardinality();
        if (nRemoved == 0) {
            return 0L;
        }

        // pool for the random fallback when a node has no replacement candidates
        IntArrayList liveNodes = new IntArrayList();
        for (int i = 0; i < graph.getIdUpperBound(); i++) {
            if (graph.containsNode(i) && !toDelete.get(i)) {
                liveNodes.addInt(i);
            }
        }

        ConcurrentHashMap<Integer, Set<Integer>> newEdges = new ConcurrentHashMap<>();
        parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound())
                .parallel()
                .forEach(i -> collectReplacementEdges(i, toDelete, newEdges)))
            .join();

        simdExecutor.submit(() -> newEdges.entrySet()
                .parallelStream()
                .forEach(e -> rewireAroundDeleted(e.getKey(), e.getValue(), toDelete, liveNodes)))
            .join();

        if (toDelete.get(graph.entry())) {
            // the current entry point is among the deleted nodes
            updateEntryPoint();
        }

        assert toDelete.cardinality() == nRemoved : "cardinality changed";
        // assumes nextSetBit returns Integer.MAX_VALUE once no further bit is set
        for (int i = toDelete.nextSetBit(0); i != Integer.MAX_VALUE; i = toDelete.nextSetBit(i + 1)) {
            graph.removeNode(i);
        }
        return (long) nRemoved * graph.ramBytesUsedOneNode();
    }

    /**
     * For live node {@code i}, records the live out-neighbors of each deleted neighbor {@code j}
     * (excluding {@code i} itself) as replacement candidates.
     *
     * <p>An entry is created as soon as {@code i} has any deleted neighbor, even if that yields no
     * candidates; that empty set later triggers the random fallback.
     */
    private void collectReplacementEdges(int i, ThreadSafeGrowableBitSet toDelete, ConcurrentHashMap<Integer, Set<Integer>> newEdges) {
        ConcurrentNeighborSet neighbors = graph.getNeighbors(i);
        if (neighbors == null || toDelete.get(i)) {
            return;
        }
        NodesIterator it = neighbors.iterator();
        while (it.hasNext()) {
            int j = it.nextInt();
            if (!toDelete.get(j)) {
                continue;
            }
            Set<Integer> newEdgesForI = newEdges.computeIfAbsent(i, key -> ConcurrentHashMap.newKeySet());
            NodesIterator jt = graph.getNeighbors(j).iterator();
            while (jt.hasNext()) {
                int k = jt.nextInt();
                if (i != k && !toDelete.get(k)) {
                    newEdgesForI.add(k);
                }
            }
        }
    }

    /**
     * Scores the replacement candidates of {@code node} and lets its neighbor set swap them in for
     * the deleted neighbors.
     *
     * <p>If there are no candidates, up to {@code 2 * maxDegree} random live nodes are sampled
     * instead, stopping once {@code maxDegree} distinct ones have been collected.
     */
    private void rewireAroundDeleted(int node, Set<Integer> replacementIds, ThreadSafeGrowableBitSet toDelete, IntArrayList liveNodes) {
        ScoreFunction sf = scoreProvider.searchProviderFor(node).scoreFunction();
        ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
        NodeArray candidates = new NodeArray(graph.maxDegree());
        for (int k : replacementIds) {
            candidates.insertSorted(k, sf.similarityTo(k));
        }
        if (candidates.size() == 0) {
            ThreadLocalRandom random = ThreadLocalRandom.current();
            for (int attempt = 0; attempt < 2 * graph.maxDegree(); attempt++) {
                int randomNode = liveNodes.getInt(random.nextInt(liveNodes.size()));
                if (randomNode != node && !candidates.contains(randomNode)) {
                    candidates.insertSorted(randomNode, sf.similarityTo(randomNode));
                }
                if (candidates.size() == graph.maxDegree()) {
                    break;
                }
            }
        }
        neighbors.replaceDeletedNeighbors(toDelete, candidates);
    }

    /**
     * Picks a new entry point.
     *
     * <p>The result is the best hit of a beam search (from the current entry) for the score
     * provider's approximate centroid. A random live node is returned when the centroid is (nearly)
     * the zero vector. {@code -1} is returned when the graph is empty or the search returns nothing.
     */
    private int approximateMedioid() {
        if (graph.size() == 0) {
            return -1;
        }
        VectorFloat<?> centroid = scoreProvider.approximateCentroid();
        if ((double) VectorUtil.dotProduct(centroid, centroid) < 1.0E-6) {
            // a (near-)zero centroid gives no useful direction to search toward
            return randomLiveNode();
        }
        int ep = graph.entry();
        SearchScoreProvider ssp = scoreProvider.searchProviderFor(centroid);
        try (GraphSearcher gs = searchers.get()) {
            SearchResult result = gs.searchInternal(ssp, beamWidth, 0.0f, 0.0f, ep, Bits.ALL);
            SearchResult.NodeScore[] hits = result.getNodes();
            return hits.length == 0 ? -1 : hits[0].node;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Inserts the union of search-derived and concurrent candidates with {@code insertDiverse}, then
     * backlinks with {@link #neighborOverflow}.
     *
     * <p>The merge is skipped when either side is empty.
     */
    private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural, NodeArray concurrent) {
        NodeArray toMerge;
        if (concurrent.size() == 0) {
            toMerge = natural;
        } else if (natural.size() == 0) {
            toMerge = concurrent;
        } else {
            toMerge = NodeArray.merge(natural, concurrent);
        }
        neighbors.insertDiverse(toMerge);
        neighbors.backlink(graph::getNeighbors, neighborOverflow);
    }

    /**
     * Clears {@code scratch} and appends the search results in their given order.
     *
     * @return {@code scratch}
     */
    private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) {
        scratch.clear();
        for (SearchResult.NodeScore candidate : candidates) {
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /**
     * Clears {@code scratch} and fills it, sorted by score, with every node in {@code inProgress}
     * except {@code newNode}.
     *
     * @return {@code scratch}
     */
    private NodeArray getConcurrentCandidates(int newNode, Set<Integer> inProgress, NodeArray scratch, ScoreFunction scoreFunction) {
        scratch.clear();
        for (int n : inProgress) {
            if (n != newNode) {
                scratch.insertSorted(n, scoreFunction.similarityTo(n));
            }
        }
        return scratch;
    }

    /** Closes the per-thread searchers. */
    @Override
    public void close() throws Exception {
        searchers.close();
    }

    /**
     * Returns a random node that is present and not marked deleted, or {@code -1} if there is none.
     *
     * <p>It tries 3 uniform random ids first, then falls back to a full scan.
     */
    @VisibleForTesting
    int randomLiveNode() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        for (int attempt = 0; attempt < 3; attempt++) {
            int idUpperBound = graph.getIdUpperBound();
            if (idUpperBound == 0) {
                return -1;
            }
            int n = random.nextInt(idUpperBound);
            if (graph.containsNode(n) && !graph.getDeletedNodes().get(n)) {
                return n;
            }
        }
        ArrayList<Integer> live = new ArrayList<>();
        for (int i = 0; i < graph.getIdUpperBound(); i++) {
            if (graph.containsNode(i) && !graph.getDeletedNodes().get(i)) {
                live.add(i);
            }
        }
        if (live.isEmpty()) {
            return -1;
        }
        return live.get(random.nextInt(live.size()));
    }

    /**
     * Asserts that no node is marked deleted and that every edge points to a node present in the
     * graph.
     *
     * <p>Checks only run with {@code -ea}.
     */
    @VisibleForTesting
    void validateAllNodesLive() {
        assert graph.getDeletedNodes().cardinality() == 0;
        for (int i = 0; i < graph.getIdUpperBound(); i++) {
            if (!graph.containsNode(i)) {
                continue;
            }
            NodesIterator it = graph.getNeighbors(i).iterator();
            while (it.hasNext()) {
                int j = it.nextInt();
                assert graph.containsNode(j) : String.format("Edge %d -> %d is invalid", i, j);
            }
        }
    }

    /**
     * Loads a graph into this (empty) builder.
     *
     * <p>The stream layout is {@code int size, int entryNode, int maxDegree}, followed by
     * {@code size} records of {@code int node, int nNeighbors, nNeighbors x int neighbor}. Neighbor
     * scores are recomputed with the current score provider's exact score function and added in
     * stored order.
     *
     * @throws IllegalStateException if the graph already contains nodes
     * @throws IOException if reading fails
     */
    public void load(RandomAccessReader in) throws IOException {
        if (graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int size = in.readInt();
        int entryNode = in.readInt();
        int maxDegree = in.readInt();
        for (int i = 0; i < size; i++) {
            int node = in.readInt();
            int nNeighbors = in.readInt();
            ScoreFunction.ExactScoreFunction sf = scoreProvider.searchProviderFor(node).exactScoreFunction();
            NodeArray ca = new NodeArray(maxDegree);
            for (int j = 0; j < nNeighbors; j++) {
                int neighbor = in.readInt();
                ca.addInOrder(neighbor, sf.similarityTo(neighbor));
            }
            graph.addNode(node, new ConcurrentNeighborSet(node, maxDegree, scoreProvider, alpha, ca));
        }
        graph.updateEntryNode(entryNode);
    }

    /** {@link Bits} that accepts every index except one (used to keep a node out of its own search). */
    private static final class ExcludingBits
    implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != excluded;
        }
    }
}

