/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.oauth2.profile.impl;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.messaging.context.MessageContext;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.criteria.UsageCriterion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.RefreshToken;
import com.nimbusds.oauth2.sdk.token.Token;

import net.shibboleth.idp.plugin.oidc.op.oauth2.messaging.context.OAuth2TokenMgmtResponseContext;
import net.shibboleth.idp.plugin.oidc.op.profile.impl.AbstractOIDCRequestAction;
import net.shibboleth.idp.plugin.oidc.op.token.support.AccessTokenClaimsSet;
import net.shibboleth.idp.plugin.oidc.op.token.support.RefreshTokenClaimsSet;
import net.shibboleth.idp.profile.ActionSupport;
import net.shibboleth.idp.profile.IdPEventIds;
import net.shibboleth.oidc.jwt.claims.ClaimsValidator;
import net.shibboleth.oidc.jwt.claims.JWTValidationException;
import net.shibboleth.oidc.profile.config.navigate.IssuedClaimsValidatorLookupFunction;
import net.shibboleth.oidc.profile.core.OidcEventIds;
import net.shibboleth.oidc.security.impl.JWTSignatureValidationUtil;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import net.shibboleth.utilities.java.support.security.DataSealer;
import net.shibboleth.utilities.java.support.security.DataSealerException;

/**
 * Action that processes a token by validating it and populating the resulting {@link JWTClaimsSet}
 * into an {@link OAuth2TokenMgmtResponseContext} placed beneath the outbound {@link MessageContext}.
 * 
 * <p>Access tokens are first tried as signed JWTs. A JWT access token must carry a {@code typ} header of
 * {@code at+jwt} (or {@code application/at+jwt}) and a signature that verifies against a signing credential
 * obtained from the configured {@link CredentialResolver}. Tokens that are not JWTs are unwrapped as opaque
 * values using the configured {@link DataSealer}. The resulting claims are finally checked by the
 * {@link ClaimsValidator} obtained from the lookup strategy.</p>
 * 
 * <p>If the token can't be validated, the context is not populated.</p>
 * 
 * @param <T> request message type
 * 
 * @since 3.1.0
 *      
 * @post If the token is valid for use, ProfileRequestContext.getOutboundMessageContext().getSubcontext(
 *      OAuth2TokenMgmtResponseContext.class) != null and the context contains the token's {@link JWTClaimsSet}.
 * @event {@link EventIds#PROCEED_EVENT_ID}
 * @event {@link IdPEventIds#INVALID_PROFILE_CONFIG}
 */
public abstract class AbstractProcessTokenAction<T> extends AbstractOIDCRequestAction<T> {

    /** JOSE {@code typ} header value identifying a JWT access token (RFC 9068). */
    @Nonnull @NotEmpty private static final String ACCESS_TOKEN_JWT_TYPE = "at+jwt";

    /** Full media-type form of {@link #ACCESS_TOKEN_JWT_TYPE}, which RFC 9068 also requires be accepted. */
    @Nonnull @NotEmpty private static final String ACCESS_TOKEN_JWT_MEDIA_TYPE = "application/at+jwt";

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(AbstractProcessTokenAction.class);

    /** Data sealer for unwrapping token. */
    @Nullable private DataSealer dataSealer;
    
    /** Lookup strategy for claims validator. */
    @Nonnull private Function<ProfileRequestContext,ClaimsValidator> claimsValidatorLookupStrategy;
    
    /** The claims validator to use; resolved in {@link #doPreExecute(ProfileRequestContext)}. */
    @Nullable private ClaimsValidator claimsValidator;
    
    /** Source of signing keys. */
    @Nullable private CredentialResolver credentialResolver;
    
    /**
     * Copy of signed JWT for non-opaque access tokens.
     * 
     * <p>Only set by {@link #parseAccessToken(Token)} when the token fully parses as a signed JWT, and cleared
     * at the start of every execution so it never refers to a previously processed token.</p>
     */
    @Nullable private SignedJWT signedJWT;

    /** Constructor. */
    public AbstractProcessTokenAction() {
        claimsValidatorLookupStrategy = new IssuedClaimsValidatorLookupFunction();
    }
    
