/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.transformers;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.io.InputStream;
import java.nio.Buffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.transformers.ResourceCacheService;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Embedding model that computes text embeddings locally by running an ONNX transformer model through ONNX Runtime
 * and mean-pooling the token embeddings, weighted by the tokenizer's attention mask.
 * <p>
 * By default the {@code all-MiniLM-L6-v2} tokenizer and model are loaded from the Spring AI GitHub repository. Unless
 * caching is disabled, resources are resolved through {@link ResourceCacheService}. {@link #afterPropertiesSet()}
 * must be invoked (by the Spring container or manually) before any embedding call, because it builds the tokenizer
 * and the ONNX session.
 */
public class TransformersEmbeddingModel
extends AbstractEmbeddingModel
implements InitializingBean {
    public static final String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
    public static final String DEFAULT_ONNX_MODEL_URI = "https://github.com/spring-projects/spring-ai/raw/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx";
    public static final String DEFAULT_MODEL_OUTPUT_NAME = "last_hidden_state";
    private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class);
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
    /** Axis of the (batch, tokens, hidden) token-embedding tensor that mean pooling reduces over. */
    private static final int EMBEDDING_AXIS = 1;
    private final MetadataMode metadataMode;
    private final ObservationRegistry observationRegistry;
    /** Options passed to {@link HuggingFaceTokenizer#newInstance}; public for backward compatibility. */
    public Map<String, String> tokenizerOptions = Map.of();
    private Resource tokenizerResource = toResource(DEFAULT_ONNX_TOKENIZER_URI);
    private Resource modelResource = toResource(DEFAULT_ONNX_MODEL_URI);
    /** CUDA device id; a negative value keeps ONNX Runtime on the CPU. */
    private int gpuDeviceId = -1;
    private HuggingFaceTokenizer tokenizer;
    private OrtEnvironment environment;
    private OrtSession session;
    private String resourceCacheDirectory;
    private boolean disableCaching = false;
    private ResourceCacheService cacheService;
    private String modelOutputName = DEFAULT_MODEL_OUTPUT_NAME;
    /** Input names declared by the loaded ONNX model; used to drop tensors the model does not accept. */
    private Set<String> onnxModelInputs;
    private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /** Creates a model that embeds document content without metadata and without observations. */
    public TransformersEmbeddingModel() {
        this(MetadataMode.NONE);
    }

    /**
     * Creates a model using the given metadata mode and a no-op observation registry.
     * @param metadataMode how document metadata is included when embedding a {@link Document}
     */
    public TransformersEmbeddingModel(MetadataMode metadataMode) {
        this(metadataMode, ObservationRegistry.NOOP);
    }

    /**
     * Creates a model with the given metadata mode and observation registry.
     * @param metadataMode how document metadata is included when embedding a {@link Document}
     * @param observationRegistry registry receiving embedding observations
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry observationRegistry) {
        Assert.notNull(metadataMode, "Metadata mode should not be null");
        Assert.notNull(observationRegistry, "Observation registry should not be null");
        this.metadataMode = metadataMode;
        this.observationRegistry = observationRegistry;
    }

    /** Resolves a URI (e.g. {@code classpath:}, {@code file:}, {@code https:}) to a Spring {@link Resource}. */
    private static Resource toResource(String uri) {
        return new DefaultResourceLoader().getResource(uri);
    }

    public void setTokenizerOptions(Map<String, String> tokenizerOptions) {
        this.tokenizerOptions = tokenizerOptions;
    }

    public void setDisableCaching(boolean disableCaching) {
        this.disableCaching = disableCaching;
    }

    public void setResourceCacheDirectory(String resourceCacheDir) {
        this.resourceCacheDirectory = resourceCacheDir;
    }

    public void setGpuDeviceId(int gpuDeviceId) {
        this.gpuDeviceId = gpuDeviceId;
    }

    public void setTokenizerResource(Resource tokenizerResource) {
        this.tokenizerResource = tokenizerResource;
    }

    public void setModelResource(Resource modelResource) {
        this.modelResource = modelResource;
    }

    public void setTokenizerResource(String tokenizerResourceUri) {
        this.tokenizerResource = toResource(tokenizerResourceUri);
    }

    public void setModelResource(String modelResourceUri) {
        this.modelResource = toResource(modelResourceUri);
    }

    public void setModelOutputName(String modelOutputName) {
        this.modelOutputName = modelOutputName;
    }

    /**
     * Loads the tokenizer and the ONNX model, creating the ONNX Runtime session.
     * @throws IllegalArgumentException if the model does not expose the configured output name
     * @throws Exception if a resource cannot be read or the ONNX session cannot be created
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        this.cacheService = StringUtils.hasText(this.resourceCacheDirectory)
                ? new ResourceCacheService(this.resourceCacheDirectory) : new ResourceCacheService();

        // The stream is only needed while the tokenizer is being built; close it afterwards.
        try (InputStream tokenizerStream = getCachedResource(this.tokenizerResource).getInputStream()) {
            this.tokenizer = HuggingFaceTokenizer.newInstance(tokenizerStream, this.tokenizerOptions);
        }

        this.environment = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            if (this.gpuDeviceId >= 0) {
                sessionOptions.addCUDA(this.gpuDeviceId);
            }
            this.session = this.environment.createSession(
                    getCachedResource(this.modelResource).getContentAsByteArray(), sessionOptions);
        }

        this.onnxModelInputs = this.session.getInputNames();
        Set<String> onnxModelOutputs = this.session.getOutputNames();
        logger.info("Model input names: " + String.join(", ", this.onnxModelInputs));
        logger.info("Model output names: " + String.join(", ", onnxModelOutputs));
        Assert.isTrue(onnxModelOutputs.contains(this.modelOutputName),
                "The generative output names don't contain expected: " + this.modelOutputName
                        + ". Consider one of the available model outputs: " + String.join(", ", onnxModelOutputs));
    }

    /** Returns the resource itself when caching is disabled, otherwise its locally cached copy. */
    private Resource getCachedResource(Resource resource) {
        return this.disableCaching ? resource : this.cacheService.getCachedResource(resource);
    }

    @Override
    public float[] embed(String text) {
        return embed(List.of(text)).get(0);
    }

    @Override
    public float[] embed(Document document) {
        return embed(document.getFormattedContent(this.metadataMode));
    }

    /**
     * Embeds the texts and wraps the vectors in a response with 0-based indices.
     * @param texts texts to embed
     * @return response holding one embedding per input text, in input order
     */
    @Override
    public EmbeddingResponse embedForResponse(List<String> texts) {
        List<float[]> vectors = embed(texts);
        List<Embedding> data = new ArrayList<>(vectors.size());
        for (int i = 0; i < vectors.size(); i++) {
            data.add(new Embedding(vectors.get(i), i));
        }
        return new EmbeddingResponse(data);
    }

    @Override
    public List<float[]> embed(List<String> texts) {
        return call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build())).getResults()
            .stream()
            .map(Embedding::getOutput)
            .toList();
    }

    /**
     * Embeds every instruction of the request inside an embedding-model observation.
     * @param request request whose instructions are embedded
     * @return response with one embedding per instruction, indexed from 0 in input order
     * @throws RuntimeException wrapping any {@link OrtException} raised by ONNX Runtime
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
            .embeddingRequest(request)
            .provider(AiProvider.ONNX.value())
            .requestOptions(request.getOptions())
            .build();

        return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                List<float[]> vectors = computeEmbeddings(request.getInstructions());
                List<Embedding> embeddings = new ArrayList<>(vectors.size());
                for (int i = 0; i < vectors.size(); i++) {
                    // 0-based indices, matching embedForResponse.
                    embeddings.add(new Embedding(vectors.get(i), i));
                }
                EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings);
                observationContext.setResponse(embeddingResponse);
                return embeddingResponse;
            });
    }

    /**
     * Tokenizes the texts, runs the ONNX model and mean-pools the configured output into one vector per text.
     * @param texts texts to embed; an empty list yields an empty result without running inference
     * @return one embedding vector per input text, in input order
     */
    private List<float[]> computeEmbeddings(List<String> texts) {
        if (texts.isEmpty()) {
            // Nothing to tokenize; creating zero-sized tensors and indexing data3d[0] would fail.
            return List.of();
        }
        try {
            Encoding[] encodings = this.tokenizer.batchEncode(texts);
            long[][] inputIds = new long[encodings.length][];
            long[][] attentionMask = new long[encodings.length][];
            long[][] tokenTypeIds = new long[encodings.length][];
            for (int i = 0; i < encodings.length; i++) {
                inputIds[i] = encodings[i].getIds();
                attentionMask[i] = encodings[i].getAttentionMask();
                tokenTypeIds[i] = encodings[i].getTypeIds();
            }

            try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(this.environment, inputIds);
                    OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(this.environment, attentionMask);
                    OnnxTensor tokenTypeIdsTensor = OnnxTensor.createTensor(this.environment, tokenTypeIds)) {
                // Some models do not declare every tensor (e.g. token_type_ids); only pass what they accept.
                Map<String, OnnxTensor> modelInputs = removeUnknownModelInputs(Map.of("input_ids", inputIdsTensor,
                        "attention_mask", attentionMaskTensor, "token_type_ids", tokenTypeIdsTensor));

                try (OrtSession.Result results = this.session.run(modelInputs)) {
                    OnnxValue output = results.get(this.modelOutputName)
                        .orElseThrow(() -> new IllegalStateException(
                                "Model did not produce the expected output: " + this.modelOutputName));
                    float[][][] tokenEmbeddings = (float[][][]) output.getValue();

                    try (NDManager manager = NDManager.newBaseManager()) {
                        NDArray pooled = meanPooling(create(tokenEmbeddings, manager), manager.create(attentionMask));
                        long rows = pooled.size(0);
                        List<float[]> vectors = new ArrayList<>((int) rows);
                        for (long i = 0; i < rows; i++) {
                            vectors.add(pooled.get(i).toFloatArray());
                        }
                        return vectors;
                    }
                }
            }
        }
        catch (OrtException ex) {
            throw new RuntimeException("ONNX model inference failed", ex);
        }
    }

    /** Keeps only the input tensors whose names the loaded ONNX model declares. */
    private Map<String, OnnxTensor> removeUnknownModelInputs(Map<String, OnnxTensor> modelInputs) {
        return modelInputs.entrySet()
            .stream()
            .filter(entry -> this.onnxModelInputs.contains(entry.getKey()))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Copies a rectangular 3-D float array into an NDArray of shape
     * [data3d.length, data3d[0].length, data3d[0][0].length].
     * Assumes the array is non-empty and not ragged (all rows padded to the same length) — TODO confirm for
     * every tokenizer configuration.
     */
    private NDArray create(float[][][] data3d, NDManager manager) {
        int dim0 = data3d.length;
        int dim1 = data3d[0].length;
        int dim2 = data3d[0][0].length;
        FloatBuffer buffer = FloatBuffer.allocate(dim0 * dim1 * dim2);
        for (float[][] data2d : data3d) {
            for (float[] data1d : data2d) {
                buffer.put(data1d);
            }
        }
        buffer.rewind();
        return manager.create(buffer, new Shape(dim0, dim1, dim2));
    }

    /**
     * Averages token embeddings over {@link #EMBEDDING_AXIS}, counting only positions whose attention-mask
     * value is non-zero.
     * @param tokenEmbeddings token embeddings, shaped (batch, tokens, hidden) as built by {@code create}
     * @param attentionMask attention mask shaped (batch, tokens)
     * @return pooled embeddings shaped (batch, hidden)
     */
    private NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) {
        NDArray attentionMaskExpanded = attentionMask.expandDims(-1)
            .broadcast(tokenEmbeddings.getShape())
            .toType(DataType.FLOAT32, false);
        NDArray sumEmbeddings = tokenEmbeddings.mul(attentionMaskExpanded).sum(new int[] { EMBEDDING_AXIS });
        // Clip the denominator so an all-zero mask does not divide by zero.
        NDArray sumMask = attentionMaskExpanded.sum(new int[] { EMBEDDING_AXIS }).clip(1.0E-9f, Float.MAX_VALUE);
        return sumEmbeddings.div(sumMask);
    }

    /**
     * Overrides the observation convention used for embedding observations.
     * @param observationConvention convention to use; must not be {@code null}
     */
    public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }
}

