/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatCompletionsToolCall;
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

/**
 * {@link ChatClient} and {@link StreamingChatClient} implementation backed by the Azure
 * OpenAI {@link OpenAIClient}.
 * <p>
 * Blocking {@link #call(Prompt)} requests are delegated to
 * {@code callWithFunctionSupport} (see {@link AbstractFunctionCallSupport}), which drives the
 * {@code do*} / {@link #isToolFunctionCall(ChatCompletions)} hooks implemented below to run
 * tool calls requested by the model. {@link #stream(Prompt)} talks to the client directly
 * and does not execute tool calls.
 */
public class AzureOpenAiChatClient
        extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
        implements ChatClient, StreamingChatClient {

    /** Deployment used by the single-argument constructor. */
    private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";

    /** Temperature used by the single-argument constructor. */
    private static final Float DEFAULT_TEMPERATURE = 0.7f;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Options applied to every request; per-prompt options take precedence over these. */
    private AzureOpenAiChatOptions defaultOptions;

    /** Azure SDK client used for both blocking and streaming requests. */
    private final OpenAIClient openAIClient;

    /**
     * Creates a client that uses the "gpt-35-turbo" deployment and a temperature of 0.7.
     * @param microsoftOpenAiClient the Azure SDK client; must not be null
     */
    public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
        this(microsoftOpenAiClient, AzureOpenAiChatOptions.builder()
            .withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
            .withTemperature(DEFAULT_TEMPERATURE)
            .build());
    }

    /**
     * Creates a client with the given default options and no function callback context.
     * @param microsoftOpenAiClient the Azure SDK client; must not be null
     * @param options the default options; must not be null
     */
    public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
        this(microsoftOpenAiClient, options, null);
    }

    /**
     * Creates a client with the given default options and function callback context.
     * @param microsoftOpenAiClient the Azure SDK client; must not be null
     * @param options the default options; must not be null
     * @param functionCallbackContext handed to {@link AbstractFunctionCallSupport}; may be null
     * (the two-argument constructor passes null)
     * @throws IllegalArgumentException if {@code microsoftOpenAiClient} or {@code options} is null
     */
    public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
            FunctionCallbackContext functionCallbackContext) {
        super(functionCallbackContext);
        Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
        Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
        this.openAIClient = microsoftOpenAiClient;
        this.defaultOptions = options;
    }

    /**
     * Replaces the default options.
     * @param defaultOptions the new default options; must not be null
     * @return this client
     * @deprecated since 0.8.0, scheduled for removal; pass the options to the constructor instead.
     */
    @Deprecated(forRemoval = true, since = "0.8.0")
    public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
        Assert.notNull(defaultOptions, "DefaultOptions must not be null");
        this.defaultOptions = defaultOptions;
        return this;
    }

    /**
     * @return the options applied to every request
     */
    public AzureOpenAiChatOptions getDefaultOptions() {
        return this.defaultOptions;
    }

    /**
     * Sends a blocking (non-streaming) chat request, resolving tool calls through the
     * function-call support, and maps every returned choice to a {@link Generation}.
     * @param prompt the prompt to send
     * @return the response, including prompt-filter metadata
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
        options.setStream(false);
        this.logger.trace("Azure ChatCompletionsOptions: {}", options);

        ChatCompletions chatCompletions = callWithFunctionSupport(options);
        this.logger.trace("Azure ChatCompletions: {}", chatCompletions);

        List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
            .map(choice -> new Generation(choice.getMessage().getContent())
                .withGenerationMetadata(generateChoiceMetadata(choice)))
            .toList();

        PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
        return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
    }

    /**
     * Sends a streaming chat request. Every choice of every received chunk (except the first
     * chunk, which is skipped) is emitted as a {@link ChatResponse} holding a single
     * {@link Generation} with the delta content, or null content when the choice has no delta.
     * <p>
     * NOTE(review): tool definitions may be attached by
     * {@link #toAzureChatCompletionsOptions(Prompt)}, but tool calls are not executed here.
     * @param prompt the prompt to send
     * @return a flux of partial responses
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
        options.setStream(true);

        IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
            .getChatCompletionsStream(options.getModel(), options);

        return Flux.fromStream(chatCompletionsStream.stream()
            // NOTE(review): the first chunk is skipped, presumably because it only carries
            // prompt-filter results — confirm against the service behavior.
            .skip(1)
            .flatMap(chunk -> nullSafeList(chunk.getChoices()).stream())
            .map(choice -> {
                String content = choice.getDelta() != null ? choice.getDelta().getContent() : null;
                Generation generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
                return new ChatResponse(List.of(generation));
            }));
    }

    /**
     * Builds the Azure request for a prompt. The prompt messages are converted first. The
     * default options then fill in unset values, and the prompt's own options override them.
     * Finally, tool definitions are attached for every function enabled by either options set.
     * @param prompt the prompt to convert
     * @return the Azure request options
     * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions}, or a
     * message has an unsupported type
     */
    ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
        Set<String> functionsForThisRequest = new HashSet<>();

        List<ChatRequestMessage> azureMessages = prompt.getInstructions().stream()
            .map(this::fromSpringAiMessage)
            .toList();
        ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

        if (this.defaultOptions != null) {
            options = merge(options, this.defaultOptions);
            functionsForThisRequest.addAll(handleFunctionCallbackConfigurations(this.defaultOptions, false));
        }

        ModelOptions modelOptions = prompt.getOptions();
        if (modelOptions != null) {
            if (!(modelOptions instanceof ChatOptions runtimeOptions)) {
                throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
                        + modelOptions.getClass().getSimpleName());
            }
            AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
                    ChatOptions.class, AzureOpenAiChatOptions.class);
            options = merge(updatedRuntimeOptions, options);
            functionsForThisRequest.addAll(handleFunctionCallbackConfigurations(updatedRuntimeOptions, true));
        }

        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            options.setTools(getFunctionTools(functionsForThisRequest));
        }
        return options;
    }

    /**
     * Converts the callbacks registered under the given names into Azure function tool
     * definitions, using each callback's name, description and JSON input schema.
     */
    private List<ChatCompletionsToolDefinition> getFunctionTools(Set<String> functionNames) {
        return this.resolveFunctionCallbacks(functionNames).stream()
            .<ChatCompletionsToolDefinition>map(functionCallback -> {
                FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName());
                functionDefinition.setDescription(functionCallback.getDescription());
                BinaryData parameters = BinaryData
                    .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
                functionDefinition.setParameters(parameters);
                return new ChatCompletionsFunctionToolDefinition(functionDefinition);
            })
            .toList();
    }

    /**
     * Maps a Spring AI message to its Azure request counterpart.
     * @throws IllegalArgumentException for message types other than USER, SYSTEM and ASSISTANT
     */
    private ChatRequestMessage fromSpringAiMessage(Message message) {
        return switch (message.getMessageType()) {
            case USER -> new ChatRequestUserMessage(message.getContent());
            case SYSTEM -> new ChatRequestSystemMessage(message.getContent());
            case ASSISTANT -> new ChatRequestAssistantMessage(message.getContent());
            default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
        };
    }

    /** Builds generation metadata from a choice's finish reason (stringified) and content-filter results. */
    private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) {
        return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults());
    }

    /** Builds prompt metadata from the response's prompt-filter results (empty when absent). */
    private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
        var promptFilterResults = nullSafeList(chatCompletions.getPromptFilterResults());
        return PromptMetadata.of(promptFilterResults.stream()
            .map(promptFilterResult -> PromptMetadata.PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
                    promptFilterResult.getContentFilterResults()))
            .toList());
    }

    /** Returns {@code list}, or an empty immutable list when it is null. */
    private <T> List<T> nullSafeList(List<T> list) {
        return list != null ? list : Collections.emptyList();
    }

    /** Returns {@code preferred} unless it is null, in which case {@code fallback}. */
    private static <T> T firstNonNull(T preferred, T fallback) {
        return preferred != null ? preferred : fallback;
    }

    /** Widens a possibly-null number to a {@link Double}, keeping null as null. */
    private static Double toDouble(Number value) {
        return value != null ? Double.valueOf(value.doubleValue()) : null;
    }

    /**
     * Merges with Azure options winning. Values already set on {@code azureOptions} are kept.
     * Unset values are filled from {@code springAiOptions}; the deployment name becomes the
     * model. Tools are not copied.
     * @return a new request, or {@code azureOptions} itself when {@code springAiOptions} is null
     */
    private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {
        if (springAiOptions == null) {
            return azureOptions;
        }
        ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
        mergedAzureOptions.setStream(azureOptions.isStream());
        mergedAzureOptions.setMaxTokens(firstNonNull(azureOptions.getMaxTokens(), springAiOptions.getMaxTokens()));
        mergedAzureOptions.setLogitBias(firstNonNull(azureOptions.getLogitBias(), springAiOptions.getLogitBias()));
        mergedAzureOptions.setStop(firstNonNull(azureOptions.getStop(), springAiOptions.getStop()));
        mergedAzureOptions
            .setTemperature(firstNonNull(azureOptions.getTemperature(), toDouble(springAiOptions.getTemperature())));
        mergedAzureOptions.setTopP(firstNonNull(azureOptions.getTopP(), toDouble(springAiOptions.getTopP())));
        mergedAzureOptions.setFrequencyPenalty(
                firstNonNull(azureOptions.getFrequencyPenalty(), toDouble(springAiOptions.getFrequencyPenalty())));
        mergedAzureOptions.setPresencePenalty(
                firstNonNull(azureOptions.getPresencePenalty(), toDouble(springAiOptions.getPresencePenalty())));
        mergedAzureOptions.setN(firstNonNull(azureOptions.getN(), springAiOptions.getN()));
        mergedAzureOptions.setUser(firstNonNull(azureOptions.getUser(), springAiOptions.getUser()));
        mergedAzureOptions.setModel(firstNonNull(azureOptions.getModel(), springAiOptions.getDeploymentName()));
        return mergedAzureOptions;
    }

    /**
     * Merges with Spring AI options winning. Starts from a copy of {@code azureOptions}; every
     * non-null value of {@code springAiOptions} then overrides it, and the deployment name
     * becomes the model.
     * @return a new request, or {@code azureOptions} itself when {@code springAiOptions} is null
     */
    private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
        if (springAiOptions == null) {
            return azureOptions;
        }
        ChatCompletionsOptions mergedAzureOptions = merge(azureOptions,
                new ChatCompletionsOptions(azureOptions.getMessages()));
        mergedAzureOptions.setStream(azureOptions.isStream());

        if (springAiOptions.getMaxTokens() != null) {
            mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
        }
        if (springAiOptions.getLogitBias() != null) {
            mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
        }
        if (springAiOptions.getStop() != null) {
            mergedAzureOptions.setStop(springAiOptions.getStop());
        }
        if (springAiOptions.getTemperature() != null) {
            mergedAzureOptions.setTemperature(toDouble(springAiOptions.getTemperature()));
        }
        if (springAiOptions.getTopP() != null) {
            mergedAzureOptions.setTopP(toDouble(springAiOptions.getTopP()));
        }
        if (springAiOptions.getFrequencyPenalty() != null) {
            mergedAzureOptions.setFrequencyPenalty(toDouble(springAiOptions.getFrequencyPenalty()));
        }
        if (springAiOptions.getPresencePenalty() != null) {
            mergedAzureOptions.setPresencePenalty(toDouble(springAiOptions.getPresencePenalty()));
        }
        if (springAiOptions.getN() != null) {
            mergedAzureOptions.setN(springAiOptions.getN());
        }
        if (springAiOptions.getUser() != null) {
            mergedAzureOptions.setUser(springAiOptions.getUser());
        }
        if (springAiOptions.getDeploymentName() != null) {
            mergedAzureOptions.setModel(springAiOptions.getDeploymentName());
        }
        return mergedAzureOptions;
    }

    /**
     * Builds a new request that keeps only the messages and stream flag of {@code toOptions}.
     * Every non-null setting of {@code fromOptions}, including its tool definitions, is copied
     * onto it. Other settings of {@code toOptions} are not carried over.
     * @return the new request, or {@code toOptions} itself when {@code fromOptions} is null
     */
    private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) {
        if (fromOptions == null) {
            return toOptions;
        }
        ChatCompletionsOptions mergedOptions = new ChatCompletionsOptions(toOptions.getMessages());
        mergedOptions.setStream(toOptions.isStream());

        if (fromOptions.getMaxTokens() != null) {
            mergedOptions.setMaxTokens(fromOptions.getMaxTokens());
        }
        if (fromOptions.getLogitBias() != null) {
            mergedOptions.setLogitBias(fromOptions.getLogitBias());
        }
        if (fromOptions.getStop() != null) {
            mergedOptions.setStop(fromOptions.getStop());
        }
        if (fromOptions.getTemperature() != null) {
            mergedOptions.setTemperature(fromOptions.getTemperature());
        }
        if (fromOptions.getTopP() != null) {
            mergedOptions.setTopP(fromOptions.getTopP());
        }
        if (fromOptions.getFrequencyPenalty() != null) {
            mergedOptions.setFrequencyPenalty(fromOptions.getFrequencyPenalty());
        }
        if (fromOptions.getPresencePenalty() != null) {
            mergedOptions.setPresencePenalty(fromOptions.getPresencePenalty());
        }
        if (fromOptions.getN() != null) {
            mergedOptions.setN(fromOptions.getN());
        }
        if (fromOptions.getUser() != null) {
            mergedOptions.setUser(fromOptions.getUser());
        }
        if (fromOptions.getModel() != null) {
            mergedOptions.setModel(fromOptions.getModel());
        }
        // Carry the tool definitions forward. Without them, the follow-up request built in
        // doCreateToolResponseRequest declares no functions, so the model cannot chain further tool calls.
        if (fromOptions.getTools() != null) {
            mergedOptions.setTools(fromOptions.getTools());
        }
        return mergedOptions;
    }

    /**
     * Runs every function tool call in {@code responseMessage} against the registered callbacks
     * and appends one tool message per call to {@code conversationHistory}. It then builds the
     * follow-up request from that history and the settings (including tools) of
     * {@code previousRequest}.
     * @throws IllegalStateException if a tool call is not a function call, or no callback is
     * registered for the requested function
     */
    @Override
    protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
            ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
        // Every tool call needs its own function invocation and its own TOOL response message.
        for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
            if (!(toolCall instanceof ChatCompletionsFunctionToolCall functionToolCall)) {
                throw new IllegalStateException("Unsupported tool call type: " + toolCall.getClass().getName());
            }
            String functionName = functionToolCall.getFunction().getName();
            String functionArguments = functionToolCall.getFunction().getArguments();

            if (!this.functionCallbackRegister.containsKey(functionName)) {
                throw new IllegalStateException("No function callback found for function name: " + functionName);
            }
            String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

            conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
        }

        return merge(previousRequest, new ChatCompletionsOptions(conversationHistory));
    }

    /** Returns the request's message list as-is. */
    @Override
    protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions request) {
        return request.getMessages();
    }

    /**
     * Wraps the tool calls of the first choice in an assistant message with empty content, so
     * the message can be added to the conversation history.
     */
    @Override
    protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
        ChatResponseMessage responseMessage = response.getChoices().get(0).getMessage();
        ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
        assistantMessage.setToolCalls(responseMessage.getToolCalls());
        return assistantMessage;
    }

    /** Sends a blocking request, using the request's model as the deployment name. */
    @Override
    protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
        return this.openAIClient.getChatCompletions(request.getModel(), request);
    }

    /**
     * @return true when the first choice finished with {@link CompletionsFinishReason#TOOL_CALLS};
     * false for a null response, no choices, or a null first choice
     */
    @Override
    protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
        if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) {
            return false;
        }
        ChatChoice choice = chatCompletions.getChoices().get(0);
        return choice != null && choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
    }
}

