/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.NeighborArray;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.FixedBitSet;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
 * The neighbor list of a single graph node, updatable from multiple threads.
 *
 * <p>The current neighbors live in a {@link ConcurrentNeighborArray} held by an
 * {@link AtomicReference}. Every update goes through {@code getAndUpdate} and either returns the
 * current array unchanged or builds a new array. The array that is already published is never
 * modified in place by this class, so {@link #iterator()} and {@link #getCurrent()} observe a
 * stable snapshot.
 *
 * <p>When the list grows too large it is pruned with a diversity heuristic. Candidates are
 * visited in array order. A candidate is kept only if none of the already-kept neighbors is more
 * similar to it than {@code score * a}, where {@code a} is relaxed from 1.0 up to {@code alpha}
 * in steps of 0.2.
 */
public class ConcurrentNeighborSet {
    private final int nodeId;
    private final AtomicReference<ConcurrentNeighborArray> neighborsRef;
    private final float alpha;
    private final NeighborSimilarity similarity;
    private final int maxConnections;
    // Fraction of maxConnections filled by the a == 1.0 pass of the most recent diversity
    // selection; NaN until a selection has run. Written as a side effect of selectDiverse, which
    // may run inside a getAndUpdate lambda.
    // NOTE(review): not volatile, so readers on other threads may observe a stale value.
    private float shortEdges = Float.NaN;

    /**
     * Creates an empty neighbor set.
     *
     * @param nodeId the node whose neighbors this set holds; inserting it as its own neighbor
     *     trips an assertion
     * @param maxConnections the number of neighbors kept after diversity pruning (at most)
     * @param similarity provides the node-to-node scores used by the diversity check
     * @param alpha the largest diversity relaxation factor tried during pruning
     */
    public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) {
        this.nodeId = nodeId;
        this.maxConnections = maxConnections;
        this.similarity = similarity;
        this.neighborsRef = new AtomicReference<ConcurrentNeighborArray>(new ConcurrentNeighborArray(maxConnections, true));
        this.alpha = alpha;
    }

    /** Creates an empty neighbor set with {@code alpha = 1.0}. */
    public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) {
        this(nodeId, maxConnections, similarity, 1.0f);
    }

    /** Copy constructor: shares the currently published array, which is never mutated in place. */
    private ConcurrentNeighborSet(ConcurrentNeighborSet old) {
        this.nodeId = old.nodeId;
        this.maxConnections = old.maxConnections;
        this.similarity = old.similarity;
        this.alpha = old.alpha;
        this.neighborsRef = new AtomicReference<ConcurrentNeighborArray>(old.neighborsRef.get());
    }

    /**
     * Returns the short-edge fraction recorded by the last diversity selection, or NaN if none
     * has run yet.
     */
    public float getShortEdges() {
        return this.shortEdges;
    }

    /** Returns an iterator over the node ids in the current snapshot of the neighbor list. */
    public NodesIterator iterator() {
        return new NeighborIterator(this.neighborsRef.get());
    }

    /**
     * Adds this node to the neighbor set of each of its current neighbors, using the score stored
     * for that edge here.
     *
     * @param neighborhoodOf maps a node id to that node's neighbor set
     * @param overflow forwarded to {@link #insert(int, float, float)}
     */
    public void backlink(Function<Integer, ConcurrentNeighborSet> neighborhoodOf, float overflow) {
        NeighborArray neighbors = this.neighborsRef.get();
        for (int i = 0; i < neighbors.size(); ++i) {
            int nbr = neighbors.node[i];
            float nbrScore = neighbors.score[i];
            ConcurrentNeighborSet nbrNbr = neighborhoodOf.apply(nbr);
            nbrNbr.insert(this.nodeId, nbrScore, overflow);
        }
    }

    /** Prunes the list to at most {@code maxConnections} diverse neighbors if it is larger than that. */
    public void cleanup() {
        this.neighborsRef.getAndUpdate(this::removeAllNonDiverse);
    }

    /** Returns the number of neighbors in the current snapshot. */
    public int size() {
        return this.neighborsRef.get().size();
    }

    /** Returns the capacity of the backing array of the current snapshot. */
    public int arrayLength() {
        return this.neighborsRef.get().node.length;
    }

    /**
     * Merges {@code natural} and {@code concurrent} candidates with the current neighbors, then
     * keeps only the diverse subset (at most {@code maxConnections} entries).
     *
     * <p>Both inputs must be sorted by descending score. The update lambda may be re-run under
     * contention, per {@link AtomicReference#getAndUpdate}.
     */
    public void insertDiverse(NeighborArray natural, NeighborArray concurrent) {
        if (natural.size() == 0 && concurrent.size() == 0) {
            return;
        }
        assert (natural.scoresDescOrder);
        assert (concurrent.scoresDescOrder);
        this.neighborsRef.getAndUpdate(current -> {
            ConcurrentNeighborArray merged = ConcurrentNeighborSet.mergeNeighbors(ConcurrentNeighborSet.mergeNeighbors(natural, current), concurrent);
            BitSet selected = this.selectDiverse(merged);
            // merged is a fresh array, so compacting it in place does not touch the published one
            merged.retain(selected);
            return merged;
        });
    }

    /** Builds a new array containing only the entries of {@code merged} whose index is set in {@code selected}. */
    private ConcurrentNeighborArray copyDiverse(NeighborArray merged, BitSet selected) {
        ConcurrentNeighborArray next = new ConcurrentNeighborArray(this.maxConnections, true);
        for (int i = 0; i < merged.size(); ++i) {
            if (!selected.get(i)) {
                continue;
            }
            int node = merged.node()[i];
            assert (node != this.nodeId) : "can't add self as neighbor at node " + this.nodeId;
            float score = merged.score()[i];
            next.addInOrder(node, score);
        }
        assert (next.size <= this.maxConnections);
        return next;
    }

    /**
     * Chooses up to {@code maxConnections} diverse entries of {@code neighbors}, trying relaxation
     * factors 1.0, 1.2, 1.4, ... up to {@code alpha}. Also records {@link #shortEdges}.
     *
     * @return a bit set of the selected indices into {@code neighbors}
     */
    private BitSet selectDiverse(NeighborArray neighbors) {
        FixedBitSet selected = new FixedBitSet(neighbors.size());
        int nSelected = 0;
        // The epsilon absorbs float error accumulated by the repeated += 0.2f.
        for (float a = 1.0f; (double) a <= (double) this.alpha + 1.0E-6 && nSelected < this.maxConnections; a += 0.2f) {
            for (int i = 0; i < neighbors.size() && nSelected < this.maxConnections; ++i) {
                if (selected.get(i)) {
                    continue;
                }
                int cNode = neighbors.node()[i];
                float cScore = neighbors.score()[i];
                if (this.isDiverse(cNode, cScore, neighbors, selected, a)) {
                    selected.set(i);
                    ++nSelected;
                }
            }
            if (a == 1.0f) {
                this.shortEdges = (float) nSelected / (float) this.maxConnections;
            }
        }
        return selected;
    }

    /**
     * Returns the currently published neighbor array. Callers should treat it as read-only,
     * since other holders of the same snapshot assume it never changes.
     */
    public ConcurrentNeighborArray getCurrent() {
        return this.neighborsRef.get();
    }

    /**
     * Merges two arrays sorted by descending score into a new descending array, dropping
     * duplicate nodes.
     *
     * <p>A node is treated as a duplicate when it was already added with exactly the same score.
     * Nodes with equal scores may appear in any order within either input. Every node added at
     * the current score is therefore remembered, not just the last one. The old approach let
     * duplicates through, e.g. merging {@code [(A,s),(B,s)]} with {@code [(B,s),(A,s)]}.
     *
     * @return a new array of capacity {@code a1.size() + a2.size()}
     */
    static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2) {
        assert (a1.scoresDescOrder);
        assert (a2.scoresDescOrder);
        ConcurrentNeighborArray merged = new ConcurrentNeighborArray(a1.size() + a2.size(), true);
        Set<Integer> nodesWithLastScore = new HashSet<>();
        float lastAddedScore = Float.NaN;
        int i = 0;
        int j = 0;
        while (i < a1.size() || j < a2.size()) {
            // Take the higher-scored head; on a tie, or once a2 is exhausted, take from a1.
            boolean fromA1 = j >= a2.size() || (i < a1.size() && a1.score()[i] >= a2.score()[j]);
            int node;
            float score;
            if (fromA1) {
                node = a1.node()[i];
                score = a1.score()[i];
                ++i;
            } else {
                node = a2.node()[j];
                score = a2.score()[j];
                ++j;
            }
            if (score != lastAddedScore) {
                // Scores only decrease, so a new score means no earlier node can repeat.
                nodesWithLastScore.clear();
                lastAddedScore = score;
            }
            if (nodesWithLastScore.add(node)) {
                merged.addInOrder(node, score);
            }
        }
        return merged;
    }

    /**
     * Inserts {@code neighborId} with {@code score}. An entry with the same node and score is
     * ignored. If the list then holds more than {@code overflow * maxConnections} entries, it is
     * pruned to at most {@code maxConnections} diverse neighbors.
     */
    public void insert(int neighborId, float score, float overflow) {
        assert (neighborId != this.nodeId) : "can't add self as neighbor at node " + this.nodeId;
        this.neighborsRef.getAndUpdate(current -> {
            // Copy first: the published array must not be mutated.
            ConcurrentNeighborArray next = current.copy();
            next.insertSorted(neighborId, score);
            float hardMax = overflow * (float) this.maxConnections;
            if ((float) next.size > hardMax) {
                next = this.removeAllNonDiverse(next);
            }
            return next;
        });
    }

    /** Inserts with an overflow factor of 1.0 (prune as soon as the list exceeds maxConnections). */
    public void insert(int neighborId, float score) {
        this.insert(neighborId, score, 1.0f);
    }

    /**
     * Returns false if some already-selected entry of {@code others} is more similar to
     * {@code node} than {@code score * alpha}. Scanning stops early if the selected set contains
     * {@code node} itself.
     */
    private boolean isDiverse(int node, float score, NeighborArray others, BitSet selected, float alpha) {
        if (others.size() == 0) {
            return true;
        }
        NeighborSimilarity.ScoreFunction scoreProvider = this.similarity.scoreProvider(node);
        // Integer.MAX_VALUE from nextSetBit is treated as "no more set bits".
        for (int i = selected.nextSetBit(0); i != Integer.MAX_VALUE; i = selected.nextSetBit(i + 1)) {
            int otherNode = others.node()[i];
            if (node == otherNode) {
                break;
            }
            if (scoreProvider.similarityTo(otherNode) > score * alpha) {
                return false;
            }
            // Never ask nextSetBit for an index past the end of the bit set.
            if (i + 1 >= selected.length()) {
                break;
            }
        }
        return true;
    }

    /** Returns {@code neighbors} unchanged if within maxConnections, else a new array of its diverse subset. */
    private ConcurrentNeighborArray removeAllNonDiverse(ConcurrentNeighborArray neighbors) {
        if (neighbors.size <= this.maxConnections) {
            return neighbors;
        }
        BitSet selected = this.selectDiverse(neighbors);
        return this.copyDiverse(neighbors, selected);
    }

    /** Returns a new set with the same configuration sharing the current neighbor snapshot. */
    public ConcurrentNeighborSet copy() {
        return new ConcurrentNeighborSet(this);
    }

    /** Returns true if node {@code i} is in the current snapshot (linear scan). */
    boolean contains(int i) {
        NodesIterator it = this.iterator();
        while (it.hasNext()) {
            if (it.nextInt() == i) {
                return true;
            }
        }
        return false;
    }

    /**
     * A {@link NeighborArray} whose {@link #insertSorted} ignores an exact (node, score)
     * duplicate, and which supports in-place compaction and copying.
     */
    static class ConcurrentNeighborArray
    extends NeighborArray {
        public ConcurrentNeighborArray(int maxSize, boolean descOrder) {
            super(maxSize, descOrder);
        }

        /**
         * Inserts at the rightmost position that keeps the sort order, unless an entry with the
         * same node and score already exists. Grows the arrays when full.
         */
        @Override
        public void insertSorted(int newNode, float newScore) {
            if (this.size == this.node.length) {
                this.growArrays();
            }
            int insertionPoint = this.scoresDescOrder
                    ? this.descSortFindRightMostInsertionPoint(newScore)
                    : this.ascSortFindRightMostInsertionPoint(newScore);
            if (this.duplicateExistsNear(insertionPoint, newNode, newScore)) {
                return;
            }
            System.arraycopy(this.node, insertionPoint, this.node, insertionPoint + 1, this.size - insertionPoint);
            System.arraycopy(this.score, insertionPoint, this.score, insertionPoint + 1, this.size - insertionPoint);
            this.node[insertionPoint] = newNode;
            this.score[insertionPoint] = newScore;
            ++this.size;
        }

        /** Scans the run of entries with score equal to {@code newScore} around {@code insertionPoint} for {@code newNode}. */
        private boolean duplicateExistsNear(int insertionPoint, int newNode, float newScore) {
            for (int i = insertionPoint - 1; i >= 0 && this.score[i] == newScore; --i) {
                if (this.node[i] == newNode) {
                    return true;
                }
            }
            for (int i = insertionPoint; i < this.size && this.score[i] == newScore; ++i) {
                if (this.node[i] == newNode) {
                    return true;
                }
            }
            return false;
        }

        /** Compacts this array in place, keeping only indices set in {@code selected} in their original order. */
        public void retain(BitSet selected) {
            int writeIdx = 0;
            for (int readIdx = 0; readIdx < this.size; ++readIdx) {
                if (!selected.get(readIdx)) {
                    continue;
                }
                if (writeIdx != readIdx) {
                    this.node[writeIdx] = this.node[readIdx];
                    this.score[writeIdx] = this.score[readIdx];
                }
                ++writeIdx;
            }
            this.size = writeIdx;
        }

        /** Returns an independent copy with the same capacity, order flag and contents. */
        public ConcurrentNeighborArray copy() {
            ConcurrentNeighborArray copy = new ConcurrentNeighborArray(this.node.length, this.scoresDescOrder);
            copy.size = this.size;
            System.arraycopy(this.node, 0, copy.node, 0, this.size);
            System.arraycopy(this.score, 0, copy.score, 0, this.size);
            return copy;
        }
    }

    /** Iterates the node ids of a single {@link NeighborArray} snapshot. */
    private static class NeighborIterator
    extends NodesIterator {
        private final NeighborArray neighbors;
        private int i;

        private NeighborIterator(NeighborArray neighbors) {
            super(neighbors.size());
            this.neighbors = neighbors;
            this.i = 0;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.neighbors.size();
        }

        @Override
        public int nextInt() {
            return this.neighbors.node[this.i++];
        }
    }
}

