/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.vertexai.VertexAiEmbeddingInstance;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link EmbeddingModel} implementation that obtains embeddings by calling the
 * {@code predict} method of a Google Cloud AI Platform {@link PredictionServiceClient}.
 *
 * <p>A new client is created (and closed) for every embedding request; the instance itself
 * only holds the immutable client settings, the target endpoint name and the retry count.
 * Use {@link #builder()} for a more readable way of constructing instances.
 */
public class VertexAiEmbeddingModel implements EmbeddingModel {
    /** Retry count used when the caller does not provide one. */
    private static final int DEFAULT_MAX_RETRIES = 3;

    private final PredictionServiceSettings settings;
    private final EndpointName endpointName;
    private final int maxRetries;

    /**
     * Creates a model bound to a single publisher model endpoint.
     *
     * @param endpoint   service endpoint passed to the client settings; must not be blank
     * @param project    project identifier; must not be blank
     * @param location   location identifier; must not be blank
     * @param publisher  publisher identifier; must not be blank
     * @param modelName  model name; must not be blank
     * @param maxRetries value passed to {@link RetryUtils#withRetry} for each predict call;
     *                   {@code null} selects the default of 3
     * @throws RuntimeException if building the client settings fails with an {@link IOException}
     */
    public VertexAiEmbeddingModel(String endpoint, String project, String location, String publisher, String modelName, Integer maxRetries) {
        try {
            this.settings = PredictionServiceSettings.newBuilder()
                    .setEndpoint(ValidationUtils.ensureNotBlank(endpoint, "endpoint"))
                    .build();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName(
                ValidationUtils.ensureNotBlank(project, "project"),
                ValidationUtils.ensureNotBlank(location, "location"),
                ValidationUtils.ensureNotBlank(publisher, "publisher"),
                ValidationUtils.ensureNotBlank(modelName, "modelName"));
        this.maxRetries = maxRetries == null ? DEFAULT_MAX_RETRIES : maxRetries;
    }

    /**
     * Embeds the text of every given segment.
     *
     * @param textSegments segments to embed; must not be {@code null}
     * @return one embedding per prediction returned by the service; an empty list when
     *         {@code textSegments} is empty
     * @throws RuntimeException if serialization, the remote call or response parsing fails
     */
    public List<Embedding> embedAll(List<TextSegment> textSegments) {
        List<String> texts = textSegments.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());
        return embedTexts(texts);
    }

    /**
     * Sends the given texts in a single predict request and converts the predictions
     * into embeddings.
     */
    private List<Embedding> embedTexts(List<String> texts) {
        // Nothing to embed: avoid creating a client and issuing a request with no instances.
        if (texts.isEmpty()) {
            return new ArrayList<>();
        }
        try {
            // Serialize the request before opening the client so that a serialization
            // failure never allocates network resources.
            List<Value> instances = new ArrayList<>(texts.size());
            for (String text : texts) {
                instances.add(toInstance(text));
            }
            try (PredictionServiceClient client = PredictionServiceClient.create(this.settings)) {
                PredictResponse response = RetryUtils.withRetry(
                        () -> client.predict(this.endpointName, instances, ValueConverter.EMPTY_VALUE),
                        this.maxRetries);
                return response.getPredictionsList().stream()
                        .map(VertexAiEmbeddingModel::toVector)
                        .map(Embedding::from)
                        .collect(Collectors.toList());
            }
        } catch (RuntimeException e) {
            // Already unchecked: rethrow as-is instead of hiding it behind another wrapper.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts one text into the protobuf {@link Value} sent as a request instance, by
     * serializing a {@link VertexAiEmbeddingInstance} to JSON and parsing that JSON.
     *
     * @throws IOException if the JSON cannot be merged into the protobuf builder
     */
    private static Value toInstance(String text) throws IOException {
        Value.Builder instanceBuilder = Value.newBuilder();
        JsonFormat.parser().merge(Json.toJson(new VertexAiEmbeddingInstance(text)), instanceBuilder);
        return instanceBuilder.build();
    }

    /**
     * Extracts the embedding vector from one prediction, read from the path
     * {@code embeddings.values} of the prediction struct.
     *
     * @throws IllegalArgumentException if either the {@code embeddings} or the
     *                                  {@code values} field is missing
     */
    private static List<Float> toVector(Value prediction) {
        // getFieldsOrThrow (rather than getFieldsMap().get) reports a missing key by name
        // instead of failing later with a bare NullPointerException.
        return prediction.getStructValue()
                .getFieldsOrThrow("embeddings")
                .getStructValue()
                .getFieldsOrThrow("values")
                .getListValue()
                .getValuesList()
                .stream()
                .map(v -> Float.valueOf((float) v.getNumberValue()))
                .collect(Collectors.toList());
    }

    /**
     * Returns a new, empty {@link Builder}.
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link VertexAiEmbeddingModel}. Values are validated when
     * {@link #build()} invokes the model constructor, not when they are set.
     */
    public static class Builder {
        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;
        private Integer maxRetries;

        /** Sets the service endpoint. */
        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        /** Sets the project identifier. */
        public Builder project(String project) {
            this.project = project;
            return this;
        }

        /** Sets the location identifier. */
        public Builder location(String location) {
            this.location = location;
            return this;
        }

        /** Sets the publisher identifier. */
        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        /** Sets the model name. */
        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Sets the retry count; leaving it {@code null} selects the model's default. */
        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * Creates the model from the values collected so far.
         *
         * @return a new {@link VertexAiEmbeddingModel}
         */
        public VertexAiEmbeddingModel build() {
            return new VertexAiEmbeddingModel(this.endpoint, this.project, this.location, this.publisher, this.modelName, this.maxRetries);
        }
    }
}

