/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.arthas.nat.agent.proxy.server.handler.ws;

import com.alibaba.arthas.nat.agent.proxy.server.handler.ws.WebSocketClientHandler;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WsRequestHandler {
    private static final Logger logger = LoggerFactory.getLogger(WsRequestHandler.class);
    private final ConcurrentHashMap<Channel, Channel> channelMappings = new ConcurrentHashMap();

    public void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        if (frame instanceof CloseWebSocketFrame) {
            this.closeOutboundChannel(ctx.channel());
            ctx.close();
            return;
        }
        Channel outboundChannel = this.channelMappings.get(ctx.channel());
        if (outboundChannel == null || !outboundChannel.isActive()) {
            this.connectToDestinationServer(ctx, frame);
        } else {
            this.forwardWebSocketFrame(frame, outboundChannel);
        }
    }

    private void connectToDestinationServer(final ChannelHandlerContext ctx, WebSocketFrame frame) {
        final String nativeAgentAddress = (String)ctx.channel().attr(AttributeKey.valueOf("nativeAgentAddress")).get();
        Bootstrap b = new Bootstrap();
        ((Bootstrap)((Bootstrap)b.group(ctx.channel().eventLoop())).channel(NioSocketChannel.class)).handler(new ChannelInitializer<SocketChannel>(this){
            final /* synthetic */ WsRequestHandler this$0;
            {
                this.this$0 = this$0;
            }

            @Override
            protected void initChannel(SocketChannel ch) {
                ChannelPipeline p = ch.pipeline();
                p.addLast(new HttpClientCodec());
                p.addLast(new HttpObjectAggregator(65536));
                p.addLast(new WebSocketClientProtocolHandler(WebSocketClientHandshakerFactory.newHandshaker(URI.create("ws://" + nativeAgentAddress + "/ws"), WebSocketVersion.V13, null, false, new DefaultHttpHeaders())));
                p.addLast(new WebSocketClientHandler(ctx.channel()));
            }
        });
        String[] addressSplit = nativeAgentAddress.split(":");
        ChannelFuture f = b.connect(addressSplit[0], Integer.parseInt(addressSplit[1]));
        f.addListener(future -> {
            if (future.isSuccess()) {
                Channel outboundChannel = future.channel();
                this.channelMappings.put(ctx.channel(), outboundChannel);
                this.forwardWebSocketFrame(frame, outboundChannel);
            } else {
                logger.error("Failed to connect to destination server", future.cause());
                ctx.close();
            }
        });
    }

    private void forwardWebSocketFrame(WebSocketFrame frame, Channel outboundChannel) {
        if (outboundChannel != null && outboundChannel.isActive()) {
            outboundChannel.writeAndFlush(frame.retain()).addListener((GenericFutureListener<? extends Future<? super Void>>)((GenericFutureListener<Future>)future -> {
                if (!future.isSuccess()) {
                    logger.error("Failed to forward WebSocket frame", future.cause());
                }
            }));
        } else {
            logger.warn("Outbound channel is not active. Cannot forward frame.");
        }
    }

    private void closeOutboundChannel(Channel inboundChannel) {
        Channel outboundChannel = this.channelMappings.remove(inboundChannel);
        if (outboundChannel != null) {
            logger.info("Closing outbound channel");
            outboundChannel.close();
        }
    }

    public void channelInactive(ChannelHandlerContext ctx) {
        logger.info("Channel inactive, closing outbound channel");
        this.closeOutboundChannel(ctx.channel());
    }

    public void handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest request) {
        URI uri = null;
        try {
            uri = new URI(request.uri());
        }
        catch (URISyntaxException e) {
            return;
        }
        Map<String, String> params = this.parseQueryString(uri.getQuery());
        String nativeAgentAddress = params.get("nativeAgentAddress");
        if (nativeAgentAddress != null) {
            ctx.channel().attr(AttributeKey.valueOf("nativeAgentAddress")).set(nativeAgentAddress);
        }
        request.setUri(uri.getPath());
        ctx.fireChannelRead(request.retain());
    }

    private Map<String, String> parseQueryString(String query) {
        HashMap<String, String> params = new HashMap<String, String>();
        if (query != null) {
            String[] pairs;
            for (String pair : pairs = query.split("&")) {
                int idx = pair.indexOf("=");
                try {
                    String key = URLDecoder.decode(pair.substring(0, idx), "UTF-8");
                    String value = URLDecoder.decode(pair.substring(idx + 1), "UTF-8");
                    params.put(key, value);
                }
                catch (UnsupportedEncodingException unsupportedEncodingException) {
                    // empty catch block
                }
            }
        }
        return params;
    }
}

