/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.jdbc.auth;

import com.databricks.internal.apache.http.client.entity.UrlEncodedFormEntity;
import com.databricks.internal.apache.http.client.methods.CloseableHttpResponse;
import com.databricks.internal.apache.http.client.methods.HttpPost;
import com.databricks.internal.apache.http.client.utils.URIBuilder;
import com.databricks.internal.apache.http.message.BasicNameValuePair;
import com.databricks.internal.google.common.annotations.VisibleForTesting;
import com.databricks.internal.nimbusds.jwt.JWTClaimsSet;
import com.databricks.internal.nimbusds.jwt.SignedJWT;
import com.databricks.internal.sdk.core.CredentialsProvider;
import com.databricks.internal.sdk.core.DatabricksConfig;
import com.databricks.internal.sdk.core.HeaderFactory;
import com.databricks.internal.sdk.core.oauth.OAuthResponse;
import com.databricks.internal.sdk.core.oauth.Token;
import com.databricks.internal.sdk.core.oauth.TokenSource;
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.JsonUtil;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory;
import com.databricks.jdbc.exception.DatabricksDriverException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * {@link CredentialsProvider} / {@link TokenSource} that wraps an external credentials provider
 * and, when the external JWT was issued by a host other than the configured Databricks host,
 * exchanges it for a Databricks token at {@code <host>/oidc/v1/token} using the OAuth token
 * exchange grant ({@code urn:ietf:params:oauth:grant-type:token-exchange}). If the exchange fails,
 * the external token itself is returned.
 *
 * <p>NOTE(review): {@link #getToken()} reassigns {@code externalHeaderFactory} and {@code
 * externalProviderHeaders} without synchronization; confirm whether instances are shared across
 * threads.
 */
