/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborSet;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodeSimilarity;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.util.AtomicFixedBitSet;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.PoolingSupport;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

/**
 * Builds an {@link OnHeapGraphIndex} over the vectors exposed by a {@link RandomAccessVectorValues}.
 * <p>
 * Nodes are inserted with {@link #addGraphNode}, which {@link #build()} invokes for every ordinal
 * from a parallel stream. Insertions that are in flight are tracked in a concurrent set so that each
 * new node also considers the other in-progress nodes as neighbor candidates. After insertion,
 * {@link #cleanup()} removes deleted nodes, reconnects unreachable nodes, and re-picks the entry node.
 *
 * @param <T> the vector representation; {@code scoreBetween} casts it to {@code byte[]} for
 *            {@link VectorEncoding#BYTE} and to {@code float[]} for {@link VectorEncoding#FLOAT32}
 */
public class GraphIndexBuilder<T> {
    private final int beamWidth;
    // Per-thread scratch arrays for "natural" (search-result) and "concurrent" (in-progress) candidates.
    private final PoolingSupport<NodeArray> naturalScratch;
    private final PoolingSupport<NodeArray> concurrentScratch;
    private final VectorSimilarityFunction similarityFunction;
    private final float neighborOverflow;
    private final float alpha;
    private final VectorEncoding vectorEncoding;
    private final PoolingSupport<GraphSearcher<?>> graphSearcher;
    @VisibleForTesting
    final OnHeapGraphIndex<T> graph;
    // Node ids whose addGraphNode call has started but not yet finished.
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();
    // Two independent vector sources so that two vectors can be read at once (e.g. query vs. candidate).
    private final PoolingSupport<RandomAccessVectorValues<T>> vectors;
    private final PoolingSupport<RandomAccessVectorValues<T>> vectorsCopy;
    private final int dimension;
    private final NodeSimilarity similarity;
    // Countdown of insertions until the entry node is recomputed (see maybeUpdateEntryPoint).
    private final AtomicInteger updateEntryNodeIn = new AtomicInteger(10000);

    /**
     * Creates a builder.
     *
     * @param vectorValues       source of the vectors to index; when it reports
     *                           {@code isValueShared()}, each thread works on its own {@code copy()}
     *                           (presumably because returned vectors alias a reused buffer — confirm)
     * @param vectorEncoding     element type of the vectors; must not be null
     * @param similarityFunction similarity used to score vector pairs; must not be null
     * @param M                  maximum connections per node; must be positive
     * @param beamWidth          search beam width used when finding neighbor candidates; must be positive
     * @param neighborOverflow   passed to {@code ConcurrentNeighborSet.backlink}
     * @param alpha              passed to every {@code ConcurrentNeighborSet} this builder creates
     * @throws NullPointerException     if {@code vectorEncoding} or {@code similarityFunction} is null
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     */
    public GraphIndexBuilder(RandomAccessVectorValues<T> vectorValues, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha) {
        // Both pools must be assigned on every path (the decompiled form left `vectors` unassigned
        // when the source shares its value buffer).
        this.vectors = vectorValues.isValueShared()
                ? PoolingSupport.newThreadBased(vectorValues::copy)
                : PoolingSupport.newNoPooling(vectorValues);
        this.vectorsCopy = vectorValues.isValueShared()
                ? PoolingSupport.newThreadBased(vectorValues::copy)
                : PoolingSupport.newNoPooling(vectorValues);
        this.dimension = vectorValues.dimension();
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        this.neighborOverflow = neighborOverflow;
        this.alpha = alpha;
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = beamWidth;
        this.similarity = node1 -> {
            try (PoolingSupport.Pooled<RandomAccessVectorValues<T>> v = vectors.get();
                 PoolingSupport.Pooled<RandomAccessVectorValues<T>> vc = vectorsCopy.get()) {
                T v1 = v.get().vectorValue(node1);
                // NOTE(review): the returned function reads vc (and v1, which comes from v) after this
                // try block has closed both handles. That is only safe if closing a pooled handle does
                // not invalidate it for the current thread — confirm against PoolingSupport.
                NodeSimilarity.ExactScoreFunction scoreFunction = node2 -> scoreBetween(v1, vc.get().vectorValue(node2));
                return scoreFunction;
            }
        };
        this.graph = new OnHeapGraphIndex<>(M, (node, m) -> new ConcurrentNeighborSet(node, m, this.similarity, alpha));
        this.graphSearcher = PoolingSupport.newThreadBased(() -> new GraphSearcher.Builder<T>(this.graph.getView()).withConcurrentUpdates().build());
        this.naturalScratch = PoolingSupport.newThreadBased(() -> new NodeArray(Math.max(beamWidth, M + 1)));
        this.concurrentScratch = PoolingSupport.newThreadBased(() -> new NodeArray(Math.max(beamWidth, M + 1)));
    }

