/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph.disk.feature;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.disk.CommonHeader;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.feature.FeatureSource;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.NVQScorer;
import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
 * Graph-index feature that stores NVQ-quantized vectors (feature id {@link FeatureId#NVQ_VECTORS}).
 *
 * <p>The header holds the serialized {@link NVQuantization} compressor. Each node's inline data
 * holds that node's {@link NVQuantization.QuantizedVector}.
 *
 * <p>Rerankers produced by {@link #rerankerFor} decode node vectors into a per-thread scratch
 * buffer (a {@link ThreadLocal}). Calls from different threads therefore do not share that buffer.
 */
public class NVQ implements Feature {
    private final NVQuantization nvq;
    private final NVQScorer scorer;
    // One reusable decode buffer per thread, so reranking does not allocate per node.
    private final ThreadLocal<NVQuantization.QuantizedVector> reusableQuantizedVector;

    /**
     * Creates the feature around an existing quantizer.
     *
     * @param nvq the quantizer used for header serialization, sizing and scoring
     */
    public NVQ(NVQuantization nvq) {
        this.nvq = nvq;
        this.scorer = new NVQScorer(this.nvq);
        this.reusableQuantizedVector = ThreadLocal.withInitial(
                () -> NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension));
    }

    @Override
    public FeatureId id() {
        return FeatureId.NVQ_VECTORS;
    }

    /** Returns the serialized size of the quantizer written by {@link #writeHeader}. */
    @Override
    public int headerSize() {
        return this.nvq.compressorSize();
    }

    /** Returns the size of one node's inline quantized vector. */
    @Override
    public int featureSize() {
        return this.nvq.compressedVectorSize();
    }

    /** Returns the vector dimension, taken from the length of the quantizer's global mean. */
    public int dimension() {
        return this.nvq.globalMean.length();
    }

    /**
     * Reads an NVQ feature from {@code reader}.
     *
     * @param header the common index header (currently unused here)
     * @param reader positioned at the serialized quantizer
     * @return a feature wrapping the loaded quantizer
     * @throws UncheckedIOException if reading the quantizer fails
     */
    static NVQ load(CommonHeader header, RandomAccessReader reader) {
        try {
            return new NVQ(NVQuantization.load(reader));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Writes the quantizer to {@code out} as this feature's header.
     *
     * @throws IOException if the underlying write fails
     */
    @Override
    public void writeHeader(DataOutput out) throws IOException {
        // NOTE(review): the literal 5 appears to be a serialization version passed to
        // NVQuantization.write. This file is decompiled, so it may be an inlined constant.
        // Confirm it against the writer's current version constant.
        this.nvq.write(out, 5);
    }

    /**
     * Writes one node's quantized vector.
     *
     * @param state_ must be an {@link NVQ.State}; any other type fails with ClassCastException
     * @throws IOException if the underlying write fails
     */
    @Override
    public void writeInline(DataOutput out, Feature.State state_) throws IOException {
        State state = (State) state_;
        state.vector.write(out);
    }

    /**
     * Builds an exact score function that compares {@code queryVector} against each node's stored
     * quantized vector.
     *
     * <p>Each call decodes the node's vector into the calling thread's scratch buffer and then scores it.
     *
     * @param queryVector the query to score against
     * @param vsf the similarity function to use
     * @param source supplies per-node readers for {@link FeatureId#NVQ_VECTORS}
     * @return a score function over node ids; it throws {@link UncheckedIOException} if a node's
     *     vector cannot be read
     */
    public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, FeatureSource source) {
        NVQScorer.NVQScoreFunction function = this.scorer.scoreFunctionFor(queryVector, vsf);
        return node -> {
            // Fetch the per-thread buffer once and use it for both decoding and scoring.
            NVQuantization.QuantizedVector scratch = this.reusableQuantizedVector.get();
            try {
                RandomAccessReader reader = source.featureReaderForNode(node, FeatureId.NVQ_VECTORS);
                NVQuantization.QuantizedVector.loadInto(reader, scratch);
            } catch (IOException e) {
                // Matches load(): surface I/O failures as UncheckedIOException, keeping the cause.
                throw new UncheckedIOException("Failed to read NVQ vector for node " + node, e);
            }
            return function.similarityTo(scratch);
        };
    }

    /** Per-node write state: the quantized vector to emit via {@link NVQ#writeInline}. */
    public static class State implements Feature.State {
        public final NVQuantization.QuantizedVector vector;

        public State(NVQuantization.QuantizedVector vector) {
            this.vector = vector;
        }
    }
}

