/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.elasticsearch.simdvec.internal.Similarities;

public abstract sealed class Int7SQVectorScorerSupplier
implements RandomVectorScorerSupplier {
    static final byte BITS = 7;
    final int dims;
    final int maxOrd;
    final float scoreCorrectionConstant;
    final MemorySegmentAccessInput input;
    final QuantizedByteVectorValues values;
    final ScalarQuantizedVectorSimilarity fallbackScorer;

    protected Int7SQVectorScorerSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant, ScalarQuantizedVectorSimilarity fallbackScorer) {
        this.input = input;
        this.values = values;
        this.dims = values.dimension();
        this.maxOrd = values.size();
        this.scoreCorrectionConstant = scoreCorrectionConstant;
        this.fallbackScorer = fallbackScorer;
    }

    protected final void checkOrdinal(int ord) {
        if (ord < 0 || ord > this.maxOrd) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException {
        int length = this.dims;
        long firstByteOffset = (long)firstOrd * (long)(length + 4);
        long secondByteOffset = (long)secondOrd * (long)(length + 4);
        MemorySegment firstSeg = this.input.segmentSliceOrNull(firstByteOffset, (long)length);
        if (firstSeg == null) {
            return this.fallbackScore(firstByteOffset, secondByteOffset);
        }
        float firstOffset = Float.intBitsToFloat(this.input.readInt(firstByteOffset + (long)length));
        MemorySegment secondSeg = this.input.segmentSliceOrNull(secondByteOffset, (long)length);
        if (secondSeg == null) {
            return this.fallbackScore(firstByteOffset, secondByteOffset);
        }
        float secondOffset = Float.intBitsToFloat(this.input.readInt(secondByteOffset + (long)length));
        return this.scoreFromSegments(firstSeg, firstOffset, secondSeg, secondOffset);
    }

    abstract float scoreFromSegments(MemorySegment var1, float var2, MemorySegment var3, float var4);

    protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
        byte[] a = new byte[this.dims];
        this.input.readBytes(firstByteOffset, a, 0, a.length);
        float aOffsetValue = Float.intBitsToFloat(this.input.readInt(firstByteOffset + (long)this.dims));
        byte[] b = new byte[this.dims];
        this.input.readBytes(secondByteOffset, b, 0, a.length);
        float bOffsetValue = Float.intBitsToFloat(this.input.readInt(secondByteOffset + (long)this.dims));
        return this.fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
    }

    public UpdateableRandomVectorScorer scorer() {
        return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer((KnnVectorValues)this.values){
            private int ord;
            {
                this.ord = -1;
            }

            public float score(int node) throws IOException {
                Int7SQVectorScorerSupplier.this.checkOrdinal(node);
                return Int7SQVectorScorerSupplier.this.scoreFromOrds(this.ord, node);
            }

            public void setScoringOrdinal(int node) throws IOException {
                Int7SQVectorScorerSupplier.this.checkOrdinal(node);
                this.ord = node;
            }
        };
    }

    static boolean checkIndex(long index, long length) {
        return index >= 0L && index < length;
    }

    public static final class MaxInnerProductSupplier
    extends Int7SQVectorScorerSupplier {
        public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity((VectorSimilarityFunction)VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, (float)scoreCorrectionConstant, (byte)7));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int dotProduct = Similarities.dotProduct7u(a, b, this.dims);
            assert (dotProduct >= 0);
            float adjustedDistance = (float)dotProduct * this.scoreCorrectionConstant + aOffset + bOffset;
            if (adjustedDistance < 0.0f) {
                return 1.0f / (1.0f + -1.0f * adjustedDistance);
            }
            return adjustedDistance + 1.0f;
        }

        public MaxInnerProductSupplier copy() {
            return new MaxInnerProductSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }

    public static final class DotProductSupplier
    extends Int7SQVectorScorerSupplier {
        public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity((VectorSimilarityFunction)VectorSimilarityFunction.DOT_PRODUCT, (float)scoreCorrectionConstant, (byte)7));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int dotProduct = Similarities.dotProduct7u(a, b, this.dims);
            assert (dotProduct >= 0);
            float adjustedDistance = (float)dotProduct * this.scoreCorrectionConstant + aOffset + bOffset;
            return Math.max((1.0f + adjustedDistance) / 2.0f, 0.0f);
        }

        public DotProductSupplier copy() {
            return new DotProductSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }

    public static final class EuclideanSupplier
    extends Int7SQVectorScorerSupplier {
        public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity((VectorSimilarityFunction)VectorSimilarityFunction.EUCLIDEAN, (float)scoreCorrectionConstant, (byte)7));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int squareDistance = Similarities.squareDistance7u(a, b, this.dims);
            float adjustedDistance = (float)squareDistance * this.scoreCorrectionConstant;
            return 1.0f / (1.0f + adjustedDistance);
        }

        public EuclideanSupplier copy() {
            return new EuclideanSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }
}

