/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.messaging.SubProtocolHandler;

/**
 * A {@link SubProtocolHandler} for STOMP over WebSocket ("v10.stomp", "v11.stomp", "v12.stomp").
 *
 * <p>Incoming WebSocket text messages are decoded into STOMP frames, stamped with the WebSocket
 * session id and principal, and sent to the supplied output channel. Outgoing messages are
 * translated to the appropriate STOMP command (e.g. CONNECT_ACK becomes CONNECTED), encoded and
 * written to the WebSocket session. Decoding/sending failures are reported to the client as a
 * STOMP ERROR frame.
 */
public class StompSubProtocolHandler
implements SubProtocolHandler {

    /** Native header added to the CONNECTED frame carrying the authenticated principal's name. */
    public static final String CONNECTED_USER_HEADER = "user-name";

    private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");

    /** Header on a CONNECT_ACK message that holds the original CONNECT message. */
    private static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";

    /** Header whose value, when present, replaces the destination of an outgoing MESSAGE frame. */
    private static final String SUBSCRIBE_DESTINATION_HEADER = "subscribeDestination";

    private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

    private final StompDecoder stompDecoder = new StompDecoder();

    private final StompEncoder stompEncoder = new StompEncoder();

    private UserSessionRegistry userSessionRegistry;

    /**
     * Set the registry used to record user-name to session-id mappings when a STOMP session is
     * connected and to remove them when the WebSocket session ends. May be {@code null}, in which
     * case no registration takes place.
     */
    public void setUserSessionRegistry(UserSessionRegistry registry) {
        this.userSessionRegistry = registry;
    }

    /**
     * Return the configured {@link UserSessionRegistry}, or {@code null} if none was set.
     */
    public UserSessionRegistry getUserSessionRegistry() {
        return this.userSessionRegistry;
    }

    /**
     * Return the WebSocket sub-protocol names handled by this class.
     */
    @Override
    public List<String> getSupportedProtocols() {
        return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
    }

    /**
     * Decode a WebSocket text message into a STOMP frame and send it to {@code outputChannel}.
     *
     * <p>If the message cannot be decoded, or sending to the channel fails, the failure is logged
     * and a STOMP ERROR frame is sent back to the client via {@link #sendErrorMessage}. The session
     * is not closed here.
     *
     * @param session the client WebSocket session; its id and principal are copied onto the message
     * @param webSocketMessage the raw message; must be a {@link TextMessage}
     * @param outputChannel the channel that receives the decoded STOMP message
     */
    @Override
    public void handleMessageFromClient(WebSocketSession session, WebSocketMessage<?> webSocketMessage,
            MessageChannel outputChannel) {

        Message<?> message = null;
        Throwable decodeFailure = null;
        try {
            Assert.isInstanceOf(TextMessage.class, webSocketMessage);
            String payload = ((TextMessage) webSocketMessage).getPayload();
            ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(UTF8_CHARSET));
            message = this.stompDecoder.decode(byteBuffer);
            if (message == null) {
                decodeFailure = new IllegalStateException("Not a valid STOMP frame: " + payload);
            }
        }
        catch (Throwable ex) {
            decodeFailure = ex;
        }

        if (decodeFailure != null) {
            logger.error("Failed to parse WebSocket message as STOMP frame", decodeFailure);
            sendErrorMessage(session, decodeFailure);
            return;
        }

        try {
            StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
            if (logger.isTraceEnabled()) {
                String kind = SimpMessageType.HEARTBEAT.equals(headers.getMessageType()) ? "heartbeat" : "message";
                logger.trace("Received " + kind + " from client session=" + session.getId());
            }
            headers.setSessionId(session.getId());
            headers.setUser(session.getPrincipal());
            message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
            outputChannel.send(message);
        }
        catch (Throwable ex) {
            // Only an ERROR frame is sent; the session itself is not closed here.
            logger.error("Failed to send client message to application via MessageChannel in session "
                    + session.getId() + "; sending STOMP ERROR to client", ex);
            sendErrorMessage(session, ex);
        }
    }

    /**
     * Send a STOMP ERROR frame, whose "message" header is {@code error.getMessage()}, to the client.
     *
     * <p>This is best effort: a failure to write the frame is logged at debug level and otherwise
     * ignored so that error reporting never raises a second exception.
     *
     * @param session the client WebSocket session
     * @param error the failure to report
     */
    protected void sendErrorMessage(WebSocketSession session, Throwable error) {
        StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
        headers.setMessage(error.getMessage());
        Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
        String payload = new String(this.stompEncoder.encode(message), UTF8_CHARSET);
        try {
            session.sendMessage(new TextMessage(payload));
        }
        catch (Throwable ex) {
            // Deliberately not rethrown (the session may already be broken), but leave a trace.
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to send STOMP ERROR frame to client session=" + session.getId(), ex);
            }
        }
    }

    /**
     * Encode a message as a STOMP frame and write it to the client session.
     *
     * <p>A CONNECT_ACK is turned into a CONNECTED frame (with the negotiated version and no
     * heart-beating), and MESSAGE-type messages get the STOMP server MESSAGE command. MESSAGE
     * frames without a subscription id, and messages whose payload is not {@code byte[]}, are
     * logged and dropped. If the outgoing frame is an ERROR, the session is closed with
     * {@link CloseStatus#PROTOCOL_ERROR} afterwards.
     *
     * @param session the client WebSocket session
     * @param message the message to send; its payload must be a {@code byte[]}
     */
    @Override
    public void handleMessageToClient(WebSocketSession session, Message<?> message) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);

        if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
            StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
            connectedHeaders.setVersion(getVersion(headers));
            // Advertise no heart-beating in either direction.
            connectedHeaders.setHeartbeat(0L, 0L);
            headers = connectedHeaders;
        }
        else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
            headers.updateStompCommandAsServerMessage();
        }

        if (headers.getCommand() == StompCommand.CONNECTED) {
            afterStompSessionConnected(headers, session);
        }

        if (StompCommand.MESSAGE.equals(headers.getCommand())) {
            if (headers.getSubscriptionId() == null) {
                logger.error("Ignoring message, no subscriptionId header: " + message);
                return;
            }
            if (message.getHeaders().containsKey(SUBSCRIBE_DESTINATION_HEADER)) {
                headers.setDestination((String) message.getHeaders().get(SUBSCRIBE_DESTINATION_HEADER));
            }
        }

        if (!(message.getPayload() instanceof byte[])) {
            logger.error("Ignoring message, expected byte[] content: " + message);
            return;
        }

        try {
            // Safe cast: the payload type was checked just above.
            Message<byte[]> outbound =
                    MessageBuilder.withPayload((byte[]) message.getPayload()).setHeaders(headers).build();
            byte[] bytes = this.stompEncoder.encode(outbound);
            // Serialize writes to the same session.
            synchronized (session) {
                session.sendMessage(new TextMessage(new String(bytes, UTF8_CHARSET)));
            }
        }
        catch (Throwable ex) {
            logger.error("Failed to send WebSocket message to client session=" + session.getId(), ex);
            sendErrorMessage(session, ex);
        }
        finally {
            if (StompCommand.ERROR.equals(headers.getCommand())) {
                try {
                    session.close(CloseStatus.PROTOCOL_ERROR);
                }
                catch (IOException ignored) {
                    // Nothing more can be done; the session is being torn down anyway.
                }
            }
        }
    }

    /**
     * Pick the STOMP version for the CONNECTED frame from the accept-version header of the
     * original CONNECT message carried on the CONNECT_ACK.
     *
     * @param connectAckHeaders headers of the CONNECT_ACK message
     * @return "1.2" or "1.1" if offered (1.2 preferred), or {@code null} if no accept-version
     *         header was sent
     * @throws IllegalArgumentException if the CONNECT_ACK does not carry the original CONNECT
     * @throws StompConversionException if only unsupported versions were offered
     */
    private String getVersion(StompHeaderAccessor connectAckHeaders) {
        Message<?> connectMessage = (Message<?>) connectAckHeaders.getHeader(CONNECT_MESSAGE_HEADER);
        // Validate before using the message so a missing header yields the descriptive error.
        Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders);
        Set<String> acceptVersions = StompHeaderAccessor.wrap(connectMessage).getAcceptVersion();
        if (acceptVersions.contains("1.2")) {
            return "1.2";
        }
        if (acceptVersions.contains("1.1")) {
            return "1.1";
        }
        if (acceptVersions.isEmpty()) {
            return null;
        }
        throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
    }

    /**
     * If the session has a principal, add its name as the {@link #CONNECTED_USER_HEADER} native
     * header and, when a registry is configured, register the session id under the resolved user
     * name.
     */
    private void afterStompSessionConnected(StompHeaderAccessor headers, WebSocketSession session) {
        Principal principal = session.getPrincipal();
        if (principal == null) {
            return;
        }
        headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
        if (this.userSessionRegistry != null) {
            String userName = resolveNameForUserSessionRegistry(principal);
            this.userSessionRegistry.registerSessionId(userName, session.getId());
        }
    }

    /**
     * Return the name under which a principal's sessions are registered: the destination user name
     * for a {@link DestinationUserNameProvider}, otherwise {@link Principal#getName()}.
     */
    private String resolveNameForUserSessionRegistry(Principal principal) {
        if (principal instanceof DestinationUserNameProvider) {
            return ((DestinationUserNameProvider) principal).getDestinationUserName();
        }
        return principal.getName();
    }

    /**
     * Return the session id carried in the message's STOMP/SiMP headers.
     */
    @Override
    public String resolveSessionId(Message<?> message) {
        return StompHeaderAccessor.wrap(message).getSessionId();
    }

    /**
     * No-op: nothing needs to happen until the client sends a CONNECT frame.
     */
    @Override
    public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
    }

    /**
     * Unregister the session from the user registry (if configured and the session has a
     * principal) and send a DISCONNECT message for the session to {@code outputChannel}.
     */
    @Override
    public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
        Principal principal = session.getPrincipal();
        if (this.userSessionRegistry != null && principal != null) {
            String userName = resolveNameForUserSessionRegistry(principal);
            this.userSessionRegistry.unregisterSessionId(userName, session.getId());
        }
        StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
        headers.setSessionId(session.getId());
        Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
        outputChannel.send(message);
    }
}

