/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.FixedBitSet;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntFunction;

/**
 * A concurrently updatable neighbor list for a single graph node.
 *
 * <p>The current neighbors live in a {@link Neighbors} snapshot held by an {@link AtomicReference}.
 * Every mutator builds its result inside {@link AtomicReference#getAndUpdate}, working on a copy or
 * a freshly built/merged {@link NodeArray} rather than the published one. This assumes
 * {@code NodeArray.merge} returns a new array (TODO confirm). As a result, readers such as
 * {@link #iterator()} see a consistent snapshot without locking.
 *
 * <p>Diversity pruning keeps a candidate only if, for every already-selected neighbor, the
 * similarity between the candidate and that neighbor is not greater than {@code score * a}. Here
 * {@code score} is the candidate's stored score and {@code a} is swept from 1.0 up to
 * {@code alpha} in steps of 0.2.
 */
public class ConcurrentNeighborSet {
    /** Sentinel this class treats as "no further set bit" when walking a BitSet via nextSetBit. */
    private static final int NO_MORE_BITS = Integer.MAX_VALUE;

    /** Id of the node whose neighbors are stored here; it may not be added as its own neighbor. */
    private final int nodeId;
    /** Current neighbor snapshot; replaced atomically and treated as read-only once published. */
    private final AtomicReference<Neighbors> neighborsRef;
    /** Upper bound of the diversity-threshold sweep (1.0 means only the strict test is applied). */
    private final float alpha;
    private final BuildScoreProvider scoreProvider;
    /** Target degree; a diversity pass selects at most this many neighbors. */
    private final int maxConnections;
    /**
     * Fraction of {@code maxConnections} selected during the most recent a == 1.0 diversity pass,
     * or NaN if no such pass has run yet.
     *
     * <p>volatile: it is written from inside getAndUpdate update functions that run on arbitrary
     * threads and is read by {@link #getShortEdges()} without other synchronization. Because those
     * update functions can be retried under contention, the value reflects the last pass executed,
     * which is not necessarily the one whose result was published.
     */
    private volatile float shortEdges = Float.NaN;

    public ConcurrentNeighborSet(int nodeId, int maxConnections, BuildScoreProvider scoreProvider) {
        this(nodeId, maxConnections, scoreProvider, 1.0f);
    }

    public ConcurrentNeighborSet(int nodeId, int maxConnections, BuildScoreProvider scoreProvider, float alpha) {
        this(nodeId, maxConnections, scoreProvider, alpha, new NodeArray(maxConnections));
    }

    ConcurrentNeighborSet(int nodeId, int maxConnections, BuildScoreProvider scoreProvider, float alpha, NodeArray nodes) {
        this.nodeId = nodeId;
        this.maxConnections = maxConnections;
        this.scoreProvider = scoreProvider;
        this.alpha = alpha;
        // Initial contents are not yet known to be diverse.
        this.neighborsRef = new AtomicReference<Neighbors>(new Neighbors(nodes, 0));
    }

    /** Returns the short-edge ratio from the most recent a == 1.0 diversity pass, or NaN if none ran. */
    public float getShortEdges() {
        return shortEdges;
    }

    /** Returns an iterator over the neighbor ids of the snapshot current at call time. */
    public NodesIterator iterator() {
        return new NeighborIterator(neighborsRef.get().nodes);
    }

    /**
     * Adds this node to the neighbor set of each of its current neighbors. The reverse edge reuses
     * the score stored for the forward edge.
     *
     * @param neighborhoodOf maps a node id to its neighbor set (must not return null)
     * @param overflow multiplier on maxConnections that a neighbor set may reach before it is pruned
     */
    public void backlink(IntFunction<ConcurrentNeighborSet> neighborhoodOf, float overflow) {
        NodeArray snapshot = neighborsRef.get().nodes;
        for (int i = 0; i < snapshot.size(); i++) {
            int nbr = snapshot.node[i];
            ConcurrentNeighborSet nbrNbr = neighborhoodOf.apply(nbr);
            assert (nbrNbr != null) : "Node " + nbr + " not found";
            nbrNbr.insert(nodeId, snapshot.score[i], overflow);
        }
    }

    /**
     * Prunes the list to a diverse subset if it holds more than maxConnections entries. In every
     * case the resulting list is then marked as fully diverse.
     */
    public void enforceDegree() {
        neighborsRef.getAndUpdate(old -> {
            NodeArray pruned = removeAllNonDiverse(old.nodes, old.diverseBefore);
            return new Neighbors(pruned, pruned.size);
        });
    }

