/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.profile.impl;

import java.util.Collection;
import java.util.Collections;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import net.minidev.json.JSONObject;
import net.shibboleth.idp.attribute.AttributeEncodingException;
import net.shibboleth.idp.attribute.AttributesMapContainer;
import net.shibboleth.idp.attribute.IdPAttribute;
import net.shibboleth.idp.attribute.context.AttributeContext;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoder;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoderRegistry;
import net.shibboleth.idp.attribute.transcoding.TranscoderSupport;
import net.shibboleth.idp.attribute.transcoding.TranscodingRule;
import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCAuthenticationResponseContext;
import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCAuthenticationResponseTokenClaimsContext;
import net.shibboleth.idp.profile.IdPEventIds;
import net.shibboleth.idp.profile.context.RelyingPartyContext;
import net.shibboleth.oidc.profile.config.navigate.AlwaysIncludedAttributesLookupFunction;
import net.shibboleth.oidc.profile.config.navigate.DeniedUserInfoAttributesLookupFunction;
import net.shibboleth.oidc.profile.config.navigate.EncodedAttributesLookupFunction;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullElements;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.service.ReloadableService;
import net.shibboleth.utilities.java.support.service.ServiceableComponent;

import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Action that checks for any released attributes marked for token delivery (encoded for back-channel recovery).
 * For such attributes an {@link OIDCAuthenticationResponseTokenClaimsContext} is created under the
 * {@link OIDCAuthenticationResponseContext}, and the claims produced by transcoding the marked attributes are
 * placed there, split between the ID token, the UserInfo response, or both.
 **/
