/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock;

import dev.langchain4j.Internal;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.PdfFileContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.bedrock.AwsDocumentConverter;
import dev.langchain4j.model.bedrock.BedrockCachePointPlacement;
import dev.langchain4j.model.bedrock.BedrockChatRequestParameters;
import dev.langchain4j.model.bedrock.BedrockTokenUsage;
import dev.langchain4j.model.bedrock.Utils;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.DefaultChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.ToolChoice;
import dev.langchain4j.model.output.FinishReason;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.model.AnyToolChoice;
import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentFormat;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ImageSource;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.ReasoningContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ReasoningTextBlock;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;

/**
 * Base class for the Bedrock chat models in this package.
 *
 * <p>It holds the configuration shared by its subclasses (region, timeout, thinking flags,
 * default request parameters, listeners) and the conversions between langchain4j chat types
 * and the AWS Bedrock runtime model types ({@link Message}, {@link ContentBlock},
 * {@link ToolConfiguration}, {@link ConverseResponse}, ...).
 */
@Internal
abstract class AbstractBedrockChatModel {
    /** Attribute key under which a reasoning signature is stored on an {@link AiMessage}. */
    private static final String THINKING_SIGNATURE_KEY = "thinking_signature";

    /** Cache point type used for every cache point block this class emits. */
    private static final String DEFAULT_CACHE_POINT_TYPE = "default";

    /** AWS region; {@link Region#US_EAST_1} when the builder did not set one. */
    protected final Region region;

    /** Timeout taken from the builder; one minute when the builder did not set one. */
    protected final Duration timeout;

    /** Whether reasoning content of a response is copied into the resulting {@link AiMessage}; defaults to {@code false}. */
    protected final boolean returnThinking;

    /** Whether the thinking of an {@link AiMessage} is sent back as reasoning content; defaults to {@code true}. */
    protected final boolean sendThinking;

    /** Default request parameters assembled from the builder; never {@code null}. */
    protected final BedrockChatRequestParameters defaultRequestParameters;

    /** Copy of the listeners supplied to the builder. */
    protected final List<ChatModelListener> listeners;

    /**
     * Initializes the shared configuration from the given builder.
     *
     * <p>Default request parameters supplied through the builder are validated with
     * {@link #validate(ChatRequestParameters)} first. A model id set directly on the builder takes
     * precedence over the model name of those parameters.
     *
     * @param builder the builder carrying the configuration
     * @throws UnsupportedFeatureException if the supplied default parameters use an unsupported feature
     */
    protected AbstractBedrockChatModel(AbstractBuilder<?> builder) {
        this.region = dev.langchain4j.internal.Utils.getOrDefault(builder.region, Region.US_EAST_1);
        this.timeout = dev.langchain4j.internal.Utils.getOrDefault(builder.timeout, Duration.ofMinutes(1L));
        this.returnThinking = dev.langchain4j.internal.Utils.getOrDefault(builder.returnThinking, Boolean.FALSE);
        this.sendThinking = dev.langchain4j.internal.Utils.getOrDefault(builder.sendThinking, Boolean.TRUE);
        this.listeners = dev.langchain4j.internal.Utils.copy(builder.listeners);

        ChatRequestParameters configured = builder.defaultRequestParameters;
        final ChatRequestParameters commonParameters;
        if (configured != null) {
            validate(configured);
            commonParameters = configured;
        } else {
            commonParameters = DefaultChatRequestParameters.EMPTY;
        }

        // Bedrock-specific settings are only available when the caller supplied the Bedrock parameter type.
        BedrockChatRequestParameters bedrockParameters = configured instanceof BedrockChatRequestParameters bedrock
                ? bedrock
                : BedrockChatRequestParameters.EMPTY;

        this.defaultRequestParameters = BedrockChatRequestParameters.builder()
                .modelName(dev.langchain4j.internal.Utils.getOrDefault(builder.modelId, commonParameters.modelName()))
                .temperature(commonParameters.temperature())
                .topP(commonParameters.topP())
                .maxOutputTokens(commonParameters.maxOutputTokens())
                .stopSequences(commonParameters.stopSequences())
                .toolSpecifications(commonParameters.toolSpecifications())
                .toolChoice(commonParameters.toolChoice())
                .additionalModelRequestFields(bedrockParameters.additionalModelRequestFields())
                .promptCaching(bedrockParameters.cachePointPlacement())
                .build();
    }

