package org.springframework.web.reactive.socket.server.upgrade;

import java.util.Collections;
import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.server.ServerContainer;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * A {@link RequestUpgradeStrategy} for Tomcat.
 *
 * <p>The reactive request/response is unwrapped to the underlying
 * {@link HttpServletRequest}/{@link HttpServletResponse}, and the upgrade itself is
 * delegated to Tomcat's {@link WsServerContainer#doUpgrade}.
 *
 * <p>The optional container settings exposed through the setters are applied once, the
 * first time the {@code WsServerContainer} is looked up from the {@code ServletContext}.
 * Changing them after the first upgrade has no effect on that container.
 */
public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {

    /** {@code ServletContext} attribute name under which the {@link WsServerContainer} is looked up. */
    private static final String SERVER_CONTAINER_ATTR = "javax.websocket.server.ServerContainer";

    @Nullable
    private Long asyncSendTimeout;

    @Nullable
    private Long maxSessionIdleTimeout;

    @Nullable
    private Integer maxTextMessageBufferSize;

    @Nullable
    private Integer maxBinaryMessageBufferSize;

    /**
     * Container resolved and configured on first use.
     * Volatile so that a fully configured instance is visible to every request thread.
     */
    @Nullable
    private volatile WsServerContainer serverContainer;

    /**
     * Set the value passed to {@code ServerContainer#setAsyncSendTimeout} when the
     * container is first initialized (milliseconds, per the JSR-356 contract).
     * @param timeoutInMillis the timeout, or {@code null} to leave the container default
     */
    public void setAsyncSendTimeout(Long timeoutInMillis) {
        this.asyncSendTimeout = timeoutInMillis;
    }

    /**
     * Return the configured async send timeout, or {@code null} if not set.
     */
    @Nullable
    public Long getAsyncSendTimeout() {
        return this.asyncSendTimeout;
    }

    /**
     * Set the value passed to {@code ServerContainer#setDefaultMaxSessionIdleTimeout} when
     * the container is first initialized (milliseconds, per the JSR-356 contract).
     * @param timeoutInMillis the timeout, or {@code null} to leave the container default
     */
    public void setMaxSessionIdleTimeout(Long timeoutInMillis) {
        this.maxSessionIdleTimeout = timeoutInMillis;
    }

    /**
     * Return the configured max session idle timeout, or {@code null} if not set.
     */
    @Nullable
    public Long getMaxSessionIdleTimeout() {
        return this.maxSessionIdleTimeout;
    }

    /**
     * Set the value passed to {@code ServerContainer#setDefaultMaxTextMessageBufferSize}
     * when the container is first initialized.
     * @param bufferSize the buffer size, or {@code null} to leave the container default
     */
    public void setMaxTextMessageBufferSize(Integer bufferSize) {
        this.maxTextMessageBufferSize = bufferSize;
    }

    /**
     * Return the configured max text message buffer size, or {@code null} if not set.
     */
    @Nullable
    public Integer getMaxTextMessageBufferSize() {
        return this.maxTextMessageBufferSize;
    }

    /**
     * Set the value passed to {@code ServerContainer#setDefaultMaxBinaryMessageBufferSize}
     * when the container is first initialized.
     * @param bufferSize the buffer size, or {@code null} to leave the container default
     */
    public void setMaxBinaryMessageBufferSize(Integer bufferSize) {
        this.maxBinaryMessageBufferSize = bufferSize;
    }

    /**
     * Return the configured max binary message buffer size, or {@code null} if not set.
     */
    @Nullable
    public Integer getMaxBinaryMessageBufferSize() {
        return this.maxBinaryMessageBufferSize;
    }

    /**
     * Upgrade the given exchange to a WebSocket session handled by {@code handler}.
     * <p>The native Servlet request/response, the handshake info and the buffer factory are
     * captured eagerly. The returned {@code Mono} first completes the reactive response and
     * then asks Tomcat's container to perform the upgrade.
     * @param exchange the current exchange
     * @param handler the handler for the WebSocket session
     * @param subProtocol the selected sub-protocol, or {@code null} for none
     * @param handshakeInfoFactory supplier of the handshake info for the session
     * @return completion signal for the upgrade
     * @throws IllegalArgumentException if the request or response is not backed by Servlet objects
     */
    @Override // org.springframework.web.reactive.socket.server.RequestUpgradeStrategy
    public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
            @Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {

        ServerHttpRequest request = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();

        HttpServletRequest servletRequest = getNativeRequest(request);
        HttpServletResponse servletResponse = getNativeResponse(response);

        HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
        DataBufferFactory bufferFactory = response.bufferFactory();

        // The adapter wraps each native session in a TomcatWebSocketSession that carries
        // the handshake info and buffer factory captured above.
        DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(servletRequest.getRequestURI(),
                new StandardWebSocketHandlerAdapter(handler,
                        session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory)));
        config.setSubprotocols(subProtocol != null ?
                Collections.singletonList(subProtocol) : Collections.emptyList());

        return response.setComplete()
                .then(Mono.fromCallable(() -> {
                    getContainer(servletRequest).doUpgrade(
                            servletRequest, servletResponse, config, Collections.emptyMap());
                    // A Callable is used (not a Runnable) because doUpgrade throws checked
                    // exceptions; returning null makes the Mono complete empty.
                    return null;
                }));
    }

    /**
     * Unwrap decorators until the Servlet request backing {@code request} is found.
     * @throws IllegalArgumentException if no {@link HttpServletRequest} backs the request
     */
    private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
        if (request instanceof AbstractServerHttpRequest) {
            Object nativeRequest = ((AbstractServerHttpRequest) request).getNativeRequest();
            // Guard the cast so a non-Servlet runtime yields the documented IAE, not a CCE.
            if (nativeRequest instanceof HttpServletRequest) {
                return (HttpServletRequest) nativeRequest;
            }
        }
        else if (request instanceof ServerHttpRequestDecorator) {
            return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
        }
        throw new IllegalArgumentException("Couldn't find HttpServletRequest in " + request.getClass().getName());
    }

    /**
     * Unwrap decorators until the Servlet response backing {@code response} is found.
     * @throws IllegalArgumentException if no {@link HttpServletResponse} backs the response
     */
    private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
        if (response instanceof AbstractServerHttpResponse) {
            Object nativeResponse = ((AbstractServerHttpResponse) response).getNativeResponse();
            // Guard the cast so a non-Servlet runtime yields the documented IAE, not a CCE.
            if (nativeResponse instanceof HttpServletResponse) {
                return (HttpServletResponse) nativeResponse;
            }
        }
        else if (response instanceof ServerHttpResponseDecorator) {
            return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
        }
        throw new IllegalArgumentException("Couldn't find HttpServletResponse in " + response.getClass().getName());
    }

    /**
     * Return the Tomcat container, resolving it from the {@code ServletContext} and applying
     * the configured settings on first use.
     * <p>Concurrent first calls may each configure the container. This is harmless because the
     * same values are applied to the same instance. The field is only published after
     * configuration completes.
     * @throws IllegalStateException if the attribute is missing or not a {@link WsServerContainer}
     */
    private WsServerContainer getContainer(HttpServletRequest request) {
        WsServerContainer container = this.serverContainer;
        if (container == null) {
            Object attribute = request.getServletContext().getAttribute(SERVER_CONTAINER_ATTR);
            Assert.state(attribute instanceof WsServerContainer,
                    "ServletContext attribute 'javax.websocket.server.ServerContainer' not found.");
            container = (WsServerContainer) attribute;
            // Configure before publishing so no other thread can observe an unconfigured container.
            initServerContainer(container);
            this.serverContainer = container;
        }
        return container;
    }

    /**
     * Apply every non-null configured setting to the given container; unset values leave
     * the container's own defaults untouched.
     */
    private void initServerContainer(ServerContainer container) {
        if (this.asyncSendTimeout != null) {
            container.setAsyncSendTimeout(this.asyncSendTimeout);
        }
        if (this.maxSessionIdleTimeout != null) {
            container.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout);
        }
        if (this.maxTextMessageBufferSize != null) {
            container.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize);
        }
        if (this.maxBinaryMessageBufferSize != null) {
            container.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
        }
    }
}
