/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.driver.internal;

import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.neo4j.driver.internal.BaseDriver;
import org.neo4j.driver.internal.RoutingErrorHandler;
import org.neo4j.driver.internal.RoutingNetworkSession;
import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.Logger;
import org.neo4j.driver.v1.Logging;
import org.neo4j.driver.v1.Record;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.StatementResult;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.ConnectionFailureException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.util.Function;

/**
 * Driver that spreads sessions over the members of a cluster.
 * <p>
 * Starting from a seed router, it calls the {@code dbms.cluster.routing.getServers} procedure to learn which
 * addresses serve the {@code READ}, {@code WRITE} and {@code ROUTE} roles, and hands out reader/writer
 * connections by cycling through those addresses. The routing table is rebuilt whenever it is stale (see
 * {@link ClusterView#isStale()}), and addresses that fail are removed from it.
 */
public class RoutingDriver extends BaseDriver {
    private static final String GET_SERVERS = "dbms.cluster.routing.getServers";
    /** Largest TTL (seconds) that can be multiplied by 1000 without overflowing a {@code long}. */
    private static final long MAX_TTL = Long.MAX_VALUE / 1000L;
    /** Orders addresses by host, then by port; supplied to every round-robin set in {@link ClusterView}. */
    private static final Comparator<BoltServerAddress> COMPARATOR = new Comparator<BoltServerAddress>() {

        @Override
        public int compare(BoltServerAddress o1, BoltServerAddress o2) {
            int compare = o1.host().compareTo(o2.host());
            if (compare == 0) {
                compare = Integer.compare(o1.port(), o2.port());
            }
            return compare;
        }
    };

    private final ConnectionPool connections;
    private final Function<Connection, Session> sessionProvider;
    private final Clock clock;
    // Replaced under this object's monitor in checkServers(), but read without it by session(),
    // acquire*Connection() and the accessors below; volatile makes the new view visible to them.
    private volatile ClusterView clusterView;

    /**
     * Creates the driver and performs the initial discovery against {@code seedAddress}.
     *
     * @throws ServiceUnavailableException if no router could supply a routing table
     */
    public RoutingDriver(BoltServerAddress seedAddress, ConnectionPool connections, SecurityPlan securityPlan, Function<Connection, Session> sessionProvider, Clock clock, Logging logging) {
        super(securityPlan, logging);
        this.connections = connections;
        this.sessionProvider = sessionProvider;
        this.clock = clock;
        // An expiry of 0 makes the seed-only view stale, so checkServers() discovers immediately.
        this.clusterView = new ClusterView(0L, clock, this.log);
        this.clusterView.addRouter(seedAddress);
        this.checkServers();
    }

    /**
     * Rebuilds the routing table if the current one is stale, purging pooled connections to any
     * address that is not part of the new table.
     */
    private synchronized void checkServers() {
        if (!this.clusterView.isStale()) {
            return;
        }
        Set<BoltServerAddress> oldAddresses = this.clusterView.all();
        ClusterView newView = this.newClusterView();
        oldAddresses.removeAll(newView.all());
        for (BoltServerAddress boltServerAddress : oldAddresses) {
            this.connections.purge(boltServerAddress);
        }
        this.clusterView = newView;
    }

    /**
     * Converts the record's {@code ttl} (seconds) into an absolute expiry in clock milliseconds.
     * A negative TTL, a TTL too large to convert, or an expiry that overflows yields
     * {@link Long#MAX_VALUE}, i.e. a table that never expires by time.
     */
    private long calculateNewExpiry(Record record) {
        long ttl = record.get("ttl").asLong();
        if (ttl < 0L || ttl >= MAX_TTL) {
            return Long.MAX_VALUE;
        }
        long nextExpiry = this.clock.millis() + 1000L * ttl;
        // Adding the current time can still wrap around; treat that as "never expires".
        return nextExpiry < 0L ? Long.MAX_VALUE : nextExpiry;
    }

    /**
     * Asks the known routers, one at a time, for a fresh routing table and returns the first one that
     * lists at least one router. Routers whose call throws are forgotten. Only called from
     * {@link #checkServers()}, i.e. while holding this driver's monitor.
     *
     * @throws ServiceUnavailableException after closing the driver, if no router produced a usable table
     */
    private ClusterView newClusterView() {
        BoltServerAddress address = null;
        // Capture the count up front (as acquireReadConnection/acquireWriteConnection do): forget()
        // shrinks the router set, so re-reading the size each iteration would stop before every router
        // had been tried. The second condition avoids asking an emptied set for its next element.
        int attempts = this.clusterView.numberOfRouters();
        for (int i = 0; i < attempts && this.clusterView.numberOfRouters() > 0; ++i) {
            address = this.clusterView.nextRouter();
            ClusterView discovered;
            try {
                discovered = this.call(address, GET_SERVERS, new Function<Record, ClusterView>() {

                    @Override
                    public ClusterView apply(Record record) {
                        return RoutingDriver.this.toClusterView(record);
                    }
                });
            } catch (Exception e) {
                this.log.debug("Discovery on %s failed: %s", address.toString(), e.getMessage());
                this.forget(address);
                continue;
            }
            // A table without routers would leave us unable to rediscover later; try the next router.
            if (discovered.numberOfRouters() > 0) {
                return discovered;
            }
        }
        this.close();
        throw new ServiceUnavailableException(String.format("Server %s couldn't perform discovery", address == null ? "`UNKNOWN`" : address.toString()));
    }

