/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.pq.CompressedDecoder;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

/**
 * A {@link ProductQuantization} codebook paired with a list of byte-encoded (compressed) vectors,
 * plus per-thread scratch buffers handed out to score functions.
 *
 * <p>Vectors are addressed by their position (ordinal) in the list. The list passed to the
 * constructor is stored by reference, not copied.
 */
public class CompressedVectors {
    /** Width of each per-subspace scratch array: the number of distinct values one {@code byte} code can take. */
    private static final int CODES_PER_SUBSPACE = 256;

    /** Codebook for the compressed vectors; package-private (presumably read by the decoders in this package). */
    final ProductQuantization pq;
    /** Compressed vectors indexed by ordinal; stored exactly as supplied to the constructor. */
    private final List<byte[]> compressedVectors;
    /** Lazily allocated, per-thread scratch buffer returned by {@link #reusablePartialSums()}. */
    private final ThreadLocal<float[][]> partialSums;
    /** Lazily allocated, per-thread scratch buffer returned by {@link #reusablePartialMagnitudes()}. */
    private final ThreadLocal<float[][]> partialMagnitudes;

    /**
     * Creates an instance over the given codebook and compressed vectors.
     *
     * @param pq the codebook
     * @param compressedVectors the compressed vectors, indexed by ordinal; kept by reference, not copied
     */
    public CompressedVectors(ProductQuantization pq, List<byte[]> compressedVectors) {
        this.pq = pq;
        this.compressedVectors = compressedVectors;
        // Scratch buffers are allocated on first use in each thread, not here.
        this.partialSums = ThreadLocal.withInitial(() -> CompressedVectors.initFloatFragments(pq));
        this.partialMagnitudes = ThreadLocal.withInitial(() -> CompressedVectors.initFloatFragments(pq));
    }

    /**
     * Allocates one {@value #CODES_PER_SUBSPACE}-slot float array per PQ subspace.
     *
     * <p>NOTE(review): the width presumably gives one slot per possible byte code (centroid) in a
     * subspace. Confirm against ProductQuantization.
     */
    private static float[][] initFloatFragments(ProductQuantization pq) {
        float[][] fragments = new float[pq.getSubspaceCount()][];
        for (int i = 0; i < fragments.length; ++i) {
            fragments[i] = new float[CODES_PER_SUBSPACE];
        }
        return fragments;
    }

    /**
     * Serializes the codebook followed by every compressed vector.
     *
     * <p>Layout: whatever {@code pq.write(out)} emits, then an {@code int} vector count, then an
     * {@code int} per-vector length (taken from {@code pq.getSubspaceCount()}), then the raw bytes
     * of each vector in ordinal order. {@link #load} reads this same layout.
     *
     * <p>NOTE(review): each vector is assumed to be exactly {@code pq.getSubspaceCount()} bytes.
     * A vector of any other length is written as-is, and {@link #load} would then misalign.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    public void write(DataOutput out) throws IOException {
        this.pq.write(out);
        out.writeInt(this.compressedVectors.size());
        out.writeInt(this.pq.getSubspaceCount());
        for (byte[] v : this.compressedVectors) {
            out.write(v);
        }
    }

    /**
     * Reads an instance in the layout produced by {@link #write}.
     *
     * @param in reader to read from; it is first positioned at {@code offset}
     * @param offset position at which the serialized data starts
     * @return a new instance backed by a freshly allocated list
     * @throws IOException if reading fails
     */
    public static CompressedVectors load(RandomAccessReader in, long offset) throws IOException {
        in.seek(offset);
        ProductQuantization pq = ProductQuantization.load(in);
        int size = in.readInt();
        List<byte[]> compressedVectors = new ArrayList<>(size);
        int compressedDimension = in.readInt();
        for (int i = 0; i < size; ++i) {
            byte[] vector = new byte[compressedDimension];
            in.readFully(vector);
            compressedVectors.add(vector);
        }
        return new CompressedVectors(pq, compressedVectors);
    }

    /**
     * Two instances are equal when they have the same class, equal codebooks, and the same
     * number of vectors with byte-for-byte equal contents at every ordinal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        CompressedVectors that = (CompressedVectors) o;
        if (!Objects.equals(this.pq, that.pq)) {
            return false;
        }
        if (this.compressedVectors.size() != that.compressedVectors.size()) {
            return false;
        }
        return IntStream.range(0, this.compressedVectors.size())
                .allMatch(i -> Arrays.equals(this.compressedVectors.get(i), that.compressedVectors.get(i)));
    }

    /**
     * Hash consistent with {@link #equals}: it combines the codebook's hash with the
     * content-based hash of every vector.
     */
    @Override
    public int hashCode() {
        // List.hashCode() would call byte[].hashCode(), which is identity-based. Equal instances
        // (per equals, which uses Arrays.equals) could then hash differently, so hash contents instead.
        int result = Objects.hashCode(this.pq);
        for (byte[] v : this.compressedVectors) {
            result = 31 * result + Arrays.hashCode(v);
        }
        return result;
    }

    /**
     * Returns an approximate score function for query {@code q} under the given similarity.
     *
     * @param q the query vector
     * @param similarityFunction one of {@code DOT_PRODUCT}, {@code EUCLIDEAN}, {@code COSINE}
     * @return a decoder that scores compressed vectors against {@code q}
     * @throws IllegalArgumentException if {@code similarityFunction} is any other value
     */
    public NeighborSimilarity.ApproximateScoreFunction approximateScoreFunctionFor(float[] q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new CompressedDecoder.DotProductDecoder(this, q);
            case EUCLIDEAN:
                return new CompressedDecoder.EuclideanDecoder(this, q);
            case COSINE:
                return new CompressedDecoder.CosineDecoder(this, q);
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * Returns the compressed vector at {@code ordinal}.
     *
     * <p>The stored array itself is returned, not a copy, so mutating it mutates this instance.
     */
    public byte[] get(int ordinal) {
        return this.compressedVectors.get(ordinal);
    }

    /**
     * Returns the calling thread's partial-sums scratch buffer.
     *
     * <p>The same arrays are returned on every call from a given thread and are never cleared
     * here, so callers must overwrite them before reading.
     */
    float[][] reusablePartialSums() {
        return this.partialSums.get();
    }

    /**
     * Returns the calling thread's partial-magnitudes scratch buffer.
     *
     * <p>The same arrays are returned on every call from a given thread and are never cleared
     * here, so callers must overwrite them before reading.
     */
    float[][] reusablePartialMagnitudes() {
        return this.partialMagnitudes.get();
    }
}

