/*
 * Decompiled with CFR 0.152.
 */
package com.vesoft.nebula.driver.graph.net;

import com.alibaba.fastjson.JSON;
import com.google.common.base.Charsets;
import com.google.protobuf.ByteString;
import com.vesoft.nebula.driver.graph.ErrorCode;
import com.vesoft.nebula.driver.graph.data.HostAddress;
import com.vesoft.nebula.driver.graph.exception.AuthFailedException;
import com.vesoft.nebula.driver.graph.exception.IOErrorException;
import com.vesoft.nebula.driver.graph.net.AuthResult;
import com.vesoft.nebula.driver.graph.net.Connection;
import com.vesoft.nebula.proto.common.ClientInfo;
import com.vesoft.nebula.proto.common.Common;
import com.vesoft.nebula.proto.graph.AuthRequest;
import com.vesoft.nebula.proto.graph.AuthResponse;
import com.vesoft.nebula.proto.graph.ExecuteRequest;
import com.vesoft.nebula.proto.graph.ExecuteResponse;
import com.vesoft.nebula.proto.graph.GraphServiceGrpc;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * gRPC-backed {@link Connection} to a graph service endpoint.
 *
 * <p>Channels are kept in a static map keyed by server address, so they are shared by all
 * instances of this class. A static read/write lock separates channel lookup/creation (read
 * lock) from the bulk shutdown performed by {@link #close()} (write lock).
 *
 * <p>NOTE(review): {@link #close()} shuts down and removes <em>every</em> cached channel, not
 * only the one used by this instance. Other instances recover by lazily re-creating their
 * channel on the next request. Confirm that this sharing model is intended.
 */
public class GrpcConnection extends Connection {
    private static final Logger LOGGER = LoggerFactory.getLogger(GrpcConnection.class);

    /** Channels shared by all instances, keyed by server address. */
    private static final ConcurrentHashMap<HostAddress, ManagedChannel> channels = new ConcurrentHashMap<>();

    /** Blocking stub bound to the channel of {@code serverAddr}; {@code null} after {@link #close()}. */
    private GraphServiceGrpc.GraphServiceBlockingStub stub;

    /** Deadline, in milliseconds, applied to the authenticate call. */
    private long connectTimeout = 0L;

    /** Default deadline, in milliseconds, applied by {@link #execute(long, String)}. */
    private long requestTimeout = 0L;

    /** Charset used to convert between Java strings and protobuf byte strings. */
    private final Charset charset = Charsets.UTF_8;

    /** Read lock: look up / create a channel. Write lock: shut down and clear all channels. */
    private static final ReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Records the target address and timeouts, then binds a blocking stub to the shared channel
     * for that address, creating the channel if none is cached.
     *
     * @param address        server address to connect to
     * @param connectTimeout deadline in milliseconds used for authentication
     * @param requestTimeout default deadline in milliseconds used for statement execution
     */
    @Override
    public void open(HostAddress address, long connectTimeout, long requestTimeout) {
        this.serverAddr = address;
        this.connectTimeout = connectTimeout;
        this.requestTimeout = requestTimeout;
        // The stub is created from the channel obtained while the read lock is held, so a
        // concurrent close() cannot remove the channel between lookup and stub creation.
        this.ensureStub();
    }

    /**
     * Shuts down all cached channels (if any) and drops this instance's stub.
     */
    @Override
    public void close() {
        if (!channels.isEmpty()) {
            GrpcConnection.closeChannel();
        }
        this.stub = null;
    }

    /**
     * Checks liveness by executing {@code RETURN 1} in the given session.
     *
     * @param sessionID session to execute the probe statement in
     * @param timeoutMs deadline for the probe in milliseconds
     * @return {@code true} if the response status code equals
     *         {@code ErrorCode.SUCCESSFUL_COMPLETION.code}
     * @throws IOErrorException declared by {@link #execute(long, String, long)}
     */
    @Override
    public boolean ping(long sessionID, long timeoutMs) throws IOErrorException {
        ExecuteResponse response = this.execute(sessionID, "RETURN 1", timeoutMs);
        return ErrorCode.SUCCESSFUL_COMPLETION.code.equals(response.getStatus().getCode().toString(this.charset));
    }

    /**
     * Authenticates against the server using the connect timeout as the call deadline.
     *
     * @param user        user name, sent as bytes in this connection's charset
     * @param authOptions authentication options; serialized to JSON and sent as the auth info
     * @return the result carrying the session id returned by the server
     * @throws AuthFailedException if the response status code is not
     *                             {@code ErrorCode.SUCCESSFUL_COMPLETION.code}; the message is
     *                             the status message returned by the server
     */
    public AuthResult authenticate(String user, Map<String, Object> authOptions) throws AuthFailedException {
        ClientInfo clientInfo = ClientInfo.newBuilder()
                .setLang(ClientInfo.Language.JAVA)
                .setProtocolVersion(Common.getDescriptor().getOptions().getExtension(Common.protocolVersion))
                .setVersion(ByteString.copyFrom("5.0.0", this.charset))
                .build();
        String authInfoString = JSON.toJSONString(authOptions);
        AuthRequest authReq = AuthRequest.newBuilder()
                .setUsername(ByteString.copyFrom(user, this.charset))
                .setAuthInfo(ByteString.copyFrom(authInfoString, this.charset))
                .setClientInfo(clientInfo)
                .build();
        AuthResponse resp = this.ensureStub()
                .withDeadlineAfter(this.connectTimeout, TimeUnit.MILLISECONDS)
                .authenticate(authReq);
        String code = resp.getStatus().getCode().toString(this.charset);
        if (!ErrorCode.SUCCESSFUL_COMPLETION.code.equals(code)) {
            throw new AuthFailedException(resp.getStatus().getMessage().toString(this.charset));
        }
        return new AuthResult(resp.getSessionId());
    }

    /**
     * Executes a statement in the given session with an explicit deadline.
     *
     * @param sessionID session to execute in
     * @param stmt      statement text, sent as bytes in this connection's charset
     * @param timeout   call deadline in milliseconds
     * @return the raw response returned by the server
     * @throws IOErrorException declared for callers; nothing in this method throws it directly
     */
    public ExecuteResponse execute(long sessionID, String stmt, long timeout) throws IOErrorException {
        ExecuteRequest request = ExecuteRequest.newBuilder()
                .setSessionId(sessionID)
                .setStmt(ByteString.copyFrom(stmt, this.charset))
                .build();
        return this.ensureStub()
                .withDeadlineAfter(timeout, TimeUnit.MILLISECONDS)
                .execute(request);
    }

    /**
     * Executes a statement in the given session using the request timeout passed to
     * {@link #open(HostAddress, long, long)} as the deadline.
     *
     * @param sessionID session to execute in
     * @param stmt      statement text
     * @return the raw response returned by the server
     * @throws IOErrorException declared by {@link #execute(long, String, long)}
     */
    public ExecuteResponse execute(long sessionID, String stmt) throws IOErrorException {
        return this.execute(sessionID, stmt, this.requestTimeout);
    }

    /**
     * Returns a stub bound to the currently cached channel for {@code serverAddr}, creating the
     * channel and/or the stub when needed.
     *
     * <p>The stub is rebound whenever it is {@code null} or refers to a channel other than the
     * cached one. This covers the case where any instance's {@link #close()} cleared the shared
     * map and the entry was later re-created, possibly by a different instance; the previous
     * stub would otherwise be {@code null} or still point at a shut-down channel.
     *
     * <p>No deadline is configured here: every call site applies its own
     * {@code withDeadlineAfter}, which replaces any deadline previously set on the stub.
     */
    private GraphServiceGrpc.GraphServiceBlockingStub ensureStub() {
        lock.readLock().lock();
        try {
            ManagedChannel channel = channels.computeIfAbsent(this.serverAddr, key -> this.createChannel());
            GraphServiceGrpc.GraphServiceBlockingStub current = this.stub;
            if (current == null || current.getChannel() != channel) {
                current = GraphServiceGrpc.newBlockingStub(channel);
                this.stub = current;
            }
            return current;
        } finally {
            lock.readLock().unlock();
        }
    }

    /** Builds a new plaintext channel to the host and port of {@code serverAddr}. */
    private ManagedChannel createChannel() {
        return ManagedChannelBuilder.forAddress(this.serverAddr.getHost(), this.serverAddr.getPort())
                .usePlaintext()
                .build();
    }

    /**
     * Shuts down every cached channel that is not already shut down, waiting up to five seconds
     * per channel for termination, then clears the cache.
     *
     * <p>If the calling thread is interrupted while waiting, the remaining channels are still
     * shut down but no longer awaited, and the thread's interrupt status is restored before
     * returning so that callers can observe it.
     */
    private static void closeChannel() {
        boolean interrupted = false;
        lock.writeLock().lock();
        try {
            for (ManagedChannel channel : channels.values()) {
                if (channel.isShutdown()) {
                    continue;
                }
                channel.shutdownNow();
                if (interrupted) {
                    // Already interrupted once: do not block on the remaining channels.
                    continue;
                }
                try {
                    channel.awaitTermination(5L, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    interrupted = true;
                    LOGGER.warn("close grpc connection is interrupted.", e);
                }
            }
            channels.clear();
        } finally {
            lock.writeLock().unlock();
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}

