/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.VectorFloat;

/**
 * Builds similarity score functions that compare a full-precision query vector against vectors
 * quantized by an {@link NVQuantization}.
 *
 * <p>Only the {@code EIGHT} bits-per-dimension setting is currently handled; any other setting is
 * rejected with an {@link IllegalArgumentException} when a score function is requested.
 */
public class NVQScorer {
    /** Quantizer supplying the global mean, the sub-vector split and the bit width used for scoring. */
    final NVQuantization nvq;

    /**
     * Creates a scorer bound to a quantizer.
     *
     * @param nvq the quantizer whose quantized vectors will be scored
     */
    public NVQScorer(NVQuantization nvq) {
        this.nvq = nvq;
    }

    /**
     * Returns a function that scores quantized vectors against {@code query}.
     *
     * <p>All query-side preparation (splitting into sub-vectors, 8-bit shuffling, global-mean
     * terms) happens once here, so the returned function can be applied to many quantized vectors.
     *
     * @param query the full-precision query vector
     * @param similarityFunction one of {@code DOT_PRODUCT}, {@code EUCLIDEAN} or {@code COSINE}
     * @return a score function bound to {@code query}
     * @throws IllegalArgumentException if the similarity function, or the quantizer's bits per
     *     dimension, is not supported
     */
    public NVQScoreFunction scoreFunctionFor(VectorFloat<?> query, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return dotProductScoreFunctionFor(query);
            case EUCLIDEAN:
                return euclideanScoreFunctionFor(query);
            case COSINE:
                return cosineScoreFunctionFor(query);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Applies {@link VectorUtil#nvqShuffleQueryInPlace8bit} to every sub-vector, mutating each one
     * in place.
     */
    private static void shuffleInPlace8bit(VectorFloat<?>[] subVectors) {
        for (VectorFloat<?> subVector : subVectors) {
            VectorUtil.nvqShuffleQueryInPlace8bit(subVector);
        }
    }

    /**
     * Builds the DOT_PRODUCT score function.
     *
     * <p>The score is {@code (1 + sum_i nvqDotProduct8bit(q_i, x_i) + <query, globalMean>) / 2}.
     * This follows from {@code <q, x> = <q, x - globalMean> + <q, globalMean>}, with the
     * query-constant second term computed once. NOTE(review): this presumes the quantized
     * sub-vectors encode {@code x - globalMean} (the EUCLIDEAN path shifts the query by the same
     * mean); confirm against NVQuantization.
     */
    private NVQScoreFunction dotProductScoreFunctionFor(VectorFloat<?> query) {
        // Computed before the in-place shuffle below, so it always sees the unshuffled query.
        float queryGlobalBias = VectorUtil.dotProduct(query, this.nvq.globalMean);
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(query);
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                shuffleInPlace8bit(querySubVectors);
                return quantized -> {
                    float nvqDot = 0.0f;
                    for (int i = 0; i < querySubVectors.length; ++i) {
                        NVQuantization.QuantizedSubVector svDB = quantized.subVectors[i];
                        nvqDot += VectorUtil.nvqDotProduct8bit(querySubVectors[i], svDB.bytes,
                                svDB.growthRate, svDB.midpoint, svDB.minValue, svDB.maxValue);
                    }
                    // Maps a dot product d to (1 + d) / 2. NOTE(review): the result only lies in
                    // [0, 1] when d is in [-1, 1] (e.g. unit-norm vectors); no normalization is done.
                    return (1.0f + nvqDot + queryGlobalBias) / 2.0f;
                };
            default:
                throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension);
        }
    }

    /**
     * Builds the EUCLIDEAN score function.
     *
     * <p>The query is first shifted by {@code -globalMean}. The squared L2 distances of its
     * sub-vectors to the quantized sub-vectors are summed into {@code dist}, and the score is
     * {@code 1 / (1 + dist)}.
     */
    private NVQScoreFunction euclideanScoreFunctionFor(VectorFloat<?> query) {
        VectorFloat<?> shiftedQuery = VectorUtil.sub(query, this.nvq.globalMean);
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(shiftedQuery);
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                shuffleInPlace8bit(querySubVectors);
                return quantized -> {
                    float dist = 0.0f;
                    for (int i = 0; i < querySubVectors.length; ++i) {
                        NVQuantization.QuantizedSubVector svDB = quantized.subVectors[i];
                        dist += VectorUtil.nvqSquareL2Distance8bit(querySubVectors[i], svDB.bytes,
                                svDB.growthRate, svDB.midpoint, svDB.minValue, svDB.maxValue);
                    }
                    return 1.0f / (1.0f + dist);
                };
            default:
                throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension);
        }
    }

    /**
     * Builds the COSINE score function.
     *
     * <p>For each sub-vector, {@code VectorUtil.nvqCosine8bit} returns a pair. Element 0 is
     * accumulated into the numerator and element 1 into a squared normalization term for the
     * quantized vector. The cosine is {@code numerator / |query| / sqrt(squaredNormalization)},
     * mapped to a score as {@code (1 + cosine) / 2}.
     *
     * <p>There is no guard against a zero-norm query or a zero {@code squaredNormalization}; in
     * those cases the division yields NaN or infinity.
     */
    private NVQScoreFunction cosineScoreFunctionFor(VectorFloat<?> query) {
        // Computed before the in-place shuffle below, so it always sees the unshuffled query.
        float queryNorm = (float) Math.sqrt(VectorUtil.dotProduct(query, query));
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(query);
        VectorFloat<?>[] meanSubVectors = this.nvq.getSubVectors(this.nvq.globalMean);
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                // The query and mean sub-vectors are shuffled in lock-step; the loop is bounded by
                // the query's sub-vector count.
                for (int i = 0; i < querySubVectors.length; ++i) {
                    VectorUtil.nvqShuffleQueryInPlace8bit(querySubVectors[i]);
                    VectorUtil.nvqShuffleQueryInPlace8bit(meanSubVectors[i]);
                }
                return quantized -> {
                    float cos = 0.0f;
                    float squaredNormalization = 0.0f;
                    for (int i = 0; i < querySubVectors.length; ++i) {
                        NVQuantization.QuantizedSubVector svDB = quantized.subVectors[i];
                        float[] partialCosSim = VectorUtil.nvqCosine8bit(querySubVectors[i], svDB.bytes,
                                svDB.growthRate, svDB.midpoint, svDB.minValue, svDB.maxValue, meanSubVectors[i]);
                        cos += partialCosSim[0];
                        squaredNormalization += partialCosSim[1];
                    }
                    float cosine = cos / queryNorm / (float) Math.sqrt(squaredNormalization);
                    return (1.0f + cosine) / 2.0f;
                };
            default:
                throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension);
        }
    }

    /**
     * Scores one NVQ-quantized vector against the query captured when the function was built.
     * For the functions produced by {@link NVQScorer}, a higher score means more similar.
     */
    @FunctionalInterface
    public interface NVQScoreFunction {
        /**
         * Computes the similarity between the captured query and a quantized vector.
         *
         * @param vector the quantized vector to score
         * @return the similarity score
         */
        float similarityTo(NVQuantization.QuantizedVector vector);
    }
}

