/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.bedrock.BedrockMistralAiChatModelResponse;
import dev.langchain4j.model.bedrock.internal.AbstractBedrockChatModel;
import dev.langchain4j.model.bedrock.internal.Json;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.output.Response;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

/**
 * Chat model that builds a Mistral instruct-style prompt from a list of chat messages,
 * sends it as a JSON body through an {@link InvokeModelRequest} and converts the JSON
 * answer into a {@link Response} holding an {@link AiMessage}.
 *
 * <p>Instances are created through {@link #builder()}. When not set explicitly,
 * {@code topK} defaults to {@code 200} and {@code model} defaults to
 * {@link Types#Mistral7bInstructV0_2}.
 */
public class BedrockMistralAiChatModel
        extends AbstractBedrockChatModel<BedrockMistralAiChatModelResponse> {
    private static final Logger log = LoggerFactory.getLogger(BedrockMistralAiChatModel.class);

    /** Value sent as the {@code top_k} request parameter. */
    private final int topK;

    /** Model identifier returned by {@link #getModelId()}. */
    private final String model;

    /**
     * Returns the identifier of the model to invoke.
     *
     * @return the configured model id
     */
    @Override
    protected String getModelId() {
        return model;
    }

    /**
     * Assembles the request parameters that are serialized to JSON and used as the request body.
     *
     * @param prompt the fully formatted prompt text
     * @return a mutable map with the keys {@code prompt}, {@code max_tokens}, {@code temperature},
     *         {@code top_p}, {@code top_k} and {@code stop}
     */
    @Override
    protected Map<String, Object> getRequestParameters(String prompt) {
        Map<String, Object> parameters = new HashMap<>(7);
        parameters.put("prompt", prompt);
        parameters.put("max_tokens", getMaxTokens());
        parameters.put("temperature", getTemperature());
        parameters.put("top_p", Float.valueOf(getTopP()));
        parameters.put("top_k", topK);
        parameters.put("stop", getStopSequences());
        return parameters;
    }

    /**
     * Generates a model answer for the given conversation.
     *
     * <p>The request body is always encoded as UTF-8, matching the UTF-8 decoding applied to the
     * response body, so the result does not depend on the JVM's default charset.
     *
     * <p>Any {@link RuntimeException} raised while decoding, parsing or converting the response,
     * or while building the listener response, is passed to {@code listenerErrorResponse} and
     * then rethrown. Exceptions thrown by an individual listener's {@code onResponse} are logged
     * and otherwise ignored.
     *
     * @param messages the conversation; only {@code USER} and {@code AI} messages are accepted
     * @return the model answer
     * @throws IllegalArgumentException if a message has a type other than {@code USER} or {@code AI}
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        String prompt = buildPrompt(messages);
        String body = Json.toJson(getRequestParameters(prompt));

        // Encode explicitly as UTF-8: the platform default charset may differ and would
        // corrupt non-ASCII characters in the JSON body.
        InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder()
                .modelId(getModelId())
                .body(SdkBytes.fromUtf8String(body))
                .build();

        ChatModelRequest modelListenerRequest =
                createModelListenerRequest(invokeModelRequest, messages, Collections.emptyList());
        ConcurrentHashMap<Object, Object> attributes = new ConcurrentHashMap<>();
        ChatModelRequestContext requestContext =
                new ChatModelRequestContext(modelListenerRequest, attributes);

        InvokeModelResponse invokeModelResponse =
                RetryUtils.withRetry(() -> invoke(invokeModelRequest, requestContext), getMaxRetries());

        try {
            // Parsing happens inside the try block so that a malformed response is also
            // reported through listenerErrorResponse.
            String rawResponse = invokeModelResponse.body().asUtf8String().trim();
            BedrockMistralAiChatModelResponse result = Json.fromJson(rawResponse, getResponseClassType());

            Response<AiMessage> responseMessage = toAiMessage(result);
            ChatModelResponse modelListenerResponse =
                    createModelListenerResponse(null, null, responseMessage);
            ChatModelResponseContext responseContext =
                    new ChatModelResponseContext(modelListenerResponse, modelListenerRequest, attributes);

            listeners.forEach(listener -> {
                try {
                    listener.onResponse(responseContext);
                } catch (Exception e) {
                    // A failing listener must not break the call for the remaining listeners.
                    log.warn("Exception while calling model listener", e);
                }
            });
            return responseMessage;
        } catch (RuntimeException e) {
            listenerErrorResponse(e, modelListenerRequest, attributes);
            throw e;
        }
    }

    /**
     * Formats the conversation as a single prompt string.
     *
     * <p>The prompt starts with {@code <s>} and ends with {@code </s>}. Each {@code USER} message
     * is wrapped as {@code [INST] text [/INST]}; each {@code AI} message is surrounded by single
     * spaces.
     *
     * @param messages the conversation to format
     * @return the formatted prompt
     * @throws IllegalArgumentException if a message has a type other than {@code USER} or {@code AI}
     */
    private String buildPrompt(List<ChatMessage> messages) {
        StringBuilder prompt = new StringBuilder("<s>");
        for (ChatMessage message : messages) {
            switch (message.type()) {
                case USER:
                    prompt.append("[INST] ").append(message.text()).append(" [/INST]");
                    break;
                case AI:
                    prompt.append(" ").append(message.text()).append(" ");
                    break;
                default:
                    throw new IllegalArgumentException(
                            "Bedrock Mistral AI does not support the message type: " + message.type());
            }
        }
        return prompt.append("</s>").toString();
    }

    /**
     * Returns the class the JSON response body is deserialized into.
     *
     * @return {@code BedrockMistralAiChatModelResponse.class}
     */
    @Override
    public Class<BedrockMistralAiChatModelResponse> getResponseClassType() {
        return BedrockMistralAiChatModelResponse.class;
    }

    /** Default {@code topK}, used when the builder did not set one. */
    private static int $default$topK() {
        return 200;
    }

    /** Default model id, used when the builder did not set one. */
    private static String $default$model() {
        return Types.Mistral7bInstructV0_2.getValue();
    }

    /**
     * Creates the model from a builder, falling back to the defaults for every value the
     * builder did not set explicitly.
     *
     * @param b the builder carrying the configuration
     */
    protected BedrockMistralAiChatModel(BedrockMistralAiChatModelBuilder<?, ?> b) {
        super(b);
        this.topK = b.topK$set ? b.topK$value : BedrockMistralAiChatModel.$default$topK();
        this.model = b.model$set ? b.model$value : BedrockMistralAiChatModel.$default$model();
    }

    /**
     * Creates a new builder for this model.
     *
     * @return a fresh builder
     */
    public static BedrockMistralAiChatModelBuilder<?, ?> builder() {
        return new BedrockMistralAiChatModelBuilderImpl();
    }

    /**
     * Returns the value sent as the {@code top_k} request parameter.
     *
     * @return the configured top-k value
     */
    @Override
    public int getTopK() {
        return topK;
    }

    /**
     * Returns the configured model identifier.
     *
     * @return the model id
     */
    public String getModel() {
        return model;
    }

    /** Known model identifiers that can be passed to the builder's {@code model(String)}. */
    public enum Types {
        Mistral7bInstructV0_2("mistral.mistral-7b-instruct-v0:2"),
        MistralMixtral8x7bInstructV0_1("mistral.mixtral-8x7b-instruct-v0:1");

        private final String value;

        Types(String modelID) {
            this.value = modelID;
        }

        /**
         * Returns the model identifier string of this constant.
         *
         * @return the model id
         */
        public String getValue() {
            return value;
        }
    }

    /**
     * Builder for {@link BedrockMistralAiChatModel}. The {@code $set} flags record whether a
     * value was provided explicitly, so the model constructor can apply defaults otherwise.
     *
     * @param <C> the concrete model type being built
     * @param <B> the concrete builder type, returned by the fluent setters
     */
    public abstract static class BedrockMistralAiChatModelBuilder<
                    C extends BedrockMistralAiChatModel, B extends BedrockMistralAiChatModelBuilder<C, B>>
            extends AbstractBedrockChatModel.AbstractBedrockChatModelBuilder<BedrockMistralAiChatModelResponse, C, B> {
        private boolean topK$set;
        private int topK$value;
        private boolean model$set;
        private String model$value;

        /**
         * Sets the value sent as the {@code top_k} request parameter.
         *
         * @param topK the top-k value
         * @return this builder
         */
        @Override
        public B topK(int topK) {
            this.topK$value = topK;
            this.topK$set = true;
            return self();
        }

        /**
         * Sets the model identifier.
         *
         * @param model the model id
         * @return this builder
         */
        public B model(String model) {
            this.model$value = model;
            this.model$set = true;
            return self();
        }

        @Override
        protected abstract B self();

        @Override
        public abstract C build();

        @Override
        public String toString() {
            return "BedrockMistralAiChatModel.BedrockMistralAiChatModelBuilder(super=" + super.toString() + ", topK$value=" + this.topK$value + ", model$value=" + this.model$value + ")";
        }
    }

    /** Concrete builder returned by {@link BedrockMistralAiChatModel#builder()}. */
    private static final class BedrockMistralAiChatModelBuilderImpl
            extends BedrockMistralAiChatModelBuilder<BedrockMistralAiChatModel, BedrockMistralAiChatModelBuilderImpl> {
        private BedrockMistralAiChatModelBuilderImpl() {
        }

        @Override
        protected BedrockMistralAiChatModelBuilderImpl self() {
            return this;
        }

        @Override
        public BedrockMistralAiChatModel build() {
            return new BedrockMistralAiChatModel(this);
        }
    }
}