    /**
     * Converts the system messages of a conversation without adding a cache point.
     *
     * @param messages the conversation
     * @return one system block per {@link SystemMessage}, in conversation order
     */
    protected List<SystemContentBlock> extractSystemMessages(List<ChatMessage> messages) {
        return this.extractSystemMessages(messages, null);
    }

    /**
     * Converts the system messages of a conversation into Bedrock system blocks.
     *
     * @param messages the conversation
     * @param cachePointPlacement when {@link BedrockCachePointPlacement#AFTER_SYSTEM} and at least one
     *        system message exists, a cache point block is appended; may be {@code null}
     * @return a mutable list with one block per {@link SystemMessage}, plus the optional cache point
     */
    protected List<SystemContentBlock> extractSystemMessages(List<ChatMessage> messages, BedrockCachePointPlacement cachePointPlacement) {
        List<SystemContentBlock> systemBlocks = new ArrayList<>();
        for (ChatMessage message : messages) {
            if (message.type() == ChatMessageType.SYSTEM) {
                systemBlocks.add(SystemContentBlock.builder().text(((SystemMessage) message).text()).build());
            }
        }
        if (cachePointPlacement == BedrockCachePointPlacement.AFTER_SYSTEM && !systemBlocks.isEmpty()) {
            systemBlocks.add(SystemContentBlock.builder().cachePoint(defaultCachePoint()).build());
        }
        return systemBlocks;
    }

    /**
     * Converts the non-system messages of a conversation without adding a cache point.
     *
     * @param messages the conversation
     * @return the Bedrock messages, in conversation order
     */
    protected List<Message> extractRegularMessages(List<ChatMessage> messages) {
        return this.extractRegularMessages(messages, null);
    }

    /**
     * Converts the non-system messages of a conversation into Bedrock messages.
     *
     * <p>System messages are skipped. Consecutive {@link ToolExecutionResultMessage}s are merged into a
     * single user-role message (see {@link #handleToolResult}).
     *
     * @param messages the conversation
     * @param cachePointPlacement when {@link BedrockCachePointPlacement#AFTER_USER_MESSAGE}, a cache point
     *        block is appended to the content of the first {@link UserMessage} only; may be {@code null}
     * @return a mutable list of Bedrock messages, in conversation order
     * @throws IllegalArgumentException if a message is of an unsupported type
     */
    protected List<Message> extractRegularMessages(List<ChatMessage> messages, BedrockCachePointPlacement cachePointPlacement) {
        List<Message> bedrockMessages = new ArrayList<>();
        // Accumulates the tool results of a run of consecutive tool result messages.
        List<ContentBlock> pendingToolResults = new ArrayList<>();
        boolean cachePointAdded = false;
        // Index loop: handleToolResult needs the position to look at the following message.
        for (int i = 0; i < messages.size(); i++) {
            ChatMessage msg = messages.get(i);
            if (msg instanceof ToolExecutionResultMessage toolResult) {
                this.handleToolResult(toolResult, pendingToolResults, bedrockMessages, i, messages);
                continue;
            }
            if (msg instanceof SystemMessage) {
                continue;
            }
            Message bedrockMessage = this.convertToBedRockMessage(msg);
            if (cachePointPlacement == BedrockCachePointPlacement.AFTER_USER_MESSAGE && msg instanceof UserMessage && !cachePointAdded) {
                List<ContentBlock> contentWithCachePoint = new ArrayList<>(bedrockMessage.content());
                contentWithCachePoint.add(ContentBlock.builder().cachePoint(defaultCachePoint()).build());
                bedrockMessage = Message.builder().role(bedrockMessage.role()).content(contentWithCachePoint).build();
                cachePointAdded = true;
            }
            bedrockMessages.add(bedrockMessage);
        }
        return bedrockMessages;
    }

