/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.authn.duo.impl;

import java.text.ParseException;
import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.spec.SecretKeySpec;

import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;


import net.shibboleth.idp.authn.AuthnEventIds;
import net.shibboleth.idp.authn.context.AuthenticationContext;
import net.shibboleth.idp.plugin.authn.duo.AbstractDuoAuthenticationAction;
import net.shibboleth.idp.plugin.authn.duo.DuoOIDCIntegration;
import net.shibboleth.idp.plugin.authn.duo.context.DuoOIDCAuthenticationContext;
import net.shibboleth.oidc.security.credential.BasicJWKCredential;
import net.shibboleth.oidc.security.impl.JWSAssemblyUtils;
import net.shibboleth.oidc.security.impl.JWTSignatureValidationUtil;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.logic.ConstraintViolationException;


/**
 * Action to validate the signature of the Duo authentication token (a JWT). The JWT <strong>must</strong> be
 * signed using an algorithm from the HMAC_SHA family, keyed with the secret of the Duo client integration;
 * any other token type or algorithm, including 'none', results in an error event being returned to the flow.
 * 
 * <p>
 * Events raised:
 * <ul>
 * <li>{@link AuthnEventIds#INVALID_AUTHN_CTX} if the token, its claims set or the integration is unavailable.</li>
 * <li>{@link AuthnEventIds#NO_CREDENTIALS} if the token is unsigned, is not HMAC signed, or its signature is
 * rejected by the signature validation utility (which is given this event id to report failures with).</li>
 * <li>{@link AuthnEventIds#AUTHN_EXCEPTION} if the token is of a type this action cannot handle.</li>
 * </ul>
 * </p>
 * 
 * @event {@link net.shibboleth.idp.authn.AuthnEventIds#NO_CREDENTIALS}
 * @event {@link net.shibboleth.idp.authn.AuthnEventIds#INVALID_AUTHN_CTX}
 * @event {@link net.shibboleth.idp.authn.AuthnEventIds#AUTHN_EXCEPTION}
 * @pre <pre>
 *      ProfileRequestContext.getSubcontext(AuthenticationContext.class, false) != null
 *      </pre>
 * 
 * @pre <pre>
 *      AuthenticationContext.getSubcontext(DuoOIDCAuthenticationContext.class, false) != null
 *      </pre>
 */
public class ValidateTokenSignature extends AbstractDuoAuthenticationAction {
    
    /** 
     * The HMAC 'family' of signature algorithms is the only supported, based on the
     * shared secret in the client integration.
     */
    @Nonnull private static final JWSAlgorithm.Family SUPPORTED_SIGNATURE_FAMILY = 
            JWSAlgorithm.Family.HMAC_SHA;

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(ValidateTokenSignature.class);
    
    /** 
     * The signature algorithm used when building the verification credential. This is fixed and not
     * taken from the JWS. There is no reason, in the Duo case, to determine the algorithm from the JWS
     * as HS512 is the only required algorithm.
     */
    @Nonnull private JWSAlgorithm signatureAlgorithm;

    /** The Duo authentication token. */
    @Nullable private JWT token;
    
    /** The parsed claimset. */
    @Nullable private JWTClaimsSet claimSet;
    
    /** The Duo integration appropriate for this request.*/
    @Nullable private DuoOIDCIntegration integration;
    
    /** Constructor. Defaults the signature algorithm to HS512. */
    public ValidateTokenSignature() {
        // This is the default HMAC algorithm Duo supports.
        signatureAlgorithm = JWSAlgorithm.HS512;
    }
    
    /**
     * Set the signature algorithm to use. Only supports one of the HMAC_SHA family.
     * 
     * @param algo the JWS signature algorithm.
     * 
     * @throws ConstraintViolationException if the algorithm is null or not in the HMAC_SHA family
     */
    public void setSignatureAlgorithm(@Nonnull final JWSAlgorithm algo) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
        
        Constraint.isNotNull(algo, "Signature algorithm can not be null");
        