public class DatabricksTokenFederationProvider
implements CredentialsProvider,
TokenSource {
    private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(DatabricksTokenFederationProvider.class);
    // Initialized to an already-expired placeholder; not read anywhere else in this class.
    private Token token;
    private HeaderFactory externalHeaderFactory;
    // Fixed form parameters for the token-exchange request; subject_token / client_id are added per call.
    private static final Map<String, String> TOKEN_EXCHANGE_PARAMS = Map.of("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange", "scope", "sql", "subject_token_type", "urn:ietf:params:oauth:token-type:jwt", "return_original_token_if_authenticated", "true");
    private static final String TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token";
    private final IDatabricksConnectionContext connectionContext;
    private final CredentialsProvider credentialsProvider;
    private DatabricksConfig config;
    // Headers most recently produced by the external provider; copied into every outgoing header map.
    private Map<String, String> externalProviderHeaders;
    private IDatabricksHttpClient hc;

    /**
     * Creates a provider that delegates to {@code credentialsProvider} for the external token.
     * The config and external header factory are populated later by {@link #configure}.
     */
    public DatabricksTokenFederationProvider(IDatabricksConnectionContext connectionContext, CredentialsProvider credentialsProvider) {
        this.connectionContext = connectionContext;
        this.credentialsProvider = credentialsProvider;
        this.externalProviderHeaders = new HashMap<String, String>();
        this.hc = DatabricksHttpClientFactory.getInstance().getClient(connectionContext);
        this.config = null;
        this.externalHeaderFactory = null;
        this.token = new Token("", "", "", Instant.now().minus(Duration.ofMinutes(1L)));
    }

    /**
     * Test-only constructor: configures the external provider eagerly. Note that it does not
     * create an HTTP client, so {@link #exchangeToken} cannot perform a real request.
     */
    @VisibleForTesting
    DatabricksTokenFederationProvider(IDatabricksConnectionContext connectionContext, CredentialsProvider credentialsProvider, DatabricksConfig config) {
        this.connectionContext = connectionContext;
        this.credentialsProvider = credentialsProvider;
        this.config = config;
        this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
        this.externalProviderHeaders = new HashMap<String, String>();
        this.token = new Token("", "", "", Instant.now().minus(Duration.ofMinutes(1L)));
    }

    /** Returns the auth type of the wrapped external provider. */
    @Override
    public String authType() {
        return this.credentialsProvider.authType();
    }

    /** Returns the wrapped external credentials provider. */
    public CredentialsProvider getCredentialsProvider() {
        return this.credentialsProvider;
    }

    /**
     * Configures the external provider and returns a header factory that, on each invocation,
     * resolves a (possibly exchanged) token and emits the external headers with the
     * {@code Authorization} header replaced by {@code "<tokenType> <accessToken>"}.
     *
     * <p>When {@code DriverUtil.isRunningAgainstFake()} is true, the external provider's header
     * factory is returned unchanged.
     */
    @Override
    public HeaderFactory configure(DatabricksConfig databricksConfig) {
        LOGGER.debug("DatabricksTokenFederation configure");
        if (DriverUtil.isRunningAgainstFake()) {
            return this.credentialsProvider.configure(databricksConfig);
        }
        this.config = databricksConfig;
        this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
        return () -> {
            Token exchangedToken = this.getToken();
            HashMap<String, String> headers = new HashMap<String, String>(this.externalProviderHeaders);
            headers.put("Authorization", exchangedToken.getTokenType() + " " + exchangedToken.getAccessToken());
            return headers;
        };
    }

    /**
     * Fetches the current external token and returns either an exchanged Databricks token (when
     * the JWT issuer host differs from the configured host and the exchange succeeds) or a token
     * built from the external JWT itself.
     *
     * @throws DatabricksDriverException with {@code AUTH_ERROR} if the Authorization header is
     *     missing/malformed or the external token cannot be parsed as a JWT
     */
    @Override
    public Token getToken() {
        if (this.externalHeaderFactory == null) {
            this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
        }
        this.externalProviderHeaders = this.externalHeaderFactory.headers();
        String[] tokenInfo = this.extractTokenInfoFromHeader(this.externalProviderHeaders);
        String accessTokenType = tokenInfo[0];
        String accessToken = tokenInfo[1];
        try {
            // Parsing here also validates that the external token is a well-formed signed JWT.
            JWTClaimsSet claims = SignedJWT.parse(accessToken).getJWTClaimsSet();
            if (!this.isSameHost(claims.getIssuer(), this.config.getHost())) {
                Optional<Token> exchanged = this.tryTokenExchange(accessToken, accessTokenType);
                if (exchanged.isPresent()) {
                    return exchanged.get();
                }
            }
            // Same issuer host, or the exchange failed: use the external token directly.
            return this.createToken(accessToken, accessTokenType);
        }
        catch (Exception e) {
            LOGGER.error(e, "Failed to refresh access token");
            throw new DatabricksDriverException("Failed to refresh access token", (Throwable)e, DatabricksDriverErrorCode.AUTH_ERROR);
        }
    }

    /**
     * Attempts a token exchange, returning {@link Optional#empty()} (after logging) on any
     * failure so the caller can fall back to the external token.
     */
    @VisibleForTesting
    Optional<Token> tryTokenExchange(String accessToken, String accessTokenType) {
        LOGGER.debug("Token tryTokenExchange(String accessToken, String accessTokenType = {})", accessTokenType);
        try {
            return Optional.of(this.exchangeToken(accessToken));
        }
        catch (Exception e) {
            LOGGER.error(e, "Token exchange failed, falling back to using external token");
            return Optional.empty();
        }
    }

    /**
     * Builds a {@link Token} whose expiry is taken from the JWT's {@code exp} claim.
     *
     * @throws ParseException if {@code accessToken} is not a parseable JWT or has no {@code exp}
     *     claim
     */
    @VisibleForTesting
    Token createToken(String accessToken, String accessTokenType) throws ParseException {
        JWTClaimsSet claims = SignedJWT.parse(accessToken).getJWTClaimsSet();
        if (claims.getExpirationTime() == null) {
            // Without this check a missing claim surfaces as an opaque NullPointerException.
            throw new ParseException("JWT is missing the 'exp' (expiration time) claim", 0);
        }
        Instant expiry = claims.getExpirationTime().toInstant();
        return new Token(accessToken, accessTokenType, "", expiry);
    }

    /**
     * Exchanges {@code accessToken} at the workspace token endpoint, adding the identity
     * federation client id when the connection context provides one.
     */
    @VisibleForTesting
    Token exchangeToken(String accessToken) {
        LOGGER.debug("Token exchangeToken( String accessToken )");
        String host = this.config.getHost();
        if (host != null) {
            // Avoid producing "...//oidc/v1/token" when the host carries trailing slashes.
            host = host.replaceAll("/+$", "");
        }
        String tokenUrl = host + TOKEN_EXCHANGE_ENDPOINT;
        HashMap<String, String> params = new HashMap<String, String>(TOKEN_EXCHANGE_PARAMS);
        params.put("subject_token", accessToken);
        String clientId = this.connectionContext.getIdentityFederationClientId();
        if (clientId != null) {
            params.put("client_id", clientId);
        }
        HashMap<String, String> headers = new HashMap<String, String>();
        headers.put("Accept", "*/*");
        headers.put("Content-Type", "application/x-www-form-urlencoded");
        return this.retrieveToken(this.hc, tokenUrl, params, headers);
    }

    /**
     * POSTs {@code params} as a URL-encoded form to {@code tokenUrl}, parses the JSON body as an
     * {@link OAuthResponse} and converts it into a {@link Token}. The HTTP response is always
     * closed.
     *
     * @throws DatabricksDriverException with {@code AUTH_ERROR} on any failure
     */
    @VisibleForTesting
    Token retrieveToken(IDatabricksHttpClient hc, String tokenUrl, Map<String, String> params, Map<String, String> headers) {
        try {
            HttpPost postRequest = new HttpPost(new URIBuilder(tokenUrl).build());
            postRequest.setEntity(new UrlEncodedFormEntity(params.entrySet().stream().map(e -> new BasicNameValuePair(e.getKey(), e.getValue())).collect(Collectors.toList()), StandardCharsets.UTF_8));
            headers.forEach(postRequest::setHeader);
            try (CloseableHttpResponse response = hc.execute(postRequest)) {
                OAuthResponse resp = JsonUtil.getMapper().readValue(response.getEntity().getContent(), OAuthResponse.class);
                return this.createToken(resp.getAccessToken(), resp.getTokenType());
            }
        }
        catch (Exception e2) {
            LOGGER.error(e2, "Failed to retrieve the exchanged token");
            throw new DatabricksDriverException("Failed to retrieve the exchanged token", (Throwable)e2, DatabricksDriverErrorCode.AUTH_ERROR);
        }
    }

    /**
     * Returns whether both URLs have the same host (case-insensitively, as DNS names are).
     * Returns {@code false} if either string cannot be parsed as a URL (including {@code null}).
     */
    private boolean isSameHost(String url1, String url2) {
        try {
            String host1 = new URL(url1).getHost();
            String host2 = new URL(url2).getHost();
            return host1.equalsIgnoreCase(host2);
        }
        catch (MalformedURLException e) {
            LOGGER.error(e, "Unable to parse URL String");
            return false;
        }
    }

    /**
     * Splits the {@code Authorization} header into {@code [tokenType, accessToken]}.
     *
     * @throws DatabricksDriverException with {@code AUTH_ERROR} if the header map or the
     *     Authorization header is missing, or the header lacks a space-separated token
     */
    private String[] extractTokenInfoFromHeader(Map<String, String> headers) {
        String authHeader = headers == null ? null : headers.get("Authorization");
        String[] parts = authHeader == null ? null : authHeader.split(" ", 2);
        if (parts == null || parts.length < 2) {
            // The header value itself is deliberately not included: it may contain a secret.
            IllegalArgumentException cause = new IllegalArgumentException(authHeader == null ? "Authorization header is missing" : "Authorization header is not of the form '<type> <token>'");
            LOGGER.error(cause, "Failed to extract token info from header");
            throw new DatabricksDriverException("Failed to extract token info from header", (Throwable)cause, DatabricksDriverErrorCode.AUTH_ERROR);
        }
        return parts;
    }
}

