/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;

/**
 * Base class for approximate score functions that compare one query vector against vectors that
 * were compressed with a {@link ProductQuantization} and are stored in a {@link PQVectors}.
 *
 * <p>Each concrete decoder precomputes, at construction time, a lookup table with one entry per
 * (subspace, centroid) pair; scoring a node afterwards only needs that node's encoded bytes
 * ({@code cv.get(node)}) and the table.
 *
 * <p>NOTE(review): when the quantizer has a non-null {@code globalCentroid}, the query is shifted
 * by that centroid before the tables are built. Assuming the codebooks were trained on vectors
 * shifted the same way, Euclidean distances are unaffected by the shift, but dot-product and
 * cosine scores are then computed between shifted vectors and differ from the unshifted values —
 * confirm callers only enable global centering where that is intended.
 */
abstract class PQDecoder implements ScoreFunction.ApproximateScoreFunction {
    /** Used to allocate the centroid squared-norm table. */
    private static final VectorTypeSupport VECTOR_TYPES =
            VectorizationProvider.getInstance().getVectorTypeSupport();

    /** The compressed vectors (and their quantizer, {@code cv.pq}) that queries are scored against. */
    protected final PQVectors cv;

    protected PQDecoder(PQVectors cv) {
        this.cv = cv;
    }

    /**
     * Returns {@code query} minus the quantizer's global centroid, or {@code query} itself when the
     * quantizer has no global centroid.
     */
    private static VectorFloat<?> centerQuery(ProductQuantization quantizer, VectorFloat<?> query) {
        VectorFloat<?> centroid = quantizer.globalCentroid;
        if (centroid == null) {
            return query;
        }
        return VectorUtil.sub(query, centroid);
    }

    /**
     * Builds a table holding, for every subspace {@code m} and cluster {@code j}, the squared L2
     * norm of the {@code j}-th {@code size}-wide slice of codebook {@code m} (its {@code j}-th
     * centroid), stored at index {@code m * clusterCount + j}.
     */
    private static VectorFloat<?> computeCentroidSquaredNorms(ProductQuantization quantizer) {
        int subspaceCount = quantizer.getSubspaceCount();
        int clusterCount = quantizer.getClusterCount();
        VectorFloat<?> norms = VECTOR_TYPES.createFloatVector(subspaceCount * clusterCount);
        for (int m = 0; m < subspaceCount; m++) {
            int width = quantizer.subvectorSizesAndOffsets[m][0];
            VectorFloat<?> codebook = quantizer.codebooks[m];
            int rowStart = m * clusterCount;
            for (int j = 0; j < clusterCount; j++) {
                int centroidStart = j * width;
                norms.set(rowStart + j,
                        VectorUtil.dotProduct(codebook, centroidStart, codebook, centroidStart, width));
            }
        }
        return norms;
    }

    /**
     * Decoder whose lookup table is filled by {@link VectorUtil#calculatePartialSums} for a given
     * {@link VectorSimilarityFunction}, and whose per-node raw score comes from
     * {@link VectorUtil#assembleAndSum}.
     */
    protected abstract static class CachingDecoder extends PQDecoder {
        /**
         * Per-(subspace, centroid) partial results for the centered query.
         *
         * <p>NOTE(review): this buffer comes from {@code cv.reusablePartialSums()}; if that buffer is
         * shared, building another decoder from the same {@code cv} may overwrite this table —
         * confirm its ownership rules.
         */
        protected final VectorFloat<?> partialSums;

        protected CachingDecoder(PQVectors cv, VectorFloat<?> query, VectorSimilarityFunction vsf) {
            super(cv);
            ProductQuantization quantizer = this.cv.pq;
            this.partialSums = cv.reusablePartialSums();
            VectorFloat<?> centeredQuery = PQDecoder.centerQuery(quantizer, query);
            int subspaceCount = quantizer.getSubspaceCount();
            int clusterCount = quantizer.getClusterCount();
            for (int m = 0; m < subspaceCount; m++) {
                // [0] is the subvector width, [1] its starting offset within the full vector.
                int[] sizeAndOffset = quantizer.subvectorSizesAndOffsets[m];
                VectorUtil.calculatePartialSums(
                        quantizer.codebooks[m], m, sizeAndOffset[0], clusterCount,
                        centeredQuery, sizeAndOffset[1], vsf, this.partialSums);
            }
        }

