/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.profile.impl;

import java.text.ParseException;
import java.time.Duration;
import java.util.function.Function;
import java.util.function.Predicate;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.storage.ReplayCache;
import org.opensaml.storage.RevocationCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.GrantType;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;

import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCAuthenticationResponseContext;
import net.shibboleth.idp.plugin.oidc.op.profile.logic.DefaultChainRevocationLifetimeLookupStrategy;
import net.shibboleth.idp.plugin.oidc.op.storage.RevocationCacheContexts;
import net.shibboleth.idp.plugin.oidc.op.token.support.AuthorizeCodeClaimsSet;
import net.shibboleth.idp.plugin.oidc.op.token.support.RefreshTokenClaimsSet;
import net.shibboleth.idp.plugin.oidc.op.token.support.TokenClaimsSet;
import net.shibboleth.idp.profile.IdPEventIds;
import net.shibboleth.idp.profile.context.RelyingPartyContext;
import net.shibboleth.oidc.profile.config.logic.RefreshTokensEnabledPredicate;
import net.shibboleth.oidc.profile.config.navigate.RevocationLifetimeLookupFunction;
import net.shibboleth.oidc.profile.core.OidcEventIds;
import net.shibboleth.utilities.java.support.annotation.ParameterName;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.primitive.StringSupport;
import net.shibboleth.utilities.java.support.security.DataSealer;
import net.shibboleth.utilities.java.support.security.DataSealerException;

/**
 * Action that validates an authorization grant.
 * 
 * <p>A grant is valid if it is successfully unwrapped, parsed as a code or refresh token, is unexpired, was issued
 * to the expected client and either has not been used before (authz code), or neither the refresh token itself nor
 * the authz code used to produce it has been revoked (refresh token).</p>
 * 
 * <p>Replay of an authorization code, or presentation of a refresh token that has been individually revoked, causes
 * the whole token chain rooted at the original authorization code to be revoked.</p>
 * 
 * <p>The validated claims from the grant are stored in the response context, from which they can be obtained via
 * {@link OIDCAuthenticationResponseContext#getAuthorizationGrantClaimsSet()}.</p>
 * 
 * <p>Note that the addition of support for the "client_credentials" grant type means that there may not in fact be a
 * grant to validate, or a resulting claims set.</p>
 * 
 * <p>Failures are signalled with {@link OidcEventIds#INVALID_GRANT}, except for a missing relying party context
 * ({@link IdPEventIds#INVALID_RELYING_PARTY_CTX}) and a failure to record the chain revocation triggered by a revoked
 * refresh token ({@link IdPEventIds#INVALID_PROFILE_CONFIG}).</p>
 */
public class ValidateGrant extends AbstractOIDCTokenResponseAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(ValidateGrant.class);

    /** Data sealer for unwrapping authorization codes and refresh tokens. */
    @Nonnull private final DataSealer dataSealer;

    /** Message replay cache instance to use. */
    @NonnullAfterInit private ReplayCache replayCache;

    /** Message revocation cache instance to use. */
    @NonnullAfterInit private RevocationCache revocationCache;

    /**
     * Strategy used to locate the {@link RelyingPartyContext} associated with a given {@link ProfileRequestContext}.
     */
    @Nonnull private Function<ProfileRequestContext, RelyingPartyContext> relyingPartyContextLookupStrategy;

    /** Predicate used to indicate whether refresh tokens are enabled. */
    @Nonnull private Predicate<ProfileRequestContext> refreshTokensEnabledPredicate;

    /** Lookup function to supply chain revocation lifetime. */
    @Nonnull private Function<ProfileRequestContext,Duration> chainRevocationLifetimeLookupStrategy;

    /** The RelyingPartyContext to operate on; populated by {@link #doPreExecute(ProfileRequestContext)}. */
    @Nullable private RelyingPartyContext rpCtx;

    /**
     * Constructor.
     * 
     * @param sealer sealer to decrypt/hmac authorize code.
     */
    public ValidateGrant(@Nonnull @ParameterName(name = "sealer") final DataSealer sealer) {
        dataSealer = Constraint.isNotNull(sealer, "DataSealer cannot be null");
        relyingPartyContextLookupStrategy = new ChildContextLookup<>(RelyingPartyContext.class);
        refreshTokensEnabledPredicate = new RefreshTokensEnabledPredicate();
        chainRevocationLifetimeLookupStrategy = new DefaultChainRevocationLifetimeLookupStrategy();
        // The default lookup is told not to restrict itself to the active profile only.
        ((RevocationLifetimeLookupFunction) chainRevocationLifetimeLookupStrategy).setUseActiveProfileOnly(false);
    }

    /**
     * Set the strategy used to locate the {@link RelyingPartyContext} associated with a given
     * {@link ProfileRequestContext}.
     * 
     * @param strategy strategy used to locate the {@link RelyingPartyContext} associated with a given
     *            {@link ProfileRequestContext}
     */
    public void setRelyingPartyContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext, RelyingPartyContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        relyingPartyContextLookupStrategy =
                Constraint.isNotNull(strategy, "RelyingPartyContext lookup strategy cannot be null");
    }

    /**
     * Set the predicate used to indicate whether refresh tokens are enabled.
     *
     * @param predicate predicate used to indicate whether refresh tokens are enabled.
     */
    public void setRefreshTokensEnabledPredicate(@Nonnull final Predicate<ProfileRequestContext> predicate) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        refreshTokensEnabledPredicate =
                Constraint.isNotNull(predicate, "Refresh tokens enabled predicate cannot be null");
    }

    /**
     * Set the replay cache instance to use.
     * 
     * @param cache The replayCache to set.
     */
    public void setReplayCache(@Nonnull final ReplayCache cache) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        replayCache = Constraint.isNotNull(cache, "ReplayCache cannot be null");
    }

    /**
     * Set the revocation cache instance to use.
     * 
     * @param cache The revocationCache to set.
     */
    public void setRevocationCache(@Nonnull final RevocationCache cache) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        revocationCache = Constraint.isNotNull(cache, "RevocationCache cannot be null");
    }

    /**
     * Set a lookup strategy for the chain revocation lifetime.
     *
     * <p>If the strategy returns {@code null} at runtime, the revocation cache's default lifetime is used.</p>
     *
     * @param strategy What to set (must not be null).
     */
    public void setChainRevocationLifetimeLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,Duration> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        chainRevocationLifetimeLookupStrategy = Constraint.isNotNull(strategy, "Lookup strategy cannot be null");
    }

    /** {@inheritDoc} */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (replayCache == null || revocationCache == null) {
            throw new ComponentInitializationException("ReplayCache and RevocationCache cannot be null");
        }
    }

    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        rpCtx = relyingPartyContextLookupStrategy.apply(profileRequestContext);
        if (rpCtx == null) {
            log.error("{} No relying party context associated with this profile request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, IdPEventIds.INVALID_RELYING_PARTY_CTX);
            return false;
        }
        
        return true;
    }
    