    /**
     * Adds one tool result to the pending blocks and, when the run of tool results ends, flushes the
     * pending blocks as a single user-role message.
     *
     * @param toolResult the tool result to convert
     * @param blocks accumulator of pending tool result blocks; cleared after a flush
     * @param bedrockMessages target list receiving the flushed message
     * @param currentIndex index of {@code toolResult} within {@code allMessages}
     * @param allMessages the whole conversation, used to inspect the following message
     */
    protected void handleToolResult(ToolExecutionResultMessage toolResult, List<ContentBlock> blocks, List<Message> bedrockMessages, int currentIndex, List<ChatMessage> allMessages) {
        blocks.add(this.createToolResultBlock(toolResult));
        int nextIndex = currentIndex + 1;
        boolean runEndsHere = nextIndex >= allMessages.size()
                || !(allMessages.get(nextIndex) instanceof ToolExecutionResultMessage);
        if (runEndsHere) {
            // Snapshot before clearing so the built message never aliases the shared accumulator.
            List<ContentBlock> snapshot = new ArrayList<>(blocks);
            bedrockMessages.add(Message.builder().role(ConversationRole.USER).content(snapshot).build());
            blocks.clear();
        }
    }

    /**
     * Wraps a tool execution result into a Bedrock tool-result content block.
     *
     * @param toolResult the tool result; its id becomes the tool use id and its text the single content entry
     * @return the content block
     */
    protected ContentBlock createToolResultBlock(ToolExecutionResultMessage toolResult) {
        ToolResultContentBlock resultText = ToolResultContentBlock.builder().text(toolResult.text()).build();
        ToolResultBlock resultBlock = ToolResultBlock.builder()
                .toolUseId(toolResult.id())
                .content(resultText)
                .build();
        return ContentBlock.builder().toolResult(resultBlock).build();
    }

    /**
     * Converts a user or AI message into a Bedrock message.
     *
     * @param message a {@link UserMessage} or {@link AiMessage}
     * @return the Bedrock message
     * @throws IllegalArgumentException if the message is of any other type
     */
    protected Message convertToBedRockMessage(ChatMessage message) {
        if (message instanceof UserMessage userMsg) {
            return this.createUserMessage(userMsg);
        }
        if (message instanceof AiMessage aiMsg) {
            return this.createAiMessage(aiMsg);
        }
        throw new IllegalArgumentException("Unsupported message type: " + message.getClass());
    }

    /**
     * Builds a user-role Bedrock message from the contents of a {@link UserMessage}.
     *
     * @param message the user message
     * @return the Bedrock message
     * @throws IllegalArgumentException if a content item is of an unsupported type
     */
    protected Message createUserMessage(UserMessage message) {
        List<ContentBlock> content = this.convertContents(message.contents());
        return Message.builder().role(ConversationRole.USER).content(content).build();
    }

    /**
     * Builds an assistant-role Bedrock message from an {@link AiMessage}.
     *
     * <p>Blocks are emitted in this order: reasoning (only when {@link #sendThinking} is set and the message
     * has thinking), text (when present), tool use requests (when present).
     *
     * @param message the AI message
     * @return the Bedrock message
     */
    protected Message createAiMessage(AiMessage message) {
        List<ContentBlock> blocks = new ArrayList<>();
        if (this.sendThinking && message.thinking() != null) {
            ReasoningTextBlock reasoningText = ReasoningTextBlock.builder()
                    .text(message.thinking())
                    .signature(message.attribute(THINKING_SIGNATURE_KEY, String.class))
                    .build();
            ReasoningContentBlock reasoning = ReasoningContentBlock.builder().reasoningText(reasoningText).build();
            blocks.add(ContentBlock.builder().reasoningContent(reasoning).build());
        }
        if (message.text() != null) {
            blocks.add(ContentBlock.builder().text(message.text()).build());
        }
        if (message.hasToolExecutionRequests()) {
            blocks.addAll(this.convertToolRequests(message.toolExecutionRequests()));
        }
        return Message.builder().role(ConversationRole.ASSISTANT).content(blocks).build();
    }

    /**
     * Converts tool execution requests into Bedrock tool-use content blocks.
     *
     * @param requests the requests; their JSON arguments are converted with {@link AwsDocumentConverter}
     * @return an unmodifiable list with one block per request
     */
    protected List<ContentBlock> convertToolRequests(List<ToolExecutionRequest> requests) {
        return requests.stream()
                .map(request -> {
                    ToolUseBlock toolUse = ToolUseBlock.builder()
                            .name(request.name())
                            .toolUseId(request.id())
                            .input(AwsDocumentConverter.documentFromJson(request.arguments()))
                            .build();
                    return ContentBlock.builder().toolUse(toolUse).build();
                })
                .toList();
    }

