/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.jdbc.plugin;

import java.sql.Connection;
import java.sql.SQLException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsUtilities;
import software.amazon.jdbc.AwsWrapperProperty;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.JdbcCallable;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;

/**
 * Connection plugin that authenticates with an AWS IAM database authentication token instead of
 * a static password.
 *
 * <p>On every {@code connect} call the plugin resolves the target port and AWS region. It then
 * either reuses a cached, unexpired token for the (region, host, port, user) tuple or generates a
 * new one through {@link RdsUtilities}. The token is written into the connection properties as the
 * password before the connection attempt proceeds.
 *
 * <p>The token cache is a static {@link ConcurrentHashMap}, so it is shared by every instance of
 * this plugin in the JVM.
 */
public class IamAuthConnectionPlugin
extends AbstractConnectionPlugin {
    private static final Logger LOGGER = Logger.getLogger(IamAuthConnectionPlugin.class.getName());
    /** Cached tokens keyed by {@code region:host:port:user}; shared across all plugin instances. */
    static final ConcurrentHashMap<String, TokenInfo> tokenCache = new ConcurrentHashMap<>();
    /**
     * Default cache lifetime of a token, in seconds.
     *
     * <p>RDS IAM authentication tokens are valid for 15 minutes. The cache timestamp is taken after
     * the token is generated, so caching for the full 15 minutes could hand out a token the server
     * already rejects. Retiring the entry 30 seconds early leaves a safety margin.
     */
    private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30;
    /** Logged in place of the token itself, which is a usable credential until it expires. */
    private static final String REDACTED_TOKEN = "<redacted>";
    public static final int PG_PORT = 5432;
    public static final int MYSQL_PORT = 3306;
    protected static final AwsWrapperProperty SPECIFIED_PORT = new AwsWrapperProperty("iamDefaultPort", null, "Overrides default port that is used to generate IAM token");
    protected static final AwsWrapperProperty SPECIFIED_REGION = new AwsWrapperProperty("iamRegion", null, "Overrides AWS region that is used to generate IAM token");
    protected static final AwsWrapperProperty SPECIFIED_EXPIRATION = new AwsWrapperProperty("iamExpiration", null, "IAM token cache expiration in seconds");
    protected final RdsUtils rdsUtils = new RdsUtils();

    /**
     * Returns the JDBC methods this plugin intercepts.
     *
     * @return a new mutable set containing only {@code "connect"}
     */
    @Override
    public Set<String> getSubscribedMethods() {
        return new HashSet<String>(Collections.singletonList("connect"));
    }

    /**
     * Stores an IAM authentication token as the password in {@code props}, then opens the
     * connection through {@code connectFunc}.
     *
     * @param driverProtocol JDBC URL protocol prefix (for example {@code jdbc:postgresql:}). It is
     *     used to infer a default port when neither the host nor {@code iamDefaultPort} supplies one.
     * @param hostSpec target host. Its host name is used for region detection, token generation and
     *     the cache key.
     * @param props connection properties. They must contain a non-empty user, and their password
     *     entry is overwritten with the token.
     * @param isInitialConnection not used by this plugin
     * @param connectFunc the next step of the connection pipeline
     * @return the connection produced by {@code connectFunc}
     * @throws SQLException if the user is missing, the region cannot be derived from the host name,
     *     or {@code connectFunc} fails
     * @throws RuntimeException if no port can be determined for an unrecognized protocol
     * @throws IllegalArgumentException if {@code iamDefaultPort} is not positive
     */
    @Override
    public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, boolean isInitialConnection, JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
        String user = PropertyDefinition.USER.getString(props);
        if (StringUtils.isNullOrEmpty(user)) {
            throw new SQLException(PropertyDefinition.USER.name + " is null or empty.");
        }
        String host = hostSpec.getHost();
        int port = this.resolvePort(driverProtocol, hostSpec, props);
        Region region = this.resolveRegion(host, props);
        int tokenExpirationSec = this.resolveTokenExpirationSec(props);

        String cacheKey = this.getCacheKey(user, host, port, region);
        TokenInfo tokenInfo = tokenCache.get(cacheKey);
        if (tokenInfo != null && !tokenInfo.isExpired()) {
            // Do not log the token value: anyone reading the log could authenticate with it.
            LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.useCachedIamToken", new Object[]{REDACTED_TOKEN}));
            PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
        } else {
            String token = this.generateAuthenticationToken(user, host, port, region);
            LOGGER.finest(() -> Messages.get("IamAuthConnectionPlugin.generatedNewIamToken", new Object[]{REDACTED_TOKEN}));
            PropertyDefinition.PASSWORD.set(props, token);
            tokenCache.put(cacheKey, new TokenInfo(token, Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS)));
        }
        return connectFunc.call();
    }

    /**
     * Picks the port for token generation and the cache key. The explicit host port wins, then
     * {@code iamDefaultPort}, then the driver's well-known default port.
     */
    private int resolvePort(String driverProtocol, HostSpec hostSpec, Properties props) {
        if (hostSpec.isPortSpecified()) {
            return hostSpec.getPort();
        }
        if (!StringUtils.isNullOrEmpty(SPECIFIED_PORT.getString(props))) {
            int port = SPECIFIED_PORT.getInteger(props);
            if (port <= 0) {
                throw new IllegalArgumentException(Messages.get("IamAuthConnectionPlugin.invalidPort", new Object[]{port}));
            }
            return port;
        }
        if (driverProtocol.startsWith("jdbc:mysql:")) {
            return MYSQL_PORT;
        }
        if (driverProtocol.startsWith("jdbc:postgresql:")) {
            return PG_PORT;
        }
        throw new RuntimeException(Messages.get("IamAuthConnectionPlugin.missingPort"));
    }

    /** Uses {@code iamRegion} when set; otherwise derives the region from the RDS host name. */
    private Region resolveRegion(String host, Properties props) throws SQLException {
        String specifiedRegion = SPECIFIED_REGION.getString(props);
        return StringUtils.isNullOrEmpty(specifiedRegion) ? this.getRdsRegion(host) : Region.of(specifiedRegion);
    }

    /**
     * Returns the cache lifetime in seconds: {@code iamExpiration} when set, otherwise
     * {@link #DEFAULT_TOKEN_EXPIRATION_SEC}. Values above the token's own lifetime are not capped.
     */
    private int resolveTokenExpirationSec(Properties props) {
        return StringUtils.isNullOrEmpty(SPECIFIED_EXPIRATION.getString(props))
                ? DEFAULT_TOKEN_EXPIRATION_SEC
                : SPECIFIED_EXPIRATION.getInteger(props);
    }

    /**
     * Generates a fresh RDS IAM authentication token, using the SDK's default credentials chain.
     *
     * @param user database user name
     * @param hostname database host name
     * @param port database port
     * @param region AWS region of the database
     * @return the signed authentication token
     */
    String generateAuthenticationToken(String user, String hostname, int port, Region region) {
        RdsUtilities utilities = RdsUtilities.builder()
                .credentialsProvider((AwsCredentialsProvider) DefaultCredentialsProvider.create())
                .region(region)
                .build();
        return utilities.generateAuthenticationToken(builder -> builder.hostname(hostname).port(port).username(user));
    }

    /** Builds the token-cache key in the form {@code region:host:port:user}. */
    private String getCacheKey(String user, String hostname, int port, Region region) {
        return String.format("%s:%s:%d:%s", region, hostname, port, user);
    }

    /** Removes every cached token. */
    static void clearCache() {
        tokenCache.clear();
    }

    /**
     * Derives the AWS region from an RDS host name.
     *
     * @param hostname the database host name
     * @return the matching SDK {@link Region}
     * @throws SQLException if no region can be parsed from the host name, or the parsed region is
     *     not one the SDK knows
     */
    private Region getRdsRegion(String hostname) throws SQLException {
        String rdsRegion = this.rdsUtils.getRdsRegion(hostname);
        if (StringUtils.isNullOrEmpty(rdsRegion)) {
            String exceptionMessage = Messages.get("IamAuthConnectionPlugin.unsupportedHostname", new Object[]{hostname});
            LOGGER.fine(() -> exceptionMessage);
            throw new SQLException(exceptionMessage);
        }
        Optional<Region> regionOptional = Region.regions().stream()
                .filter(r -> r.id().equalsIgnoreCase(rdsRegion))
                .findFirst();
        if (!regionOptional.isPresent()) {
            String exceptionMessage = Messages.get("IamAuthConnectionPlugin.unsupportedRegion", new Object[]{rdsRegion});
            LOGGER.fine(() -> exceptionMessage);
            throw new SQLException(exceptionMessage);
        }
        return regionOptional.get();
    }

    /** Immutable pair of a token and the instant after which the cache must stop handing it out. */
    static class TokenInfo {
        private final String token;
        private final Instant expiration;

        public TokenInfo(String token, Instant expiration) {
            this.token = token;
            this.expiration = expiration;
        }

        public String getToken() {
            return this.token;
        }

        public Instant getExpiration() {
            return this.expiration;
        }

        /** Returns {@code true} once the current time is strictly after the expiration instant. */
        public boolean isExpired() {
            return Instant.now().isAfter(this.expiration);
        }
    }
}

