/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;

/**
 * Embedder that turns text into a sparse, mapped 1-d tensor of term weights using a SPLADE-style ONNX model.
 *
 * <p>The text is tokenized with a Hugging Face tokenizer. The token ids and attention mask are then fed to
 * the model, together with the token type ids unless their input name is configured as empty. For every
 * vocabulary entry the score {@code log(1 + relu(max over sequence positions of the model output))} is
 * computed. Entries scoring above the configured term score threshold become cells labelled with the
 * decoded token.
 */
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    /** Name of the token type ids model input; an empty name means no token type ids are sent to the model. */
    private final String tokenTypeIdsName;
    private final String outputName;
    /** Only terms whose score is strictly greater than this value are emitted. */
    private final double termScoreThreshold;
    /** Selects {@link #sparsifyCustomReduce} (true) or the tensor-function based reduction (false). */
    private final boolean useCustomReduce;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;

    @Inject
    public SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config) {
        this(onnx, runtime, config, true);
    }

    /**
     * Creates the embedder, loading the tokenizer and the ONNX model and validating the model's input/output names.
     *
     * @param onnx            runtime used to create the ONNX evaluator
     * @param runtime         receives sequence-length and latency samples
     * @param config          embedder configuration
     * @param useCustomReduce whether to use the hand-written reduction ({@link #sparsifyCustomReduce})
     *                        instead of the generic tensor-function reduction
     * @throws IllegalArgumentException if the model lacks a configured input or output name
     */
    SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config, boolean useCustomReduce) {
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        this.tokenTypeIdsName = config.transformerTokenTypeIds();
        this.termScoreThreshold = config.termScoreThreshold();
        this.useCustomReduce = useCustomReduce;
        this.tokenizer = createTokenizer(config);
        this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), createOnnxOptions(config));
        validateModel();
    }

    /** Builds the tokenizer, enforcing truncation unless the tokenizer model already truncates longest-first with a known limit. */
    private static HuggingFaceTokenizer createTokenizer(SpladeEmbedderConfig config) {
        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(true)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            // Prefer the tokenizer's own limit when it is known and not above the configured maximum.
            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
                    ? info.maxLength()
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        return builder.build();
    }

    /** Translates the transformer settings in the config into ONNX evaluator options. */
    private static OnnxEvaluatorOptions createOnnxOptions(SpladeEmbedderConfig config) {
        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        // A negative device number leaves the GPU device unset.
        if (config.transformerGpuDevice() >= 0) {
            options.setGpuDevice(config.transformerGpuDevice());
        }
        options.setExecutionMode(config.transformerExecutionMode().toString());
        options.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        return options;
    }

    /**
     * Checks that the model declares the configured input ids, attention mask and output names.
     * The token type ids input is not validated here.
     *
     * @throws IllegalArgumentException if a required name is missing from the model
     */
    public void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        validateName(outputs, outputName, "output");
    }

    /** Returns whether {@code target} is a 1-d tensor type whose single dimension is mapped. */
    protected boolean verifyTensorType(TensorType target) {
        return target.dimensions().size() == 1 && target.dimensions().get(0).isMapped();
    }

    /** Throws if {@code name} is not a key of {@code types}; {@code type} is only used in the error message. */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name
                    + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Not supported: this embedder only produces tensors.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Integer> embed(String text, Embedder.Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds {@code text} into a sparse tensor of term weights.
     *
     * @param text       the text to embed
     * @param context    supplies the language used for tokenization and is passed along with runtime samples
     * @param tensorType destination type; must be a mapped 1-d tensor type
     * @return a tensor of {@code tensorType} with one cell per term scoring above the threshold, labelled with the decoded token
     * @throws IllegalArgumentException if {@code tensorType} is not a mapped 1-d tensor type
     */
    @Override
    public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
        if (!verifyTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid splade embedder tensor destination. Wanted a mapped 1-d tensor, got " + tensorType);
        }
        long start = System.nanoTime();
        Encoding encoding = tokenizer.encode(text, context.getLanguage());
        runtime.sampleSequenceLength(encoding.ids().size(), context);

        // Each input becomes a [d0 (batch of 1), d1 (sequence)] tensor.
        Tensor inputIds = createTensorRepresentation(encoding.ids(), "d1").expand("d0");
        Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1").expand("d0");
        Map<String, Tensor> inputs;
        if (tokenTypeIdsName.isEmpty()) {
            // Models without a token type ids input are configured with an empty name.
            inputs = Map.of(inputIdsName, inputIds, attentionMaskName, attentionMask);
        } else {
            Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1").expand("d0");
            inputs = Map.of(inputIdsName, inputIds, attentionMaskName, attentionMask, tokenTypeIdsName, tokenTypeIds);
        }

        IndexedTensor output = (IndexedTensor) evaluator.evaluate(inputs).get(outputName);
        Tensor spladeTensor = useCustomReduce
                ? sparsifyCustomReduce(output, tensorType)
                : sparsifyReduce(output, tensorType);
        runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000.0, context);
        return spladeTensor;
    }

    /**
     * Reduces the model output to term weights using tensor functions: max over {@code d0} and {@code d1},
     * then {@code log(1 + relu(x))}.
     * NOTE(review): assumes the model output dimensions are named d0 (batch), d1 (sequence) and one vocabulary dimension.
     */
    private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) {
        Tensor maxPerTerm = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1");
        IndexedTensor vocab = (IndexedTensor) maxPerTerm.map(x -> Math.log(1.0 + (x > 0.0 ? x : 0.0)));
        String dimension = tensorType.dimensions().get(0).name();
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        long[] tokens = new long[1];
        for (long i = 0; i < vocab.size(); i++) {
            double score = vocab.get(i);
            if (score > termScoreThreshold) {
                tokens[0] = i;
                builder.cell().label(dimension, tokenizer.decode(tokens)).value(score);
            }
        }
        return builder.build();
    }

    /**
     * Reduces a [batch, sequence, vocabulary] model output to term weights by walking the backing values directly.
     * For each vocabulary index it computes {@code log(1 + max(0, max over the sequence))}.
     *
     * @param modelOutput a 3-d indexed tensor with batch size 1
     * @param tensorType  the mapped 1-d destination type
     * @return a tensor with one cell per term scoring above the threshold, labelled with the decoded token
     * @throws IllegalArgumentException if the output is not 3-d, the batch size is not 1,
     *                                  or the sequence or vocabulary size exceeds {@code Integer.MAX_VALUE}
     */
    public Tensor sparsifyCustomReduce(IndexedTensor modelOutput, TensorType tensorType) {
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        long[] shape = modelOutput.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
        }
        long batch = shape[0];
        if (batch != 1L) {
            throw new IllegalArgumentException("Batch size must be 1");
        }
        if (shape[1] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int");
        }
        if (shape[2] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int");
        }
        int sequenceLength = (int) shape[1];
        int vocabSize = (int) shape[2];
        String dimension = tensorType.dimensions().get(0).name();
        long[] tokens = new long[1];
        DirectIndexedAddress directAddress = modelOutput.directAddress();
        directAddress.setIndex(0, 0);
        // The distance between consecutive sequence positions does not depend on the vocabulary index.
        long sequenceStride = directAddress.getStride(1);
        for (int v = 0; v < vocabSize; v++) {
            directAddress.setIndex(2, v);
            long base = directAddress.getDirectIndex();
            // Starting from 0 folds the relu into the max: negative values can never win.
            double maxValue = 0.0;
            for (int s = 0; s < sequenceLength; s++) {
                double value = modelOutput.get(base + s * sequenceStride);
                if (value > maxValue) {
                    maxValue = value;
                }
            }
            double logOfRelu = Math.log(1.0 + maxValue);
            if (logOfRelu > termScoreThreshold) {
                tokens[0] = v;
                builder.cell().label(dimension, tokenizer.decode(tokens)).value(logOfRelu);
            }
        }
        return builder.build();
    }

    /** Creates a 1-d float indexed tensor over {@code dimension} holding {@code input} in order. */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell((float) input.get(i).longValue(), new long[] { i });
        }
        return builder.build();
    }

    /** Releases the ONNX evaluator and the tokenizer; the tokenizer is closed even if closing the evaluator fails. */
    @Override
    public void deconstruct() {
        try {
            evaluator.close();
        } finally {
            tokenizer.close();
        }
    }
}