        /**
         * Returns the raw (untransformed) score for one encoded vector by delegating to
         * {@link VectorUtil#assembleAndSum}, which presumably adds up the table entry selected by
         * each code of {@code encoded}.
         */
        protected float decodedSimilarity(ByteSequence<?> encoded) {
            int clusterCount = this.cv.pq.getClusterCount();
            return VectorUtil.assembleAndSum(this.partialSums, clusterCount, encoded);
        }
    }

    /** Euclidean decoder: maps the raw decoded value {@code d} to {@code 1 / (1 + d)}. */
    static class EuclideanDecoder extends CachingDecoder {
        public EuclideanDecoder(PQVectors cv, VectorFloat<?> query) {
            super(cv, query, VectorSimilarityFunction.EUCLIDEAN);
        }

        @Override
        public float similarityTo(int node2) {
            // Presumably a distance accumulated over all subspaces: larger means less similar.
            float decoded = this.decodedSimilarity(this.cv.get(node2));
            return 1.0f / (1.0f + decoded);
        }
    }

    /** Dot-product decoder: maps the raw decoded value {@code d} to {@code (1 + d) / 2}. */
    static class DotProductDecoder extends CachingDecoder {
        public DotProductDecoder(PQVectors cv, VectorFloat<?> query) {
            super(cv, query, VectorSimilarityFunction.DOT_PRODUCT);
        }

        @Override
        public float similarityTo(int node2) {
            float decoded = this.decodedSimilarity(this.cv.get(node2));
            // For normalized vectors this rescales [-1, 1] onto [0, 1].
            return (1.0f + decoded) / 2.0f;
        }
    }

    /**
     * Cosine decoder: scores {@code (1 + cos) / 2}, where {@code cos} is computed by
     * {@link VectorUtil#pqDecodedCosineSimilarity} from query/centroid dot products, centroid
     * squared norms and the query's squared norm.
     */
    static class CosineDecoder extends PQDecoder {
        /** Dot product of centroid {@code j} of codebook {@code m} with the centered query's m-th subvector, at {@code m * clusterCount + j}. */
        protected final VectorFloat<?> partialSums;
        /** Squared norm of centroid {@code j} of codebook {@code m}, at {@code m * clusterCount + j}; memoized on {@code cv}. */
        protected final VectorFloat<?> aMagnitude;
        /** Squared norm of the centered query. */
        protected final float bMagnitude;

        public CosineDecoder(PQVectors cv, VectorFloat<?> query) {
            super(cv);
            ProductQuantization quantizer = this.cv.pq;
            // The centroid norms depend only on the codebooks: reuse the cached table when present,
            // otherwise compute it once. The update function has no side effects, so it is safe even
            // if (as with an AtomicReference, presumably) it is re-applied under contention.
            this.aMagnitude = cv.partialSquaredMagnitudes().updateAndGet(
                    cached -> cached != null ? cached : PQDecoder.computeCentroidSquaredNorms(quantizer));
            this.partialSums = cv.reusablePartialSums();
            VectorFloat<?> centeredQuery = PQDecoder.centerQuery(quantizer, query);
            int subspaceCount = quantizer.getSubspaceCount();
            int clusterCount = quantizer.getClusterCount();
            for (int m = 0; m < subspaceCount; m++) {
                int width = quantizer.subvectorSizesAndOffsets[m][0];
                int queryStart = quantizer.subvectorSizesAndOffsets[m][1];
                VectorFloat<?> codebook = quantizer.codebooks[m];
                int rowStart = m * clusterCount;
                for (int j = 0; j < clusterCount; j++) {
                    this.partialSums.set(rowStart + j,
                            VectorUtil.dotProduct(codebook, j * width, centeredQuery, queryStart, width));
                }
            }
            this.bMagnitude = VectorUtil.dotProduct(centeredQuery, centeredQuery);
        }

        @Override
        public float similarityTo(int node2) {
            float cosine = this.decodedCosine(node2);
            return (1.0f + cosine) / 2.0f;
        }

        /** Returns the approximate cosine between the centered query and the encoded vector of {@code node2}. */
        protected float decodedCosine(int node2) {
            ByteSequence<?> code = this.cv.get(node2);
            int clusterCount = this.cv.pq.getClusterCount();
            return VectorUtil.pqDecodedCosineSimilarity(
                    code, clusterCount, this.partialSums, this.aMagnitude, this.bMagnitude);
        }
    }
}

