/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.security.impl;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.function.Predicate;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import net.shibboleth.idp.plugin.oidc.op.criterion.ClientInformationCriterion;
import net.shibboleth.oidc.jwk.RemoteJwkSetCache;
import net.shibboleth.oidc.security.credential.BasicJWKCredential;
import net.shibboleth.oidc.security.impl.OIDCDecryptionParameters;
import net.shibboleth.utilities.java.support.annotation.constraint.Positive;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.resolver.ResolverException;

import org.opensaml.security.credential.Credential;
import org.opensaml.xmlsec.EncryptionConfiguration;
import org.opensaml.xmlsec.EncryptionParameters;
import org.opensaml.xmlsec.criterion.EncryptionConfigurationCriterion;
import org.opensaml.xmlsec.criterion.EncryptionOptionalCriterion;
import org.opensaml.xmlsec.impl.BasicEncryptionParametersResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.openid.connect.sdk.rp.OIDCClientInformation;
import com.nimbusds.openid.connect.sdk.rp.OIDCClientMetadata;

/**
 * A specialization of {@link BasicEncryptionParametersResolver} which resolves both encryption and decryption
 * credentials and algorithm preferences using the client registration data of an OIDC client. The credentials and
 * algorithm preferences are resolved for request object decryption, id token encryption or userinfo response
 * encryption, depending on the configured {@link ParameterType}.
 * 
 * <p>
 * In addition to the {@link net.shibboleth.utilities.java.support.resolver.Criterion} inputs documented in
 * {@link BasicEncryptionParametersResolver}, the inputs and associated modes of operation documented for
 * {@link ClientInformationCriterion} are also supported.
 * </p>
 * 
 * <p>
 * Whenever the client information does not allow a credential to be resolved, the resolver falls back to the
 * behaviour of {@link BasicEncryptionParametersResolver}.
 * </p>
 */
public class OIDCClientInformationEncryptionParametersResolver extends BasicEncryptionParametersResolver {

    /** Logger. */
    @Nonnull
    private final Logger log = LoggerFactory.getLogger(OIDCClientInformationEncryptionParametersResolver.class);

    /**
     * Whether to create parameters for request object decryption, id token encryption or userinfo response encryption.
     */
    public enum ParameterType {
        
        /** Type for request object decryption. */
        REQUEST_OBJECT_DECRYPTION,
        
        /** Type for id_token encryption. */
        IDTOKEN_ENCRYPTION,
        
        /** Type for user info encryption. */
        USERINFO_ENCRYPTION
    }

    /**
     * Whether to create parameters for request object decryption, id token encryption or userinfo response encryption.
     * Defaults to id token encryption.
     */
    @Nonnull
    private ParameterType target = ParameterType.IDTOKEN_ENCRYPTION;
    
    /** The cache for remote JWK key sets, consulted when the client has registered a JWK set URI. */
    @Nullable
    private RemoteJwkSetCache remoteJwkSetCache;

    /** The remote key refresh interval. Default value: 30 minutes. */
    @Positive
    @Nonnull
    private Duration keyFetchInterval = Duration.ofMinutes(30);
    
    /**
     * Constructor.
     */
    public OIDCClientInformationEncryptionParametersResolver() {
        super();
    }

    /**
     * Whether to create parameters for request object decryption, id token encryption or userinfo response encryption.
     * 
     * @param value Whether to create parameters for request object decryption, id token encryption or userinfo
     *            response encryption. Must not be null.
     */
    public void setParameterType(@Nonnull final ParameterType value) {
        // Rejecting null here avoids a NullPointerException from the switch on the target during resolution
        target = Constraint.isNotNull(value, "Parameter type cannot be null");
    }

    /**
     * Set the cache for remote JWK key sets.
     * 
     * @param jwkSetCache What to set.
     */
    public void setRemoteJwkSetCache(final RemoteJwkSetCache jwkSetCache) {
        remoteJwkSetCache = Constraint.isNotNull(jwkSetCache, "The remote JWK set cache cannot be null");
    }

    /**
     * Set the remote key refresh interval. The expiry passed to the remote JWK set cache is the current time plus this
     * interval.
     * 
     * @param interval What to set; must be non-null and strictly positive.
     */
    public void setKeyFetchInterval(@Positive @Nonnull final Duration interval) {
        Constraint.isFalse(interval == null || interval.isNegative() || interval.isZero(),
                "Remote key refresh must be greater than 0");
        keyFetchInterval = interval;
    }

