/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding.huggingface;

import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * An {@link Embedder} that tokenizes text with a {@link HuggingFaceTokenizer}, evaluates an ONNX
 * transformer model on the resulting token ids, and pools the per-token model output into a single
 * embedding tensor using the configured {@link PoolingStrategy}.
 */
@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {

    private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    /** Name of the token type ids model input; empty when the model is not fed this input. */
    private final String tokenTypeIdsName;
    private final String outputName;
    /** Whether pooled embeddings are scaled to unit L2 length before being returned. */
    private final boolean normalize;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final PoolingStrategy poolingStrategy;

    /**
     * Creates the embedder: builds the tokenizer, loads the ONNX model and validates that the model
     * exposes the configured input and output names.
     *
     * @throws IllegalArgumentException if the model lacks a configured input or output
     */
    @Inject
    public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) {
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.tokenTypeIdsName = config.transformerTokenTypeIds();
        this.outputName = config.transformerOutput();
        this.normalize = config.normalize();

        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(true)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
        // Configure truncation explicitly unless the tokenizer model already specifies a max length
        // together with longest-first truncation.
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            // Prefer the model's own positive max length, but never exceed the configured token limit.
            int maxLength = info.maxLength() > 0
                    ? Math.min(info.maxLength(), config.transformerMaxTokens())
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        this.tokenizer = builder.build();
        this.poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());

        OnnxEvaluatorOptions onnxOpts = new OnnxEvaluatorOptions();
        // A negative GPU device means "no GPU"; only set the device when one is configured.
        if (config.transformerGpuDevice() >= 0) {
            onnxOpts.setGpuDevice(config.transformerGpuDevice());
        }
        onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
        onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
        this.validateModel();
    }

    /**
     * Checks that the loaded model has every configured input (input ids, attention mask and, when
     * configured, token type ids) and the configured output.
     *
     * @throws IllegalArgumentException if any of them is missing
     */
    public void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        if (!tokenTypeIdsName.isEmpty()) {
            validateName(inputs, tokenTypeIdsName, "input");
        }
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        validateName(outputs, outputName, "output");
    }

    /** Throws an {@link IllegalArgumentException} listing the available names if {@code name} is absent. */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name
                    + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Returns the token ids of {@code s} as produced by the tokenizer, recording sequence length and
     * latency with the runtime.
     */
    public List<Integer> embed(String s, Embedder.Context context) {
        long start = System.nanoTime();
        List<Integer> tokens = tokenizer.embed(s, context);
        runtime.sampleSequenceLength((long) tokens.size(), context);
        runtime.sampleEmbeddingLatency(elapsedMillisSince(start), context);
        return tokens;
    }

    /** Releases the ONNX evaluator and the tokenizer. */
    public void deconstruct() {
        evaluator.close();
        tokenizer.close();
    }

    /**
     * Embeds {@code s} into a tensor of {@code tensorType}.
     *
     * <p>The text is tokenized, fed to the model as a batch of one, and the model output is pooled
     * into a sentence embedding (optionally normalized). When the target value type is INT8, the
     * embedding is pooled into 8 float cells per target cell and then sign-packed by
     * {@link #binarize(IndexedTensor, TensorType)}.
     *
     * @throws IllegalArgumentException if the model output does not have 3 dimensions, or if an INT8
     *         target needs more float dimensions than the model produces
     */
    public Tensor embed(String s, Embedder.Context context, TensorType tensorType) {
        long start = System.nanoTime();
        Encoding encoding = tokenizer.encode(s, context.getLanguage());
        runtime.sampleSequenceLength((long) encoding.ids().size(), context);

        IndexedTensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
        IndexedTensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
        IndexedTensor tokenTypeIds = tokenTypeIdsName.isEmpty()
                ? null
                : createTensorRepresentation(encoding.typeIds(), "d1");

        // Each input gets a leading "d0" dimension for the batch.
        Map<String, Tensor> inputs;
        if (tokenTypeIds == null || tokenTypeIds.isEmpty()) {
            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                            attentionMaskName, attentionMask.expand("d0"));
        } else {
            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                            attentionMaskName, attentionMask.expand("d0"),
                            tokenTypeIdsName, tokenTypeIds.expand("d0"));
        }

        Map<String, Tensor> outputs = evaluator.evaluate(inputs);
        IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName);
        long[] resultShape = tokenEmbeddings.shape();
        if (resultShape.length != 3) {
            throw new IllegalArgumentException("Expected 3 output dimensions for output name '" + outputName
                    + "': [batch, sequence, embedding], got " + resultShape.length);
        }

        Tensor result;
        if (tensorType.valueType() == TensorType.Value.INT8) {
            long outputDimensions = resultShape[2];
            long targetDim = tensorType.dimensions().get(0).size().get();
            // Every int8 target cell packs 8 float dimensions.
            if (targetDim * 8L > outputDimensions) {
                throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
            }
            long firstDimensions = 8L * targetDim;
            String name = tensorType.indexedSubtype().dimensions().get(0).name();
            TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build();
            result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask);
            result = normalize ? normalize(result, poolingType) : result;
            result = binarize((IndexedTensor) result, tensorType);
        } else {
            result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
            result = normalize ? normalize(result, tensorType) : result;
        }
        runtime.sampleEmbeddingLatency(elapsedMillisSince(start), context);
        return result;
    }

    /**
     * Scales the first dimension of {@code embedding} (of {@code tensorType}'s size) to unit L2 length.
     * An all-zero embedding is returned with its values unchanged instead of becoming NaN.
     */
    Tensor normalize(Tensor embedding, TensorType tensorType) {
        long size = tensorType.dimensions().get(0).size().get();
        double sumOfSquares = 0.0;
        for (int i = 0; i < size; i++) {
            double item = embedding.get(TensorAddress.of(new int[] { i }));
            sumOfSquares += item * item;
        }
        double magnitude = Math.sqrt(sumOfSquares);
        // A zero vector has no direction; dividing by 0 would turn every cell into NaN.
        double divisor = magnitude == 0.0 ? 1.0 : magnitude;

        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        for (int i = 0; i < size; i++) {
            double value = embedding.get(TensorAddress.of(new int[] { i }));
            builder.cell(value / divisor, new long[] { i });
        }
        return builder.build();
    }

    /**
     * Packs the signs of {@code embedding}'s cells into int8 cells of {@code tensorType}.
     *
     * <p>Each group of 8 consecutive cells becomes one output cell. The first cell of a group maps to
     * the most significant bit, and a bit is 1 when its cell value is strictly positive. A trailing
     * group of fewer than 8 cells is dropped.
     */
    public static Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        int bits = 0;
        int index = 0;
        for (int d = 0; d < embedding.sizeAsInt(); d++) {
            bits = (bits << 1) | (embedding.get((long) d) > 0.0 ? 1 : 0);
            if ((d + 1) % 8 == 0) {
                // Narrow to a signed byte so the MSB becomes the int8 sign bit.
                builder.cell(TensorAddress.of(new int[] { index++ }), (float) (byte) bits);
                bits = 0;
            }
        }
        return builder.build();
    }

    /** Converts a list of token values into a 1-d float tensor with the given dimension name. */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long) size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell((float) input.get(i).longValue(), new long[] { i });
        }
        return builder.build();
    }

    /** Returns the milliseconds elapsed since {@code startNanos} (a {@link System#nanoTime()} value). */
    private static double elapsedMillisSince(long startNanos) {
        return (double) (System.nanoTime() - startNanos) / 1000000.0;
    }
}