    /**
     * Converts message contents into Bedrock content blocks.
     *
     * @param contents the contents; may be {@code null} or empty
     * @return an empty list for {@code null}/empty input, otherwise one block per content item
     * @throws IllegalArgumentException if a content item is of an unsupported type
     */
    protected List<ContentBlock> convertContents(List<Content> contents) {
        if (dev.langchain4j.internal.Utils.isNullOrEmpty(contents)) {
            return Collections.emptyList();
        }
        return contents.stream().map(this::convertContent).toList();
    }

    /**
     * Converts a single content item into a Bedrock content block.
     *
     * <p>Supported types are {@link TextContent}, {@link PdfFileContent} and {@link ImageContent}. For PDFs
     * the bytes come from the base64 data when present and are otherwise read from the URL.
     *
     * @param content the content item
     * @return the content block
     * @throws IllegalArgumentException if the content type is not supported
     */
    protected ContentBlock convertContent(Content content) {
        if (content instanceof TextContent text) {
            return ContentBlock.builder().text(text.text()).build();
        }
        if (content instanceof PdfFileContent pdfContent) {
            URI url = pdfContent.pdfFile().url();
            SdkBytes bytes = SdkBytes.fromByteArray(loadBytes(pdfContent.pdfFile().base64Data(), url));
            DocumentBlock document = DocumentBlock.builder()
                    .format(DocumentFormat.PDF)
                    .source(DocumentSource.builder().bytes(bytes).build())
                    .name(extractFilenameWithoutExtensionFromUri(url))
                    .build();
            return ContentBlock.builder().document(document).build();
        }
        if (content instanceof ImageContent image) {
            return this.createImageBlock(image);
        }
        throw new IllegalArgumentException("Unsupported content type: " + content.getClass());
    }

    /**
     * Converts image content into a Bedrock image block.
     *
     * <p>The bytes come from the base64 data when present and are otherwise read from the image URL; the
     * format is determined by {@code Utils.extractAndValidateFormat}.
     *
     * @param imageContent the image content
     * @return the content block
     */
    protected ContentBlock createImageBlock(ImageContent imageContent) {
        // Bytes are resolved before the format, matching the original evaluation order.
        SdkBytes bytes = SdkBytes.fromByteArray(loadBytes(imageContent.image().base64Data(), imageContent.image().url()));
        String imgFormat = Utils.extractAndValidateFormat(imageContent.image());
        ImageBlock imageBlock = ImageBlock.builder()
                .format(imgFormat)
                .source(ImageSource.builder().bytes(bytes).build())
                .build();
        return ContentBlock.builder().image(imageBlock).build();
    }

    /**
     * Builds the tool configuration for a request without adding a cache point.
     *
     * @param chatRequest the chat request
     * @return the tool configuration, or {@code null} when the request declares no tools
     */
    protected ToolConfiguration extractToolConfigurationFrom(ChatRequest chatRequest) {
        return this.extractToolConfigurationFrom(chatRequest, null);
    }

    /**
     * Builds the Bedrock tool configuration for a request.
     *
     * @param chatRequest the chat request providing tool specifications and the tool choice
     * @param cachePointPlacement when {@link BedrockCachePointPlacement#AFTER_TOOLS}, a cache point entry is
     *        appended after the tools; may be {@code null}
     * @return the tool configuration, or {@code null} when the request declares no tools; the tool choice is
     *         only set (to "any") when the request parameters ask for {@link ToolChoice#REQUIRED}
     */
    protected ToolConfiguration extractToolConfigurationFrom(ChatRequest chatRequest, BedrockCachePointPlacement cachePointPlacement) {
        // Fully qualified: the simple name ToolSpecification is taken by the AWS SDK import.
        List<dev.langchain4j.agent.tool.ToolSpecification> toolSpecifications = chatRequest.toolSpecifications();
        if (toolSpecifications == null || toolSpecifications.isEmpty()) {
            return null;
        }

        List<Tool> allTools = new ArrayList<>(toolSpecifications.size() + 1);
        for (dev.langchain4j.agent.tool.ToolSpecification toolSpecification : toolSpecifications) {
            allTools.add(toBedrockTool(toolSpecification));
        }
        if (cachePointPlacement == BedrockCachePointPlacement.AFTER_TOOLS) {
            allTools.add(Tool.builder().cachePoint(defaultCachePoint()).build());
        }

        ToolConfiguration.Builder toolConfigurationBuilder = ToolConfiguration.builder().tools(allTools);
        ChatRequestParameters parameters = chatRequest.parameters();
        if (parameters != null && parameters.toolChoice() == ToolChoice.REQUIRED) {
            toolConfigurationBuilder.toolChoice(
                    software.amazon.awssdk.services.bedrockruntime.model.ToolChoice.fromAny(AnyToolChoice.builder().build()));
        }
        return toolConfigurationBuilder.build();
    }