        if (!SUPPORTED_SIGNATURE_FAMILY.contains(algo)) {
            throw new ConstraintViolationException("Signature algorithm must be one of "+SUPPORTED_SIGNATURE_FAMILY);
        }
        signatureAlgorithm = algo;
    }

    /**
     * Resolve the token, its claims set and the Duo integration from the Duo context, failing with
     * {@link AuthnEventIds#INVALID_AUTHN_CTX} if any of them is unavailable.
     * 
     * {@inheritDoc}
     */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthenticationContext authenticationContext,
            @Nonnull final DuoOIDCAuthenticationContext duoContext) {

        token = duoContext.getAuthToken();
        if (token == null) {
            log.error("{} Duo 2FA token is not available", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.INVALID_AUTHN_CTX);
            return false;
        }
        try {
            // Parse the claimset here, so parsing only has to happen once, and we fail fast on error
            // (e.g. bad JSON).
            claimSet = token.getJWTClaimsSet();
        } catch (final ParseException e) {
            // Keep the cause so the reason for the parse failure is diagnosable from the logs.
            log.error("{} Claimset of Duo 2FA token could not be parsed", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.INVALID_AUTHN_CTX);
            return false;
        }
        integration = duoContext.getIntegration();
        if (integration == null) {
            log.error("{} Duo integration is not available", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.INVALID_AUTHN_CTX);
            return false;
        }
        
        return true;
    }
    
    /**
     * Validate the token's signature using an HMAC credential built from the integration's secret key.
     * Unsigned, 'none' or non-HMAC tokens are rejected before any validation is attempted.
     * 
     * {@inheritDoc}
     */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthenticationContext authenticationContext,
            @Nonnull final DuoOIDCAuthenticationContext duoContext) {

        log.debug("{} Validating token signature for subject '{}'",getLogPrefix(),claimSet.getSubject());
        
        // Only HMAC signatures are supported. Plain JWTs or those with a 'none' algorithm are not allowed,
        // so fail fast here, before attempting validation. Algorithm#equals compares by algorithm name,
        // which does not rely on the parser returning the shared NONE instance.
        if (token instanceof PlainJWT || Algorithm.NONE.equals(token.getHeader().getAlgorithm())) {
            log.error("{} Invalid token signature for subject '{}'. Token must be signed using one of the supported "
                    + "algorithms '{}'",getLogPrefix(),claimSet.getSubject(),SUPPORTED_SIGNATURE_FAMILY); 
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.NO_CREDENTIALS);
            return;
        }
        
        if (!(token instanceof SignedJWT)) {
            // Any other JWT type (e.g. an encrypted JWT) is not handled by this action.
            log.error("{} Unable to validate token signature for subject '{}' and client '{}', "
                    + "unknown token type",getLogPrefix(),claimSet.getSubject(),integration.getClientId());
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.AUTHN_EXCEPTION);
            return;
        }
        
        final SignedJWT signedToken = (SignedJWT) token;
        
        // Enforce the documented contract explicitly: the JWS header must name an HMAC algorithm, otherwise
        // the token cannot have been signed with the shared integration secret.
        final JWSAlgorithm tokenAlgorithm = signedToken.getHeader().getAlgorithm();
        if (!SUPPORTED_SIGNATURE_FAMILY.contains(tokenAlgorithm)) {
            log.error("{} Invalid token signature algorithm '{}' for subject '{}'. Token must be signed using one "
                    + "of the supported algorithms '{}'", getLogPrefix(), tokenAlgorithm, claimSet.getSubject(),
                    SUPPORTED_SIGNATURE_FAMILY);
            ActionSupport.buildEvent(profileRequestContext, AuthnEventIds.NO_CREDENTIALS);
            return;
        }
        
        // The verification key is derived from the integration's shared client secret.
        final BasicJWKCredential jwkCredential = new BasicJWKCredential();
        jwkCredential.setSecretKey(new SecretKeySpec(
                JWSAssemblyUtils.getSecretBytes(integration.getSecretKey()), "NONE"));
        jwkCredential.setAlgorithm(signatureAlgorithm);
        
        // A null result means the signature validated; otherwise it is the event to signal.
        final String errorEventId = 
                JWTSignatureValidationUtil.validateSignature(List.of(jwkCredential), 
                        signedToken, AuthnEventIds.NO_CREDENTIALS);
        
        if (errorEventId != null) {
            log.error("{} Token signature is invalid for subject '{}' and client '{}'",getLogPrefix(),
                    claimSet.getSubject(),integration.getClientId());
            ActionSupport.buildEvent(profileRequestContext, errorEventId);
            return;
        }
        
        log.debug("{} Token signature is valid for subject '{}'; using algorithm "
                + "'{}' for client '{}'", getLogPrefix(),claimSet.getSubject(),
                signatureAlgorithm, integration.getClientId());
    }

}
