/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.driver.jdbc.client;

import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.driver.jdbc.shaded.com.google.common.collect.ImmutableMap;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.CallOption;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.CallStatus;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightClient;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.Location;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.util.Preconditions;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.driver.jdbc.shaded.org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.driver.jdbc.shaded.org.apache.calcite.avatica.Meta;
import org.apache.arrow.driver.jdbc.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.arrow.driver.jdbc.shaded.org.slf4j.Logger;
import org.apache.arrow.driver.jdbc.shaded.org.slf4j.LoggerFactory;

public final class ArrowFlightSqlClientHandler
implements AutoCloseable {
    private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);
    private static final String CATALOG = "catalog";
    private final FlightSqlClient sqlClient;
    private final Set<CallOption> options = new HashSet<CallOption>();
    private final Builder builder;
    private final Optional<String> catalog;

    ArrowFlightSqlClientHandler(FlightSqlClient sqlClient, Builder builder, Collection<CallOption> credentialOptions, Optional<String> catalog) {
        this.options.addAll(builder.options);
        this.options.addAll(credentialOptions);
        this.sqlClient = Preconditions.checkNotNull(sqlClient);
        this.builder = builder;
        this.catalog = catalog;
    }

    static ArrowFlightSqlClientHandler createNewHandler(FlightClient client, Builder builder, Collection<CallOption> options, Optional<String> catalog) {
        ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog);
        handler.setSetCatalogInSessionIfPresent();
        return handler;
    }

    private CallOption[] getOptions() {
        return this.options.toArray(new CallOption[0]);
    }

    public List<CloseableEndpointStreamPair> getStreams(FlightInfo flightInfo) throws SQLException {
        ArrayList<CloseableEndpointStreamPair> endpoints = new ArrayList<CloseableEndpointStreamPair>(flightInfo.getEndpoints().size());
        try {
            for (FlightEndpoint endpoint : flightInfo.getEndpoints()) {
                if (endpoint.getLocations().isEmpty()) {
                    endpoints.add(new CloseableEndpointStreamPair(this.sqlClient.getStream(endpoint.getTicket(), this.getOptions()), null));
                    continue;
                }
                ArrayList<Exception> exceptions = new ArrayList<Exception>();
                CloseableEndpointStreamPair stream = null;
                for (Location location : endpoint.getLocations()) {
                    URI endpointUri = location.getUri();
                    if (endpointUri.getScheme().equals("arrow-flight-reuse-connection")) {
                        stream = new CloseableEndpointStreamPair(this.sqlClient.getStream(endpoint.getTicket(), this.getOptions()), null);
                        break;
                    }
                    Builder builderForEndpoint = new Builder(this.builder).withHost(endpointUri.getHost()).withPort(endpointUri.getPort()).withEncryption(endpointUri.getScheme().equals("grpc+tls"));
                    ArrowFlightSqlClientHandler endpointHandler = null;
                    try {
                        endpointHandler = builderForEndpoint.build();
                        stream = new CloseableEndpointStreamPair(endpointHandler.sqlClient.getStream(endpoint.getTicket(), endpointHandler.getOptions()), endpointHandler.sqlClient);
                        stream.getStream().getSchema();
                        break;
                    }
                    catch (Exception ex) {
                        if (endpointHandler != null) {
                            AutoCloseables.close(endpointHandler);
                        }
                        exceptions.add(ex);
                    }
                }
                if (stream != null) {
                    endpoints.add(stream);
                    continue;
                }
                if (exceptions.isEmpty()) {
                    throw new IllegalStateException("Could not connect to endpoint and no errors occurred");
                }
                Exception ex = (Exception)exceptions.remove(0);
                while (!exceptions.isEmpty()) {
                    ex.addSuppressed((Throwable)exceptions.remove(exceptions.size() - 1));
                }
                throw ex;
            }
        }
        catch (Exception outerException) {
            try {
                AutoCloseables.close(endpoints);
            }
            catch (Exception innerEx) {
                outerException.addSuppressed(innerEx);
            }
            if (outerException instanceof SQLException) {
                throw (SQLException)outerException;
            }
            throw new SQLException(outerException);
        }
        return endpoints;
    }

    public FlightInfo getInfo(String query) {
        return this.sqlClient.execute(query, this.getOptions());
    }

    @Override
    public void close() throws SQLException {
        if (this.catalog.isPresent()) {
            this.sqlClient.closeSession(new CloseSessionRequest(), this.getOptions());
        }
        try {
            AutoCloseables.close(this.sqlClient);
        }
        catch (Exception e) {
            throw new SQLException("Failed to clean up client resources.", e);
        }
    }

    private void setSetCatalogInSessionIfPresent() {
        SetSessionOptionsRequest setSessionOptionRequest;
        SetSessionOptionsResult result;
        if (this.catalog.isPresent() && (result = this.sqlClient.setSessionOptions(setSessionOptionRequest = new SetSessionOptionsRequest(ImmutableMap.builder().put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(this.catalog.get())).build()), this.getOptions())).hasErrors()) {
            Map<String, SetSessionOptionsResult.Error> errors = result.getErrors();
            for (Map.Entry<String, SetSessionOptionsResult.Error> error : errors.entrySet()) {
                LOGGER.warn(error.toString());
            }
            throw CallStatus.INVALID_ARGUMENT.withDescription(String.format("Cannot set session option for catalog = %s. Check log for details.", this.catalog)).toRuntimeException();
        }
    }

    public PreparedStatement prepare(String query) {
        final FlightSqlClient.PreparedStatement preparedStatement = this.sqlClient.prepare(query, this.getOptions());
        return new PreparedStatement(){

            @Override
            public FlightInfo executeQuery() throws SQLException {
                return preparedStatement.execute(ArrowFlightSqlClientHandler.this.getOptions());
            }

            @Override
            public long executeUpdate() {
                return preparedStatement.executeUpdate(ArrowFlightSqlClientHandler.this.getOptions());
            }

            @Override
            public Meta.StatementType getType() {
                Schema schema = preparedStatement.getResultSetSchema();
                return schema.getFields().isEmpty() ? Meta.StatementType.UPDATE : Meta.StatementType.SELECT;
            }

            @Override
            public Schema getDataSetSchema() {
                return preparedStatement.getResultSetSchema();
            }

            @Override
            public Schema getParameterSchema() {
                return preparedStatement.getParameterSchema();
            }

            @Override
            public void setParameters(VectorSchemaRoot parameters) {
                preparedStatement.setParameters(parameters);
            }

            @Override
            public void close() {
                try {
                    preparedStatement.close(ArrowFlightSqlClientHandler.this.getOptions());
                }
                catch (FlightRuntimeException fre) {
                    if (fre.status().code().equals((Object)FlightStatusCode.UNAVAILABLE) || fre.status().code().equals((Object)FlightStatusCode.INTERNAL) && fre.getMessage().contains("Connection closed after GOAWAY")) {
                        LOGGER.warn("Supressed error closing PreparedStatement", fre);
                        return;
                    }
                    throw fre;
                }
            }
        };
    }

    public FlightInfo getCatalogs() {
        return this.sqlClient.getCatalogs(this.getOptions());
    }

    public FlightInfo getImportedKeys(String catalog, String schema, String table) {
        return this.sqlClient.getImportedKeys(TableRef.of(catalog, schema, table), this.getOptions());
    }

    public FlightInfo getExportedKeys(String catalog, String schema, String table) {
        return this.sqlClient.getExportedKeys(TableRef.of(catalog, schema, table), this.getOptions());
    }

    public FlightInfo getSchemas(String catalog, String schemaPattern) {
        return this.sqlClient.getSchemas(catalog, schemaPattern, this.getOptions());
    }

    public FlightInfo getTableTypes() {
        return this.sqlClient.getTableTypes(this.getOptions());
    }

    public FlightInfo getTables(String catalog, String schemaPattern, String tableNamePattern, List<String> types, boolean includeSchema) {
        return this.sqlClient.getTables(catalog, schemaPattern, tableNamePattern, types, includeSchema, this.getOptions());
    }

    public FlightInfo getSqlInfo(FlightSql.SqlInfo ... info) {
        return this.sqlClient.getSqlInfo(info, this.getOptions());
    }

    public FlightInfo getPrimaryKeys(String catalog, String schema, String table) {
        return this.sqlClient.getPrimaryKeys(TableRef.of(catalog, schema, table), this.getOptions());
    }

    public FlightInfo getCrossReference(String pkCatalog, String pkSchema, String pkTable, String fkCatalog, String fkSchema, String fkTable) {
        return this.sqlClient.getCrossReference(TableRef.of(pkCatalog, pkSchema, pkTable), TableRef.of(fkCatalog, fkSchema, fkTable), this.getOptions());
    }

    public static final class Builder {
        private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<FlightClientMiddleware.Factory>();
        private final Set<CallOption> options = new HashSet<CallOption>();
        private String host;
        private int port;
        @VisibleForTesting
        String username;
        @VisibleForTesting
        String password;
        @VisibleForTesting
        String trustStorePath;
        @VisibleForTesting
        String trustStorePassword;
        @VisibleForTesting
        String token;
        @VisibleForTesting
        boolean useEncryption = true;
        @VisibleForTesting
        boolean disableCertificateVerification;
        @VisibleForTesting
        boolean useSystemTrustStore = true;
        @VisibleForTesting
        String tlsRootCertificatesPath;
        @VisibleForTesting
        String clientCertificatePath;
        @VisibleForTesting
        String clientKeyPath;
        @VisibleForTesting
        private BufferAllocator allocator;
        @VisibleForTesting
        boolean retainCookies = true;
        @VisibleForTesting
        boolean retainAuth = true;
        @VisibleForTesting
        Optional<String> catalog = Optional.empty();
        @VisibleForTesting
        ClientIncomingAuthHeaderMiddleware.Factory authFactory = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
        @VisibleForTesting
        ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

        public Builder() {
        }

        @VisibleForTesting
        Builder(Builder original) {
            this.middlewareFactories.addAll(original.middlewareFactories);
            this.options.addAll(original.options);
            this.host = original.host;
            this.port = original.port;
            this.username = original.username;
            this.password = original.password;
            this.trustStorePath = original.trustStorePath;
            this.trustStorePassword = original.trustStorePassword;
            this.token = original.token;
            this.useEncryption = original.useEncryption;
            this.disableCertificateVerification = original.disableCertificateVerification;
            this.useSystemTrustStore = original.useSystemTrustStore;
            this.tlsRootCertificatesPath = original.tlsRootCertificatesPath;
            this.clientCertificatePath = original.clientCertificatePath;
            this.clientKeyPath = original.clientKeyPath;
            this.allocator = original.allocator;
            this.catalog = original.catalog;
            if (original.retainCookies) {
                this.cookieFactory = original.cookieFactory;
            }
            if (original.retainAuth) {
                this.authFactory = original.authFactory;
            }
        }

        public Builder withHost(String host) {
            this.host = host;
            return this;
        }

        public Builder withPort(int port) {
            this.port = port;
            return this;
        }

        public Builder withUsername(String username) {
            this.username = username;
            return this;
        }

        public Builder withPassword(String password) {
            this.password = password;
            return this;
        }

        public Builder withTrustStorePath(String trustStorePath) {
            this.trustStorePath = trustStorePath;
            return this;
        }

        public Builder withTrustStorePassword(String trustStorePassword) {
            this.trustStorePassword = trustStorePassword;
            return this;
        }

        public Builder withEncryption(boolean useEncryption) {
            this.useEncryption = useEncryption;
            return this;
        }

        public Builder withDisableCertificateVerification(boolean disableCertificateVerification) {
            this.disableCertificateVerification = disableCertificateVerification;
            return this;
        }

        public Builder withSystemTrustStore(boolean useSystemTrustStore) {
            this.useSystemTrustStore = useSystemTrustStore;
            return this;
        }

        public Builder withTlsRootCertificates(String tlsRootCertificatesPath) {
            this.tlsRootCertificatesPath = tlsRootCertificatesPath;
            return this;
        }

        public Builder withClientCertificate(String clientCertificatePath) {
            this.clientCertificatePath = clientCertificatePath;
            return this;
        }

        public Builder withClientKey(String clientKeyPath) {
            this.clientKeyPath = clientKeyPath;
            return this;
        }

        public Builder withToken(String token) {
            this.token = token;
            return this;
        }

        public Builder withBufferAllocator(BufferAllocator allocator) {
            this.allocator = allocator.newChildAllocator("ArrowFlightSqlClientHandler", 0L, allocator.getLimit());
            return this;
        }

        public Builder withRetainCookies(boolean retainCookies) {
            this.retainCookies = retainCookies;
            return this;
        }

        public Builder withRetainAuth(boolean retainAuth) {
            this.retainAuth = retainAuth;
            return this;
        }

        public Builder withMiddlewareFactories(FlightClientMiddleware.Factory ... factories) {
            return this.withMiddlewareFactories(Arrays.asList(factories));
        }

        public Builder withMiddlewareFactories(Collection<FlightClientMiddleware.Factory> factories) {
            this.middlewareFactories.addAll(factories);
            return this;
        }

        public Builder withCallOptions(CallOption ... options) {
            return this.withCallOptions(Arrays.asList(options));
        }

        public Builder withCallOptions(Collection<CallOption> options) {
            this.options.addAll(options);
            return this;
        }

        public Builder withCatalog(@Nullable String catalog) {
            this.catalog = Optional.ofNullable(catalog);
            return this;
        }

        public ArrowFlightSqlClientHandler build() throws SQLException {
            HashSet<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories = new HashSet<FlightClientMiddleware.Factory>(this.middlewareFactories);
            FlightClient client = null;
            boolean isUsingUserPasswordAuth = this.username != null && this.token == null;
            try {
                Location location;
                if (isUsingUserPasswordAuth) {
                    buildTimeMiddlewareFactories.add(this.authFactory);
                }
                FlightClient.Builder clientBuilder = FlightClient.builder().allocator(this.allocator);
                buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
                buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
                if (this.useEncryption) {
                    location = Location.forGrpcTls(this.host, this.port);
                    clientBuilder.useTls();
                } else {
                    location = Location.forGrpcInsecure(this.host, this.port);
                }
                clientBuilder.location(location);
                if (this.useEncryption) {
                    if (this.disableCertificateVerification) {
                        clientBuilder.verifyServer(false);
                    } else if (this.tlsRootCertificatesPath != null) {
                        clientBuilder.trustedCertificates(ClientAuthenticationUtils.getTlsRootCertificatesStream(this.tlsRootCertificatesPath));
                    } else if (this.useSystemTrustStore) {
                        clientBuilder.trustedCertificates(ClientAuthenticationUtils.getCertificateInputStreamFromSystem(this.trustStorePassword));
                    } else if (this.trustStorePath != null) {
                        clientBuilder.trustedCertificates(ClientAuthenticationUtils.getCertificateStream(this.trustStorePath, this.trustStorePassword));
                    }
                    if (this.clientCertificatePath != null && this.clientKeyPath != null) {
                        clientBuilder.clientCertificate(ClientAuthenticationUtils.getClientCertificateStream(this.clientCertificatePath), ClientAuthenticationUtils.getClientKeyStream(this.clientKeyPath));
                    }
                }
                client = clientBuilder.build();
                ArrayList<CallOption> credentialOptions = new ArrayList<CallOption>();
                if (isUsingUserPasswordAuth) {
                    if (this.authFactory.getCredentialCallOption() != null) {
                        credentialOptions.add(this.authFactory.getCredentialCallOption());
                    } else {
                        credentialOptions.add(ClientAuthenticationUtils.getAuthenticate(client, this.username, this.password, this.authFactory, this.options.toArray(new CallOption[0])));
                    }
                } else if (this.token != null) {
                    credentialOptions.add(ClientAuthenticationUtils.getAuthenticate(client, new CredentialCallOption(new BearerCredentialWriter(this.token)), this.options.toArray(new CallOption[0])));
                }
                return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions, this.catalog);
            }
            catch (IOException | IllegalArgumentException | GeneralSecurityException | FlightRuntimeException e) {
                SQLException originalException = new SQLException(e);
                if (client != null) {
                    try {
                        client.close();
                    }
                    catch (InterruptedException interruptedException) {
                        originalException.addSuppressed(interruptedException);
                    }
                }
                throw originalException;
            }
        }
    }

    public static interface PreparedStatement
    extends AutoCloseable {
        public FlightInfo executeQuery() throws SQLException;

        public long executeUpdate();

        public Meta.StatementType getType();

        public Schema getDataSetSchema();

        public Schema getParameterSchema();

        public void setParameters(VectorSchemaRoot var1);

        @Override
        public void close();
    }
}

