/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.util;

import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import java.util.concurrent.atomic.AtomicLongArray;

/**
 * A fixed-capacity bit set whose 64-bit words live in an {@link AtomicLongArray}.
 *
 * <p>Single-bit mutations ({@link #set(int)}, {@link #getAndSet(int)} and {@link #clear(int)})
 * are each applied as one atomic read-modify-write of the word that contains the bit.
 * Operations that span several words ({@link #clear()}, {@link #clear(int, int)},
 * {@link #cardinality()}, {@link #prevSetBit(int)}, {@link #nextSetBit(int)}) visit the words
 * one at a time, so they are not atomic as a whole.
 *
 * <p>The capacity is rounded up to a whole number of words, so {@link #length()} may be larger
 * than the size requested at construction.
 */
public class AtomicFixedBitSet extends BitSet {
    private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(AtomicFixedBitSet.class);

    /** log2 of the number of bits held by one {@code long} word. */
    private static final int WORD_SHIFT = 6;

    /** Mask extracting a bit's position inside its word. */
    private static final int WORD_MASK = 0x3F;

    /** Fixed per-array overhead, in bytes, added by {@link #ramBytesUsed()}. */
    private static final long ARRAY_OVERHEAD_BYTES = 16L;

    private final AtomicLongArray storage;

    /**
     * Creates a bit set able to hold at least {@code numBits} bits, all initially clear.
     *
     * @param numBits the minimum number of bits the set must hold
     * @throws IllegalArgumentException if {@code numBits} is negative
     */
    public AtomicFixedBitSet(int numBits) {
        if (numBits < 0) {
            // Without this check a negative size either produced an empty set or, after the
            // unsigned shift below, a very large allocation.
            throw new IllegalArgumentException("numBits must be non-negative, got " + numBits);
        }
        // The unsigned shift keeps the word count correct even when numBits + 63 overflows int.
        int numLongs = (numBits + 63) >>> WORD_SHIFT;
        storage = new AtomicLongArray(numLongs);
    }

    /** Returns the index of the word that holds {@code bit}. */
    private static int index(int bit) {
        return bit >> WORD_SHIFT;
    }

    /** Returns a word with only the position of {@code bit} set. */
    private static long mask(int bit) {
        // A long shift uses only the low six bits of the shift count.
        return 1L << bit;
    }

    /**
     * Returns the capacity in bits, which is the number of backing words times 64.
     *
     * @return the number of addressable bits
     */
    @Override
    public int length() {
        return storage.length() << WORD_SHIFT;
    }

    /**
     * Atomically sets bit {@code i} within its word.
     *
     * @param i the bit index
     */
    @Override
    public void set(int i) {
        int wordIndex = index(i);
        long bitMask = mask(i);
        storage.getAndAccumulate(wordIndex, bitMask, (prev, m) -> prev | m);
    }

    /**
     * Reads bit {@code i}.
     *
     * @param i the bit index
     * @return {@code true} if the bit is set; {@code false} if it is clear or if {@code i} is at
     *     or beyond {@link #length()}
     */
    @Override
    public boolean get(int i) {
        if (i >= length()) {
            return false;
        }
        int wordIndex = index(i);
        long bitMask = mask(i);
        return (storage.get(wordIndex) & bitMask) != 0L;
    }

    /**
     * Atomically sets bit {@code i} within its word and reports its previous state.
     *
     * @param i the bit index
     * @return {@code true} if the bit was already set before this call
     */
    @Override
    public boolean getAndSet(int i) {
        int wordIndex = index(i);
        long bitMask = mask(i);
        long previous = storage.getAndAccumulate(wordIndex, bitMask, (prev, m) -> prev | m);
        return (previous & bitMask) != 0L;
    }

    /** Sets every word to zero, one word at a time. */
    @Override
    public void clear() {
        for (int i = 0; i < storage.length(); ++i) {
            storage.set(i, 0L);
        }
    }

    /**
     * Atomically clears bit {@code i} within its word. Does nothing if {@code i} is at or beyond
     * {@link #length()}.
     *
     * @param i the bit index
     */
    @Override
    public void clear(int i) {
        if (i >= length()) {
            return;
        }
        int wordIndex = index(i);
        long bitMask = mask(i);
        storage.getAndAccumulate(wordIndex, bitMask, (prev, m) -> prev & ~m);
    }