    /** {@inheritDoc} */
    @Override
    @Nullable
    public EncryptionParameters resolveSingle(@Nonnull final CriteriaSet criteria) throws ResolverException {
        Constraint.isNotNull(criteria, "CriteriaSet was null");
        Constraint.isNotNull(criteria.get(EncryptionConfigurationCriterion.class),
                "Resolver requires an instance of EncryptionConfigurationCriterion");

        final Predicate<String> includeExcludePredicate = getIncludeExcludePredicate(criteria);

        // For decryption we need to list all the located keys and need the extended EncryptionParameters
        final EncryptionParameters params = (target == ParameterType.REQUEST_OBJECT_DECRYPTION)
                ? new OIDCDecryptionParameters() : new EncryptionParameters();

        resolveAndPopulateCredentialsAndAlgorithms(params, criteria, includeExcludePredicate);

        final EncryptionOptionalCriterion encryptionOptionalCrit = criteria.get(EncryptionOptionalCriterion.class);
        final boolean encryptionOptional =
                encryptionOptionalCrit != null && encryptionOptionalCrit.isEncryptionOptional();

        if (validate(params, encryptionOptional)) {
            logResult(params);
            return params;
        }
        return null;
    }
    
    // Checkstyle: CyclomaticComplexity OFF
    // Checkstyle: MethodLength OFF
    // Checkstyle: ReturnCount OFF

    /**
     * {@inheritDoc}
     * 
     * <p>
     * Only the key transport algorithm and data encryption method registered by the client for the configured
     * {@link ParameterType} are considered. Symmetric algorithms use a key derived from the client secret. For
     * asymmetric encryption a key is picked from the client's JWK set. For asymmetric decryption all matching local
     * credentials are collected. If none of this succeeds, the superclass resolution is used instead.
     * </p>
     */
    @Override
    protected void resolveAndPopulateCredentialsAndAlgorithms(@Nonnull final EncryptionParameters params,
            @Nonnull final CriteriaSet criteria, @Nonnull final Predicate<String> whitelistBlacklistPredicate) {

        if (!criteria.contains(ClientInformationCriterion.class)) {
            log.debug("No client criterion, falling back to local configuration");
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }
        
        if (!criteria.contains(EncryptionConfigurationCriterion.class)) {
            log.debug("No encryption configuration criterion, falling back to default configuration");
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }
        
        final List<EncryptionConfiguration> encryptionConfigurations =
                criteria.get(EncryptionConfigurationCriterion.class).getConfigurations();
        if (encryptionConfigurations == null || encryptionConfigurations.isEmpty()) {
            log.debug("No encryption configuration, falling back to default configuration");
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }
        
        final OIDCClientInformation clientInformation =
                criteria.get(ClientInformationCriterion.class).getOidcClientInformation();
        final OIDCClientMetadata metadata = clientInformation.getOIDCMetadata();
        
        // We populate the parameters only for the algorithm the client has registered
        final JWEAlgorithm keyTransportAlgorithm;
        final EncryptionMethod registeredEncryptionMethod;
        switch (target) {
            case REQUEST_OBJECT_DECRYPTION:
                keyTransportAlgorithm = metadata.getRequestObjectJWEAlg();
                registeredEncryptionMethod = metadata.getRequestObjectJWEEnc();
                break;

            case USERINFO_ENCRYPTION:
                keyTransportAlgorithm = metadata.getUserInfoJWEAlg();
                registeredEncryptionMethod = metadata.getUserInfoJWEEnc();
                break;

            default:
                keyTransportAlgorithm = metadata.getIDTokenJWEAlg();
                registeredEncryptionMethod = metadata.getIDTokenJWEEnc();
        }
        if (keyTransportAlgorithm == null) {
            log.debug("No algorithm information in client information, falling back to default configuration");
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }
        // When only the "alg" is registered, the "enc" value defaults to A128CBC-HS256
        final EncryptionMethod encryptionMethod =
                registeredEncryptionMethod != null ? registeredEncryptionMethod : EncryptionMethod.A128CBC_HS256;

        final List<String> keyTransportAlgorithms =
                getEffectiveKeyTransportAlgorithms(criteria, whitelistBlacklistPredicate);
        log.trace("Resolved effective key transport algorithms: {}", keyTransportAlgorithms);
        if (!keyTransportAlgorithms.contains(keyTransportAlgorithm.getName())) {
            log.warn("Client requests key transport algorithm {} that is not available",
                    keyTransportAlgorithm.getName());
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }
        final List<String> dataEncryptionAlgorithms =
                getEffectiveDataEncryptionAlgorithms(criteria, whitelistBlacklistPredicate);
        log.trace("Resolved effective data encryption algorithms: {}", dataEncryptionAlgorithms);
        if (!dataEncryptionAlgorithms.contains(encryptionMethod.getName())) {
            log.warn("Client requests encryption algorithm {} that is not available", encryptionMethod.getName());
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
            return;
        }

        final boolean populated;
        if (JWEAlgorithm.Family.SYMMETRIC.contains(keyTransportAlgorithm)) {
            populated = populateFromClientSecret(params, clientInformation, keyTransportAlgorithm, encryptionMethod);
        } else if (target != ParameterType.REQUEST_OBJECT_DECRYPTION) {
            populated = populateFromClientKeys(params, metadata, keyTransportAlgorithm, encryptionMethod);
        } else {
            populated =
                    populateFromLocalKeys(params, encryptionConfigurations, keyTransportAlgorithm, encryptionMethod);
        }

        if (!populated) {
            log.debug("Not able to resolve credentials based on provided client information, "
                    + "falling back to default configuration");
            super.resolveAndPopulateCredentialsAndAlgorithms(params, criteria, whitelistBlacklistPredicate);
        }
    }
    
