/*
 * Decompiled with CFR 0.152.
 */
package org.glassfish.soteria.mechanisms.openid.controller;

import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.source.ImmutableSecret;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.util.DefaultResourceRetriever;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.glassfish.soteria.mechanisms.openid.controller.CacheKey;
import org.glassfish.soteria.mechanisms.openid.domain.OpenIdConfiguration;

/**
 * Validates JWTs (plain, signed or encrypted) and returns their verified claims.
 *
 * <p>JWS key selectors are created lazily per signing algorithm and cached. The cache key is
 * built from the algorithm together with the JWKS connect/read timeouts, the provider's JWKS URL
 * and the client secret, all read from {@link OpenIdConfiguration}.
 */
@ApplicationScoped
public class JWTValidator {

    /** Algorithm assumed when a signed token's header carries no algorithm name. */
    private static final String DEFAULT_JWS_ALGORITHM = "RS256";

    /** Size limit passed to the JWKS resource retriever (bytes, per Nimbus' retriever contract). */
    private static final int JWKS_SIZE_LIMIT = 51200;

    @Inject
    private OpenIdConfiguration configuration;

    // Wildcard-typed because the cached selectors are handed to raw DefaultJWTProcessor instances.
    private final ConcurrentHashMap<CacheKey, JWSKeySelector<?>> jwsCache = new ConcurrentHashMap<>();

    /**
     * Validates the given token and returns its claims.
     *
     * <p>Plain (unsigned) tokens only have their claims checked by {@code jwtVerifier}. Signed and
     * encrypted tokens are run through a {@link DefaultJWTProcessor} configured with a JWS key
     * selector for the header's algorithm and with {@code jwtVerifier}.
     *
     * @param token       the parsed token to validate; must not be {@code null}
     * @param jwtVerifier verifier applied to the token's claims set
     * @return the verified claims set
     * @throws NullPointerException  if {@code token} is {@code null}
     * @throws IllegalStateException if the token type or algorithm is unsupported, or if parsing,
     *                               signature or claims verification fails (cause attached)
     */
    public JWTClaimsSet validateBearerToken(JWT token, JWTClaimsSetVerifier jwtVerifier) {
        Objects.requireNonNull(token, "token");
        try {
            if (token instanceof PlainJWT) {
                // NOTE(review): unsigned tokens are accepted on claims verification alone, with no
                // signature check. Confirm every caller only passes tokens obtained over a
                // trusted channel (e.g. directly from the token endpoint over TLS).
                JWTClaimsSet claimsSet = ((PlainJWT) token).getJWTClaimsSet();
                jwtVerifier.verify(claimsSet, null);
                return claimsSet;
            }
            if (token instanceof SignedJWT) {
                SignedJWT signedToken = (SignedJWT) token;
                String alg = signingAlgorithmOf(signedToken.getHeader());
                return this.newProcessor(alg, jwtVerifier).process(signedToken, null);
            }
            if (token instanceof EncryptedJWT) {
                EncryptedJWT encryptedToken = (EncryptedJWT) token;
                // NOTE(review): this is the JWE key-management algorithm (e.g. "RSA-OAEP-256"),
                // yet it is used to look up a *JWS* key selector and no JWE key selector is set.
                // Standard JWE algorithm names are not in the RSA/EC/HMAC JWS families, so
                // createJWSKeySelector rejects them and encrypted tokens currently always fail.
                String alg = encryptedToken.getHeader().getAlgorithm().getName();
                return this.newProcessor(alg, jwtVerifier).process(encryptedToken, null);
            }
            throw new IllegalStateException("Unexpected JWT type : " + token.getClass());
        } catch (JOSEException | BadJOSEException | ParseException ex) {
            throw new IllegalStateException(ex);
        }
    }

    /** Returns the header's algorithm name, falling back to {@value #DEFAULT_JWS_ALGORITHM}. */
    private static String signingAlgorithmOf(JWSHeader header) {
        JWSAlgorithm algorithm = header.getAlgorithm();
        // Check the algorithm itself before dereferencing it so the fallback is actually reachable.
        if (algorithm == null || algorithm.getName() == null) {
            return DEFAULT_JWS_ALGORITHM;
        }
        return algorithm.getName();
    }

    /** Creates a processor using the (cached) key selector for {@code alg} and the given verifier. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // selectors are cached as JWSKeySelector<?>
    private DefaultJWTProcessor newProcessor(String alg, JWTClaimsSetVerifier jwtVerifier) {
        DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor();
        jwtProcessor.setJWSKeySelector(this.getJWSKeySelector(alg));
        jwtProcessor.setJWTClaimsSetVerifier(jwtVerifier);
        return jwtProcessor;
    }

    /** Returns the cached key selector for {@code alg} and the current configuration, creating it if absent. */
    private JWSKeySelector<?> getJWSKeySelector(String alg) {
        return this.jwsCache.computeIfAbsent(this.createCacheKey(alg), k -> this.createJWSKeySelector(alg));
    }

    /** Builds the cache key from the algorithm and the configuration values the selector depends on. */
    private CacheKey createCacheKey(String alg) {
        return new CacheKey(
                alg,
                this.configuration.getJwksConnectTimeout(),
                this.configuration.getJwksReadTimeout(),
                this.configuration.getProviderMetadata().getJwksURL(),
                this.configuration.getClientSecret());
    }

    /**
     * Creates a verification key selector for {@code alg}.
     *
     * <p>RSA and EC algorithms use keys fetched from the provider's JWKS URL. HMAC algorithms use
     * the client secret.
     *
     * @throws IllegalStateException if the algorithm is {@code none} or unsupported, or if an HMAC
     *                               algorithm is requested without a client secret
     */
    private JWSKeySelector<?> createJWSKeySelector(String alg) {
        JWSAlgorithm jwsAlgorithm = new JWSAlgorithm(alg);
        if (Algorithm.NONE.equals(jwsAlgorithm)) {
            throw new IllegalStateException("Unsupported JWS algorithm : " + jwsAlgorithm);
        }
        JWKSource jwkSource;
        if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm) || JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) {
            jwkSource = this.createRemoteJWKSource();
        } else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
            jwkSource = new ImmutableSecret(this.clientSecretBytes());
        } else {
            throw new IllegalStateException("Unsupported JWS algorithm : " + jwsAlgorithm);
        }
        return new JWSVerificationKeySelector(jwsAlgorithm, jwkSource);
    }

    /** Creates a JWK source backed by the provider's JWKS URL, using the configured timeouts. */
    private JWKSource createRemoteJWKSource() {
        DefaultResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(
                this.configuration.getJwksConnectTimeout(),
                this.configuration.getJwksReadTimeout(),
                JWKS_SIZE_LIMIT);
        return JWKSourceBuilder.create(this.configuration.getProviderMetadata().getJwksURL(), jwkSetRetriever).build();
    }

    /**
     * Returns the configured client secret as UTF-8 bytes.
     *
     * @throws IllegalStateException if no client secret is configured, or if it is empty
     */
    private byte[] clientSecretBytes() {
        // Check before converting: new String(null) would otherwise throw NPE first.
        if (Objects.isNull(this.configuration.getClientSecret())) {
            throw new IllegalStateException("Missing client secret");
        }
        byte[] clientSecret = new String(this.configuration.getClientSecret()).getBytes(StandardCharsets.UTF_8);
        // An empty secret cannot serve as an HMAC key; report it like a missing one.
        if (clientSecret.length == 0) {
            throw new IllegalStateException("Missing client secret");
        }
        return clientSecret;
    }
}

