/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.ollama;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

public class OllamaChatModel
extends AbstractToolCallSupport
implements ChatModel {
    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
    private final OllamaApi chatApi;
    private final OllamaOptions defaultOptions;
    private final ObservationRegistry observationRegistry;
    private final OllamaModelManager modelManager;
    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        super(functionCallbackContext, (FunctionCallingOptions)defaultOptions, toolFunctionCallbacks);
        Assert.notNull((Object)ollamaApi, (String)"ollamaApi must not be null");
        Assert.notNull((Object)defaultOptions, (String)"defaultOptions must not be null");
        Assert.notNull((Object)observationRegistry, (String)"observationRegistry must not be null");
        Assert.notNull((Object)modelManagementOptions, (String)"modelManagementOptions must not be null");
        this.chatApi = ollamaApi;
        this.defaultOptions = defaultOptions;
        this.observationRegistry = observationRegistry;
        this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
        this.initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
    }

    public static Builder builder() {
        return new Builder();
    }

    public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
        Assert.notNull((Object)response, (String)"OllamaApi.ChatResponse must not be null");
        return ChatResponseMetadata.builder().withUsage((Usage)OllamaChatUsage.from(response)).withModel(response.model()).withKeyValue("created-at", (Object)response.createdAt()).withKeyValue("eval-duration", (Object)response.evalDuration()).withKeyValue("eval-count", (Object)response.evalCount()).withKeyValue("load-duration", (Object)response.loadDuration()).withKeyValue("prompt-eval-duration", (Object)response.promptEvalDuration()).withKeyValue("prompt-eval-count", (Object)response.promptEvalCount()).withKeyValue("total-duration", (Object)response.totalDuration()).withKeyValue("done", (Object)response.done()).build();
    }

    public ChatResponse call(Prompt prompt) {
        OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, false);
        ChatModelObservationContext observationContext = ChatModelObservationContext.builder().prompt(prompt).provider(OllamaApi.PROVIDER_NAME).requestOptions(this.buildRequestOptions(request)).build();
        ChatResponse response = (ChatResponse)ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation((ObservationConvention)this.observationConvention, (ObservationConvention)DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry).observe(() -> {
            OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
            List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() : ollamaResponse.message().toolCalls().stream().map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), ModelOptionsUtils.toJsonString(toolCall.function().arguments()))).toList();
            AssistantMessage assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls);
            ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
            if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
                generationMetadata = ChatGenerationMetadata.from((String)ollamaResponse.doneReason(), null);
            }
            Generation generator = new Generation(assistantMessage, generationMetadata);
            ChatResponse chatResponse = new ChatResponse(List.of(generator), OllamaChatModel.from(ollamaResponse));
            observationContext.setResponse((Object)chatResponse);
            return chatResponse;
        });
        if (!this.isProxyToolCalls(prompt, this.defaultOptions) && response != null && this.isToolCall(response, Set.of("stop"))) {
            List toolCallConversation = this.handleToolCalls(prompt, response);
            return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
        }
        return response;
    }

    public Flux<ChatResponse> stream(Prompt prompt) {
        return Flux.deferContextual(contextView -> {
            OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, true);
            ChatModelObservationContext observationContext = ChatModelObservationContext.builder().prompt(prompt).provider(OllamaApi.PROVIDER_NAME).requestOptions(this.buildRequestOptions(request)).build();
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation((ObservationConvention)this.observationConvention, (ObservationConvention)DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
            observation.parentObservation((Observation)contextView.getOrDefault((Object)"micrometer.observation", null)).start();
            Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(request);
            Flux chatResponse = ollamaResponse.map(chunk -> {
                String content = chunk.message() != null ? chunk.message().content() : "";
                List<Object> toolCalls = List.of();
                if (chunk.message() != null && chunk.message().toolCalls() != null) {
                    toolCalls = chunk.message().toolCalls().stream().map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), ModelOptionsUtils.toJsonString(toolCall.function().arguments()))).toList();
                }
                AssistantMessage assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
                ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
                if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
                    generationMetadata = ChatGenerationMetadata.from((String)chunk.doneReason(), null);
                }
                Generation generator = new Generation(assistantMessage, generationMetadata);
                return new ChatResponse(List.of(generator), OllamaChatModel.from(chunk));
            });
            Flux chatResponseFlux = chatResponse.flatMap(response -> {
                if (this.isToolCall((ChatResponse)response, Set.of("stop"))) {
                    List toolCallConversation = this.handleToolCalls(prompt, (ChatResponse)response);
                    return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
                }
                return Flux.just((Object)response);
            }).doOnError(arg_0 -> ((Observation)observation).error(arg_0)).doFinally(s -> observation.stop()).contextWrite(ctx -> ctx.put((Object)"micrometer.observation", (Object)observation));
            return new MessageAggregator().aggregate(chatResponseFlux, arg_0 -> ((ChatModelObservationContext)observationContext).setResponse(arg_0));
        });
    }

    OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
        OllamaOptions mergedOptions;
        List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
            if (message instanceof UserMessage) {
                UserMessage userMessage = (UserMessage)message;
                OllamaApi.Message.Builder messageBuilder = OllamaApi.Message.builder(OllamaApi.Message.Role.USER).withContent(message.getContent());
                if (!CollectionUtils.isEmpty((Collection)userMessage.getMedia())) {
                    messageBuilder.withImages(userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
                }
                return List.of(messageBuilder.build());
            }
            if (message instanceof SystemMessage) {
                SystemMessage systemMessage = (SystemMessage)message;
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.SYSTEM).withContent(systemMessage.getContent()).build());
            }
            if (message instanceof AssistantMessage) {
                AssistantMessage assistantMessage = (AssistantMessage)message;
                List<OllamaApi.Message.ToolCall> toolCalls = null;
                if (!CollectionUtils.isEmpty((Collection)assistantMessage.getToolCalls())) {
                    toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                        OllamaApi.Message.ToolCallFunction function = new OllamaApi.Message.ToolCallFunction(toolCall.name(), ModelOptionsUtils.jsonToMap((String)toolCall.arguments()));
                        return new OllamaApi.Message.ToolCall(function);
                    }).toList();
                }
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).withContent(assistantMessage.getContent()).withToolCalls(toolCalls).build());
            }
            if (message instanceof ToolResponseMessage) {
                ToolResponseMessage toolMessage = (ToolResponseMessage)message;
                return toolMessage.getResponses().stream().map(tr -> OllamaApi.Message.builder(OllamaApi.Message.Role.TOOL).withContent(tr.responseData()).build()).toList();
            }
            throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
        }).flatMap(Collection::stream).toList();
        HashSet<String> functionsForThisRequest = new HashSet<String>();
        OllamaOptions runtimeOptions = null;
        if (prompt.getOptions() != null) {
            ChatOptions chatOptions = prompt.getOptions();
            if (chatOptions instanceof FunctionCallingOptions) {
                FunctionCallingOptions functionCallingOptions = (FunctionCallingOptions)chatOptions;
                runtimeOptions = (OllamaOptions)ModelOptionsUtils.copyToTarget((Object)functionCallingOptions, FunctionCallingOptions.class, OllamaOptions.class);
            } else {
                runtimeOptions = (OllamaOptions)ModelOptionsUtils.copyToTarget((Object)prompt.getOptions(), ChatOptions.class, OllamaOptions.class);
            }
            functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(runtimeOptions));
        }
        if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) {
            functionsForThisRequest.addAll(this.defaultOptions.getFunctions());
        }
        if (!StringUtils.hasText((String)(mergedOptions = (OllamaOptions)ModelOptionsUtils.merge(runtimeOptions, (Object)this.defaultOptions, OllamaOptions.class)).getModel())) {
            throw new IllegalArgumentException("Model is not set!");
        }
        String model = mergedOptions.getModel();
        OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(model).withStream(stream).withMessages(ollamaMessages).withOptions(mergedOptions);
        if (mergedOptions.getFormat() != null) {
            requestBuilder.withFormat(mergedOptions.getFormat());
        }
        if (mergedOptions.getKeepAlive() != null) {
            requestBuilder.withKeepAlive(mergedOptions.getKeepAlive());
        }
        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
        }
        return requestBuilder.build();
    }

    private String fromMediaData(Object mediaData) {
        if (mediaData instanceof byte[]) {
            byte[] bytes = (byte[])mediaData;
            return Base64.getEncoder().encodeToString(bytes);
        }
        if (mediaData instanceof String) {
            String text = (String)mediaData;
            return text;
        }
        throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
    }

    private List<OllamaApi.ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {
        return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
            OllamaApi.ChatRequest.Tool.Function function = new OllamaApi.ChatRequest.Tool.Function(functionCallback.getName(), functionCallback.getDescription(), functionCallback.getInputTypeSchema());
            return new OllamaApi.ChatRequest.Tool(function);
        }).toList();
    }

    private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) {
        OllamaOptions options = (OllamaOptions)ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class);
        return ChatOptionsBuilder.builder().withModel(request.model()).withFrequencyPenalty(options.getFrequencyPenalty()).withMaxTokens(options.getMaxTokens()).withPresencePenalty(options.getPresencePenalty()).withStopSequences(options.getStopSequences()).withTemperature(options.getTemperature()).withTopK(options.getTopK()).withTopP(options.getTopP()).build();
    }

    public ChatOptions getDefaultOptions() {
        return OllamaOptions.fromOptions(this.defaultOptions);
    }

    private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
        if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals((Object)pullModelStrategy)) {
            this.modelManager.pullModel(model, pullModelStrategy);
        }
    }

    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull((Object)observationConvention, (String)"observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    public static final class Builder {
        private OllamaApi ollamaApi;
        private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaModel.MISTRAL.id());
        private FunctionCallbackContext functionCallbackContext;
        private List<FunctionCallback> toolFunctionCallbacks = List.of();
        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
        private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

        private Builder() {
        }

        public Builder withOllamaApi(OllamaApi ollamaApi) {
            this.ollamaApi = ollamaApi;
            return this;
        }

        public Builder withDefaultOptions(OllamaOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
            this.functionCallbackContext = functionCallbackContext;
            return this;
        }

        public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
            this.toolFunctionCallbacks = toolFunctionCallbacks;
            return this;
        }

        public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
            this.modelManagementOptions = modelManagementOptions;
            return this;
        }

        public OllamaChatModel build() {
            return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackContext, this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
        }
    }
}