    /**
     * Converts a Bedrock response into an {@link AiMessage}.
     *
     * <p>Non-empty text blocks are joined with a blank line; tool-use blocks become tool execution requests;
     * reasoning blocks are only used when {@link #returnThinking} is set (when several carry text or a
     * signature, the last one wins).
     *
     * @param converseResponse the Bedrock response
     * @return the AI message; its text is {@code null} when the response has no non-empty text block
     * @throws IllegalArgumentException if the response contains a block of any other type
     */
    protected AiMessage aiMessageFrom(ConverseResponse converseResponse) {
        List<String> texts = new ArrayList<>();
        List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();
        String thinking = null;
        String thinkingSignature = null;

        for (ContentBlock block : converseResponse.output().message().content()) {
            ContentBlock.Type type = block.type();
            if (type == ContentBlock.Type.TOOL_USE) {
                ToolUseBlock toolUse = block.toolUse();
                toolExecutionRequests.add(ToolExecutionRequest.builder()
                        .name(toolUse.name())
                        .id(toolUse.toolUseId())
                        .arguments(AwsDocumentConverter.documentToJson(toolUse.input()))
                        .build());
            } else if (type == ContentBlock.Type.TEXT) {
                if (dev.langchain4j.internal.Utils.isNotNullOrEmpty(block.text())) {
                    texts.add(block.text());
                }
            } else if (type == ContentBlock.Type.REASONING_CONTENT) {
                // Reasoning blocks are silently dropped (not rejected) when thinking is not returned.
                if (!this.returnThinking) {
                    continue;
                }
                ReasoningContentBlock reasoningContent = block.reasoningContent();
                ReasoningTextBlock reasoningText = reasoningContent == null ? null : reasoningContent.reasoningText();
                if (reasoningText == null) {
                    continue;
                }
                if (dev.langchain4j.internal.Utils.isNotNullOrEmpty(reasoningText.text())) {
                    thinking = reasoningText.text();
                }
                if (dev.langchain4j.internal.Utils.isNotNullOrEmpty(reasoningText.signature())) {
                    thinkingSignature = reasoningText.signature();
                }
            } else {
                throw new IllegalArgumentException("Unsupported content in LLM response. Content type: " + type);
            }
        }

        // Every collected text is non-empty, so the joined text is empty exactly when no text was collected.
        String text = texts.isEmpty() ? null : String.join("\n\n", texts);
        return AiMessage.builder()
                .text(text)
                .thinking(thinking)
                .attributes(thinkingSignature == null ? null : Map.of(THINKING_SIGNATURE_KEY, thinkingSignature))
                .toolExecutionRequests(toolExecutionRequests)
                .build();
    }

    /**
     * Converts the Bedrock token usage into a {@link BedrockTokenUsage}.
     *
     * @param tokenUsage the usage reported by Bedrock; may be {@code null}
     * @return the converted usage, or a usage built with no values when {@code tokenUsage} is {@code null}
     */
    protected BedrockTokenUsage tokenUsageFrom(TokenUsage tokenUsage) {
        if (tokenUsage == null) {
            return BedrockTokenUsage.builder().build();
        }
        return BedrockTokenUsage.builder()
                .inputTokenCount(tokenUsage.inputTokens())
                .outputTokenCount(tokenUsage.outputTokens())
                .cacheWriteInputTokens(tokenUsage.cacheWriteInputTokens())
                .cacheReadInputTokens(tokenUsage.cacheReadInputTokens())
                .build();
    }

