/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.vertexai.VertexAiChatInstance;
import dev.langchain4j.model.vertexai.VertexAiParameters;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class VertexAiChatModel
implements ChatLanguageModel {
    private final PredictionServiceSettings settings;
    private final EndpointName endpointName;
    private final VertexAiParameters vertexAiParameters;
    private final Integer maxRetries;

    public VertexAiChatModel(String endpoint, String project, String location, String publisher, String modelName, Double temperature, Integer maxOutputTokens, Integer topK, Double topP, Integer maxRetries) {
        try {
            this.settings = ((PredictionServiceSettings.Builder)PredictionServiceSettings.newBuilder().setEndpoint(ValidationUtils.ensureNotBlank((String)endpoint, (String)"endpoint"))).build();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName((String)ValidationUtils.ensureNotBlank((String)project, (String)"project"), (String)ValidationUtils.ensureNotBlank((String)location, (String)"location"), (String)ValidationUtils.ensureNotBlank((String)publisher, (String)"publisher"), (String)ValidationUtils.ensureNotBlank((String)modelName, (String)"modelName"));
        this.vertexAiParameters = new VertexAiParameters(temperature, maxOutputTokens, topK, topP);
        this.maxRetries = maxRetries == null ? 3 : maxRetries;
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public AiMessage sendMessages(List<ChatMessage> messages) {
        try (PredictionServiceClient client = PredictionServiceClient.create((PredictionServiceSettings)this.settings);){
            VertexAiChatInstance vertexAiChatInstance = new VertexAiChatInstance(VertexAiChatModel.toContext(messages), VertexAiChatModel.toVertexMessages(messages));
            Value.Builder instanceBuilder = Value.newBuilder();
            JsonFormat.parser().merge(Json.toJson((Object)vertexAiChatInstance), (Message.Builder)instanceBuilder);
            List<Value> instances = Collections.singletonList(instanceBuilder.build());
            Value.Builder parametersBuilder = Value.newBuilder();
            JsonFormat.parser().merge(Json.toJson((Object)this.vertexAiParameters), (Message.Builder)parametersBuilder);
            Value parameters = parametersBuilder.build();
            PredictResponse response = (PredictResponse)RetryUtils.withRetry(() -> client.predict(this.endpointName, instances, parameters), (int)this.maxRetries);
            AiMessage aiMessage = AiMessage.aiMessage((String)VertexAiChatModel.extractContent(response));
            return aiMessage;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private static String extractContent(PredictResponse predictResponse) {
        return ((Value)((Value)predictResponse.getPredictions(0).getStructValue().getFieldsMap().get("candidates")).getListValue().getValues(0).getStructValue().getFieldsMap().get("content")).getStringValue();
    }

    private static List<VertexAiChatInstance.Message> toVertexMessages(List<ChatMessage> messages) {
        return messages.stream().filter(chatMessage -> chatMessage.type() == ChatMessageType.USER || chatMessage.type() == ChatMessageType.AI).map(chatMessage -> new VertexAiChatInstance.Message(chatMessage.type().name(), chatMessage.text())).collect(Collectors.toList());
    }

    private static String toContext(List<ChatMessage> messages) {
        return messages.stream().filter(chatMessage -> chatMessage.type() == ChatMessageType.SYSTEM).map(ChatMessage::text).collect(Collectors.joining("\n"));
    }

    public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models");
    }

    public AiMessage sendMessages(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models");
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;
        private Double temperature;
        private Integer maxOutputTokens = 200;
        private Integer topK;
        private Double topP;
        private Integer maxRetries;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder project(String project) {
            this.project = project;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder maxOutputTokens(Integer maxOutputTokens) {
            this.maxOutputTokens = maxOutputTokens;
            return this;
        }

        public Builder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public Builder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public VertexAiChatModel build() {
            return new VertexAiChatModel(this.endpoint, this.project, this.location, this.publisher, this.modelName, this.temperature, this.maxOutputTokens, this.topK, this.topP, this.maxRetries);
        }
    }
}