    /**
     * Set the data sealer instance to use.
     * 
     * @param sealer data sealer to use
     */
    public void setDataSealer(@Nullable final DataSealer sealer) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        dataSealer = sealer;
    }
    
    /**
     * Set the claims validator lookup strategy.
     * 
     * @param strategy lookup strategy
     */
    public void setClaimsValidatorLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,ClaimsValidator> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        claimsValidatorLookupStrategy = Constraint.isNotNull(strategy, "Lookup strategy cannot be null");
    }
    
    /**
     * Set the source of signing keys to use for JWT signature verification.
     * 
     * @param resolver signing key resolver
     */
    public void setCredentialResolver(@Nullable final CredentialResolver resolver) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        credentialResolver = resolver;
    }
        
    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        claimsValidator = claimsValidatorLookupStrategy.apply(profileRequestContext);
        if (claimsValidator == null) {
            log.error("{} Unable to obtain ClaimsValidator to apply", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, IdPEventIds.INVALID_PROFILE_CONFIG);
            return false;
        }
        
        return true;
    }

    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        // Discard any JWT retained from an earlier execution so the typ/signature checks below can only
        // ever apply to the token being processed now.
        signedJWT = null;

        final Token token = getToken(profileRequestContext);
        if (token == null) {
            log.error("{} Token missing from request", getLogPrefix());
            return;
        }
        
        log.debug("{} Token to introspect: {}", getLogPrefix(), token.getValue());

        final JWTClaimsSet tokenClaimsSet = parseToken(token);
        if (tokenClaimsSet == null) {
            log.warn("{} Unable to parse/decode token for introspection", getLogPrefix());
            return;
        }

        // Only JWT access tokens carry a header and signature; opaque tokens were unwrapped by the sealer.
        if (signedJWT != null && !isValidSignedJWT(signedJWT, tokenClaimsSet)) {
            return;
        }

        // Pass the claims set itself so toString() is only invoked when debug logging is enabled.
        log.debug("{} Validating parsed/decoded claims set: {}", getLogPrefix(), tokenClaimsSet);
        try {
            claimsValidator.validate(tokenClaimsSet, profileRequestContext);
        } catch (final JWTValidationException e) {
            log.warn("{} Claims validation failed, token is invalid: {}", getLogPrefix(), e.getMessage());
            return;
        }
        
        // Populate outbound tree.
        final MessageContext outboundMessageContext = profileRequestContext.getOutboundMessageContext();
        if (outboundMessageContext == null) {
            log.error("{} No outbound message context, unable to store token claims", getLogPrefix());
            return;
        }
        // Auto-create the response context (returns the existing one if present) to honor the @post contract.
        outboundMessageContext.getSubcontext(OAuth2TokenMgmtResponseContext.class, true)
                .setTokenClaimsSet(tokenClaimsSet);
    }

    /**
     * Parse the token according to its concrete type, trying both access and refresh token formats when the
     * type gives no hint.
     * 
     * @param token the token to parse
     * 
     * @return parsed claim set or null
     */
    @Nullable private JWTClaimsSet parseToken(@Nonnull final Token token) {
        if (token instanceof AccessToken) {
            return parseAccessToken(token);
        } else if (token instanceof RefreshToken) {
            return parseRefreshToken(token);
        }

        // No token hint, have to try both.
        final JWTClaimsSet accessClaims = parseAccessToken(token);
        return accessClaims != null ? accessClaims : parseRefreshToken(token);
    }

    /**
     * Verify the {@code typ} header and the signature of a JWT access token.
     * 
     * @param jwt the signed JWT to check
     * @param claims the claims extracted from the JWT, used for logging
     * 
     * @return true iff the header type is acceptable and the signature validates against the resolved
     *      signing credentials
     */
    private boolean isValidSignedJWT(@Nonnull final SignedJWT jwt, @Nonnull final JWTClaimsSet claims) {
        // Check typ header.
        final JOSEObjectType typ = jwt.getHeader().getType();
        if (typ == null || !isAccessTokenType(typ.getType())) {
            log.warn("{} Missing or invalid token type: {}", getLogPrefix(), typ != null ? typ.getType() : "null");
            return false;
        }

        if (credentialResolver == null) {
            log.error("{} No CredentialResolver available, can't verify JWT signature", getLogPrefix());
            return false;
        }

        log.debug("{} Checking JWT signature", getLogPrefix());
        final Collection<Credential> credList = new ArrayList<>();
        final CriteriaSet criteriaSet = new CriteriaSet(new UsageCriterion(UsageType.SIGNING));
        try {
            final Iterable<Credential> creds = credentialResolver.resolve(criteriaSet);
            if (creds != null) {
                creds.forEach(credList::add);
            }
        } catch (final ResolverException e) {
            log.error("{} Failure resolving signing credentials, can't verify JWT signature", getLogPrefix(), e);
            return false;
        }

        final String errorEventId = JWTSignatureValidationUtil.validateSignatureEx(credList, jwt,
                OidcEventIds.INVALID_GRANT);
        if (errorEventId != null) {
            log.warn("{} Signature on token ID '{}' invalid", getLogPrefix(), claims.getJWTID());
            return false;
        }

        return true;
    }

    /**
     * Check whether a JOSE {@code typ} value denotes a JWT access token.
     * 
     * <p>RFC 9068 requires accepting both {@code at+jwt} and {@code application/at+jwt}; RFC 7515 treats
     * media type values as case-insensitive.</p>
     * 
     * @param type the header type value, possibly null
     * 
     * @return true iff the value identifies a JWT access token
     */
    private static boolean isAccessTokenType(@Nullable final String type) {
        return ACCESS_TOKEN_JWT_TYPE.equalsIgnoreCase(type) || ACCESS_TOKEN_JWT_MEDIA_TYPE.equalsIgnoreCase(type);
    }

    /**
     * Attempt to parse token.
     * 
     * <p>The token is first tried as a signed JWT; only if it fully parses (including its claims set) is the
     * JWT retained for later header/signature checks. Otherwise it is unwrapped as an opaque, sealed access
     * token.</p>
     * 
     * @param token the token
     * 
     * @return parsed claim set or null
     */
    @Nullable protected JWTClaimsSet parseAccessToken(@Nonnull @NotEmpty final Token token) {

        // Never leave a JWT from a previous attempt behind; it is only set below on a complete parse.
        signedJWT = null;

        // Try parsing as a JWT.
        try {
            final SignedJWT jwt = SignedJWT.parse(token.getValue());
            final JWTClaimsSet claims = jwt.getJWTClaimsSet();
            if (claims != null) {
                signedJWT = jwt;
                return claims;
            }
        } catch (final ParseException e) {
            // Expected for opaque tokens, which are handled below.
            log.trace("{} Token is not a signed JWT: {}", getLogPrefix(), e.getMessage());
        }

        // Fall back to opaque.
        try {
            return AccessTokenClaimsSet.parse(token.getValue(), dataSealer).getClaimsSet();
        } catch (final DataSealerException | ParseException e) {
            // Not fatal here: the caller may still try other token formats.
            log.debug("{} Token could not be decoded as an opaque access token: {}", getLogPrefix(),
                    e.getMessage());
        }
        
        return null;
    }

    /**
     * Attempt to parse refresh token.
     * 
     * @param token the token
     * 
     * @return parsed claim set or null
     */
    @Nullable protected JWTClaimsSet parseRefreshToken(@Nonnull @NotEmpty final Token token) {

        // All refresh tokens are opaque.
        try {
            return RefreshTokenClaimsSet.parse(token.getValue(), dataSealer).getClaimsSet();
        } catch (final DataSealerException | ParseException e) {
            // Not fatal here: the caller reports the overall parse failure.
            log.debug("{} Token could not be decoded as a refresh token: {}", getLogPrefix(), e.getMessage());
        }
        
        return null;
    }

    /**
     * Get the token to process.
     * 
     * @param profileRequestContext current profile request context
     * 
     * @return the token to process
     */
    @Nullable protected abstract Token getToken(@Nonnull final ProfileRequestContext profileRequestContext);
    
}