/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.graph.disk.FusedADCNeighbors;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

/**
 * Approximate score function built from a {@link ProductQuantization}, a query vector and
 * an exact score function.
 *
 * <p>Concrete decoders are obtained through {@link #newDecoder}; one exists for each
 * supported {@link VectorSimilarityFunction} (dot product and Euclidean).
 */
public abstract class QuickADCPQDecoder
implements ScoreFunction.ApproximateScoreFunction {
    /** Quantization supplying the codebooks, subspace layout and compressed vector size. */
    protected final ProductQuantization pq;
    /** Query vector this decoder scores against. */
    protected final VectorFloat<?> query;
    /** Exact score function that single-node {@code similarityTo} calls are delegated to. */
    protected final ScoreFunction.ExactScoreFunction esf;

    /**
     * Stores the collaborators shared by every decoder.
     *
     * @param pq    product quantization to decode with
     * @param query query vector
     * @param esf   exact score function used for single-node scoring
     */
    protected QuickADCPQDecoder(ProductQuantization pq, VectorFloat<?> query, ScoreFunction.ExactScoreFunction esf) {
        this.esf = esf;
        this.query = query;
        this.pq = pq;
    }

    /**
     * Creates the decoder matching {@code similarityFunction}.
     *
     * @param neighbors          source of the packed neighbor codes of a node
     * @param pq                 product quantization to decode with
     * @param query              query vector
     * @param results            vector that bulk scores are written into; it is zeroed and
     *                           overwritten by every {@code edgeLoadingSimilarityTo} call
     * @param similarityFunction similarity to compute
     * @param esf                exact score function used for single-node scoring
     * @return a dot-product or Euclidean decoder
     * @throws IllegalArgumentException if {@code similarityFunction} is neither
     *                                  {@code DOT_PRODUCT} nor {@code EUCLIDEAN}
     */
    public static QuickADCPQDecoder newDecoder(FusedADCNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, VectorSimilarityFunction similarityFunction, ScoreFunction.ExactScoreFunction esf) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new DotProductDecoder(neighbors, pq, query, results, esf);
            case EUCLIDEAN:
                return new EuclideanDecoder(neighbors, pq, query, results, esf);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Shared bulk-scoring routine of the concrete decoders: fetches the packed neighbors of
     * {@code origin}, clears {@code results}, and lets
     * {@code VectorUtil.bulkShuffleSimilarity} fill it from the precomputed partial sums.
     *
     * @return {@code results}, the same instance that was passed in
     */
    private static VectorFloat<?> scorePackedNeighbors(FusedADCNeighbors neighbors, ProductQuantization pq, VectorFloat<?> partialSums, VectorFloat<?> results, int origin, VectorSimilarityFunction vsf) {
        ByteSequence<?> packed = neighbors.getPackedNeighbors(origin);
        // The output vector is reused across calls, so clear it before scoring.
        results.zero();
        VectorUtil.bulkShuffleSimilarity(packed, pq.compressedVectorSize(), partialSums, results, vsf);
        return results;
    }

    /** Decoder for {@link VectorSimilarityFunction#DOT_PRODUCT}. */
    static class DotProductDecoder
    extends CachingDecoder {
        private final VectorFloat<?> results;
        private final FusedADCNeighbors neighbors;

        public DotProductDecoder(FusedADCNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, VectorSimilarityFunction.DOT_PRODUCT, esf);
            this.results = results;
            this.neighbors = neighbors;
        }

        /** Scores a single node by delegating to the exact score function. */
        @Override
        public float similarityTo(int node2) {
            return esf.similarityTo(node2);
        }

        /**
         * Bulk-scores the packed neighbors of {@code origin} with the dot-product similarity.
         *
         * @return the shared results vector handed to the constructor, overwritten by this call
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
            return QuickADCPQDecoder.scorePackedNeighbors(neighbors, pq, partialSums, results, origin, VectorSimilarityFunction.DOT_PRODUCT);
        }

        @Override
        public boolean supportsEdgeLoadingSimilarity() {
            return true;
        }
    }

    /** Decoder for {@link VectorSimilarityFunction#EUCLIDEAN}. */
    static class EuclideanDecoder
    extends CachingDecoder {
        private final FusedADCNeighbors neighbors;
        private final VectorFloat<?> results;

        public EuclideanDecoder(FusedADCNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, VectorSimilarityFunction.EUCLIDEAN, esf);
            this.results = results;
            this.neighbors = neighbors;
        }

        /** Scores a single node by delegating to the exact score function. */
        @Override
        public float similarityTo(int node2) {
            return esf.similarityTo(node2);
        }

        /**
         * Bulk-scores the packed neighbors of {@code origin} with the Euclidean similarity.
         *
         * @return the shared results vector handed to the constructor, overwritten by this call
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
            return QuickADCPQDecoder.scorePackedNeighbors(neighbors, pq, partialSums, results, origin, VectorSimilarityFunction.EUCLIDEAN);
        }

        @Override
        public boolean supportsEdgeLoadingSimilarity() {
            return true;
        }
    }

    /**
     * Base for decoders that precompute, at construction time, the partial sums between the
     * query and every codebook of the product quantization.
     */
    protected static abstract class CachingDecoder
    extends QuickADCPQDecoder {
        /**
         * Partial-sum table filled by the constructor. It is obtained from
         * {@code pq.reusablePartialSums()}; NOTE(review): the name suggests a buffer shared
         * between decoders, confirm before using several decoders concurrently.
         */
        protected final VectorFloat<?> partialSums;

        /**
         * Fills {@link #partialSums} for {@code query} under similarity {@code vsf}.
         *
         * @param pq    product quantization providing codebooks and subspace layout
         * @param query query vector; the global centroid of {@code pq} is subtracted from it
         *              when one is set
         * @param vsf   similarity function the partial sums are computed for
         * @param esf   exact score function used for single-node scoring
         */
        protected CachingDecoder(ProductQuantization pq, VectorFloat<?> query, VectorSimilarityFunction vsf, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, esf);
            partialSums = pq.reusablePartialSums();

            // Center the query only when the quantization carries a global centroid.
            VectorFloat<?> globalCentroid = pq.globalCentroid;
            VectorFloat<?> effectiveQuery;
            if (globalCentroid == null) {
                effectiveQuery = query;
            } else {
                effectiveQuery = VectorUtil.sub(query, globalCentroid);
            }

            for (int subspace = 0; subspace < pq.getSubspaceCount(); ++subspace) {
                // Each entry is read as {size, offset} of the subvector for this subspace.
                int queryOffset = pq.subvectorSizesAndOffsets[subspace][1];
                int subvectorSize = pq.subvectorSizesAndOffsets[subspace][0];
                // Every subspace owns a contiguous run of getClusterCount() partial sums.
                int sumsOffset = subspace * pq.getClusterCount();
                VectorFloat<?> subspaceCodebook = pq.codebooks[subspace];
                VectorUtil.calculatePartialSums(subspaceCodebook, sumsOffset, subvectorSize, pq.getClusterCount(), effectiveQuery, queryOffset, vsf, partialSums);
            }
        }
    }
}

