/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.BQVectors;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.pq.VectorCompressor;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.IntStream;

public class BinaryQuantization
implements VectorCompressor<long[]> {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
    private final VectorFloat<?> globalCentroid;

    public BinaryQuantization(VectorFloat<?> globalCentroid) {
        this.globalCentroid = globalCentroid;
    }

    public static BinaryQuantization compute(RandomAccessVectorValues ravv) {
        return BinaryQuantization.compute(ravv, ForkJoinPool.commonPool());
    }

    public static BinaryQuantization compute(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
        List<VectorFloat<?>> vectors = ProductQuantization.extractTrainingVectors(ravv, parallelExecutor);
        VectorFloat<?> globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
        return new BinaryQuantization(globalCentroid);
    }

    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        return new BQVectors(this, (long[][])compressedVectors);
    }

    public long[][] encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        return (long[][])((ForkJoinTask)simdExecutor.submit(() -> (long[][])IntStream.range(0, ravv.size()).parallel().mapToObj(i -> this.encode(ravv.getVector(i))).toArray(x$0 -> new long[x$0][]))).join();
    }

    @Override
    public long[] encode(VectorFloat<?> v) {
        VectorFloat<?> centered = VectorUtil.sub(v, this.globalCentroid);
        int M = (int)Math.ceil((double)centered.length() / 64.0);
        long[] encoded = new long[M];
        for (int i = 0; i < M; ++i) {
            int idx;
            long bits = 0L;
            for (int j = 0; j < 64 && (idx = i * 64 + j) < centered.length(); ++j) {
                if (!(centered.get(idx) > 0.0f)) continue;
                bits |= 1L << j;
            }
            encoded[i] = bits;
        }
        return encoded;
    }

    @Override
    public int compressorSize() {
        return 4 + this.globalCentroid.length() * 4;
    }

    @Override
    public int compressedVectorSize() {
        int M = (int)Math.ceil((double)this.globalCentroid.length() / 64.0);
        return 8 * M;
    }

    @Override
    public void write(DataOutput out, int version) throws IOException {
        out.writeInt(this.globalCentroid.length());
        vectorTypeSupport.writeFloatVector(out, this.globalCentroid);
    }

    public int getOriginalDimension() {
        return this.globalCentroid.length();
    }

    public static BinaryQuantization load(RandomAccessReader in) throws IOException {
        int length = in.readInt();
        VectorFloat<?> centroid = vectorTypeSupport.readFloatVector(in, length);
        return new BinaryQuantization(centroid);
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BinaryQuantization that = (BinaryQuantization)o;
        return Objects.equals(this.globalCentroid, that.globalCentroid);
    }

    public int hashCode() {
        return Objects.hashCode(this.globalCentroid);
    }

    public String toString() {
        return "BinaryQuantization";
    }
}