    /**
     * Clears the bits in {@code [startIndex, endIndex)}. Does nothing when
     * {@code endIndex <= startIndex}.
     *
     * <p>The first and last words are updated with atomic read-modify-writes so bits outside the
     * range are preserved; words strictly in between are overwritten with zero.
     *
     * @param startIndex first bit to clear (inclusive)
     * @param endIndex end of the range (exclusive)
     */
    @Override
    public void clear(int startIndex, int endIndex) {
        if (endIndex <= startIndex) {
            return;
        }
        int firstWord = index(startIndex);
        int lastWord = index(endIndex - 1);

        // Masks of the bits to KEEP: those below startIndex in the first word and those at or
        // above endIndex in the last word.
        long keepBelowStart = ~(-1L << (startIndex & WORD_MASK));
        long keepFromEnd = ~(-1L >>> -(endIndex & WORD_MASK));

        if (firstWord == lastWord) {
            storage.getAndAccumulate(firstWord, keepBelowStart | keepFromEnd, (prev, m) -> prev & m);
            return;
        }

        storage.getAndAccumulate(firstWord, keepBelowStart, (prev, m) -> prev & m);
        for (int i = firstWord + 1; i < lastWord; ++i) {
            storage.set(i, 0L);
        }
        storage.getAndAccumulate(lastWord, keepFromEnd, (prev, m) -> prev & m);
    }

    /**
     * Counts the set bits by summing the population count of each word as it is read.
     *
     * @return the number of set bits observed
     */
    @Override
    public int cardinality() {
        int count = 0;
        for (int i = 0; i < storage.length(); ++i) {
            count += Long.bitCount(storage.get(i));
        }
        return count;
    }

    /**
     * Returns the same value as {@link #cardinality()}.
     *
     * @return the number of set bits observed
     */
    @Override
    public int approximateCardinality() {
        return cardinality();
    }

    /**
     * Finds the highest set bit at or below {@code index}.
     *
     * @param index the bit index to start from (inclusive); asserted to be within
     *     {@code [0, length())}
     * @return the index of that bit, or {@code -1} if there is none
     */
    @Override
    public int prevSetBit(int index) {
        assert (index >= 0 && index < length()) : "index=" + index + " length=" + length();
        int wordIndex = index(index);
        int subIndex = index & WORD_MASK;
        // Shift out the bits above subIndex so the bit at subIndex becomes the top bit.
        long word = storage.get(wordIndex) << (63 - subIndex);
        if (word != 0L) {
            return (wordIndex << WORD_SHIFT) + subIndex - Long.numberOfLeadingZeros(word);
        }
        while (--wordIndex >= 0) {
            word = storage.get(wordIndex);
            if (word != 0L) {
                return (wordIndex << WORD_SHIFT) + 63 - Long.numberOfLeadingZeros(word);
            }
        }
        return -1;
    }

    /**
     * Finds the lowest set bit at or above {@code index}.
     *
     * @param index the bit index to start from (inclusive); asserted to be within
     *     {@code [0, length())}
     * @return the index of that bit, or {@link Integer#MAX_VALUE} if there is none
     */
    @Override
    public int nextSetBit(int index) {
        assert (index >= 0 && index < length()) : "index=" + index + ", length=" + length();
        int wordIndex = index(index);
        if (wordIndex >= storage.length()) {
            return Integer.MAX_VALUE;
        }
        // Mask off the bits below the starting position in the first word.
        long word = storage.get(wordIndex) & (-1L << (index & WORD_MASK));
        while (word == 0L) {
            if (++wordIndex >= storage.length()) {
                return Integer.MAX_VALUE;
            }
            word = storage.get(wordIndex);
        }
        return (wordIndex << WORD_SHIFT) + Long.numberOfTrailingZeros(word);
    }

    /**
     * Estimates memory use as the shallow size of this instance plus eight bytes per backing
     * word plus a fixed 16-byte array overhead.
     *
     * @return the estimated size in bytes
     */
    @Override
    public long ramBytesUsed() {
        long storageSize = (long) storage.length() * Long.BYTES + ARRAY_OVERHEAD_BYTES;
        return BASE_RAM_BYTES_USED + storageSize;
    }
}

