/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vertexai.embedding.text;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModelName;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Text embedding model that sends the inputs of an {@link EmbeddingRequest} to a
 * Vertex AI prediction endpoint and converts the predictions into an
 * {@link EmbeddingResponse}.
 *
 * <p>Every {@link #call(EmbeddingRequest)} opens its own {@link PredictionServiceClient},
 * runs the prediction through the configured {@link RetryTemplate}, and is wrapped in a
 * Micrometer observation.
 */
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {

    /** Convention used when no custom convention has been set. */
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

    /** Model name to embedding dimensions, built from every {@link VertexAiTextEmbeddingModelName} constant. */
    private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream.of(VertexAiTextEmbeddingModelName.values())
        .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, VertexAiTextEmbeddingModelName::getDimensions));

    /** Default options, merged with the per-request options on every call. */
    public final VertexAiTextEmbeddingOptions defaultOptions;

    private final VertexAiEmbeddingConnectionDetails connectionDetails;

    private final RetryTemplate retryTemplate;

    private final ObservationRegistry observationRegistry;

    private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model using {@link RetryUtils#DEFAULT_RETRY_TEMPLATE} and a no-op observation registry.
     * @param connectionDetails source of the endpoint name and prediction service settings
     * @param defaultEmbeddingOptions default options; must not be null
     */
    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
        this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model using a no-op observation registry.
     * @param connectionDetails source of the endpoint name and prediction service settings
     * @param defaultEmbeddingOptions default options; must not be null
     * @param retryTemplate retry template applied to the prediction call; must not be null
     */
    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
        this(connectionDetails, defaultEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a fully configured model.
     * @param connectionDetails source of the endpoint name and prediction service settings; must not be null
     * @param defaultEmbeddingOptions default options; must not be null
     * @param retryTemplate retry template applied to the prediction call; must not be null
     * @param observationRegistry registry the call observation is reported to; must not be null
     * @throws IllegalArgumentException if any argument is null
     */
    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        // Fail fast here instead of with a NullPointerException on the first call.
        Assert.notNull(connectionDetails, "connectionDetails must not be null");
        Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
        this.connectionDetails = connectionDetails;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    /**
     * Embeds the formatted content of the given document.
     * @param document the document to embed; must not be null
     * @return the embedding vector
     * @throws IllegalArgumentException if {@code document} is null
     */
    public float[] embed(Document document) {
        Assert.notNull(document, "Document must not be null");
        return this.embed(document.getFormattedContent());
    }

    /**
     * Sends the request's instructions to the prediction endpoint and collects one
     * {@link Embedding} per returned prediction, indexed in response order.
     * @param request the embedding request; its options are merged with {@link #defaultOptions}
     * @return the embeddings plus metadata carrying the model name and the summed token count
     * @throws IllegalArgumentException if no model name results from merging the options
     */
    public EmbeddingResponse call(EmbeddingRequest request) {
        EmbeddingRequest embeddingRequest = this.buildEmbeddingRequest(request);
        EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
            .embeddingRequest(embeddingRequest)
            .provider(AiProvider.VERTEX_AI.value())
            .build();
        return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
            .observe(() -> {
                // The client is created per call and closed when the block exits.
                try (PredictionServiceClient client = this.createPredictionServiceClient()) {
                    EmbeddingOptions options = embeddingRequest.getOptions();
                    EndpointName endpointName = this.connectionDetails.getEndpointName(options.getModel());
                    PredictRequest.Builder predictRequestBuilder = this.getPredictRequestBuilder(request, endpointName, (VertexAiTextEmbeddingOptions) options);
                    PredictResponse embeddingResponse = RetryUtils.execute(this.retryTemplate, () -> this.getPredictResponse(client, predictRequestBuilder));

                    int index = 0;
                    int totalTokenCount = 0;
                    ArrayList<Embedding> embeddingList = new ArrayList<>();
                    for (Value prediction : embeddingResponse.getPredictionsList()) {
                        // Each prediction holds an "embeddings" struct with "statistics.token_count" and "values".
                        Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
                        Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
                        Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
                        totalTokenCount += (int) tokenCount.getNumberValue();
                        Value values = embeddings.getStructValue().getFieldsOrThrow("values");
                        float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
                        embeddingList.add(new Embedding(vectorValues, index++));
                    }

                    EmbeddingResponse response = new EmbeddingResponse(embeddingList, this.generateResponseMetadata(options.getModel(), totalTokenCount));
                    observationContext.setResponse(response);
                    return response;
                }
            });
    }

    /**
     * Builds the effective request by merging the request's options (if any) with
     * {@link #defaultOptions}.
     * @param embeddingRequest the incoming request
     * @return a new request with the same instructions and the merged options
     * @throws IllegalArgumentException if the merged options carry no model name
     */
    EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
        VertexAiTextEmbeddingOptions runtimeOptions = null;
        if (embeddingRequest.getOptions() != null) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, VertexAiTextEmbeddingOptions.class);
        }
        VertexAiTextEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, VertexAiTextEmbeddingOptions.class);
        if (!StringUtils.hasText(requestOptions.getModel())) {
            throw new IllegalArgumentException("model cannot be null or empty");
        }
        return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
    }

    /**
     * Builds the prediction request: the request-level parameters (auto-truncate, output
     * dimensionality) are set only when configured, and one instance is added per instruction.
     * @param request the request whose instructions become prediction instances
     * @param endpointName the endpoint the prediction is sent to
     * @param finalOptions the merged options; its task type is dereferenced, so it must be set
     * @return the populated request builder
     */
    protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, VertexAiTextEmbeddingOptions finalOptions) {
        PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString());

        VertexAiEmbeddingUtils.TextParametersBuilder parametersBuilder = VertexAiEmbeddingUtils.TextParametersBuilder.of();
        if (finalOptions.getAutoTruncate() != null) {
            parametersBuilder.autoTruncate(finalOptions.getAutoTruncate());
        }
        if (finalOptions.getDimensions() != null) {
            parametersBuilder.outputDimensionality(finalOptions.getDimensions());
        }
        predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));

        for (String instruction : request.getInstructions()) {
            VertexAiEmbeddingUtils.TextInstanceBuilder instanceBuilder = VertexAiEmbeddingUtils.TextInstanceBuilder.of(instruction)
                .taskType(finalOptions.getTaskType().name());
            // The title is optional and only sent when it has text.
            if (StringUtils.hasText(finalOptions.getTitle())) {
                instanceBuilder.title(finalOptions.getTitle());
            }
            predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
        }
        return predictRequestBuilder;
    }

    /**
     * Creates a prediction client from the connection details' settings.
     * @return a new client; the caller is responsible for closing it
     * @throws RuntimeException wrapping the {@link IOException} if the client cannot be created
     */
    PredictionServiceClient createPredictionServiceClient() {
        try {
            return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to create Vertex AI PredictionServiceClient", e);
        }
    }

    /**
     * Executes the prediction. Kept as a separate package-private method, presumably so
     * it can be replaced in tests.
     * @param client the client to call
     * @param predictRequestBuilder builder of the request to send
     * @return the raw prediction response
     */
    PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
        return client.predict(predictRequestBuilder.build());
    }

    /** Builds response metadata carrying the model name and a usage record for the given token total. */
    private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
        EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
        metadata.setModel(model);
        metadata.setUsage(this.getDefaultUsage(totalTokens));
        return metadata;
    }

    /** Creates a usage record whose first two counts are zero and whose last is {@code totalTokens}. */
    private DefaultUsage getDefaultUsage(Integer totalTokens) {
        return new DefaultUsage(0, 0, totalTokens);
    }

    /**
     * Returns the embedding dimensions of the default model.
     *
     * <p>The known-model table is consulted first; the superclass is only asked when the
     * default model is not listed, so its fallback (presumably a probe of the model;
     * confirm against {@code AbstractEmbeddingModel}) is not run needlessly.
     * @return the number of dimensions
     */
    public int dimensions() {
        Integer knownDimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
        return (knownDimensions != null) ? knownDimensions : super.dimensions();
    }

    /**
     * Replaces the observation convention used by {@link #call(EmbeddingRequest)}.
     * @param observationConvention the convention to use; must not be null
     * @throws IllegalArgumentException if {@code observationConvention} is null
     */
    public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }
}