public class SetTokenDeliveryAttributesToResponseContext extends AbstractOIDCResponseAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(SetTokenDeliveryAttributesToResponseContext.class);

    /**
     * Strategy used to locate the {@link AttributeContext} associated with a given {@link ProfileRequestContext}.
     */
    @Nonnull private Function<ProfileRequestContext,AttributeContext> attributeContextLookupStrategy;

    /** Strategy used to obtain the set of attribute IDs to encode for back-channel recovery. */
    @Nonnull private Function<ProfileRequestContext,Set<String>> encodedAttributesLookupStrategy;

    /** Strategy used to obtain the set of attribute IDs to include in the ID token in all cases. */
    @Nonnull private Function<ProfileRequestContext,Set<String>> alwaysIncludedAttributesLookupStrategy;

    /** Strategy used to obtain the set of attribute IDs to omit from the UserInfo response. */
    @Nonnull private Function<ProfileRequestContext,Set<String>> deniedUserInfoAttributesLookupStrategy;

    /** Transcoder registry service object. */
    @NonnullAfterInit private ReloadableService<AttributeTranscoderRegistry> transcoderRegistry;

    /**
     * Whether attributes that result in an {@link AttributeEncodingException} when being encoded should be
     * ignored or result in an {@link IdPEventIds#UNABLE_ENCODE_ATTRIBUTE} transition.
     */
    private boolean ignoringUnencodableAttributes;
    
    /** AttributeContext to use. */
    @Nullable private AttributeContext attributeCtx;
    
    /** Attributes to encode for recovery. */
    @Nullable @NonnullElements private Set<String> encodedAttributes;

    /** Attributes to include in ID token no matter what. */
    @Nullable @NonnullElements private Set<String> alwaysIncludedAttributes;
    
    /** Attributes to omit from UserInfo response. */
    @Nullable @NonnullElements private Set<String> deniedUserInfoAttributes;

    /** Constructor. */
    public SetTokenDeliveryAttributesToResponseContext() {
        attributeContextLookupStrategy = new ChildContextLookup<>(AttributeContext.class).compose(
                new ChildContextLookup<>(RelyingPartyContext.class));
        encodedAttributesLookupStrategy = new EncodedAttributesLookupFunction();
        alwaysIncludedAttributesLookupStrategy = new AlwaysIncludedAttributesLookupFunction();
        deniedUserInfoAttributesLookupStrategy = new DeniedUserInfoAttributesLookupFunction();
        
        // Encoding failures are tolerated unless explicitly configured otherwise.
        ignoringUnencodableAttributes = true;
    }

    /**
     * Sets the registry of transcoding rules to apply to encode attributes.
     * 
     * @param registry registry service interface
     */
    public void setTranscoderRegistry(@Nonnull final ReloadableService<AttributeTranscoderRegistry> registry) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        transcoderRegistry = Constraint.isNotNull(registry, "AttributeTranscoderRegistry cannot be null");
    }
    
    /**
     * Set whether the attributes that result in an {@link AttributeEncodingException} when being encoded
     * should be ignored or result in an {@link IdPEventIds#UNABLE_ENCODE_ATTRIBUTE} transition.
     * 
     * @param flag flag to set
     */
    public void setIgnoringUnencodableAttributes(final boolean flag) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        ignoringUnencodableAttributes = flag;
    }
    
    /**
     * Set the strategy used to locate the {@link AttributeContext} associated with a given
     * {@link ProfileRequestContext}.
     * 
     * @param strategy strategy used to locate the {@link AttributeContext} associated with a given
     *            {@link ProfileRequestContext}
     */
    public void setAttributeContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext, AttributeContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        attributeContextLookupStrategy =
                Constraint.isNotNull(strategy, "AttributeContext lookup strategy cannot be null");
    }
    
    /**
     * Set the strategy used to obtain the set of attribute IDs to encode for back-channel recovery.
     * 
     * @param strategy lookup strategy
     */
    public void setEncodedAttributesLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,Set<String>> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        encodedAttributesLookupStrategy = Constraint.isNotNull(strategy,
                "Encoded attributes lookup strategy cannot be null");
    }

    /**
     * Set the strategy used to obtain the set of attribute IDs always included in ID tokens.
     * 
     * @param strategy lookup strategy
     */
    public void setAlwaysIncludedAttributesLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,Set<String>> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        alwaysIncludedAttributesLookupStrategy = Constraint.isNotNull(strategy,
                "Always included ID token attributes lookup strategy cannot be null");
    }

    /**
     * Set the strategy used to obtain the set of attribute IDs to omit from UserInfo responses.
     * 
     * @param strategy lookup strategy
     */
    public void setDeniedUserInfoAttributesLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,Set<String>> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        deniedUserInfoAttributesLookupStrategy = Constraint.isNotNull(strategy,
                "Denied UserInfo attributes lookup strategy cannot be null");
    }
    
    /** {@inheritDoc} */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (transcoderRegistry == null) {
            throw new ComponentInitializationException("AttributeTranscoderRegistry cannot be null");
        }
    }
    
    /**
     * {@inheritDoc}
     * 
     * <p>Locates the {@link AttributeContext} (skipping the action if absent) and resolves the encoded,
     * always-included and UserInfo-denied attribute ID sets, substituting empty sets for null lookups.</p>
     */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        attributeCtx = attributeContextLookupStrategy.apply(profileRequestContext);
        if (attributeCtx == null) {
            log.debug("{} No AttributeSubcontext available, nothing to do", getLogPrefix());
            return false;
        }
        
        encodedAttributes = encodedAttributesLookupStrategy.apply(profileRequestContext);
        if (encodedAttributes == null) {
            encodedAttributes = Collections.emptySet();
        }

        alwaysIncludedAttributes = alwaysIncludedAttributesLookupStrategy.apply(profileRequestContext);
        if (alwaysIncludedAttributes == null) {
            alwaysIncludedAttributes = Collections.emptySet();
        }
        
        deniedUserInfoAttributes = deniedUserInfoAttributesLookupStrategy.apply(profileRequestContext);
        if (deniedUserInfoAttributes == null) {
            deniedUserInfoAttributes = Collections.emptySet();
        }

        return true;
    }

    /**
     * {@inheritDoc}
     * 
     * <p>Encodes every non-empty attribute whose ID is in the encoded-attributes set. If the transcoder
     * registry is unavailable, or a non-ignorable encoding failure occurs, an
     * {@link IdPEventIds#UNABLE_ENCODE_ATTRIBUTE} event is built. The pinned registry component is always
     * released.</p>
     */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        
        ServiceableComponent<AttributeTranscoderRegistry> component = null;
        try {
            component = transcoderRegistry.getServiceableComponent();
            if (component == null) {
                log.error("{} Attribute transcoding service unavailable", getLogPrefix());
                throw new AttributeEncodingException("Attribute transcoding service unavailable");
            }
            
            final AttributeTranscoderRegistry registry = component.getComponent();
            for (final IdPAttribute attribute : attributeCtx.getIdPAttributes().values()) {
                if (attribute != null && !attribute.getValues().isEmpty() &&
                        encodedAttributes.contains(attribute.getId())) {
                    // This will generate the claims and add them to the appropriate claims sets.
                    encodeAttribute(registry, profileRequestContext, attribute);
                }
            }
        } catch (final AttributeEncodingException e) {
            // The cause has already been logged either above or in encodeAttribute.
            ActionSupport.buildEvent(profileRequestContext, IdPEventIds.UNABLE_ENCODE_ATTRIBUTE);
        } finally {
            if (null != component) {
                component.unpinComponent();
            }
        }
    }

    // Checkstyle: CyclomaticComplexity OFF

    /**
     * Access the registry of transcoding rules to transform the input attribute into claims, and store the
     * resulting claims in the claims set chosen by {@link #selectClaimsTarget}.
     * 
     * @param registry  registry of transcoding rules
     * @param profileRequestContext current profile request context
     * @param attribute input attribute
     * 
     * @throws AttributeEncodingException if a non-ignorable error occurs
     */
    private void encodeAttribute(@Nonnull final AttributeTranscoderRegistry registry,
            @Nonnull final ProfileRequestContext profileRequestContext, @Nonnull final IdPAttribute attribute)
                    throws AttributeEncodingException {
        
        final Collection<TranscodingRule> transcodingRules = registry.getTranscodingRules(attribute, JSONObject.class);
        if (transcodingRules.isEmpty()) {
            log.debug("{} Attribute {} does not have any transcoding rules, nothing to do", getLogPrefix(),
                    attribute.getId());
            return;
        }
        
        final OIDCAuthenticationResponseTokenClaimsContext tokenClaimsCtx =
                getOidcResponseContext().getSubcontext(OIDCAuthenticationResponseTokenClaimsContext.class, true);
        final AttributesMapContainer requestedToIdTokenContainer =
                getOidcResponseContext().getMappedIdTokenRequestedClaims();
        
        // The delivery target depends only on the attribute, not the rule, so decide it once.
        final boolean requestedToIdToken = requestedToIdTokenContainer != null
                && requestedToIdTokenContainer.get().containsKey(attribute.getId());
        final BiConsumer<String,Object> claimsTarget =
                selectClaimsTarget(tokenClaimsCtx, attribute.getId(), requestedToIdToken);
        
        for (final TranscodingRule rule : transcodingRules) {
            try {
                // Resolved for every rule (even undelivered attributes) so misconfigured rules still surface.
                final AttributeTranscoder<JSONObject> transcoder = TranscoderSupport.<JSONObject>getTranscoder(rule);
                if (claimsTarget != null) {
                    final JSONObject encodedAttribute =
                            transcoder.encode(profileRequestContext, attribute, JSONObject.class, rule);
                    if (encodedAttribute != null) {
                        encodedAttribute.forEach(claimsTarget);
                    }
                }
            } catch (final AttributeEncodingException e) {
                log.warn("{} Unable to encode attribute {}", getLogPrefix(), attribute.getId(), e);
                if (!ignoringUnencodableAttributes) {
                    throw e;
                }
            }
        }
    }

    /**
     * Choose the claims set that claims produced for an attribute are delivered to.
     * 
     * <ul>
     *   <li>Always included and not denied for UserInfo: shared claims (ID token and UserInfo).</li>
     *   <li>Always included but denied for UserInfo: ID token claims only.</li>
     *   <li>Not always included and not denied for UserInfo: shared claims if requested for the ID token,
     *       otherwise UserInfo claims only.</li>
     *   <li>Denied for UserInfo and not always included: ID token claims only if requested for the ID token,
     *       otherwise not delivered.</li>
     * </ul>
     * 
     * @param tokenClaimsCtx context holding the claims sets
     * @param attributeId ID of the attribute being encoded
     * @param requestedToIdToken whether the attribute was requested for the ID token
     * 
     * @return a sink storing claim name/value pairs in the chosen claims set, or null if the attribute is not
     *         to be delivered
     */
    @Nullable private BiConsumer<String,Object> selectClaimsTarget(
            @Nonnull final OIDCAuthenticationResponseTokenClaimsContext tokenClaimsCtx,
            @Nonnull final String attributeId, final boolean requestedToIdToken) {
        final boolean alwaysIncluded = alwaysIncludedAttributes.contains(attributeId);
        final boolean deniedUserInfo = deniedUserInfoAttributes.contains(attributeId);
        
        if (alwaysIncluded && !deniedUserInfo) {
            // Deliver for UserInfo and ID token
            return (name, value) -> tokenClaimsCtx.getClaims().setClaim(name, value);
        } else if (alwaysIncluded) {
            // Deliver only for ID token
            return (name, value) -> tokenClaimsCtx.getIdtokenClaims().setClaim(name, value);
        } else if (!deniedUserInfo) {
            // Deliver only for UserInfo, unless requested in ID token too
            if (requestedToIdToken) {
                return (name, value) -> tokenClaimsCtx.getClaims().setClaim(name, value);
            }
            return (name, value) -> tokenClaimsCtx.getUserinfoClaims().setClaim(name, value);
        } else if (requestedToIdToken) {
            // Deliver only in ID token, if requested
            return (name, value) -> tokenClaimsCtx.getIdtokenClaims().setClaim(name, value);
        }
        return null;
    }

    // Checkstyle: CyclomaticComplexity ON

}