/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.BQVectors;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer;
import io.github.jbellis.jvector.pq.VectorCompressor;
import io.github.jbellis.jvector.util.PoolingSupport;
import io.github.jbellis.jvector.vector.VectorUtil;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class BinaryQuantization
implements VectorCompressor<long[]> {
    private final float[] globalCentroid;

    public BinaryQuantization(float[] globalCentroid) {
        this.globalCentroid = globalCentroid;
    }

    public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv) {
        return BinaryQuantization.compute(ravv, ForkJoinPool.commonPool());
    }

    public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor) {
        float P = Math.min(1.0f, 128000.0f / (float)ravv.size());
        PoolingSupport<RandomAccessVectorValues<float[]>> ravvCopy = ravv.isValueShared() ? PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv);
        List vectors = (List)((ForkJoinTask)parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel().filter(i -> ThreadLocalRandom.current().nextFloat() < P).mapToObj(targetOrd -> {
            try (PoolingSupport.Pooled pooledRavv = ravvCopy.get();){
                RandomAccessVectorValues localRavv = (RandomAccessVectorValues)pooledRavv.get();
                float[] v = (float[])localRavv.vectorValue(targetOrd);
                float[] fArray = localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
                return fArray;
            }
        }).collect(Collectors.toList()))).join();
        float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
        return new BinaryQuantization(globalCentroid);
    }

    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        return new BQVectors(this, (long[][])compressedVectors);
    }

    public long[][] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor) {
        return (long[][])((ForkJoinTask)simdExecutor.submit(() -> (long[][])((Stream)vectors.stream().parallel()).map(this::encode).toArray(x$0 -> new long[x$0][]))).join();
    }

    @Override
    public long[] encode(float[] v) {
        float[] centered = VectorUtil.sub(v, this.globalCentroid);
        int M = (int)Math.ceil((double)centered.length / 64.0);
        long[] encoded = new long[M];
        for (int i = 0; i < M; ++i) {
            int idx;
            long bits = 0L;
            for (int j = 0; j < 64 && (idx = i * 64 + j) < centered.length; ++j) {
                if (!(centered[idx] > 0.0f)) continue;
                bits |= 1L << j;
            }
            encoded[i] = bits;
        }
        return encoded;
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(this.globalCentroid.length);
        Io.writeFloats(out, this.globalCentroid);
    }

    public int getOriginalDimension() {
        return this.globalCentroid.length;
    }

    public static BinaryQuantization load(RandomAccessReader in) throws IOException {
        int length = in.readInt();
        float[] centroid = new float[length];
        in.readFully(centroid);
        return new BinaryQuantization(centroid);
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BinaryQuantization that = (BinaryQuantization)o;
        return Arrays.equals(this.globalCentroid, that.globalCentroid);
    }

    public int hashCode() {
        return Arrays.hashCode(this.globalCentroid);
    }

    public String toString() {
        return "BinaryQuantization";
    }
}

