/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.util.Assert;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.messaging.SessionConnectEvent;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.SessionUnsubscribeEvent;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

/**
 * A {@link SubProtocolHandler} for STOMP, advertised for the {@code v10.stomp},
 * {@code v11.stomp} and {@code v12.stomp} WebSocket sub-protocols.
 *
 * <p>Inbound WebSocket messages are decoded into STOMP frames by a per-session
 * {@link BufferingStompDecoder} and sent to the output {@link MessageChannel}.
 * Outbound messages with a {@code byte[]} payload are encoded into STOMP frames
 * and written to the WebSocket session. When an {@link ApplicationEventPublisher}
 * is configured, connect / connected / subscribe / unsubscribe / disconnect
 * session events are published.
 */
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {

    /**
     * Lower bound for a session's text message size limit; in
     * {@link #afterSessionStarted} any smaller limit is raised to this value.
     */
    public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16640;

    /** Native header added to CONNECTED frames carrying the session principal's name. */
    public static final String CONNECTED_USER_HEADER = "user-name";

    /**
     * Native header which, when present on an outbound MESSAGE frame, is removed
     * and its value used as the frame's destination.
     */
    private static final String ORIGINAL_DESTINATION_HEADER = "simpOrigDestination";

    /** Header of a CONNECT_ACK message that must hold the original CONNECT message. */
    private static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";

    private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

    private static final byte[] EMPTY_PAYLOAD = new byte[0];


    private int messageSizeLimit = 64 * 1024;

    private UserSessionRegistry userSessionRegistry;

    private final StompEncoder stompEncoder = new StompEncoder();

    private final StompDecoder stompDecoder = new StompDecoder();

    /** One buffering decoder per WebSocket session id, added on start and removed on end. */
    private final Map<String, BufferingStompDecoder> decoders =
            new ConcurrentHashMap<String, BufferingStompDecoder>();

    private MessageHeaderInitializer headerInitializer;

    /** Lazily computed; {@code null} until the output channel has been inspected once. */
    private Boolean immutableMessageInterceptorPresent;

    private ApplicationEventPublisher eventPublisher;

    private final Stats stats = new Stats();


    /**
     * Configure the size limit passed to the {@link BufferingStompDecoder}
     * created for each new session. Defaults to 65536 (64 * 1024).
     * @param messageSizeLimit the limit to use for sessions started afterwards
     */
    public void setMessageSizeLimit(int messageSizeLimit) {
        this.messageSizeLimit = messageSizeLimit;
    }

    /**
     * Return the configured message size limit.
     */
    public int getMessageSizeLimit() {
        return this.messageSizeLimit;
    }

    /**
     * Provide a registry in which the session id is registered under the user
     * name when a CONNECTED frame is sent and unregistered when the session ends.
     * @param registry the registry, or {@code null} to disable registration
     */
    public void setUserSessionRegistry(UserSessionRegistry registry) {
        this.userSessionRegistry = registry;
    }

    /**
     * Return the configured {@link UserSessionRegistry}, possibly {@code null}.
     */
    public UserSessionRegistry getUserSessionRegistry() {
        return this.userSessionRegistry;
    }

    /**
     * Configure a {@link MessageHeaderInitializer}; it is handed to the STOMP
     * decoder and applied to the DISCONNECT message created when a session ends.
     * @param headerInitializer the initializer, or {@code null}
     */
    public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
        this.headerInitializer = headerInitializer;
        this.stompDecoder.setHeaderInitializer(headerInitializer);
    }

    /**
     * Return the configured {@link MessageHeaderInitializer}, possibly {@code null}.
     */
    public MessageHeaderInitializer getHeaderInitializer() {
        return this.headerInitializer;
    }

    @Override
    public List<String> getSupportedProtocols() {
        return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
    }

    @Override
    public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
        this.eventPublisher = applicationEventPublisher;
    }

    /**
     * Return a summary of the CONNECT, CONNECTED and DISCONNECT frames processed so far.
     */
    public String getStatsInfo() {
        return this.stats.toString();
    }


    /**
     * Decode the WebSocket message into zero or more STOMP frames and send each
     * complete frame to {@code outputChannel}.
     * <p>A decoding failure sends a STOMP ERROR to the client and drops the
     * message. A failure while handling an individual frame sends a STOMP ERROR
     * and continues with the next frame. Message types other than text and
     * binary are ignored.
     */
    @Override
    public void handleMessageFromClient(WebSocketSession session, WebSocketMessage<?> webSocketMessage,
            MessageChannel outputChannel) {

        List<Message<byte[]>> messages;
        try {
            ByteBuffer byteBuffer;
            if (webSocketMessage instanceof TextMessage) {
                byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
            }
            else if (webSocketMessage instanceof BinaryMessage) {
                byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
            }
            else {
                return;
            }

            BufferingStompDecoder decoder = this.decoders.get(session.getId());
            if (decoder == null) {
                throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
            }

            messages = decoder.decode(byteBuffer);
            if (messages.isEmpty()) {
                // The decoder keeps the partial content until more data arrives.
                if (logger.isTraceEnabled()) {
                    logger.trace("Incomplete STOMP frame content received in session " + session +
                            ", bufferSize=" + decoder.getBufferSize() +
                            ", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
                }
                return;
            }
        }
        catch (Throwable ex) {
            if (logger.isErrorEnabled()) {
                logger.error("Failed to parse " + webSocketMessage + " in session " + session.getId() +
                        ". Sending STOMP ERROR to client.", ex);
            }
            sendErrorMessage(session, ex);
            return;
        }

        for (Message<byte[]> message : messages) {
            try {
                StompHeaderAccessor headerAccessor =
                        MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);

                if (logger.isTraceEnabled()) {
                    logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
                }

                headerAccessor.setSessionId(session.getId());
                headerAccessor.setSessionAttributes(session.getAttributes());
                headerAccessor.setUser(session.getPrincipal());
                // Freeze the headers here unless an ImmutableMessageChannelInterceptor
                // on the output channel will take care of it.
                if (!detectImmutableMessageInterceptor(outputChannel)) {
                    headerAccessor.setImmutable();
                }

                StompCommand command = headerAccessor.getCommand();
                if (StompCommand.CONNECT.equals(command)) {
                    this.stats.incrementConnectCount();
                }
                else if (StompCommand.DISCONNECT.equals(command)) {
                    this.stats.incrementDisconnectCount();
                }

                try {
                    SimpAttributesContextHolder.setAttributesFromMessage(message);
                    if (this.eventPublisher != null) {
                        if (StompCommand.CONNECT.equals(command)) {
                            publishEvent(new SessionConnectEvent(this, message));
                        }
                        else if (StompCommand.SUBSCRIBE.equals(command)) {
                            publishEvent(new SessionSubscribeEvent(this, message));
                        }
                        else if (StompCommand.UNSUBSCRIBE.equals(command)) {
                            publishEvent(new SessionUnsubscribeEvent(this, message));
                        }
                    }
                    outputChannel.send(message);
                }
                finally {
                    SimpAttributesContextHolder.resetAttributes();
                }
            }
            catch (Throwable ex) {
                logger.error("Failed to send client message to application via MessageChannel in session " +
                        session.getId() + ". Sending STOMP ERROR to client.", ex);
                sendErrorMessage(session, ex);
            }
        }
    }

    /**
     * Check (once, then cached) whether {@code channel} is an
     * {@link AbstractMessageChannel} with an {@link ImmutableMessageChannelInterceptor}.
     */
    private boolean detectImmutableMessageInterceptor(MessageChannel channel) {
        if (this.immutableMessageInterceptorPresent != null) {
            return this.immutableMessageInterceptorPresent;
        }
        if (channel instanceof AbstractMessageChannel) {
            for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) {
                if (interceptor instanceof ImmutableMessageChannelInterceptor) {
                    this.immutableMessageInterceptorPresent = true;
                    return true;
                }
            }
        }
        this.immutableMessageInterceptorPresent = false;
        return false;
    }

    /**
     * Publish the event. Any failure is logged and does not propagate to the caller.
     */
    private void publishEvent(ApplicationEvent event) {
        try {
            this.eventPublisher.publishEvent(event);
        }
        catch (Throwable ex) {
            logger.error("Error publishing " + event, ex);
        }
    }

    /**
     * Send a STOMP ERROR frame whose {@code message} header is the error's message.
     * A failure to send is logged at debug level only.
     * @param session the session to send to
     * @param error the error to report
     */
    protected void sendErrorMessage(WebSocketSession session, Throwable error) {
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
        headerAccessor.setMessage(error.getMessage());
        byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
        try {
            session.sendMessage(new TextMessage(bytes));
        }
        catch (Throwable ex) {
            logger.debug("Failed to send STOMP ERROR to client.", ex);
        }
    }

    /**
     * Encode {@code message} as a STOMP frame and send it to the client.
     * <p>Messages whose payload is not {@code byte[]} are logged and ignored.
     * MESSAGE frames get their destination restored from
     * {@value #ORIGINAL_DESTINATION_HEADER} when present. CONNECTED frames
     * update stats, the user registry and SockJS heartbeats, and trigger a
     * {@link SessionConnectedEvent}. The session is closed with
     * {@link CloseStatus#PROTOCOL_ERROR} after an ERROR frame or when sending fails.
     * {@link SessionLimitExceededException} is rethrown to the caller.
     */
    @Override
    public void handleMessageToClient(WebSocketSession session, Message<?> message) {
        if (!(message.getPayload() instanceof byte[])) {
            logger.error("Expected byte[] payload. Ignoring " + message + ".");
            return;
        }

        StompHeaderAccessor stompAccessor = getStompHeaderAccessor(message);
        StompCommand command = stompAccessor.getCommand();

        if (StompCommand.MESSAGE.equals(command)) {
            if (stompAccessor.getSubscriptionId() == null) {
                logger.warn("No STOMP \"subscription\" header in " + message);
            }
            String origDestination = stompAccessor.getFirstNativeHeader(ORIGINAL_DESTINATION_HEADER);
            if (origDestination != null) {
                stompAccessor = toMutableAccessor(stompAccessor, message);
                stompAccessor.removeNativeHeader(ORIGINAL_DESTINATION_HEADER);
                stompAccessor.setDestination(origDestination);
            }
        }
        else if (StompCommand.CONNECTED.equals(command)) {
            this.stats.incrementConnectedCount();
            stompAccessor = afterStompSessionConnected(message, stompAccessor, session);
            if (this.eventPublisher != null) {
                try {
                    SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
                    SimpAttributesContextHolder.setAttributes(simpAttributes);
                    // Safe: the payload was verified to be byte[] above.
                    @SuppressWarnings("unchecked")
                    Message<byte[]> connectedMessage = (Message<byte[]>) message;
                    publishEvent(new SessionConnectedEvent(this, connectedMessage));
                }
                finally {
                    SimpAttributesContextHolder.resetAttributes();
                }
            }
        }

        try {
            byte[] payload = (byte[]) message.getPayload();
            byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload);
            // The session may be wrapped in a decorator, so unwrap before the
            // SockJS check (as afterStompSessionConnected does); binary frames are
            // only used for non-empty application/octet-stream content.
            boolean useBinary = (payload.length > 0 &&
                    !(WebSocketSessionDecorator.unwrap(session) instanceof SockJsSession) &&
                    MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType()));
            if (useBinary) {
                session.sendMessage(new BinaryMessage(bytes));
            }
            else {
                session.sendMessage(new TextMessage(bytes));
            }
        }
        catch (SessionLimitExceededException ex) {
            // Propagate to the caller unchanged.
            throw ex;
        }
        catch (Throwable ex) {
            logger.debug("Failed to send WebSocket message to client in session " + session.getId(), ex);
            command = StompCommand.ERROR;
        }
        finally {
            if (StompCommand.ERROR.equals(command)) {
                try {
                    session.close(CloseStatus.PROTOCOL_ERROR);
                }
                catch (IOException ignored) {
                    // Best effort: the session is being closed because of an error already.
                }
            }
        }
    }

    /**
     * Obtain a {@link StompHeaderAccessor} for an outbound message, converting
     * generic SiMP messages as needed: CONNECT_ACK becomes CONNECTED,
     * DISCONNECT_ACK becomes an ERROR with message "Session closed.", and a
     * missing or SEND command is updated to the server-side equivalent.
     * @throws IllegalStateException if the message has no header accessor or
     * one of an unexpected type
     */
    private StompHeaderAccessor getStompHeaderAccessor(Message<?> message) {
        MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
        if (accessor == null) {
            throw new IllegalStateException("No header accessor in " + message);
        }

        StompHeaderAccessor stompAccessor;
        if (accessor instanceof StompHeaderAccessor) {
            stompAccessor = (StompHeaderAccessor) accessor;
        }
        else if (accessor instanceof SimpMessageHeaderAccessor) {
            stompAccessor = StompHeaderAccessor.wrap(message);
            if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
                stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
            }
            else if (SimpMessageType.DISCONNECT_ACK.equals(stompAccessor.getMessageType())) {
                stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
                stompAccessor.setMessage("Session closed.");
            }
            else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
                stompAccessor.updateStompCommandAsServerMessage();
            }
        }
        else {
            throw new IllegalStateException(
                    "Unexpected header accessor type: " + accessor.getClass() + " in " + message);
        }
        return stompAccessor;
    }

    /**
     * Build CONNECTED headers from a CONNECT_ACK, choosing the STOMP version
     * from the original CONNECT frame's accept-version header (1.2 preferred,
     * then 1.1, none if the header is empty) and setting heart-beat to 0,0.
     * @throws IllegalArgumentException if only other versions are accepted
     */
    private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor connectAckHeaders) {
        Message<?> message = (Message<?>) connectAckHeaders.getHeader(CONNECT_MESSAGE_HEADER);
        Assert.notNull(message, "Original STOMP CONNECT not found in " + connectAckHeaders);
        StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);

        Set<String> acceptVersions = connectHeaders.getAcceptVersion();
        String version;
        if (acceptVersions.contains("1.2")) {
            version = "1.2";
        }
        else if (acceptVersions.contains("1.1")) {
            version = "1.1";
        }
        else if (acceptVersions.isEmpty()) {
            version = null;
        }
        else {
            throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'");
        }

        StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
        connectedHeaders.setVersion(version);
        connectedHeaders.setHeartbeat(0, 0);
        return connectedHeaders;
    }

    /**
     * Return {@code headerAccessor} if it is mutable, otherwise a new mutable
     * accessor wrapping {@code message}.
     */
    protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
        return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
    }

    /**
     * Post-process an outbound CONNECTED frame. If the session has a principal,
     * add {@link #CONNECTED_USER_HEADER} and register the session id in the
     * {@link UserSessionRegistry}, if one is configured. If the heart-beat
     * header's second value is positive and the (unwrapped) session is a
     * {@link SockJsSession}, disable that session's heartbeat.
     * @return the accessor to use for encoding, possibly a new mutable one
     */
    private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
            WebSocketSession session) {

        Principal principal = session.getPrincipal();
        if (principal != null) {
            accessor = toMutableAccessor(accessor, message);
            accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
            if (this.userSessionRegistry != null) {
                String userName = getSessionRegistryUserName(principal);
                this.userSessionRegistry.registerSessionId(userName, session.getId());
            }
        }

        long[] heartbeat = accessor.getHeartbeat();
        if (heartbeat[1] > 0) {
            WebSocketSession target = WebSocketSessionDecorator.unwrap(session);
            if (target instanceof SockJsSession) {
                ((SockJsSession) target).disableHeartbeat();
            }
        }
        return accessor;
    }

    /**
     * Return the name under which the principal is registered: the
     * {@link DestinationUserNameProvider} name if available, else {@link Principal#getName()}.
     */
    private String getSessionRegistryUserName(Principal principal) {
        if (principal instanceof DestinationUserNameProvider) {
            return ((DestinationUserNameProvider) principal).getDestinationUserName();
        }
        return principal.getName();
    }

    @Override
    public String resolveSessionId(Message<?> message) {
        return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
    }

    /**
     * Raise the session's text message size limit to at least
     * {@link #MINIMUM_WEBSOCKET_MESSAGE_SIZE} and create its STOMP decoder.
     */
    @Override
    public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
        if (session.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
            session.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
        }
        this.decoders.put(session.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit()));
    }

    /**
     * Release per-session state and notify the application of the disconnect.
     * This discards the session's decoder, unregisters the user session if
     * applicable, publishes a {@link SessionDisconnectEvent} when an event
     * publisher is set, and sends a synthetic DISCONNECT to {@code outputChannel}.
     * Session attributes are completed in all cases.
     */
    @Override
    public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
        this.decoders.remove(session.getId());

        Principal principal = session.getPrincipal();
        if (principal != null && this.userSessionRegistry != null) {
            String userName = getSessionRegistryUserName(principal);
            this.userSessionRegistry.unregisterSessionId(userName, session.getId());
        }

        Message<byte[]> message = createDisconnectMessage(session);
        SimpAttributes simpAttributes = SimpAttributes.fromMessage(message);
        try {
            SimpAttributesContextHolder.setAttributes(simpAttributes);
            if (this.eventPublisher != null) {
                publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus));
            }
            outputChannel.send(message);
        }
        finally {
            SimpAttributesContextHolder.resetAttributes();
            simpAttributes.sessionCompleted();
        }
    }

    /**
     * Create an empty-payload DISCONNECT message carrying the session id,
     * attributes and principal, initialized by the header initializer if set.
     */
    private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
        if (getHeaderInitializer() != null) {
            getHeaderInitializer().initHeaders(headerAccessor);
        }
        headerAccessor.setSessionId(session.getId());
        headerAccessor.setSessionAttributes(session.getAttributes());
        headerAccessor.setUser(session.getPrincipal());
        return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
    }

    @Override
    public String toString() {
        return "StompSubProtocolHandler" + getSupportedProtocols();
    }


    /**
     * Thread-safe counters for processed CONNECT, CONNECTED and DISCONNECT frames.
     */
    private static final class Stats {

        private final AtomicInteger connect = new AtomicInteger();

        private final AtomicInteger connected = new AtomicInteger();

        private final AtomicInteger disconnect = new AtomicInteger();

        public void incrementConnectCount() {
            this.connect.incrementAndGet();
        }

        public void incrementConnectedCount() {
            this.connected.incrementAndGet();
        }

        public void incrementDisconnectCount() {
            this.disconnect.incrementAndGet();
        }

        @Override
        public String toString() {
            return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" + this.connected.get() +
                    ")-DISCONNECT(" + this.disconnect.get() + ")";
        }
    }

}

