/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.server.McpStatelessServerHandler;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityException;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.publisher.Mono;

public class WebFluxStatelessServerTransport
implements McpStatelessServerTransport {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class);
    private final McpJsonMapper jsonMapper;
    private final String mcpEndpoint;
    private final RouterFunction<?> routerFunction;
    private McpStatelessServerHandler mcpHandler;
    private McpTransportContextExtractor<ServerRequest> contextExtractor;
    private volatile boolean isClosing = false;
    private final ServerTransportSecurityValidator securityValidator;

    private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor<ServerRequest> contextExtractor, ServerTransportSecurityValidator securityValidator) {
        Assert.notNull((Object)jsonMapper, (String)"jsonMapper must not be null");
        Assert.notNull((Object)mcpEndpoint, (String)"mcpEndpoint must not be null");
        Assert.notNull(contextExtractor, (String)"contextExtractor must not be null");
        Assert.notNull((Object)securityValidator, (String)"Security validator must not be null");
        this.jsonMapper = jsonMapper;
        this.mcpEndpoint = mcpEndpoint;
        this.contextExtractor = contextExtractor;
        this.securityValidator = securityValidator;
        this.routerFunction = RouterFunctions.route().GET(this.mcpEndpoint, this::handleGet).POST(this.mcpEndpoint, this::handlePost).build();
    }

    public void setMcpHandler(McpStatelessServerHandler mcpHandler) {
        this.mcpHandler = mcpHandler;
    }

    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
        });
    }

    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }

    private Mono<ServerResponse> handleGet(ServerRequest request) {
        return ServerResponse.status((HttpStatusCode)HttpStatus.METHOD_NOT_ALLOWED).build();
    }

    private Mono<ServerResponse> handlePost(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        try {
            HttpHeaders headers = request.headers().asHttpHeaders();
            this.securityValidator.validateHeaders((Map)headers);
        }
        catch (ServerTransportSecurityException e) {
            return ServerResponse.status((int)e.getStatusCode()).bodyValue((Object)e.getMessage());
        }
        McpTransportContext transportContext = this.contextExtractor.extract((Object)request);
        List acceptHeaders = request.headers().asHttpHeaders().getAccept();
        if (!acceptHeaders.contains(MediaType.APPLICATION_JSON) || !acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
            return ServerResponse.badRequest().build();
        }
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((McpJsonMapper)this.jsonMapper, (String)body);
                if (message instanceof McpSchema.JSONRPCRequest) {
                    McpSchema.JSONRPCRequest jsonrpcRequest = (McpSchema.JSONRPCRequest)message;
                    return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest).flatMap(jsonrpcResponse -> {
                        try {
                            String json = this.jsonMapper.writeValueAsString(jsonrpcResponse);
                            return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).bodyValue((Object)json);
                        }
                        catch (IOException e) {
                            logger.error("Failed to serialize response: {}", (Object)e.getMessage());
                            return ServerResponse.status((HttpStatusCode)HttpStatus.INTERNAL_SERVER_ERROR).bodyValue((Object)new McpError((Object)"Failed to serialize response"));
                        }
                    });
                }
                if (message instanceof McpSchema.JSONRPCNotification) {
                    McpSchema.JSONRPCNotification jsonrpcNotification = (McpSchema.JSONRPCNotification)message;
                    return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification).then(ServerResponse.accepted().build());
                }
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"The server accepts either requests or notifications"));
            }
            catch (IOException | IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", (Object)e.getMessage());
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Invalid message format"));
            }
        }).contextWrite(ctx -> ctx.put((Object)"MCP_TRANSPORT_CONTEXT", (Object)transportContext));
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private McpJsonMapper jsonMapper;
        private String mcpEndpoint = "/mcp";
        private McpTransportContextExtractor<ServerRequest> contextExtractor = serverRequest -> McpTransportContext.EMPTY;
        private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;

        private Builder() {
        }

        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull((Object)jsonMapper, (String)"JsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull((Object)messageEndpoint, (String)"Message endpoint must not be null");
            this.mcpEndpoint = messageEndpoint;
            return this;
        }

        public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
            Assert.notNull(contextExtractor, (String)"Context extractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
            Assert.notNull((Object)securityValidator, (String)"Security validator must not be null");
            this.securityValidator = securityValidator;
            return this;
        }

        public WebFluxStatelessServerTransport build() {
            Assert.notNull((Object)this.mcpEndpoint, (String)"Message endpoint must be set");
            return new WebFluxStatelessServerTransport(this.jsonMapper == null ? McpJsonDefaults.getMapper() : this.jsonMapper, this.mcpEndpoint, this.contextExtractor, this.securityValidator);
        }
    }
}

