/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.oauthbearer;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.security.auth.login.AppConfigurationEntry;
import org.apache.kafka.common.security.oauthbearer.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.JwtValidatorException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.internals.secured.BasicOAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ClaimValidationUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ConfigurationUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.SerializedJwt;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.NumericDate;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Broker-side {@link JwtValidator} that fully validates an OAuth access token (JWT).
 * <p>
 * The token is processed by a jose4j {@link JwtConsumer} configured with:
 * <ul>
 *   <li>a {@link CloseableVerificationKeyResolver} for signature verification;</li>
 *   <li>{@link AlgorithmConstraints#DISALLOW_NONE}, so unsigned tokens are rejected;</li>
 *   <li>required {@code exp} and {@code iat} claims;</li>
 *   <li>when configured, the expected audience(s), expected issuer and allowed clock skew.</li>
 * </ul>
 * The scope, subject, expiration and issued-at claims are then checked with {@link ClaimValidationUtils}.
 * The result is wrapped in a {@link BasicOAuthBearerToken}.
 * <p>
 * {@link #configure(Map, String, List)} must be called before {@link #validate(String)}.
 */
public class BrokerJwtValidator
implements JwtValidator {
    private static final Logger log = LoggerFactory.getLogger(BrokerJwtValidator.class);
    // Empty when the resolver should be obtained from VerificationKeyResolverFactory during configure().
    private final Optional<CloseableVerificationKeyResolver> verificationKeyResolverOpt;
    // Set by configure(); null until then.
    private JwtConsumer jwtConsumer;
    private String scopeClaimName;
    private String subClaimName;

    /**
     * Creates a validator whose verification key resolver is built from configuration in
     * {@link #configure(Map, String, List)}.
     */
    public BrokerJwtValidator() {
        this.verificationKeyResolverOpt = Optional.empty();
    }

    /**
     * Creates a validator that uses the given verification key resolver instead of building one from configuration.
     *
     * @param verificationKeyResolver resolver used to look up the keys that verify token signatures
     */
    BrokerJwtValidator(CloseableVerificationKeyResolver verificationKeyResolver) {
        this.verificationKeyResolverOpt = Optional.of(verificationKeyResolver);
    }

    /**
     * Builds the jose4j {@link JwtConsumer} and records the scope/subject claim names from the broker configuration.
     *
     * @param configs           broker configuration values
     * @param saslMechanism     SASL mechanism the configuration applies to
     * @param jaasConfigEntries JAAS entries, used when the verification key resolver is built from configuration
     */
    @Override
    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
        ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
        // NOTE(review): assumes the audience config is a list of strings; the element type is not checked here.
        @SuppressWarnings("unchecked")
        List<String> expectedAudiencesList = (List<String>) cu.get("sasl.oauthbearer.expected.audience");
        Set<String> expectedAudiences = expectedAudiencesList != null ? Set.copyOf(expectedAudiencesList) : null;
        Integer clockSkew = cu.validateInteger("sasl.oauthbearer.clock.skew.seconds", false);
        String expectedIssuer = cu.validateString("sasl.oauthbearer.expected.issuer", false);
        String scopeClaimName = cu.validateString("sasl.oauthbearer.scope.claim.name");
        String subClaimName = cu.validateString("sasl.oauthbearer.sub.claim.name");
        // NOTE(review): this class never closes the resolver it obtains here; confirm that the factory
        // or the caller owns its lifecycle.
        CloseableVerificationKeyResolver verificationKeyResolver = this.verificationKeyResolverOpt
            .orElseGet(() -> VerificationKeyResolverFactory.get(configs, saslMechanism, jaasConfigEntries));
        JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder();
        if (clockSkew != null) {
            jwtConsumerBuilder.setAllowedClockSkewInSeconds(clockSkew);
        }
        if (expectedAudiences != null && !expectedAudiences.isEmpty()) {
            jwtConsumerBuilder.setExpectedAudience(expectedAudiences.toArray(new String[0]));
        }
        if (expectedIssuer != null) {
            jwtConsumerBuilder.setExpectedIssuer(expectedIssuer);
        }
        this.jwtConsumer = jwtConsumerBuilder
            .setJwsAlgorithmConstraints(AlgorithmConstraints.DISALLOW_NONE)
            .setRequireExpirationTime()
            .setRequireIssuedAt()
            .setVerificationKeyResolver((VerificationKeyResolver) verificationKeyResolver)
            .build();
        this.scopeClaimName = scopeClaimName;
        this.subClaimName = subClaimName;
    }

    /**
     * Verifies the access token and converts its claims into an {@link OAuthBearerToken}.
     *
     * @param accessToken the serialized JWT
     * @return a token carrying the validated scopes, expiration, subject and issued-at time
     * @throws JwtValidatorException if the token cannot be processed, fails validation, or has malformed claims
     * @throws IllegalStateException if {@link #configure(Map, String, List)} has not been called
     */
    @Override
    public OAuthBearerToken validate(String accessToken) throws JwtValidatorException {
        if (this.jwtConsumer == null) {
            throw new IllegalStateException("BrokerJwtValidator must be configured before validate() is called");
        }
        SerializedJwt serializedJwt = new SerializedJwt(accessToken);
        JwtContext jwt;
        try {
            jwt = this.jwtConsumer.process(serializedJwt.getToken());
        } catch (InvalidJwtException e) {
            throw new JwtValidatorException(String.format("Could not validate the access token: %s", e.getMessage()), e);
        }
        JwtClaims claims = jwt.getJwtClaims();
        Object scopeRaw = this.getClaim(() -> claims.getClaimValue(this.scopeClaimName), this.scopeClaimName);
        Collection<String> scopeRawCollection = toScopeCollection(this.scopeClaimName, scopeRaw);
        NumericDate expirationRaw = this.getClaim(claims::getExpirationTime, "exp");
        String subRaw = this.getClaim(() -> claims.getStringClaimValue(this.subClaimName), this.subClaimName);
        NumericDate issuedAtRaw = this.getClaim(claims::getIssuedAt, "iat");
        Set<String> scopes = ClaimValidationUtils.validateScopes(this.scopeClaimName, scopeRawCollection);
        long expiration = ClaimValidationUtils.validateExpiration("exp", expirationRaw != null ? Long.valueOf(expirationRaw.getValueInMillis()) : null);
        String sub = ClaimValidationUtils.validateSubject(this.subClaimName, subRaw);
        Long issuedAt = ClaimValidationUtils.validateIssuedAt("iat", issuedAtRaw != null ? Long.valueOf(issuedAtRaw.getValueInMillis()) : null);
        return new BasicOAuthBearerToken(accessToken, scopes, expiration, sub, issuedAt);
    }

    /**
     * Normalizes the raw scope claim into a collection of strings.
     * <p>
     * A single string becomes a one-element list. A collection is used as-is after checking that every non-null
     * element is a string; null elements are left for {@link ClaimValidationUtils#validateScopes} to handle, as
     * before. Any other value, including an absent claim ({@code null}), yields an empty set.
     *
     * @param scopeClaimName name of the scope claim, used in error messages
     * @param scopeRaw       raw claim value as returned by jose4j
     * @return the scope values to validate
     * @throws JwtValidatorException if the claim is a collection containing a non-string element
     */
    private static Collection<String> toScopeCollection(String scopeClaimName, Object scopeRaw) throws JwtValidatorException {
        if (scopeRaw instanceof String) {
            return Collections.singletonList((String) scopeRaw);
        }
        if (!(scopeRaw instanceof Collection)) {
            return Collections.emptySet();
        }
        Collection<?> rawCollection = (Collection<?>) scopeRaw;
        for (Object element : rawCollection) {
            // A JSON array such as [1, 2] would otherwise surface as a ClassCastException during scope validation.
            if (element != null && !(element instanceof String)) {
                throw new JwtValidatorException(String.format("The '%s' claim must contain only string values, but found a %s",
                    scopeClaimName, element.getClass().getSimpleName()));
            }
        }
        // Safe: every non-null element was verified to be a String above.
        @SuppressWarnings("unchecked")
        Collection<String> scopes = (Collection<String>) rawCollection;
        return scopes;
    }

    /**
     * Reads a claim through the supplier, logging the value at debug level.
     *
     * @param supplier  accessor for the claim value
     * @param claimName claim name used in logs and error messages
     * @return the claim value, possibly {@code null}
     * @throws JwtValidatorException if the supplier reports a {@link MalformedClaimException}
     */
    private <T> T getClaim(ClaimSupplier<T> supplier, String claimName) throws JwtValidatorException {
        try {
            T value = supplier.get();
            log.debug("getClaim - {}: {}", claimName, value);
            return value;
        } catch (MalformedClaimException e) {
            throw new JwtValidatorException(String.format("Could not extract the '%s' claim from the access token", claimName), e);
        }
    }

    /**
     * Supplies a claim value, allowing jose4j's checked {@link MalformedClaimException} to propagate.
     *
     * @param <T> type of the claim value
     */
    @FunctionalInterface
    public interface ClaimSupplier<T> {
        T get() throws MalformedClaimException;
    }
}