    /**
     * Builds a view from a single {@code getServers} record: {@code ttl} sets the expiry and each entry of
     * {@code servers} is added under its role. Entries with any other role are ignored.
     */
    private ClusterView toClusterView(Record record) {
        long expire = this.calculateNewExpiry(record);
        ClusterView view = new ClusterView(expire, this.clock, this.log);
        for (ServerInfo server : this.servers(record)) {
            switch (server.role()) {
                case "READ":
                    view.addReaders(server.addresses());
                    break;
                case "WRITE":
                    view.addWriters(server.addresses());
                    break;
                case "ROUTE":
                    view.addRouters(server.addresses());
                    break;
                default:
                    break;
            }
        }
        return view;
    }

    /** Parses the {@code servers} list of a routing record into {@link ServerInfo} entries. */
    private List<ServerInfo> servers(Record record) {
        return record.get("servers").asList(new Function<Value, ServerInfo>() {

            @Override
            public ServerInfo apply(Value value) {
                return new ServerInfo(value.get("addresses").asList(new Function<Value, BoltServerAddress>() {

                    @Override
                    public BoltServerAddress apply(Value value) {
                        return new BoltServerAddress(value.asString());
                    }
                }), value.get("role").asString());
            }
        });
    }

    /**
     * Runs {@code CALL <procedureName>} on {@code address} and maps its single result record with
     * {@code recorder}. The session is always closed before returning or throwing.
     *
     * @throws IllegalStateException if the procedure returned no records (the address is forgotten first)
     */
    private <T> T call(BoltServerAddress address, String procedureName, Function<Record, T> recorder) {
        Connection connection = this.connections.acquire(address);
        try (Session session = this.sessionProvider.apply(connection)) {
            StatementResult records = session.run(String.format("CALL %s", procedureName));
            if (!records.hasNext()) {
                this.forget(address);
                throw new IllegalStateException("Server responded with empty result");
            }
            return recorder.apply(records.single());
        }
    }

    /** Purges pooled connections to {@code address} and removes it from every role in the current view. */
    private synchronized void forget(BoltServerAddress address) {
        this.connections.purge(address);
        this.clusterView.remove(address);
    }

    /** Opens a session in {@link AccessMode#WRITE} mode. */
    @Override
    public Session session() {
        return this.session(AccessMode.WRITE);
    }

    /**
     * Opens a session on a connection for {@code mode}. Connection failures reported by the session
     * forget the address entirely; write failures only drop it from the writers.
     */
    @Override
    public Session session(AccessMode mode) {
        return new RoutingNetworkSession(mode, this.acquireConnection(mode), new RoutingErrorHandler() {

            @Override
            public void onConnectionFailure(BoltServerAddress address) {
                RoutingDriver.this.forget(address);
            }

            @Override
            public void onWriteFailure(BoltServerAddress address) {
                RoutingDriver.this.clusterView.removeWriter(address);
            }
        });
    }

    /**
     * Refreshes the routing table if needed, then acquires a connection suited to {@code role}.
     *
     * @throws ClientException for access modes other than READ and WRITE
     */
    private Connection acquireConnection(AccessMode role) {
        this.checkServers();
        switch (role) {
            case READ:
                return this.acquireReadConnection();
            case WRITE:
                return this.acquireWriteConnection();
            default:
                break;
        }
        throw new ClientException(role + " is not supported for creating new sessions");
    }

    /**
     * Tries each known reader once, in round-robin order, forgetting those that refuse a connection.
     *
     * @throws ConnectionFailureException if no reader accepted a connection
     */
    private Connection acquireReadConnection() {
        int numberOfServers = this.clusterView.numberOfReaders();
        for (int i = 0; i < numberOfServers; ++i) {
            BoltServerAddress address = this.clusterView.nextReader();
            try {
                return this.connections.acquire(address);
            } catch (ConnectionFailureException e) {
                this.forget(address);
            }
        }
        throw new ConnectionFailureException("Failed to connect to any read server");
    }

