/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.pq.LocallyAdaptiveVectorQuantization;
import io.github.jbellis.jvector.vector.ArrayVectorFloat;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtilSupport;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.util.List;

/**
 * Scalar, pure-Java implementation of {@link VectorUtilSupport}.
 *
 * <p>All operations are plain element-by-element loops. The {@code dotProduct}, {@code cosine}
 * and {@code squareDistance} overloads cast their arguments to {@link ArrayVectorFloat} and read
 * the backing {@code float[]} directly, so they throw {@link ClassCastException} when given any
 * other {@link VectorFloat} implementation. The remaining methods go through the generic
 * {@code get}/{@code set} accessors.
 */
final class DefaultVectorUtilSupport
implements VectorUtilSupport {
    DefaultVectorUtilSupport() {
    }

    /**
     * Computes the dot product of two whole vectors.
     *
     * <p>The loop is manually unrolled. The first {@code a.length % 8} elements are summed one at a
     * time. The rest, a multiple of 8, is consumed in 32-element chunks and then 8-element chunks.
     * Iteration is bounded by {@code a}'s length, so {@code bv} must be at least as long as
     * {@code av}.
     */
    @Override
    public float dotProduct(VectorFloat<?> av, VectorFloat<?> bv) {
        float[] a = ((ArrayVectorFloat) av).get();
        float[] b = ((ArrayVectorFloat) bv).get();
        float res = 0.0f;
        int i = 0;
        // Scalar prologue: afterwards the number of remaining elements is a multiple of 8.
        for (; i < a.length % 8; ++i) {
            res += b[i] * a[i];
        }
        if (a.length < 8) {
            return res;
        }
        while (i + 31 < a.length) {
            res += b[i + 0] * a[i + 0] + b[i + 1] * a[i + 1] + b[i + 2] * a[i + 2] + b[i + 3] * a[i + 3] + b[i + 4] * a[i + 4] + b[i + 5] * a[i + 5] + b[i + 6] * a[i + 6] + b[i + 7] * a[i + 7];
            res += b[i + 8] * a[i + 8] + b[i + 9] * a[i + 9] + b[i + 10] * a[i + 10] + b[i + 11] * a[i + 11] + b[i + 12] * a[i + 12] + b[i + 13] * a[i + 13] + b[i + 14] * a[i + 14] + b[i + 15] * a[i + 15];
            res += b[i + 16] * a[i + 16] + b[i + 17] * a[i + 17] + b[i + 18] * a[i + 18] + b[i + 19] * a[i + 19] + b[i + 20] * a[i + 20] + b[i + 21] * a[i + 21] + b[i + 22] * a[i + 22] + b[i + 23] * a[i + 23];
            res += b[i + 24] * a[i + 24] + b[i + 25] * a[i + 25] + b[i + 26] * a[i + 26] + b[i + 27] * a[i + 27] + b[i + 28] * a[i + 28] + b[i + 29] * a[i + 29] + b[i + 30] * a[i + 30] + b[i + 31] * a[i + 31];
            i += 32;
        }
        // Fewer than 32 (but a multiple of 8) elements remain.
        while (i + 7 < a.length) {
            res += b[i + 0] * a[i + 0] + b[i + 1] * a[i + 1] + b[i + 2] * a[i + 2] + b[i + 3] * a[i + 3] + b[i + 4] * a[i + 4] + b[i + 5] * a[i + 5] + b[i + 6] * a[i + 6] + b[i + 7] * a[i + 7];
            i += 8;
        }
        return res;
    }

    /**
     * Computes the dot product of {@code length} elements of {@code av} starting at
     * {@code aoffset} and {@code length} elements of {@code bv} starting at {@code boffset}.
     */
    @Override
    public float dotProduct(VectorFloat<?> av, int aoffset, VectorFloat<?> bv, int boffset, int length) {
        float[] b = ((ArrayVectorFloat) bv).get();
        float[] a = ((ArrayVectorFloat) av).get();
        float sum = 0.0f;
        for (int i = 0; i < length; ++i) {
            sum += a[aoffset + i] * b[boffset + i];
        }
        return sum;
    }

    /**
     * Computes the cosine similarity of two whole vectors: {@code dot(a,b) / sqrt(|a|^2 * |b|^2)}.
     *
     * <p>The final division and square root are done in double precision. A zero-norm input yields
     * a division by zero (NaN or infinity); no special handling is applied.
     */
    @Override
    public float cosine(VectorFloat<?> av, VectorFloat<?> bv) {
        float[] a = ((ArrayVectorFloat) av).get();
        float[] b = ((ArrayVectorFloat) bv).get();
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        int dim = a.length;
        for (int i = 0; i < dim; ++i) {
            float elem1 = a[i];
            float elem2 = b[i];
            sum += elem1 * elem2;
            norm1 += elem1 * elem1;
            norm2 += elem2 * elem2;
        }
        return (float) ((double) sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /**
     * Computes the cosine similarity of two equal-length slices, {@code av[aoffset..aoffset+length)}
     * and {@code bv[boffset..boffset+length)}.
     */
    @Override
    public float cosine(VectorFloat<?> av, int aoffset, VectorFloat<?> bv, int boffset, int length) {
        float[] a = ((ArrayVectorFloat) av).get();
        float[] b = ((ArrayVectorFloat) bv).get();
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        for (int i = 0; i < length; ++i) {
            float elem1 = a[aoffset + i];
            float elem2 = b[boffset + i];
            sum += elem1 * elem2;
            norm1 += elem1 * elem1;
            norm2 += elem2 * elem2;
        }
        return (float) ((double) sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /**
     * Computes the squared Euclidean distance between two whole vectors.
     *
     * <p>Full 8-element groups go through {@link #squareDistanceUnrolled}; any remaining tail is
     * handled one element at a time.
     */
    @Override
    public float squareDistance(VectorFloat<?> av, VectorFloat<?> bv) {
        float[] a = ((ArrayVectorFloat) av).get();
        float[] b = ((ArrayVectorFloat) bv).get();
        float squareSum = 0.0f;
        int dim = a.length;
        int i = 0;
        while (i + 8 <= dim) {
            squareSum += DefaultVectorUtilSupport.squareDistanceUnrolled(a, b, i);
            i += 8;
        }
        while (i < dim) {
            float diff = a[i] - b[i];
            squareSum += diff * diff;
            ++i;
        }
        return squareSum;
    }

    /** Returns the sum of squared differences of the 8 elements starting at {@code index}. */
    private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
        float diff0 = v1[index + 0] - v2[index + 0];
        float diff1 = v1[index + 1] - v2[index + 1];
        float diff2 = v1[index + 2] - v2[index + 2];
        float diff3 = v1[index + 3] - v2[index + 3];
        float diff4 = v1[index + 4] - v2[index + 4];
        float diff5 = v1[index + 5] - v2[index + 5];
        float diff6 = v1[index + 6] - v2[index + 6];
        float diff7 = v1[index + 7] - v2[index + 7];
        return diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3 + diff4 * diff4 + diff5 * diff5 + diff6 * diff6 + diff7 * diff7;
    }

    /** Computes the squared Euclidean distance between two equal-length slices. */
    @Override
    public float squareDistance(VectorFloat<?> av, int aoffset, VectorFloat<?> bv, int boffset, int length) {
        float[] a = ((ArrayVectorFloat) av).get();
        float[] b = ((ArrayVectorFloat) bv).get();
        float squareSum = 0.0f;
        for (int i = 0; i < length; ++i) {
            float diff = a[aoffset + i] - b[boffset + i];
            squareSum += diff * diff;
        }
        return squareSum;
    }

    /**
     * Returns a new vector holding the element-wise sum of {@code vectors}.
     *
     * <p>The result length is taken from the first vector, so all vectors are expected to share
     * that length. An empty list fails on {@code vectors.get(0)}.
     */
    @Override
    public VectorFloat<?> sum(List<VectorFloat<?>> vectors) {
        ArrayVectorFloat sum = new ArrayVectorFloat(vectors.get(0).length());
        for (VectorFloat<?> vector : vectors) {
            for (int i = 0; i < vector.length(); ++i) {
                sum.set(i, sum.get(i) + vector.get(i));
            }
        }
        return sum;
    }

    /** Returns the sum of all elements of {@code vector} (0 for an empty vector). */
    @Override
    public float sum(VectorFloat<?> vector) {
        float sum = 0.0f;
        for (int i = 0; i < vector.length(); ++i) {
            sum += vector.get(i);
        }
        return sum;
    }

    /** Multiplies every element of {@code vector} by {@code multiplier}, in place. */
    @Override
    public void scale(VectorFloat<?> vector, float multiplier) {
        for (int i = 0; i < vector.length(); ++i) {
            vector.set(i, vector.get(i) * multiplier);
        }
    }

    /** Adds {@code v2} to {@code v1} element-wise, in place; iterates over {@code v1}'s length. */
    @Override
    public void addInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
        for (int i = 0; i < v1.length(); ++i) {
            v1.set(i, v1.get(i) + v2.get(i));
        }
    }

    /** Subtracts {@code v2} from {@code v1} element-wise, in place; iterates over {@code v1}'s length. */
    @Override
    public void subInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
        for (int i = 0; i < v1.length(); ++i) {
            v1.set(i, v1.get(i) - v2.get(i));
        }
    }

    /** Returns a new vector {@code a - b} with {@code a}'s length. */
    @Override
    public VectorFloat<?> sub(VectorFloat<?> a, VectorFloat<?> b) {
        return this.sub(a, 0, b, 0, a.length());
    }

    /**
     * Returns a new vector of {@code length} elements where element {@code i} is
     * {@code a[aOffset + i] - b[bOffset + i]}.
     */
    @Override
    public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int bOffset, int length) {
        ArrayVectorFloat result = new ArrayVectorFloat(length);
        for (int i = 0; i < length; ++i) {
            result.set(i, a.get(aOffset + i) - b.get(bOffset + i));
        }
        return result;
    }

    /**
     * Sums one value per entry of {@code baseOffsets}. For entry {@code i}, the value read is
     * {@code data[dataBase * i + unsigned(baseOffsets[i])]}; that is, {@code dataBase} is the stride
     * between consecutive groups in {@code data}, and each byte (read as unsigned, 0-255) selects an
     * element within its group.
     */
    @Override
    public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
        float sum = 0.0f;
        for (int i = 0; i < baseOffsets.length(); ++i) {
            sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i)));
        }
        return sum;
    }

    /**
     * Returns the number of differing bits between two bit-packed vectors. Iterates over
     * {@code v1}'s length, so {@code v2} must be at least as long.
     */
    @Override
    public int hammingDistance(long[] v1, long[] v2) {
        int hd = 0;
        for (int i = 0; i < v1.length; ++i) {
            hd += Long.bitCount(v1[i] ^ v2[i]);
        }
        return hd;
    }

    /**
     * Accumulates shuffled partial scores into the first 32 slots of {@code results}, then converts
     * every entry of {@code results} to a similarity score.
     *
     * <p>For each codebook {@code i} and slot {@code j}, the shuffle byte at {@code i * 32 + j}
     * selects an entry in codebook {@code i}'s 32-wide block of {@code partials}. The conversions
     * are {@code 1 / (1 + x)} for EUCLIDEAN and {@code (x + 1) / 2} for DOT_PRODUCT. Any other
     * function throws {@link UnsupportedOperationException}, but only once the conversion loop
     * runs, i.e. when {@code results} is non-empty.
     */
    @Override
    public void bulkShuffleSimilarity(ByteSequence<?> shuffles, int codebookCount, VectorFloat<?> partials, VectorSimilarityFunction vsf, VectorFloat<?> results) {
        for (int i = 0; i < codebookCount; ++i) {
            for (int j = 0; j < 32; ++j) {
                // NOTE(review): the shuffle byte is used as a signed value here (unlike
                // assembleAndSum); presumably shuffle codes are always in [0, 32) -- confirm.
                results.set(j, results.get(j) + partials.get(i * 32 + shuffles.get(i * 32 + j)));
            }
        }
        for (int i = 0; i < results.length(); ++i) {
            switch (vsf) {
                case EUCLIDEAN:
                    results.set(i, 1.0f / (1.0f + results.get(i)));
                    break;
                case DOT_PRODUCT:
                    results.set(i, (results.get(i) + 1.0f) / 2.0f);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
            }
        }
    }

    /**
     * For each of {@code clusterCount} centroids of {@code size} elements stored contiguously in
     * {@code codebook} (centroid {@code i} at offset {@code i * size}), writes its raw dot product
     * (DOT_PRODUCT) or squared distance (EUCLIDEAN) against the query slice starting at
     * {@code queryOffset} into {@code partialSums[codebookBase + i]}.
     *
     * @throws UnsupportedOperationException for any other similarity function, when
     *     {@code clusterCount > 0}
     */
    @Override
    public void calculatePartialSums(VectorFloat<?> codebook, int codebookBase, int size, int clusterCount, VectorFloat<?> query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat<?> partialSums) {
        for (int i = 0; i < clusterCount; ++i) {
            switch (vsf) {
                case DOT_PRODUCT:
                    partialSums.set(codebookBase + i, this.dotProduct(codebook, i * size, query, queryOffset, size));
                    break;
                case EUCLIDEAN:
                    partialSums.set(codebookBase + i, this.squareDistance(codebook, i * size, query, queryOffset, size));
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
            }
        }
    }

    /**
     * Returns the largest element of {@code v}, or {@code -Float.MAX_VALUE} if {@code v} is empty.
     */
    @Override
    public float max(VectorFloat<?> v) {
        // Seed with the most negative finite float. Float.MIN_VALUE is the smallest *positive*
        // float and would wrongly be returned for vectors whose elements are all <= 0.
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < v.length(); ++i) {
            max = Math.max(max, v.get(i));
        }
        return max;
    }

    /** Returns the smallest element of {@code v}, or {@code Float.MAX_VALUE} if {@code v} is empty. */
    @Override
    public float min(VectorFloat<?> v) {
        float min = Float.MAX_VALUE;
        for (int i = 0; i < v.length(); ++i) {
            min = Math.min(min, v.get(i));
        }
        return min;
    }

    /**
     * Dot product of {@code vector} with an LVQ-packed vector, computed on the quantized codes and
     * then corrected: {@code scale * sum(vector[i] * q[i]) + bias * vectorSum}.
     * {@code vectorSum} is presumably the precomputed sum of {@code vector}'s elements (TODO
     * confirm with callers).
     */
    @Override
    public float lvqDotProduct(VectorFloat<?> vector, LocallyAdaptiveVectorQuantization.PackedVector packedVector, float vectorSum) {
        float sum = 0.0f;
        for (int i = 0; i < vector.length(); ++i) {
            sum += vector.get(i) * (float) packedVector.getQuantized(i);
        }
        sum *= packedVector.scale;
        sum += packedVector.bias * vectorSum;
        return sum;
    }

    /** Squared Euclidean distance between {@code vector} and the dequantized packed vector. */
    @Override
    public float lvqSquareL2Distance(VectorFloat<?> vector, LocallyAdaptiveVectorQuantization.PackedVector packedVector) {
        float sum = 0.0f;
        for (int i = 0; i < vector.length(); ++i) {
            float diff = vector.get(i) - packedVector.getDequantized(i);
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Cosine similarity between {@code vector} and the dequantized packed vector with
     * {@code centroid} added back element-wise.
     */
    @Override
    public float lvqCosine(VectorFloat<?> vector, LocallyAdaptiveVectorQuantization.PackedVector packedVector, VectorFloat<?> centroid) {
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        for (int i = 0; i < vector.length(); ++i) {
            float elem1 = vector.get(i);
            float elem2 = packedVector.getDequantized(i) + centroid.get(i);
            sum += elem1 * elem2;
            norm1 += elem1 * elem1;
            norm2 += elem2 * elem2;
        }
        return (float) ((double) sum / Math.sqrt((double) norm1 * (double) norm2));
    }
}