// Checkstyle: CyclomaticComplexity|MethodLength|ReturnCount OFF
    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        final AuthorizationGrant grant = getTokenRequest().getAuthorizationGrant();
        if (grant == null) {
            log.warn("{} No authorization grant found in the token request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        final GrantType grantType = grant.getType();
        log.debug("{} Validating grant type: {}", getLogPrefix(), grantType);

        final TokenClaimsSet tokenClaimsSet;
        if (GrantType.AUTHORIZATION_CODE.equals(grantType)) {
            tokenClaimsSet = validateAuthorizationCode(profileRequestContext, (AuthorizationCodeGrant) grant);
        } else if (GrantType.REFRESH_TOKEN.equals(grantType)) {
            tokenClaimsSet = validateRefreshToken(profileRequestContext, (RefreshTokenGrant) grant);
        } else if (GrantType.CLIENT_CREDENTIALS.equals(grantType)) {
            // Nothing to unwrap for this grant type, hence no claims set to validate or store.
            return;
        } else {
            log.warn("{} Grant type not supported", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        if (tokenClaimsSet == null) {
            // The helper has already logged the reason and built the failure event.
            return;
        }

        if (!tokenClaimsSet.isTimeValid()) {
            log.warn("{} Token is expired or not yet valid", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        // The grant must have been issued to the client that is now presenting it.
        final String clientId = tokenClaimsSet.getClientID().getValue();
        final String relyingPartyId = rpCtx.getRelyingPartyId();
        if (!clientId.equals(relyingPartyId)) {
            log.warn("{} Token issued to client {}, invalid for {}", getLogPrefix(), clientId, relyingPartyId);
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        getOidcResponseContext().setAuthorizationGrantClaimsSet(tokenClaimsSet);
    }

    /**
     * Unwraps the authorization code carried by the grant and checks that it has not been used before.
     *
     * <p>If a replay is detected, the token chain rooted at the code is revoked, as RFC 6749 section 4.1.2
     * recommends for codes used more than once.</p>
     *
     * @param profileRequestContext current profile request context
     * @param codeGrant the authorization code grant to validate
     * @return the unwrapped claims set, or {@code null} if validation failed, in which case an
     *         {@link OidcEventIds#INVALID_GRANT} event has already been built
     */
    @Nullable private TokenClaimsSet validateAuthorizationCode(
            @Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthorizationCodeGrant codeGrant) {
        final String codeValue =
                codeGrant.getAuthorizationCode() != null ? codeGrant.getAuthorizationCode().getValue() : null;
        if (codeValue == null) {
            log.warn("{} Authorization code grant did not contain a code", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return null;
        }

        try {
            final AuthorizeCodeClaimsSet authzCodeClaimsSet = AuthorizeCodeClaimsSet.parse(codeValue, dataSealer);
            if (log.isDebugEnabled()) {
                // Only pay for serialization when it will actually be logged.
                log.debug("{} Authz code unwrapped {}", getLogPrefix(), authzCodeClaimsSet.serialize());
            }
            // The code ID is recorded until the code's expiry; a failed check means it has been seen before.
            if (!replayCache.check(getClass().getName(), authzCodeClaimsSet.getID(),
                    authzCodeClaimsSet.getExp())) {
                log.error("{} Replay detected of authz code {}", getLogPrefix(), authzCodeClaimsSet.getID());
                if (!revokeChain(authzCodeClaimsSet.getID(),
                        chainRevocationLifetimeLookupStrategy.apply(profileRequestContext))) {
                    log.error("{} Fatal error, unable to save replayed code to revocation cache",
                            getLogPrefix());
                }
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return null;
            }
            return authzCodeClaimsSet;
        } catch (final DataSealerException | ParseException e) {
            log.warn("{} Unwrapping authz code failed: {}", getLogPrefix(), e.getMessage());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return null;
        }
    }

    /**
     * Unwraps the refresh token carried by the grant and checks its revocation status.
     *
     * <p>The grant is rejected if refresh tokens are disabled, if the chain rooted at the originating authorization
     * code has been revoked, or if this particular refresh token has been revoked. In the last case, the whole chain
     * is revoked as well, since presenting a revoked refresh token suggests it is being reused.</p>
     *
     * @param profileRequestContext current profile request context
     * @param refreshTokenGrant the refresh token grant to validate
     * @return the unwrapped claims set, or {@code null} if validation failed, in which case a failure event has
     *         already been built ({@link IdPEventIds#INVALID_PROFILE_CONFIG} if the chain revocation could not be
     *         stored, otherwise {@link OidcEventIds#INVALID_GRANT})
     */
    @Nullable private TokenClaimsSet validateRefreshToken(
            @Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final RefreshTokenGrant refreshTokenGrant) {
        if (!refreshTokensEnabledPredicate.test(profileRequestContext)) {
            log.warn("{} Refresh token grant detected, but not enabled", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return null;
        }

        final String tokenValue =
                refreshTokenGrant.getRefreshToken() != null ? refreshTokenGrant.getRefreshToken().getValue() : null;
        if (tokenValue == null) {
            log.warn("{} Refresh token grant did not contain a refresh token", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return null;
        }

        try {
            final RefreshTokenClaimsSet refreshTokenClaimsSet = RefreshTokenClaimsSet.parse(tokenValue, dataSealer);

            // Revocation of the chain is keyed by the root identifier; fall back to the token's own ID if absent.
            final String rootJti = refreshTokenClaimsSet.getRootTokenIdentifier();
            final String rootJtiToUse;
            if (StringSupport.trimOrNull(rootJti) == null) {
                log.warn("{} No root token identifier returned, using JWT id for checking revocation status",
                        getLogPrefix());
                rootJtiToUse = refreshTokenClaimsSet.getID();
            } else {
                rootJtiToUse = rootJti;
            }

            if (revocationCache.isRevoked(RevocationCacheContexts.AUTHORIZATION_CODE, rootJtiToUse)) {
                log.error("{} Authz code {} and all derived tokens have been revoked", getLogPrefix(),
                        rootJtiToUse);
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return null;
            }

            if (revocationCache.isRevoked(RevocationCacheContexts.SINGLE_ACCESS_OR_REFRESH_TOKENS,
                    refreshTokenClaimsSet.getID())) {
                log.error("{} The refresh token {} has been revoked. Revoking the full chain now.",
                        getLogPrefix(), refreshTokenClaimsSet.getID());
                if (!revokeChain(rootJtiToUse, chainRevocationLifetimeLookupStrategy.apply(profileRequestContext))) {
                    log.error("{} Fatal error, unable to store revocation into the revocation cache",
                            getLogPrefix());
                    ActionSupport.buildEvent(profileRequestContext, IdPEventIds.INVALID_PROFILE_CONFIG);
                    return null;
                }
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return null;
            }

            return refreshTokenClaimsSet;
        } catch (final ParseException | DataSealerException e) {
            log.warn("{} Unwrapping refresh token failed: {}", getLogPrefix(), e.getMessage());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return null;
        }
    }
// Checkstyle: CyclomaticComplexity|MethodLength|ReturnCount ON

    /**
     * Revokes the token chain with the given id, optionally with a given lifetime. If the given lifetime is null,
     * the default lifetime set to the {@link RevocationCache} is used.
     * 
     * @param id The identifier to be revoked in {@link RevocationCacheContexts#AUTHORIZATION_CODE} context.
     * @param lifetime The lifetime for the revocation, or null to use the cache default
     * @return The result returned by the {@link RevocationCache}; {@code false} means the revocation was not stored
     */
    protected boolean revokeChain(@Nonnull final String id, @Nullable final Duration lifetime) {
        if (lifetime == null) {
            log.warn("{} No profile-specific revocation lifetime could be resolved, using default value",
                    getLogPrefix());
            return revocationCache.revoke(RevocationCacheContexts.AUTHORIZATION_CODE, id);
        }
        return revocationCache.revoke(RevocationCacheContexts.AUTHORIZATION_CODE, id, lifetime);
    }
    
}