    /**
     * Inserts every ordinal in {@code [0, size)} of the vector source in parallel, then runs
     * {@link #cleanup()}.
     *
     * @return the built graph (the same instance as {@link #getGraph()})
     */
    public OnHeapGraphIndex<T> build() {
        int size;
        try (PoolingSupport.Pooled<RandomAccessVectorValues<T>> v = vectors.get()) {
            size = v.get().size();
        }
        PhysicalCoreExecutor.instance.execute(() -> IntStream.range(0, size).parallel().forEach(i -> {
            try (PoolingSupport.Pooled<RandomAccessVectorValues<T>> v = vectors.get()) {
                addGraphNode(i, v.get());
            }
        }));
        cleanup();
        return graph;
    }

    /**
     * Post-processes the graph: removes nodes marked deleted, calls {@code cleanup()} on every
     * neighbor set, reconnects nodes unreachable from the entry node, and re-picks the entry node
     * as an approximate medoid. Does nothing on an empty graph, including one that becomes empty
     * once deleted nodes are removed.
     */
    public void cleanup() {
        if (graph.size() == 0) {
            return;
        }
        graph.validateEntryNode();
        removeDeletedNodes();
        // If every node was deleted the entry node is now -1; the steps below would index with it.
        if (graph.size() == 0) {
            return;
        }
        IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
            ConcurrentNeighborSet neighbors = graph.getNeighbors(i);
            if (neighbors != null) {
                neighbors.cleanup();
            }
        });
        reconnectOrphanedNodes();
        graph.updateEntryNode(approximateMedioid());
        updateEntryNodeIn.set(graph.size());
    }

    /**
     * Runs up to three passes. Each pass marks every node reachable from the entry node. Then, for
     * each live node that was not reached, it searches the graph for that node's vector. The node is
     * added (non-diversely) to the neighbor list of the first search result that has not already been
     * used as a reconnection target in this pass. Stops early after a pass that finds no unreached
     * nodes.
     */
    private void reconnectOrphanedNodes() {
        for (int pass = 0; pass < 3; ++pass) {
            AtomicFixedBitSet connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
            connectedNodes.set(graph.entry());
            NodeArray entryNeighbors = graph.getNeighbors(graph.entry()).getCurrent();
            IntStream.range(0, entryNeighbors.size).parallel()
                    .forEach(k -> findConnected(connectedNodes, entryNeighbors.node[k]));

            int nReconnected = 0;
            try (PoolingSupport.Pooled<GraphSearcher<?>> gs = graphSearcher.get();
                 PoolingSupport.Pooled<RandomAccessVectorValues<T>> v1 = vectors.get();
                 PoolingSupport.Pooled<RandomAccessVectorValues<T>> v2 = vectorsCopy.get()) {
                Set<Integer> connectionTargets = new HashSet<>();
                for (int node = 0; node < graph.getIdUpperBound(); ++node) {
                    if (connectedNodes.get(node) || !graph.containsNode(node)) {
                        continue;
                    }
                    Bits notSelfBits = createNotSelfBits(node);
                    T value = v1.get().vectorValue(node);
                    NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.get().vectorValue(i), value);
                    SearchResult.NodeScore[] candidates = gs.get()
                            .searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), notSelfBits)
                            .getNodes();
                    for (SearchResult.NodeScore ns : candidates) {
                        // Spread reconnections: each target accepts at most one orphan per pass.
                        if (connectionTargets.add(ns.node)) {
                            graph.getNeighbors(ns.node).insertNotDiverse(node, ns.score, true);
                            break;
                        }
                    }
                    // Counted even if no unused target was found, so another pass is attempted.
                    nReconnected++;
                }
            }
            if (nReconnected == 0) {
                break;
            }
        }
    }

    /**
     * Breadth-first traversal from {@code start}, setting a bit in {@code connectedNodes} for every
     * node reached. Nodes whose bit is already set are not expanded, so concurrent callers sharing
     * the bitset do not repeat work.
     */
    private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        queue.add(start);
        GraphIndex.View<T> view = graph.getView();
        while (!queue.isEmpty()) {
            int next = queue.pop();
            if (connectedNodes.getAndSet(next)) {
                continue;
            }
            NodesIterator it = view.getNeighborsIterator(next);
            while (it.hasNext()) {
                queue.add(it.nextInt());
            }
        }
    }

    /** @return the graph being built; it is mutated by further calls on this builder */
    public OnHeapGraphIndex<T> getGraph() {
        return graph;
    }

    /** @return the number of {@link #addGraphNode} calls currently in flight */
    public int insertsInProgress() {
        return insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} into the graph and connects it. Candidates come from a beam search of the
     * current graph plus the other insertions in flight when this one started. {@link #build()} calls
     * this concurrently from a parallel stream.
     *
     * @param node    id of the node to insert
     * @param vectors source from which the node's own vector is read
     * @return {@code graph.ramBytesUsedOneNode(0)}, an estimate of the memory used by one node
     */
    public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {
        T value = vectors.vectorValue(node);
        // The node's neighbor set is created before the id is published as in-progress, so any
        // concurrent insert that picks it up as a candidate finds the node already in the graph.
        ConcurrentNeighborSet newNodeNeighbors = graph.addNode(node);
        insertionsInProgress.add(node);
        ConcurrentSkipListSet<Integer> inProgressBefore = insertionsInProgress.clone();
        try (PoolingSupport.Pooled<GraphSearcher<?>> gs = graphSearcher.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> vc = vectorsCopy.get();
             PoolingSupport.Pooled<NodeArray> naturalScratchPooled = naturalScratch.get();
             PoolingSupport.Pooled<NodeArray> concurrentScratchPooled = concurrentScratch.get()) {
            int ep = graph.entry();
            NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.get().vectorValue(i), value);
            ExcludingBits bits = new ExcludingBits(node);
            SearchResult result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, ep, bits);
            NodeArray natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled.get());
            NodeArray concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled.get(), vectors, vc.get());
            updateNeighbors(newNodeNeighbors, natural, concurrent);
            maybeUpdateEntryPoint(node);
            maybeImproveOlderNode();
        } finally {
            insertionsInProgress.remove(node);
        }
        return graph.ramBytesUsedOneNode(0);
    }

    /**
     * For low-dimensional data ({@code dimension <= 3}) in graphs larger than 20000 nodes, tries up to
     * three random ids in {@code [0, size)} and re-runs {@link #improveConnections} on the first one
     * that is present in the graph.
     */
    private void maybeImproveOlderNode() {
        if (dimension <= 3 && graph.size() > 20000) {
            for (int attempt = 0; attempt < 3; ++attempt) {
                int olderNode = ThreadLocalRandom.current().nextInt(graph.size());
                if (graph.containsNode(olderNode)) {
                    improveConnections(olderNode);
                    break;
                }
            }
        }
    }

    /**
     * Offers {@code node} as the initial entry node via {@code maybeSetInitialEntryNode}. When the
     * countdown reaches zero, recomputes the entry node as the approximate medoid, improves its
     * connections, and extends the countdown by the current graph size.
     */
    private void maybeUpdateEntryPoint(int node) {
        graph.maybeSetInitialEntryNode(node);
        if (updateEntryNodeIn.decrementAndGet() == 0) {
            int newEntryNode = approximateMedioid();
            graph.updateEntryNode(newEntryNode);
            improveConnections(newEntryNode);
            updateEntryNodeIn.addAndGet(graph.size());
        }
    }

    /**
     * Re-searches the graph for {@code node}'s vector and merges the results into its neighbor set
     * (no concurrent candidates are considered).
     *
     * @param node id of a node already present in the graph
     */
    public void improveConnections(int node) {
        try (PoolingSupport.Pooled<RandomAccessVectorValues<T>> pv = vectors.get();
             PoolingSupport.Pooled<GraphSearcher<?>> gs = graphSearcher.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> vc = vectorsCopy.get();
             PoolingSupport.Pooled<NodeArray> naturalScratchPooled = naturalScratch.get()) {
            T value = pv.get().vectorValue(node);
            int ep = graph.entry();
            NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.get().vectorValue(i), value);
            ExcludingBits bits = new ExcludingBits(node);
            SearchResult result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, ep, bits);
            NodeArray natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled.get());
            updateNeighbors(graph.getNeighbors(node), natural, NodeArray.EMPTY);
        }
    }

    /**
     * Marks {@code node} as deleted. The node is physically removed by the next {@link #cleanup()}.
     */
    public void markNodeDeleted(int node) {
        graph.markDeleted(node);
    }

    /**
     * Removes every node marked deleted, then repairs the survivors. Deleted ids are dropped from
     * each live node's neighbor set. Nodes left with fewer than {@code 1 + maxDegree / 2} neighbors
     * are padded with randomly chosen live nodes. Nodes that lost a neighbor get fresh search-based
     * connections. If the entry node was deleted it is replaced (or set to -1 when the graph is now
     * empty). Finally the deleted-node set is cleared.
     *
     * @return {@code nRemoved * graph.ramBytesUsedOneNode(0)}, or 0 if nothing was deleted
     */
    private long removeDeletedNodes() {
        BitSet deletedNodes = graph.getDeletedNodes();
        int nRemoved = deletedNodes.cardinality();
        if (nRemoved == 0) {
            return 0L;
        }
        // Integer.MAX_VALUE is treated as the "no more set bits" sentinel of nextSetBit.
        for (int i = deletedNodes.nextSetBit(0); i != Integer.MAX_VALUE; i = deletedNodes.nextSetBit(i + 1)) {
            boolean success = graph.removeNode(i);
            assert success : String.format("Node %d marked deleted but not present", i);
        }

        int[] liveNodes = graph.rawNodes();
        Set<Integer> affectedLiveNodes = new HashSet<>();
        Random random = new Random();
        int minConnections = 1 + graph.maxDegree() / 2;
        try (PoolingSupport.Pooled<RandomAccessVectorValues<T>> v1 = vectors.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> v2 = vectorsCopy.get()) {
            for (int node : liveNodes) {
                assert !deletedNodes.get(node);
                ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
                if (neighbors.removeDeletedNeighbors(deletedNodes)) {
                    affectedLiveNodes.add(node);
                }
                if (neighbors.size() >= minConnections) {
                    continue;
                }
                // Sample up to 2 * maxDegree random live nodes to refill the neighbor list.
                NodeArray randomConnections = new NodeArray(graph.maxDegree() - neighbors.size());
                for (int attempt = 0; attempt < 2 * graph.maxDegree(); ++attempt) {
                    int randomNode = liveNodes[random.nextInt(liveNodes.length)];
                    if (randomNode != node && !randomConnections.contains(randomNode)) {
                        float score = scoreBetween(v1.get().vectorValue(node), v2.get().vectorValue(randomNode));
                        randomConnections.insertSorted(randomNode, score);
                    }
                    if (randomConnections.size == randomConnections.node.length) {
                        break;
                    }
                }
                neighbors.padWithRandom(randomConnections);
            }
        }

        if (deletedNodes.get(graph.entry())) {
            graph.updateEntryNode(graph.size() > 0 ? graph.getNodes().nextInt() : -1);
        }
        for (int node : affectedLiveNodes) {
            addNNDescentConnections(node);
        }
        deletedNodes.clear();
        return (long) nRemoved * graph.ramBytesUsedOneNode(0);
    }

    /**
     * Searches the graph for {@code node}'s vector (excluding the node itself) and merges the
     * results into its neighbor set.
     */
    private void addNNDescentConnections(int node) {
        Bits notSelfBits = createNotSelfBits(node);
        try (PoolingSupport.Pooled<GraphSearcher<?>> gs = graphSearcher.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> v1 = vectors.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> v2 = vectorsCopy.get();
             PoolingSupport.Pooled<NodeArray> scratch = naturalScratch.get()) {
            T value = v1.get().vectorValue(node);
            NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.get().vectorValue(i), value);
            SearchResult result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), notSelfBits);
            NodeArray candidates = toScratchCandidates(result.getNodes(), result.getNodes().length, scratch.get());
            updateNeighbors(graph.getNeighbors(node), candidates, NodeArray.EMPTY);
        }
    }

    /** @return a {@link Bits} that accepts every index except {@code node} */
    private static Bits createNotSelfBits(final int node) {
        return new ExcludingBits(node);
    }

    /**
     * For {@link VectorEncoding#FLOAT32}, computes the centroid of all live vectors and returns the
     * node the graph search finds closest to it. For any other encoding, or if the search returns no
     * nodes, returns the current entry node.
     */
    private int approximateMedioid() {
        assert graph.size() > 0;
        if (vectorEncoding != VectorEncoding.FLOAT32) {
            return graph.entry();
        }
        try (PoolingSupport.Pooled<GraphSearcher<?>> gs = graphSearcher.get();
             PoolingSupport.Pooled<RandomAccessVectorValues<T>> vc = vectorsCopy.get()) {
            float[] centroid = new float[dimension];
            for (NodesIterator it = graph.getNodes(); it.hasNext(); ) {
                VectorUtil.addInPlace(centroid, (float[]) vc.get().vectorValue(it.nextInt()));
            }
            VectorUtil.divInPlace(centroid, graph.size());
            // Safe: this branch only runs for FLOAT32, where T is float[].
            @SuppressWarnings("unchecked")
            T centroidValue = (T) centroid;
            NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.get().vectorValue(i), centroidValue);
            SearchResult.NodeScore[] nearest = gs.get()
                    .searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), Bits.ALL)
                    .getNodes();
            return nearest.length > 0 ? nearest[0].node : graph.entry();
        }
    }

    /** Merges candidates into {@code neighbors} diversely, then back-links them into their neighbors' sets. */
    private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural, NodeArray concurrent) {
        neighbors.insertDiverse(natural, concurrent);
        neighbors.backlink(graph::getNeighbors, neighborOverflow);
    }

    /**
     * Copies the first {@code count} candidates into {@code scratch} (after clearing it) using
     * {@code addInOrder}. Presumably this expects the candidates to be in the search's sort order —
     * confirm against NodeArray.
     */
    private NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) {
        scratch.clear();
        for (int i = 0; i < count; ++i) {
            SearchResult.NodeScore candidate = candidates[i];
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /**
     * Fills {@code scratch} (after clearing it) with every in-progress node other than
     * {@code newNode}, scored against {@code newNode}'s vector.
     */
    private NodeArray getConcurrentCandidates(int newNode, Set<Integer> inProgress, NodeArray scratch, RandomAccessVectorValues<T> values, RandomAccessVectorValues<T> valuesCopy) {
        scratch.clear();
        for (int n : inProgress) {
            if (n != newNode) {
                scratch.insertSorted(n, scoreBetween(values.vectorValue(newNode), valuesCopy.vectorValue(n)));
            }
        }
        return scratch;
    }

    /** Scores two vectors with this builder's encoding and similarity function. */
    protected float scoreBetween(T v1, T v2) {
        return scoreBetween(vectorEncoding, similarityFunction, v1, v2);
    }

    /**
     * Scores two vectors, casting them to {@code byte[]} or {@code float[]} according to {@code encoding}.
     *
     * @throws IllegalArgumentException if {@code encoding} is neither BYTE nor FLOAT32
     */
    static <T> float scoreBetween(VectorEncoding encoding, VectorSimilarityFunction similarityFunction, T v1, T v2) {
        switch (encoding) {
            case BYTE:
                return similarityFunction.compare((byte[]) v1, (byte[]) v2);
            case FLOAT32:
                return similarityFunction.compare((float[]) v1, (float[]) v2);
            default:
                throw new IllegalArgumentException("Unsupported vector encoding: " + encoding);
        }
    }

    /**
     * Loads a graph into this (empty) builder. The format read is: node count, entry node, max
     * degree, then for each node its id, neighbor count, and neighbor ids. Neighbor scores are not
     * stored; they are recomputed with this builder's similarity.
     *
     * @throws IllegalStateException if the graph already contains nodes
     * @throws IOException           if reading from {@code in} fails
     */
    public void load(RandomAccessReader in) throws IOException {
        if (graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int size = in.readInt();
        int entryNode = in.readInt();
        int maxDegree = in.readInt();
        for (int i = 0; i < size; ++i) {
            int node = in.readInt();
            int nNeighbors = in.readInt();
            NodeArray ca = new NodeArray(maxDegree);
            for (int j = 0; j < nNeighbors; ++j) {
                int neighbor = in.readInt();
                ca.addInOrder(neighbor, similarity.score(node, neighbor));
            }
            graph.addNode(node, new ConcurrentNeighborSet(node, maxDegree, similarity, alpha, ca));
        }
        graph.updateEntryNode(entryNode);
    }

    /** {@link Bits} that accepts every index except a single excluded one; {@code length()} is unsupported. */
    private static class ExcludingBits implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != excluded;
        }

        @Override
        public int length() {
            throw new UnsupportedOperationException();
        }
    }
}