    // Checkstyle: CyclomaticComplexity ON
    // Checkstyle: MethodLength ON
    // Checkstyle: ReturnCount ON

    /**
     * Populate the parameters with a symmetric key derived from the client secret (AES key wrap family).
     * 
     * @param params the parameters to populate
     * @param clientInformation the client whose secret is the basis of the key
     * @param keyTransportAlgorithm the registered key transport algorithm
     * @param encryptionMethod the registered (or defaulted) data encryption method
     * @return whether a credential was set on the parameters
     */
    private boolean populateFromClientSecret(@Nonnull final EncryptionParameters params,
            @Nonnull final OIDCClientInformation clientInformation, @Nonnull final JWEAlgorithm keyTransportAlgorithm,
            @Nonnull final EncryptionMethod encryptionMethod) {
        final Secret secret = clientInformation.getSecret();
        if (secret == null) {
            log.warn("No client secret available");
            return false;
        }
        final BasicJWKCredential jwkCredential = new BasicJWKCredential();
        jwkCredential.setAlgorithm(keyTransportAlgorithm);
        try {
            jwkCredential.setSecretKey(generateSymmetricKey(secret.getValueBytes(), keyTransportAlgorithm));
        } catch (final NoSuchAlgorithmException e) {
            log.warn("Unable to generate secret key: {}", e.getMessage());
            return false;
        }
        // The same symmetric key serves for decryption (e.g. request objects)
        if (params instanceof OIDCDecryptionParameters) {
            ((OIDCDecryptionParameters) params).getKeyTransportDecryptionCredentials().add(jwkCredential);
        }
        setCredentialAndAlgorithms(params, jwkCredential, keyTransportAlgorithm, encryptionMethod);
        return true;
    }

    /**
     * Populate the parameters with the first suitable encryption key from the client's registered JWK set. The key set
     * is fetched via the remote JWK set cache if the client registered a JWK set URI, else taken from the metadata.
     * 
     * @param params the parameters to populate
     * @param metadata the client's registered metadata
     * @param keyTransportAlgorithm the registered key transport algorithm (RSA or ECDH-ES family)
     * @param encryptionMethod the registered (or defaulted) data encryption method
     * @return whether a credential was set on the parameters
     */
    private boolean populateFromClientKeys(@Nonnull final EncryptionParameters params,
            @Nonnull final OIDCClientMetadata metadata, @Nonnull final JWEAlgorithm keyTransportAlgorithm,
            @Nonnull final EncryptionMethod encryptionMethod) {
        final JWKSet keySet;
        if (metadata.getJWKSetURI() != null) {
            if (remoteJwkSetCache == null) {
                log.warn("Client registered a JWK set URI but no remote JWK set cache is configured");
                return false;
            }
            keySet = remoteJwkSetCache.fetch(metadata.getJWKSetURI(), Instant.now().plus(keyFetchInterval));
        } else {
            keySet = metadata.getJWKSet();
        }
        if (keySet == null) {
            log.warn("No keyset available");
            return false;
        }
        for (final JWK key : keySet.getKeys()) {
            // Keys explicitly declared for signing are never used for encryption
            if (KeyUse.SIGNATURE.equals(key.getKeyUse())) {
                continue;
            }
            final boolean rsaMatch = JWEAlgorithm.Family.RSA.contains(keyTransportAlgorithm)
                    && KeyType.RSA.equals(key.getKeyType());
            final boolean ecMatch = JWEAlgorithm.Family.ECDH_ES.contains(keyTransportAlgorithm)
                    && KeyType.EC.equals(key.getKeyType());
            if (!rsaMatch && !ecMatch) {
                continue;
            }
            final BasicJWKCredential jwkCredential = new BasicJWKCredential();
            jwkCredential.setAlgorithm(keyTransportAlgorithm);
            jwkCredential.setKid(key.getKeyID());
            try {
                if (rsaMatch) {
                    jwkCredential.setPublicKey(((RSAKey) key).toPublicKey());
                } else {
                    jwkCredential.setPublicKey(((ECKey) key).toPublicKey());
                }
            } catch (final JOSEException e) {
                // A single broken key should not prevent using another valid key from the same set
                log.warn("Unable to parse key {} from keyset: {}", key.getKeyID(), e.getMessage());
                continue;
            }
            log.debug("Selected key {} for alg {} and enc {}", key.getKeyID(), keyTransportAlgorithm.getName(),
                    encryptionMethod.getName());
            setCredentialAndAlgorithms(params, jwkCredential, keyTransportAlgorithm, encryptionMethod);
            return true;
        }
        return false;
    }

