/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.mcp.client.transport.http;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import dev.langchain4j.mcp.client.transport.http.SseSubscriber;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link McpTransport} for the MCP "Streamable HTTP" transport.
 *
 * <p>Every client message is serialized to JSON and sent as an HTTP POST to a single endpoint URL.
 * The server may answer with a plain JSON body or, for requests, with a {@code text/event-stream}
 * body; both are dispatched to the {@link McpOperationHandler} passed to
 * {@link #start(McpOperationHandler)}. A session id returned in the {@code Mcp-Session-Id} response
 * header is remembered and sent on every subsequent non-initialize request. If the server answers a
 * (non-handshake) message with HTTP 404, the last initialize handshake is replayed once and the
 * message is resent.
 */
public class StreamableHttpMcpTransport implements McpTransport {
    private static final Logger DEFAULT_TRAFFIC_LOG = LoggerFactory.getLogger("MCP");
    private static final Logger LOG = LoggerFactory.getLogger(StreamableHttpMcpTransport.class);
    static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final String url;
    private final Supplier<Map<String, String>> customHeadersSupplier;
    private final boolean logResponses;
    private final boolean logRequests;
    private final Logger trafficLog;
    private final HttpClient httpClient;
    /** Response future of the initialize request currently in flight, or {@code null} when none is. */
    private final AtomicReference<CompletableFuture<JsonNode>> initializeInProgress = new AtomicReference<>();
    /** Session id assigned by the server through the {@code Mcp-Session-Id} header, or {@code null}. */
    private final AtomicReference<String> mcpSessionId = new AtomicReference<>();
    private volatile McpOperationHandler operationHandler;
    // volatile: written by initialize() on a caller thread, read on HttpClient threads when a 404 triggers a re-initialize.
    private volatile McpInitializeRequest initializeRequest;

    /**
     * Creates the transport from a builder.
     *
     * @param builder configuration; {@code url} is mandatory, everything else has a default
     */
    public StreamableHttpMcpTransport(Builder builder) {
        this.url = ValidationUtils.ensureNotNull(builder.url, "Missing server endpoint URL");
        this.logRequests = builder.logRequests;
        this.logResponses = builder.logResponses;
        this.trafficLog = Utils.getOrDefault(builder.logger, DEFAULT_TRAFFIC_LOG);
        Duration timeout = Utils.getOrDefault(builder.timeout, Duration.ofSeconds(60L));
        // Default: no custom headers.
        this.customHeadersSupplier = builder.customHeadersSupplier != null ? builder.customHeadersSupplier : Map::of;
        HttpClient.Builder clientBuilder = HttpClient.newBuilder().connectTimeout(timeout);
        if (builder.executor != null) {
            clientBuilder.executor(builder.executor);
        }
        this.httpClient = clientBuilder.build();
    }

    /**
     * Stores the handler that receives every JSON message coming back from the server.
     *
     * @param operationHandler handler used for response dispatch and pending-operation tracking
     */
    @Override
    public void start(McpOperationHandler operationHandler) {
        this.operationHandler = operationHandler;
    }

    /**
     * Sends the initialize request and, once it succeeds, the {@code initialized} notification.
     *
     * <p>While the initialize request is in flight, other messages are deferred until it completes.
     * The request is remembered so it can be replayed if the server later answers with HTTP 404.
     *
     * @param operation the initialize request
     * @return future completed with the server's initialize response
     */
    @Override
    public CompletableFuture<JsonNode> initialize(McpInitializeRequest operation) {
        this.initializeRequest = operation;
        CompletableFuture<JsonNode> initializeResponse = this.execute(operation, operation.getId());
        this.initializeInProgress.set(initializeResponse);
        return initializeResponse
                // Clear the marker on success AND failure, and only if no newer initialize replaced it;
                // a failed handshake must not block every later message.
                .whenComplete((response, error) -> this.initializeInProgress.compareAndSet(initializeResponse, null))
                .thenCompose(originalResponse -> this.execute(new McpInitializationNotification(), null)
                        .thenApply(ignored -> originalResponse));
    }

    /**
     * Builds the POST request for one client message, attaching the session id (except for
     * initialize, which starts a new session) and any custom headers.
     */
    private HttpRequest createRequest(McpClientMessage message) throws JsonProcessingException {
        String body = OBJECT_MAPPER.writeValueAsString(message);
        if (this.logRequests) {
            this.trafficLog.info("Request: {}", body);
        }
        HttpRequest.Builder builder = HttpRequest.newBuilder();
        String sessionId = this.mcpSessionId.get();
        if (sessionId != null && !(message instanceof McpInitializeRequest)) {
            builder.header("Mcp-Session-Id", sessionId);
        }
        // The supplier is consulted per request so headers (e.g. tokens) can change over time.
        Map<String, String> headers = this.customHeadersSupplier.get();
        if (headers != null) {
            headers.forEach(builder::header);
        }
        return builder.uri(URI.create(this.url))
                .header("Content-Type", "application/json")
                .header("Accept", "application/json,text/event-stream")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    /** Sends a message that expects a JSON-RPC response correlated by the message id. */
    @Override
    public CompletableFuture<JsonNode> executeOperationWithResponse(McpClientMessage operation) {
        return this.execute(operation, operation.getId());
    }

    /** Sends a message without waiting for a response; delivery failures are logged. */
    @Override
    public void executeOperationWithoutResponse(McpClientMessage operation) {
        this.execute(operation, null).exceptionally(t -> {
            LOG.warn("Failed to send MCP message", t);
            return null;
        });
    }

    /** No-op: this transport performs no health checks. */
    @Override
    public void checkHealth() {
    }

    /** No-op: the callback is ignored by this transport. */
    @Override
    public void onFailure(Runnable actionOnFailure) {
    }

    private CompletableFuture<JsonNode> execute(McpClientMessage message, Long id) {
        return this.execute(message, id, false);
    }

    /**
     * Sends {@code message}, first waiting (without blocking) for any in-flight initialize.
     *
     * @param message message to send
     * @param id      JSON-RPC id to correlate a response with, or {@code null} for notifications
     * @param isRetry whether this is the single resend after a 404-triggered re-initialize
     */
    private CompletableFuture<JsonNode> execute(McpClientMessage message, Long id, boolean isRetry) {
        CompletableFuture<JsonNode> pendingInitialize = this.initializeInProgress.get();
        if (pendingInitialize != null) {
            // Chain instead of join(): this can run on an HttpClient executor thread, and blocking it
            // could stop the very initialize response we are waiting for from being processed.
            return pendingInitialize.thenCompose(ignored -> this.send(message, id, isRetry));
        }
        return this.send(message, id, isRetry);
    }

    /** Serializes and POSTs the message; the returned future is completed by the response handling. */
    private CompletableFuture<JsonNode> send(McpClientMessage message, Long id, boolean isRetry) {
        HttpRequest request;
        try {
            request = this.createRequest(message);
        } catch (JsonProcessingException e) {
            return CompletableFuture.failedFuture(e);
        }
        CompletableFuture<JsonNode> future = new CompletableFuture<>();
        if (id != null) {
            this.operationHandler.startOperation(id, future);
        }
        this.httpClient
                .sendAsync(request, responseInfo -> this.handleResponse(responseInfo, message, id, isRetry, future))
                .exceptionally(t -> {
                    future.completeExceptionally(t);
                    return null;
                });
        return future;
    }

    /** Chooses how to consume the response body based on status code and content type. */
    private HttpResponse.BodySubscriber<Void> handleResponse(
            HttpResponse.ResponseInfo responseInfo,
            McpClientMessage message,
            Long id,
            boolean isRetry,
            CompletableFuture<JsonNode> future) {
        int statusCode = responseInfo.statusCode();
        if (!this.isExpectedStatusCode(statusCode)) {
            // Handshake messages are excluded: re-initializing because the handshake itself got a 404
            // would recurse without end.
            boolean isHandshake =
                    message instanceof McpInitializeRequest || message instanceof McpInitializationNotification;
            if (statusCode == 404 && !isHandshake && !isRetry) {
                // The server no longer knows our session: replay the handshake once, then resend.
                this.reinitializeAndRetry(message, id, future);
            } else {
                // Includes a second 404 on the retry, which previously left the future pending forever.
                future.completeExceptionally(new RuntimeException("Unexpected status code: " + statusCode));
            }
            return HttpResponse.BodySubscribers.discarding();
        }
        responseInfo.headers().firstValue("Mcp-Session-Id").ifPresent(sessionId -> {
            this.mcpSessionId.set(sessionId);
            LOG.debug("Assigned MCP session ID: {}", sessionId);
        });
        boolean isEventStream = responseInfo.headers().firstValue("Content-Type")
                .map(contentType -> contentType.contains("text/event-stream"))
                .orElse(false);
        if (id != null && isEventStream) {
            return HttpResponse.BodySubscribers.fromLineSubscriber(
                    new SseSubscriber(future, this.logResponses, this.operationHandler, this.trafficLog));
        }
        return HttpResponse.BodySubscribers.mapping(
                HttpResponse.BodySubscribers.ofString(StandardCharsets.UTF_8),
                responseBody -> {
                    this.handleJsonBody(responseBody, id, future);
                    return null;
                });
    }

    /** Replays the last initialize request and resends {@code message}, forwarding the outcome to {@code future}. */
    private void reinitializeAndRetry(McpClientMessage message, Long id, CompletableFuture<JsonNode> future) {
        McpInitializeRequest lastInitializeRequest = this.initializeRequest;
        if (lastInitializeRequest == null) {
            future.completeExceptionally(
                    new IllegalStateException("Received HTTP 404 but the transport was never initialized"));
            return;
        }
        this.initialize(lastInitializeRequest)
                .thenCompose(ignored -> this.execute(message, id, true))
                .whenComplete((result, error) -> {
                    if (error != null) {
                        future.completeExceptionally(error);
                    } else {
                        future.complete(result);
                    }
                });
    }

    /** Handles a non-SSE response body: logs it and dispatches the parsed JSON to the operation handler. */
    private void handleJsonBody(String responseBody, Long id, CompletableFuture<JsonNode> future) {
        if (this.logResponses) {
            this.trafficLog.info("Response: {}", responseBody);
        }
        if (id == null) {
            // Nothing is correlated with a notification; a successful POST is all the caller waits for.
            future.complete(null);
        }
        if (responseBody == null || responseBody.isBlank()) {
            // e.g. "202 Accepted" without a body. There is nothing to dispatch, and a request that
            // expected a response would otherwise wait forever.
            if (id != null) {
                future.completeExceptionally(
                        new IllegalStateException("Empty response body for request with id " + id));
            }
            return;
        }
        try {
            this.operationHandler.handle(OBJECT_MAPPER.readTree(responseBody));
        } catch (IOException e) {
            future.completeExceptionally(e);
        }
    }

    private boolean isExpectedStatusCode(int statusCode) {
        return statusCode >= 200 && statusCode < 300;
    }

    /**
     * Closes the underlying {@link HttpClient} when the runtime supports it (Java 21+); earlier runtimes
     * have nothing to close. On Java 21+ {@code HttpClient.close()} waits for in-flight exchanges to finish.
     */
    @Override
    public void close() throws IOException {
        // Call close() through the public AutoCloseable interface. Reflecting on httpClient.getClass()
        // targets a JDK-internal, non-exported class, so Method.invoke fails with IllegalAccessException.
        if (this.httpClient instanceof AutoCloseable) {
            try {
                ((AutoCloseable) this.httpClient).close();
            } catch (Exception e) {
                LOG.debug("Failed to close the HTTP client", e);
            }
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link StreamableHttpMcpTransport}. */
    public static class Builder {
        private Executor executor;
        private String url;
        private Supplier<Map<String, String>> customHeadersSupplier;
        private Duration timeout;
        private boolean logRequests = false;
        private boolean logResponses = false;
        private Logger logger;

        /** Server endpoint URL every message is POSTed to (required). */
        public Builder url(String url) {
            this.url = url;
            return this;
        }

        /** Fixed headers added to every request. */
        public Builder customHeaders(Map<String, String> customHeaders) {
            this.customHeadersSupplier = () -> customHeaders;
            return this;
        }

        /** Supplier of headers, invoked once per request, whose entries are added to that request. */
        public Builder customHeaders(Supplier<Map<String, String>> customHeadersSupplier) {
            this.customHeadersSupplier = customHeadersSupplier;
            return this;
        }

        /** Connect timeout of the underlying HttpClient; defaults to 60 seconds. */
        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /** Whether to log every request body at INFO level on the traffic logger. */
        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        /** Whether to log every response body at INFO level on the traffic logger. */
        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        /** Logger for request/response traffic; defaults to the logger named "MCP". */
        public Builder logger(Logger logger) {
            this.logger = logger;
            return this;
        }

        /** Executor handed to the underlying HttpClient; the client's default is used when unset. */
        public Builder executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        public StreamableHttpMcpTransport build() {
            return new StreamableHttpMcpTransport(this);
        }
    }
}