    /**
     * Drops neighbors flagged in {@code deletedNodes}, merges in {@code candidates}, and re-runs
     * diversity selection from scratch.
     */
    public void replaceDeletedNeighbors(Bits deletedNodes, NodeArray candidates) {
        neighborsRef.getAndUpdate(old -> {
            NodeArray liveNeighbors = new NodeArray(old.nodes.size);
            for (int i = 0; i < old.nodes.size(); i++) {
                int node = old.nodes.node[i];
                if (!deletedNodes.get(node)) {
                    liveNeighbors.addInOrder(node, old.nodes.score[i]);
                }
            }
            NodeArray merged = rescoreAndMerge(liveNeighbors, candidates);
            retainDiverse(merged, 0, scoreProvider.isExact());
            return new Neighbors(merged, merged.size);
        });
    }

    /** Returns the number of neighbors in the current snapshot. */
    public int size() {
        return neighborsRef.get().nodes.size();
    }

    /**
     * Merges {@code toMerge} into the current neighbors and re-runs diversity selection over the
     * whole merged list. Does nothing if {@code toMerge} is empty.
     */
    public void insertDiverse(NodeArray toMerge) {
        if (toMerge.size() == 0) {
            return;
        }
        neighborsRef.getAndUpdate(old -> {
            NodeArray merged = old.nodes.size > 0 ? rescoreAndMerge(old.nodes, toMerge) : toMerge.copy();
            retainDiverse(merged, 0, scoreProvider.isExact());
            return new Neighbors(merged, merged.size);
        });
    }

    /**
     * Merges two neighbor arrays. When the build provider is not exact, {@code old} is first
     * re-scored with the approximate diversity score function. Presumably this puts it on the same
     * scale as {@code toMerge} (TODO confirm).
     */
    private NodeArray rescoreAndMerge(NodeArray old, NodeArray toMerge) {
        NodeArray base = scoreProvider.isExact() ? old : computeApproximatelyScored(old);
        return NodeArray.merge(base, toMerge);
    }

    /** Returns a new array holding the nodes of {@code exact}, re-scored with the approximate score function. */
    private NodeArray computeApproximatelyScored(NodeArray exact) {
        NodeArray approximated = new NodeArray(exact.size);
        ScoreFunction sf = scoreProvider.diversityProvider().createFor(nodeId).scoreFunction();
        assert (!sf.isExact());
        for (int i = 0; i < exact.size; i++) {
            int node = exact.node[i];
            approximated.insertSorted(node, sf.similarityTo(node));
        }
        return approximated;
    }

    /**
     * Inserts a neighbor without a diversity check. The list is first truncated to
     * maxConnections - 1 entries so that the result holds at most maxConnections.
     */
    void insertNotDiverse(int node, float score) {
        neighborsRef.getAndUpdate(old -> {
            NodeArray nextNodes = old.nodes.copy();
            nextNodes.size = Math.min(nextNodes.size, maxConnections - 1);
            int insertedAt = nextNodes.insertSorted(node, score);
            if (insertedAt == -1) {
                // Not inserted (e.g. already present): keep the published snapshot unchanged.
                return old;
            }
            // Entries at or after the insertion point are no longer known to be diverse.
            return new Neighbors(nextNodes, Math.min(insertedAt, old.diverseBefore));
        });
    }

    /**
     * Reduces {@code neighbors} in place to a diverse subset of at most maxConnections entries.
     *
     * @param neighbors array to prune; mutated in place
     * @param diverseBefore entries at indices below this are kept without re-checking (up to maxConnections)
     * @param isExactScored whether {@code neighbors} carries exact scores. If not, selection uses the
     *     approximate score function and the survivors are re-scored with the exact one.
     */
    private void retainDiverse(NodeArray neighbors, int diverseBefore, boolean isExactScored) {
        FixedBitSet selected = new FixedBitSet(neighbors.size());
        int preselected = Math.min(diverseBefore, maxConnections);
        for (int i = 0; i < preselected; i++) {
            selected.set(i);
        }
        SearchScoreProvider.Factory dp = scoreProvider.diversityProvider();

        if (isExactScored) {
            retainDiverseInternal(neighbors, maxConnections, diverseBefore, selected,
                    node1 -> dp.createFor(node1).exactScoreFunction());
            neighbors.retain(selected);
            return;
        }

        assert (!scoreProvider.isExact());
        assert (diverseBefore == 0);
        retainDiverseInternal(neighbors, maxConnections, 0, selected,
                node1 -> dp.createFor(node1).scoreFunction());
        // Snapshot the ids before clearing, then rebuild the array with exact scores for the survivors.
        int[] neighborNodes = Arrays.copyOf(neighbors.node, neighbors.size);
        ScoreFunction.ExactScoreFunction sf = dp.createFor(nodeId).exactScoreFunction();
        neighbors.clear();
        for (int i = selected.nextSetBit(0); i != NO_MORE_BITS; i = selected.nextSetBit(i + 1)) {
            int neighborId = neighborNodes[i];
            neighbors.insertSorted(neighborId, sf.similarityTo(neighborId));
        }
    }

