/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.jdbc.plugin;

import java.sql.Connection;
import java.sql.SQLException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import org.checkerframework.checker.nullness.qual.NonNull;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsUtilities;
import software.amazon.jdbc.AwsWrapperProperty;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.JdbcCallable;
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.authentication.AwsCredentialsManager;
import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryCounter;
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
import software.amazon.jdbc.util.telemetry.TelemetryGauge;
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;

/**
 * Connection plugin that authenticates using AWS IAM database authentication tokens.
 *
 * <p>On every intercepted {@code connect}/{@code forceConnect} call, the plugin overwrites the
 * {@link PropertyDefinition#PASSWORD} property with an IAM authentication token produced by
 * {@link RdsUtilities}. Generated tokens are kept in a process-wide cache keyed by
 * {@code region:host:port:user}. Each cached token is reused until its locally tracked expiration
 * ({@link #IAM_EXPIRATION} seconds after generation) has passed.
 *
 * <p>If a connection attempt that used a <em>cached</em> token fails with a login exception, a
 * fresh token is generated, cached, and the connection is retried exactly once.
 */
public class IamAuthConnectionPlugin
extends AbstractConnectionPlugin {
    private static final Logger LOGGER = Logger.getLogger(IamAuthConnectionPlugin.class.getName());
    private static final String TELEMETRY_FETCH_TOKEN = "fetch IAM token";

    /** Wrapper method names this plugin intercepts. */
    private static final Set<String> subscribedMethods =
            Collections.unmodifiableSet(new HashSet<>(Arrays.asList("connect", "forceConnect")));

    /**
     * Process-wide IAM token cache keyed by {@code region:host:port:user}.
     * Package-private, presumably so tests can inspect or seed it.
     */
    static final ConcurrentHashMap<String, TokenInfo> tokenCache = new ConcurrentHashMap<>();

    /**
     * Default local cache lifetime of a generated token, in seconds (14.5 minutes). Presumably
     * chosen to stay just under the 15-minute validity of RDS IAM tokens; confirm before changing.
     */
    private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 870;

    /** Host name used when generating the token instead of the connection's host. */
    public static final AwsWrapperProperty IAM_HOST = new AwsWrapperProperty("iamHost", null, "Overrides the host that is used to generate the IAM token");
    /** Port used when generating the token; ignored unless greater than zero. */
    public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty("iamDefaultPort", null, "Overrides default port that is used to generate the IAM token");
    /** AWS region used when generating the token; derived from the host name when unset. */
    public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty("iamRegion", null, "Overrides AWS region that is used to generate the IAM token");
    /** Number of seconds a generated token is kept in {@link #tokenCache}. */
    public static final AwsWrapperProperty IAM_EXPIRATION = new AwsWrapperProperty("iamExpiration", String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds");

    static {
        // Must run after the AwsWrapperProperty fields above have been initialized.
        PropertyDefinition.registerPluginProperties(IamAuthConnectionPlugin.class);
    }

    protected final PluginService pluginService;
    protected final RdsUtils rdsUtils = new RdsUtils();
    private final TelemetryFactory telemetryFactory;
    // Not read anywhere in this class; held so the registered gauge stays referenced.
    private final TelemetryGauge cacheSizeGauge;
    private final TelemetryCounter fetchTokenCounter;

    /**
     * Creates the plugin and registers its telemetry instruments.
     *
     * @param pluginService service providing telemetry, dialect and login-exception detection
     */
    public IamAuthConnectionPlugin(@NonNull PluginService pluginService) {
        this.pluginService = pluginService;
        this.telemetryFactory = pluginService.getTelemetryFactory();
        this.cacheSizeGauge = this.telemetryFactory.createGauge("iam.tokenCache.size", () -> tokenCache.size());
        this.fetchTokenCounter = this.telemetryFactory.createCounter("iam.fetchToken.count");
    }

    @Override
    public Set<String> getSubscribedMethods() {
        return subscribedMethods;
    }

    @Override
    public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, boolean isInitialConnection, JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
        return this.connectInternal(driverProtocol, hostSpec, props, connectFunc);
    }

    /**
     * Sets an IAM token as the password in {@code props} and invokes {@code connectFunc}.
     *
     * @throws SQLException if the user property is missing, the region cannot be determined, the
     *     connection fails, or {@code connectFunc} throws an unchecked exception (wrapped)
     */
    private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Properties props, JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
        final String user = PropertyDefinition.USER.getString(props);
        if (StringUtils.isNullOrEmpty(user)) {
            throw new SQLException(PropertyDefinition.USER.name + " is null or empty.");
        }

        final String iamHost = IAM_HOST.getString(props);
        final String host = StringUtils.isNullOrEmpty(iamHost) ? hostSpec.getHost() : iamHost;
        final int port = this.getPort(props, hostSpec);
        final String iamRegion = IAM_REGION.getString(props);
        final Region region = StringUtils.isNullOrEmpty(iamRegion) ? this.getRdsRegion(host) : Region.of(iamRegion);
        final int tokenExpirationSec = IAM_EXPIRATION.getInteger(props);
        final String cacheKey = this.getCacheKey(user, host, port, region);

        final TokenInfo tokenInfo = tokenCache.get(cacheKey);
        final boolean isCachedToken = tokenInfo != null && !tokenInfo.isExpired();
        if (isCachedToken) {
            // NOTE(review): this logs the raw token (a credential) at FINEST; consider masking it.
            LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.useCachedIamToken", new Object[]{tokenInfo.getToken()}));
            PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
        } else {
            this.generateAndCacheToken(cacheKey, hostSpec, props, host, port, region, tokenExpirationSec);
        }

        try {
            return connectFunc.call();
        }
        catch (SQLException exception) {
            LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.connectException", new Object[]{exception}));
            // A cached token may be rejected before its local expiry; only in that case is a
            // single retry with a freshly generated token worthwhile.
            if (!this.pluginService.isLoginException(exception) || !isCachedToken) {
                throw exception;
            }
            this.generateAndCacheToken(cacheKey, hostSpec, props, host, port, region, tokenExpirationSec);
            return connectFunc.call();
        }
        catch (Exception exception) {
            // connectFunc.call() only declares SQLException, so this receives unchecked exceptions
            // from the first attempt. Exceptions thrown inside the handler above are not caught here.
            LOGGER.warning(() -> Messages.get("IamAuthConnectionPlugin.unhandledException", new Object[]{exception}));
            throw new SQLException(exception);
        }
    }

    /**
     * Generates a new token, sets it as the password in {@code props}, and caches it under
     * {@code cacheKey} with an expiry of {@code tokenExpirationSec} seconds.
     */
    private void generateAndCacheToken(String cacheKey, HostSpec hostSpec, Properties props, String host, int port, Region region, int tokenExpirationSec) {
        // Expiry is measured from before the token request, so the cached lifetime never
        // exceeds tokenExpirationSec counted from the start of generation.
        final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS);
        final String token = this.generateAuthenticationToken(hostSpec, props, host, port, region);
        // NOTE(review): this logs the raw token (a credential) at FINEST; consider masking it.
        LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.generatedNewIamToken", new Object[]{token}));
        PropertyDefinition.PASSWORD.set(props, token);
        tokenCache.put(cacheKey, new TokenInfo(token, tokenExpiry));
    }

    @Override
    public Connection forceConnect(@NonNull String driverProtocol, @NonNull HostSpec hostSpec, @NonNull Properties props, boolean isInitialConnection, @NonNull JdbcCallable<Connection, SQLException> forceConnectFunc) throws SQLException {
        return this.connectInternal(driverProtocol, hostSpec, props, forceConnectFunc);
    }

    /**
     * Requests an IAM authentication token from {@link RdsUtilities} for the given endpoint and
     * the user in {@code props}. The request is recorded in telemetry, and failures are marked on
     * the telemetry context before being rethrown.
     *
     * @param originalHostSpec host spec used to select the AWS credentials provider
     * @param props connection properties; the user name is read from here
     * @param hostname host name embedded in the token
     * @param port port embedded in the token
     * @param region AWS region used to sign the token
     * @return the generated token
     */
    String generateAuthenticationToken(HostSpec originalHostSpec, Properties props, String hostname, int port, Region region) {
        final TelemetryFactory factory = this.pluginService.getTelemetryFactory();
        final TelemetryContext telemetryContext = factory.openTelemetryContext(TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED);
        this.fetchTokenCounter.inc();
        try {
            final String user = PropertyDefinition.USER.getString(props);
            final RdsUtilities utilities = RdsUtilities.builder()
                    .credentialsProvider(AwsCredentialsManager.getProvider(originalHostSpec, props))
                    .region(region)
                    .build();
            return utilities.generateAuthenticationToken(builder -> builder.hostname(hostname).port(port).username(user));
        }
        catch (Exception ex) {
            telemetryContext.setSuccess(false);
            telemetryContext.setException(ex);
            throw ex;
        }
        finally {
            telemetryContext.closeContext();
        }
    }

    /** Builds the token-cache key {@code region:hostname:port:user}. */
    private String getCacheKey(String user, String hostname, int port, Region region) {
        return String.format("%s:%s:%d:%s", region, hostname, port, user);
    }

    /** Removes all cached tokens for every plugin instance in this process. */
    public static void clearCache() {
        tokenCache.clear();
    }

    /**
     * Resolves the port used for token generation. The value comes from {@link #IAM_DEFAULT_PORT}
     * when it is set and positive, otherwise from the host spec if it specifies a port, and
     * otherwise from the dialect's default port.
     */
    private int getPort(Properties props, HostSpec hostSpec) {
        if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) {
            final int defaultPort = IAM_DEFAULT_PORT.getInteger(props);
            if (defaultPort > 0) {
                return defaultPort;
            }
            LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.invalidPort", new Object[]{defaultPort}));
        }
        if (hostSpec.isPortSpecified()) {
            return hostSpec.getPort();
        }
        return this.pluginService.getDialect().getDefaultPort();
    }

    /**
     * Derives the AWS region from an RDS host name.
     *
     * @throws SQLException if no region can be parsed from {@code hostname}, or the parsed region
     *     is not one of the SDK's known {@link Region#regions()}
     */
    private Region getRdsRegion(String hostname) throws SQLException {
        final String rdsRegion = this.rdsUtils.getRdsRegion(hostname);
        if (StringUtils.isNullOrEmpty(rdsRegion)) {
            final String exceptionMessage = Messages.get("IamAuthConnectionPlugin.unsupportedHostname", new Object[]{hostname});
            LOGGER.fine(() -> exceptionMessage);
            throw new SQLException(exceptionMessage);
        }
        final Optional<Region> regionOptional = Region.regions().stream().filter(r -> r.id().equalsIgnoreCase(rdsRegion)).findFirst();
        if (!regionOptional.isPresent()) {
            final String exceptionMessage = Messages.get("AwsSdk.unsupportedRegion", new Object[]{rdsRegion});
            LOGGER.fine(() -> exceptionMessage);
            throw new SQLException(exceptionMessage);
        }
        return regionOptional.get();
    }

    /** Immutable cached token together with the instant after which it is no longer reused. */
    static class TokenInfo {
        private final String token;
        private final Instant expiration;

        public TokenInfo(String token, Instant expiration) {
            this.token = token;
            this.expiration = expiration;
        }

        public String getToken() {
            return this.token;
        }

        public Instant getExpiration() {
            return this.expiration;
        }

        /** Returns {@code true} once the current time is strictly after the expiration instant. */
        public boolean isExpired() {
            return Instant.now().isAfter(this.expiration);
        }
    }
}