    /**
     * Tries each known writer once, in round-robin order, forgetting those that refuse a connection.
     *
     * @throws ConnectionFailureException if no writer accepted a connection
     */
    private Connection acquireWriteConnection() {
        int numberOfServers = this.clusterView.numberOfWriters();
        for (int i = 0; i < numberOfServers; ++i) {
            BoltServerAddress address = this.clusterView.nextWriter();
            try {
                return this.connections.acquire(address);
            } catch (ConnectionFailureException e) {
                this.forget(address);
            }
        }
        throw new ConnectionFailureException("Failed to connect to any write server");
    }

    /** Closes the connection pool; failures are logged rather than propagated. */
    @Override
    public void close() {
        try {
            this.connections.close();
        } catch (Exception ex) {
            this.log.error(String.format("~~ [ERROR] %s", ex.getMessage()), ex);
        }
    }

    /** Read-only view of the routers in the current routing table. */
    public Set<BoltServerAddress> routingServers() {
        return Collections.unmodifiableSet(this.clusterView.routingServers);
    }

    /** Read-only view of the readers in the current routing table. */
    public Set<BoltServerAddress> readServers() {
        return Collections.unmodifiableSet(this.clusterView.readServers);
    }

    /** Read-only view of the writers in the current routing table. */
    public Set<BoltServerAddress> writeServers() {
        return Collections.unmodifiableSet(this.clusterView.writeServers);
    }

    /** The connection pool this driver acquires connections from. */
    public ConnectionPool connectionPool() {
        return this.connections;
    }

    /** One routing table: routers, readers and writers, plus the clock time after which it expires. */
    private static class ClusterView {
        /** A view with this many routers or fewer is considered stale. */
        private static final int MIN_ROUTERS = 1;
        private final ConcurrentRoundRobinSet<BoltServerAddress> routingServers = new ConcurrentRoundRobinSet<BoltServerAddress>(COMPARATOR);
        private final ConcurrentRoundRobinSet<BoltServerAddress> readServers = new ConcurrentRoundRobinSet<BoltServerAddress>(COMPARATOR);
        private final ConcurrentRoundRobinSet<BoltServerAddress> writeServers = new ConcurrentRoundRobinSet<BoltServerAddress>(COMPARATOR);
        private final Clock clock;
        private final long expires;
        private final Logger log;

        private ClusterView(long expires, Clock clock, Logger log) {
            this.expires = expires;
            this.clock = clock;
            this.log = log;
        }

        public void addRouter(BoltServerAddress router) {
            this.routingServers.add(router);
        }

        /**
         * True if the view has expired, has no more than {@link #MIN_ROUTERS} routers,
         * or lacks readers or writers.
         */
        public boolean isStale() {
            return this.expires < this.clock.millis()
                    || this.routingServers.size() <= MIN_ROUTERS
                    || this.readServers.isEmpty()
                    || this.writeServers.isEmpty();
        }

        /** A fresh mutable set containing every address in any role. */
        Set<BoltServerAddress> all() {
            Set<BoltServerAddress> all = new HashSet<BoltServerAddress>(this.routingServers.size() + this.readServers.size() + this.writeServers.size());
            all.addAll(this.routingServers);
            all.addAll(this.readServers);
            all.addAll(this.writeServers);
            return all;
        }

        public int numberOfRouters() {
            return this.routingServers.size();
        }

        public BoltServerAddress nextRouter() {
            return this.routingServers.hop();
        }

        public BoltServerAddress nextReader() {
            return this.readServers.hop();
        }

        public BoltServerAddress nextWriter() {
            return this.writeServers.hop();
        }

        public void addReaders(List<BoltServerAddress> addresses) {
            this.readServers.addAll(addresses);
        }

        public void addWriters(List<BoltServerAddress> addresses) {
            this.writeServers.addAll(addresses);
        }

        public void addRouters(List<BoltServerAddress> addresses) {
            this.routingServers.addAll(addresses);
        }

        /** Removes {@code address} from every role, logging each removal at debug level. */
        public void remove(BoltServerAddress address) {
            if (this.routingServers.remove(address)) {
                this.log.debug("Removing %s from routers", address.toString());
            }
            if (this.readServers.remove(address)) {
                this.log.debug("Removing %s from readers", address.toString());
            }
            if (this.writeServers.remove(address)) {
                this.log.debug("Removing %s from writers", address.toString());
            }
        }

        public boolean removeWriter(BoltServerAddress address) {
            return this.writeServers.remove(address);
        }

        public int numberOfReaders() {
            return this.readServers.size();
        }

        public int numberOfWriters() {
            return this.writeServers.size();
        }
    }

    /** One entry of the {@code servers} list: a role and the addresses that serve it. */
    private static class ServerInfo {
        private final List<BoltServerAddress> addresses;
        private final String role;

        public ServerInfo(List<BoltServerAddress> addresses, String role) {
            this.addresses = addresses;
            this.role = role;
        }

        public String role() {
            return this.role;
        }

        List<BoltServerAddress> addresses() {
            return this.addresses;
        }
    }
}