    /**
     * Marks diverse entries in {@code selected}, scanning from {@code diverseBefore}. The threshold
     * multiplier starts at 1.0 and grows by 0.2 up to alpha, so each later pass admits more
     * candidates. Stops once {@code max} entries are selected.
     */
    private void retainDiverseInternal(NodeArray neighbors, int max, int diverseBefore, BitSet selected, ScoreFunction.Provider scoreProvider) {
        int nSelected = diverseBefore;
        // Compared in double with a small epsilon to absorb float drift from repeated += 0.2f.
        for (float a = 1.0f; (double) a <= (double) alpha + 1.0E-6 && nSelected < max; a += 0.2f) {
            for (int i = diverseBefore; i < neighbors.size() && nSelected < max; i++) {
                if (selected.get(i)) {
                    continue;
                }
                int cNode = neighbors.node()[i];
                float cScore = neighbors.score()[i];
                ScoreFunction sf = scoreProvider.scoreFunctionFor(cNode);
                if (isDiverse(cNode, cScore, neighbors, sf, selected, a)) {
                    selected.set(i);
                    nSelected++;
                }
            }
            if (a == 1.0f && max == maxConnections) {
                shortEdges = (float) nSelected / (float) maxConnections;
            }
        }
    }

    /**
     * Returns false if some selected entry of {@code others} is more similar to {@code node}
     * (according to {@code sf}) than {@code score * alpha}. The scan stops early if it reaches an
     * entry whose id equals {@code node}.
     */
    private boolean isDiverse(int node, float score, NodeArray others, ScoreFunction sf, BitSet selected, float alpha) {
        assert (others.size > 0);
        for (int i = selected.nextSetBit(0); i != NO_MORE_BITS; i = selected.nextSetBit(i + 1)) {
            int otherNode = others.node()[i];
            if (node == otherNode) {
                break;
            }
            if (sf.similarityTo(otherNode) > score * alpha) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code neighbors} itself if it fits within maxConnections. Otherwise returns a
     * diversity-pruned copy (exact scoring), leaving the input untouched.
     */
    private NodeArray removeAllNonDiverse(NodeArray neighbors, int diverseBefore) {
        if (neighbors.size <= maxConnections) {
            return neighbors;
        }
        NodeArray copy = neighbors.copy();
        retainDiverse(copy, diverseBefore, true);
        return copy;
    }

    /** Returns the published neighbor array; treat it as read-only since other readers share it. */
    NodeArray getCurrent() {
        return neighborsRef.get().nodes;
    }

    /**
     * Inserts a single neighbor without an immediate diversity check. If the list then exceeds
     * {@code overflow * maxConnections}, it is pruned back to a diverse subset.
     */
    public void insert(int neighborId, float score, float overflow) {
        assert (neighborId != nodeId) : "can't add self as neighbor at node " + nodeId;
        float hardMax = overflow * maxConnections;
        neighborsRef.getAndUpdate(old -> {
            NodeArray nextNodes = old.nodes.copy();
            int insertionPoint = nextNodes.insertSorted(neighborId, score);
            if (insertionPoint == -1) {
                return old;
            }
            int nextDiverseBefore = Math.min(insertionPoint, old.diverseBefore);
            if (nextNodes.size > hardMax) {
                nextNodes = removeAllNonDiverse(nextNodes, nextDiverseBefore);
                nextDiverseBefore = nextNodes.size;
            }
            return new Neighbors(nextNodes, nextDiverseBefore);
        });
    }

    /** Returns true if {@code i} is among the current neighbor ids (linear scan). */
    boolean contains(int i) {
        for (NodesIterator it = iterator(); it.hasNext(); ) {
            if (it.nextInt() == i) {
                return true;
            }
        }
        return false;
    }

    /** Immutable snapshot: the neighbor array plus how many leading entries are known to be diverse. */
    private static class Neighbors {
        public final NodeArray nodes;
        /** Entries at indices below this have already passed diversity selection. */
        public final int diverseBefore;

        private Neighbors(NodeArray nodes, int diverseBefore) {
            this.nodes = nodes;
            this.diverseBefore = diverseBefore;
        }
    }

    /** Iterates over the node ids of one NodeArray snapshot. */
    private static class NeighborIterator
    extends NodesIterator {
        private final NodeArray neighbors;
        private int cursor;

        private NeighborIterator(NodeArray neighbors) {
            super(neighbors.size());
            this.neighbors = neighbors;
            this.cursor = 0;
        }

        @Override
        public boolean hasNext() {
            return cursor < neighbors.size();
        }

        @Override
        public int nextInt() {
            return neighbors.node[cursor++];
        }
    }
}