    /**
     * Maps a Bedrock stop reason to a langchain4j {@link FinishReason}.
     *
     * @param stopReason the stop reason reported by Bedrock
     * @return the matching finish reason
     * @throws IllegalArgumentException if the stop reason is {@code null} or has no mapping
     */
    protected FinishReason finishReasonFrom(StopReason stopReason) {
        if (stopReason == null) {
            // A switch would throw NullPointerException; keep the documented IllegalArgumentException.
            throw new IllegalArgumentException("Unknown stop reason: null");
        }
        return switch (stopReason) {
            case END_TURN, STOP_SEQUENCE -> FinishReason.STOP;
            case MAX_TOKENS -> FinishReason.LENGTH;
            case TOOL_USE -> FinishReason.TOOL_EXECUTION;
            case CONTENT_FILTERED -> FinishReason.CONTENT_FILTER;
            default -> throw new IllegalArgumentException("Unknown stop reason: " + stopReason);
        };
    }

    /**
     * Builds the Bedrock inference configuration from request parameters.
     *
     * @param parameters the request parameters
     * @return the inference configuration; empty stop sequences are passed as {@code null}
     */
    protected InferenceConfiguration inferenceConfigFrom(ChatRequestParameters parameters) {
        Collection<String> stopSequences = parameters.stopSequences();
        if (dev.langchain4j.internal.Utils.isNullOrEmpty(stopSequences)) {
            stopSequences = null;
        }
        return InferenceConfiguration.builder()
                .maxTokens(parameters.maxOutputTokens())
                .temperature(dblToFloat(parameters.temperature()))
                .topP(dblToFloat(parameters.topP()))
                .stopSequences(stopSequences)
                .build();
    }

    /**
     * Merges the default and request-level additional model request fields into a {@link Document}.
     *
     * <p>Request-level entries override default entries with the same key. Request-level fields are only
     * considered when the parameters are {@link BedrockChatRequestParameters}.
     *
     * @param chatRequestParameters the request parameters; may be of any parameter type
     * @return the merged fields as a document, or {@code null} when there are none
     */
    protected Document additionalRequestModelFieldsFrom(ChatRequestParameters chatRequestParameters) {
        Map<String, Object> merged = new HashMap<>();
        // Guarded like the request-level map below: defaults may carry no additional fields.
        Map<String, ?> defaults = this.defaultRequestParameters.additionalModelRequestFields();
        if (defaults != null) {
            merged.putAll(defaults);
        }
        if (chatRequestParameters instanceof BedrockChatRequestParameters bedrockParameters
                && bedrockParameters.additionalModelRequestFields() != null) {
            merged.putAll(bedrockParameters.additionalModelRequestFields());
        }
        if (merged.isEmpty()) {
            return null;
        }
        return AwsDocumentConverter.convertAdditionalModelRequestFields(merged);
    }

    /**
     * Rejects request parameters this provider does not support: {@code topK}, {@code frequencyPenalty},
     * {@code presencePenalty} and the JSON response format.
     *
     * @param parameters the parameters to check
     * @throws UnsupportedFeatureException if any unsupported parameter is set
     */
    protected static void validate(ChatRequestParameters parameters) {
        String errorTemplate = "%s is not supported yet by this model provider";
        if (parameters.topK() != null) {
            throw new UnsupportedFeatureException(String.format(errorTemplate, "'topK' parameter"));
        }
        if (parameters.frequencyPenalty() != null) {
            throw new UnsupportedFeatureException(String.format(errorTemplate, "'frequencyPenalty' parameter"));
        }
        if (parameters.presencePenalty() != null) {
            throw new UnsupportedFeatureException(String.format(errorTemplate, "'presencePenalty' parameter"));
        }
        // Identity comparison is the idiomatic, null-safe way to compare enum constants.
        if (parameters.responseFormat() != null && parameters.responseFormat().type() == ResponseFormatType.JSON) {
            throw new UnsupportedFeatureException(String.format(errorTemplate, "JSON response format"));
        }
    }

    /**
     * Narrows a boxed double to a boxed float.
     *
     * @param d the value; may be {@code null}
     * @return the float value, or {@code null} when {@code d} is {@code null}
     */
    protected static Float dblToFloat(Double d) {
        return d == null ? null : Float.valueOf(d.floatValue());
    }

    /**
     * Derives a document name from a URI, falling back to a random UUID.
     *
     * @param uri the URI handed to {@code Utils.extractCleanFileName}
     * @return the extracted name, or a random UUID string when the extracted name is {@code null} or empty
     */
    protected static String extractFilenameWithoutExtensionFromUri(URI uri) {
        String cleanFileName = Utils.extractCleanFileName(uri);
        if (dev.langchain4j.internal.Utils.isNullOrEmpty(cleanFileName)) {
            return UUID.randomUUID().toString();
        }
        return cleanFileName;
    }