    /**
     * Populate the parameters with matching local credentials from the security configuration. For
     * {@link OIDCDecryptionParameters}, all matching credentials are collected as decryption candidates. Otherwise the
     * first match is used.
     * 
     * @param params the parameters to populate
     * @param encryptionConfigurations the local encryption configurations to search
     * @param keyTransportAlgorithm the registered key transport algorithm (RSA or ECDH-ES family)
     * @param encryptionMethod the registered (or defaulted) data encryption method
     * @return whether at least one credential was set on the parameters
     */
    private boolean populateFromLocalKeys(@Nonnull final EncryptionParameters params,
            @Nonnull final List<EncryptionConfiguration> encryptionConfigurations,
            @Nonnull final JWEAlgorithm keyTransportAlgorithm, @Nonnull final EncryptionMethod encryptionMethod) {
        boolean found = false;
        for (final EncryptionConfiguration encryptionConfiguration : encryptionConfigurations) {
            for (final Credential credential : encryptionConfiguration.getKeyTransportEncryptionCredentials()) {
                final boolean matches = (JWEAlgorithm.Family.RSA.contains(keyTransportAlgorithm)
                        && credential.getPrivateKey() instanceof RSAPrivateKey)
                        || (JWEAlgorithm.Family.ECDH_ES.contains(keyTransportAlgorithm)
                                && credential.getPrivateKey() instanceof ECPrivateKey);
                if (!matches) {
                    continue;
                }
                log.debug("Picked key for alg {} and enc {}", keyTransportAlgorithm.getName(),
                        encryptionMethod.getName());
                setCredentialAndAlgorithms(params, credential, keyTransportAlgorithm, encryptionMethod);
                if (!(params instanceof OIDCDecryptionParameters)) {
                    return true;
                }
                ((OIDCDecryptionParameters) params).getKeyTransportDecryptionCredentials().add(credential);
                found = true;
            }
        }
        return found;
    }

    /**
     * Record the chosen key transport credential together with the algorithm names on the parameters.
     * 
     * @param params the parameters to populate
     * @param credential the key transport credential
     * @param keyTransportAlgorithm the key transport algorithm
     * @param encryptionMethod the data encryption method
     */
    private static void setCredentialAndAlgorithms(@Nonnull final EncryptionParameters params,
            @Nonnull final Credential credential, @Nonnull final JWEAlgorithm keyTransportAlgorithm,
            @Nonnull final EncryptionMethod encryptionMethod) {
        params.setKeyTransportEncryptionCredential(credential);
        params.setKeyTransportEncryptionAlgorithm(keyTransportAlgorithm.getName());
        params.setDataEncryptionAlgorithm(encryptionMethod.getName());
    }

    /**
     * Generate symmetric key from client secret: the left-most bytes of the SHA-256 digest of the secret, truncated
     * to the key size of the AES key wrap algorithm.
     * 
     * @param clientSecret client secret that is the basis of key
     * @param keyTransportAlgorithm algorithm the key is generated for
     * @return key derived from client secret.
     * @throws NoSuchAlgorithmException if algorithm or digest method is unsupported
     */
    @Nonnull
    private static SecretKey generateSymmetricKey(@Nonnull final byte[] clientSecret,
            @Nonnull final JWEAlgorithm keyTransportAlgorithm) throws NoSuchAlgorithmException {
        final int keyLengthBytes;
        switch (keyTransportAlgorithm.getName()) {
            case "A128KW":
            case "A128GCMKW":
                keyLengthBytes = 16;
                break;
            case "A192KW":
            case "A192GCMKW":
                keyLengthBytes = 24;
                break;
            case "A256KW":
            case "A256GCMKW":
                keyLengthBytes = 32;
                break;
            default:
                throw new NoSuchAlgorithmException(
                        "Implementation does not support generating key for " + keyTransportAlgorithm.getName());
        }
        // All supported key sizes are at most 256 bits, so a single SHA-256 digest always suffices
        final byte[] digest = MessageDigest.getInstance("SHA-256").digest(clientSecret);
        return new SecretKeySpec(digest, 0, keyLengthBytes, "AES");
    }

}
