/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.KeepAliveScheduler;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.Disposable;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

public class WebFluxStreamableServerTransportProvider
implements McpStreamableServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    private final McpJsonMapper jsonMapper;
    private final String mcpEndpoint;
    private final boolean disallowDelete;
    private final RouterFunction<?> routerFunction;
    private McpStreamableServerSession.Factory sessionFactory;
    private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap();
    private McpTransportContextExtractor<ServerRequest> contextExtractor;
    private volatile boolean isClosing = false;
    private KeepAliveScheduler keepAliveScheduler;

    private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor<ServerRequest> contextExtractor, boolean disallowDelete, Duration keepAliveInterval) {
        Assert.notNull((Object)jsonMapper, (String)"JsonMapper must not be null");
        Assert.notNull((Object)mcpEndpoint, (String)"Message endpoint must not be null");
        Assert.notNull(contextExtractor, (String)"Context extractor must not be null");
        this.jsonMapper = jsonMapper;
        this.mcpEndpoint = mcpEndpoint;
        this.contextExtractor = contextExtractor;
        this.disallowDelete = disallowDelete;
        this.routerFunction = RouterFunctions.route().GET(this.mcpEndpoint, this::handleGet).POST(this.mcpEndpoint, this::handlePost).DELETE(this.mcpEndpoint, this::handleDelete).build();
        if (keepAliveInterval != null) {
            this.keepAliveScheduler = KeepAliveScheduler.builder(() -> this.isClosing ? Flux.empty() : Flux.fromIterable(this.sessions.values())).initialDelay(keepAliveInterval).interval(keepAliveInterval).build();
            this.keepAliveScheduler.start();
        }
    }

    public List<String> protocolVersions() {
        return List.of("2024-11-05", "2025-03-26", "2025-06-18");
    }

    public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", (Object)this.sessions.size());
        return Flux.fromIterable(this.sessions.values()).flatMap(session -> session.sendNotification(method, params).doOnError(e -> logger.error("Failed to send message to session {}: {}", (Object)session.getId(), (Object)e.getMessage())).onErrorComplete()).then();
    }

    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            this.isClosing = true;
            return Flux.fromIterable(this.sessions.values()).doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", (Object)this.sessions.size())).flatMap(McpStreamableServerSession::closeGracefully).then();
        }).then().doOnSuccess(v -> {
            this.sessions.clear();
            if (this.keepAliveScheduler != null) {
                this.keepAliveScheduler.shutdown();
            }
        });
    }

    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }

    private Mono<ServerResponse> handleGet(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        McpTransportContext transportContext = this.contextExtractor.extract((Object)request);
        return Mono.defer(() -> {
            List acceptHeaders = request.headers().asHttpHeaders().getAccept();
            if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
                return ServerResponse.badRequest().build();
            }
            if (request.headers().header("Mcp-Session-Id").isEmpty()) {
                return ServerResponse.badRequest().build();
            }
            String sessionId = request.headers().asHttpHeaders().getFirst("Mcp-Session-Id");
            McpStreamableServerSession session = this.sessions.get(sessionId);
            if (session == null) {
                return ServerResponse.notFound().build();
            }
            if (!request.headers().header("Last-Event-ID").isEmpty()) {
                String lastId = request.headers().asHttpHeaders().getFirst("Last-Event-ID");
                return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body((Object)session.replay((Object)lastId).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext)), ServerSentEvent.class);
            }
            return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body((Object)Flux.create(sink -> {
                WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport((FluxSink<ServerSentEvent<?>>)sink);
                McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session.listeningStream((McpStreamableServerTransport)sessionTransport);
                sink.onDispose(() -> ((McpStreamableServerSession.McpStreamableServerSessionStream)listeningStream).close());
            }).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext)), ServerSentEvent.class);
        }).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext));
    }

    private Mono<ServerResponse> handlePost(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        McpTransportContext transportContext = this.contextExtractor.extract((Object)request);
        List acceptHeaders = request.headers().asHttpHeaders().getAccept();
        if (!acceptHeaders.contains(MediaType.APPLICATION_JSON) || !acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
            return ServerResponse.badRequest().build();
        }
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                McpSchema.JSONRPCRequest jsonrpcRequest;
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((McpJsonMapper)this.jsonMapper, (String)body);
                if (message instanceof McpSchema.JSONRPCRequest && (jsonrpcRequest = (McpSchema.JSONRPCRequest)message).method().equals("initialize")) {
                    TypeRef<McpSchema.InitializeRequest> typeReference = new TypeRef<McpSchema.InitializeRequest>(){};
                    McpSchema.InitializeRequest initializeRequest = (McpSchema.InitializeRequest)this.jsonMapper.convertValue(jsonrpcRequest.params(), (TypeRef)typeReference);
                    McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory.startSession(initializeRequest);
                    this.sessions.put(init.session().getId(), init.session());
                    return init.initResult().map(initializeResult -> {
                        McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse("2.0", jsonrpcRequest.id(), initializeResult, null);
                        try {
                            return this.jsonMapper.writeValueAsString((Object)jsonrpcResponse);
                        }
                        catch (IOException e) {
                            logger.warn("Failed to serialize initResponse", (Throwable)e);
                            throw Exceptions.propagate((Throwable)e);
                        }
                    }).flatMap(initResult -> ((ServerResponse.BodyBuilder)ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).header("Mcp-Session-Id", new String[]{init.session().getId()})).bodyValue(initResult));
                }
                if (request.headers().header("Mcp-Session-Id").isEmpty()) {
                    return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Session ID missing"));
                }
                String sessionId = request.headers().asHttpHeaders().getFirst("Mcp-Session-Id");
                McpStreamableServerSession session = this.sessions.get(sessionId);
                if (session == null) {
                    return ServerResponse.status((HttpStatusCode)HttpStatus.NOT_FOUND).bodyValue((Object)new McpError((Object)("Session not found: " + sessionId)));
                }
                if (message instanceof McpSchema.JSONRPCResponse) {
                    McpSchema.JSONRPCResponse jsonrpcResponse = (McpSchema.JSONRPCResponse)message;
                    return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build());
                }
                if (message instanceof McpSchema.JSONRPCNotification) {
                    McpSchema.JSONRPCNotification jsonrpcNotification = (McpSchema.JSONRPCNotification)message;
                    return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build());
                }
                if (message instanceof McpSchema.JSONRPCRequest) {
                    McpSchema.JSONRPCRequest jsonrpcRequest2 = (McpSchema.JSONRPCRequest)message;
                    return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body((Object)Flux.create(sink -> {
                        WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport((FluxSink<ServerSentEvent<?>>)sink);
                        Mono stream = session.responseStream(jsonrpcRequest2, (McpStreamableServerTransport)st);
                        Disposable streamSubscription = stream.onErrorComplete(err -> {
                            sink.error(err);
                            return true;
                        }).contextWrite(sink.contextView()).subscribe();
                        sink.onCancel(streamSubscription);
                    }).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext)), ServerSentEvent.class);
                }
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Unknown message type"));
            }
            catch (IOException | IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", (Object)e.getMessage());
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Invalid message format"));
            }
        }).switchIfEmpty(ServerResponse.badRequest().build()).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext));
    }

    private Mono<ServerResponse> handleDelete(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        McpTransportContext transportContext = this.contextExtractor.extract((Object)request);
        return Mono.defer(() -> {
            if (request.headers().header("Mcp-Session-Id").isEmpty()) {
                return ServerResponse.badRequest().build();
            }
            if (this.disallowDelete) {
                return ServerResponse.status((HttpStatusCode)HttpStatus.METHOD_NOT_ALLOWED).build();
            }
            String sessionId = request.headers().asHttpHeaders().getFirst("Mcp-Session-Id");
            McpStreamableServerSession session = this.sessions.get(sessionId);
            if (session == null) {
                return ServerResponse.notFound().build();
            }
            return session.delete().then(ServerResponse.ok().build());
        }).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext));
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private McpJsonMapper jsonMapper;
        private String mcpEndpoint = "/mcp";
        private McpTransportContextExtractor<ServerRequest> contextExtractor = serverRequest -> McpTransportContext.EMPTY;
        private boolean disallowDelete;
        private Duration keepAliveInterval;

        private Builder() {
        }

        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull((Object)jsonMapper, (String)"McpJsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull((Object)messageEndpoint, (String)"Message endpoint must not be null");
            this.mcpEndpoint = messageEndpoint;
            return this;
        }

        public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
            Assert.notNull(contextExtractor, (String)"contextExtractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        public Builder disallowDelete(boolean disallowDelete) {
            this.disallowDelete = disallowDelete;
            return this;
        }

        public Builder keepAliveInterval(Duration keepAliveInterval) {
            this.keepAliveInterval = keepAliveInterval;
            return this;
        }

        public WebFluxStreamableServerTransportProvider build() {
            Assert.notNull((Object)this.mcpEndpoint, (String)"Message endpoint must be set");
            return new WebFluxStreamableServerTransportProvider(this.jsonMapper == null ? McpJsonMapper.getDefault() : this.jsonMapper, this.mcpEndpoint, this.contextExtractor, this.disallowDelete, this.keepAliveInterval);
        }
    }

    private class WebFluxStreamableMcpSessionTransport
    implements McpStreamableServerTransport {
        private final FluxSink<ServerSentEvent<?>> sink;

        public WebFluxStreamableMcpSessionTransport(FluxSink<ServerSentEvent<?>> sink) {
            this.sink = sink;
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return this.sendMessage(message, null);
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId) {
            return Mono.fromSupplier(() -> {
                try {
                    return WebFluxStreamableServerTransportProvider.this.jsonMapper.writeValueAsString((Object)message);
                }
                catch (IOException e) {
                    throw Exceptions.propagate((Throwable)e);
                }
            }).doOnNext(jsonText -> {
                ServerSentEvent event = ServerSentEvent.builder().id(messageId).event(WebFluxStreamableServerTransportProvider.MESSAGE_EVENT_TYPE).data(jsonText).build();
                this.sink.next((Object)event);
            }).doOnError(e -> {
                Throwable exception = Exceptions.unwrap((Throwable)e);
                this.sink.error(exception);
            }).then();
        }

        public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
            return (T)WebFluxStreamableServerTransportProvider.this.jsonMapper.convertValue(data, typeRef);
        }

        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> this.sink.complete());
        }

        public void close() {
            this.sink.complete();
        }
    }
}