    /** Builds the cache point block used for system, message and tool cache points. */
    private static CachePointBlock defaultCachePoint() {
        return CachePointBlock.builder().type(DEFAULT_CACHE_POINT_TYPE).build();
    }

    /** Returns the decoded base64 data when present, otherwise the bytes read from the URL. */
    private static byte[] loadBytes(String base64Data, URI url) {
        if (base64Data != null) {
            return Base64.getDecoder().decode(base64Data);
        }
        return dev.langchain4j.internal.Utils.readBytes(String.valueOf(url));
    }

    /** Converts one langchain4j tool specification into a Bedrock tool entry. */
    private static Tool toBedrockTool(dev.langchain4j.agent.tool.ToolSpecification toolSpecification) {
        ToolInputSchema inputSchema = ToolInputSchema.builder()
                .json(AwsDocumentConverter.convertJsonObjectSchemaToDocument(toolSpecification))
                .build();
        ToolSpecification bedrockSpecification = ToolSpecification.builder()
                .name(toolSpecification.name())
                .description(toolSpecification.description())
                .inputSchema(inputSchema)
                .build();
        return Tool.builder().toolSpec(bedrockSpecification).build();
    }

    /**
     * Self-typed builder base holding the configuration read by the
     * {@link AbstractBedrockChatModel} constructor.
     *
     * <p>Every setter stores the given value as-is and returns {@link #self()} for chaining.
     *
     * @param <T> the concrete builder type
     */
    public abstract static class AbstractBuilder<T extends AbstractBuilder<T>> {
        protected Region region;
        protected String modelId;
        protected Duration timeout;
        protected Boolean returnThinking;
        protected Boolean sendThinking;
        protected ChatRequestParameters defaultRequestParameters;
        protected Boolean logRequests;
        protected Boolean logResponses;
        protected Logger logger;
        protected List<ChatModelListener> listeners;

        /**
         * Returns this builder typed as the concrete builder type.
         *
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe as long as subclasses pass themselves as T (self-bounded type)
        public T self() {
            return (T) this;
        }

        /**
         * Sets the default request parameters.
         *
         * @param defaultRequestParameters the parameters; may be {@code null}
         * @return this builder
         */
        public T defaultRequestParameters(ChatRequestParameters defaultRequestParameters) {
            this.defaultRequestParameters = defaultRequestParameters;
            return this.self();
        }

        /**
         * Sets the AWS region.
         *
         * @param region the region; may be {@code null}
         * @return this builder
         */
        public T region(Region region) {
            this.region = region;
            return this.self();
        }

        /**
         * Sets the model id; when non-null it takes precedence over the model name of the default
         * request parameters.
         *
         * @param modelId the model id; may be {@code null}
         * @return this builder
         */
        public T modelId(String modelId) {
            this.modelId = modelId;
            return this.self();
        }

        /**
         * Sets whether reasoning content of responses is returned.
         *
         * @param returnThinking the flag; may be {@code null}
         * @return this builder
         */
        public T returnThinking(Boolean returnThinking) {
            this.returnThinking = returnThinking;
            return this.self();
        }

        /**
         * Sets whether thinking of AI messages is sent back with requests.
         *
         * @param sendThinking the flag; may be {@code null}
         * @return this builder
         */
        public T sendThinking(Boolean sendThinking) {
            this.sendThinking = sendThinking;
            return this.self();
        }

        /**
         * Sets the timeout.
         *
         * @param timeout the timeout; may be {@code null}
         * @return this builder
         */
        public T timeout(Duration timeout) {
            this.timeout = timeout;
            return this.self();
        }

        /**
         * Sets the request logging flag.
         *
         * @param logRequests the flag; may be {@code null}
         * @return this builder
         */
        public T logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this.self();
        }

        /**
         * Sets the response logging flag.
         *
         * @param logResponses the flag; may be {@code null}
         * @return this builder
         */
        public T logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this.self();
        }

        /**
         * Sets the logger.
         *
         * @param logger the logger; may be {@code null}
         * @return this builder
         */
        public T logger(Logger logger) {
            this.logger = logger;
            return this.self();
        }

        /**
         * Sets the chat model listeners.
         *
         * @param listeners the listeners; may be {@code null}
         * @return this builder
         */
        public T listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this.self();
        }
    }
}

