/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.oauth2.messaging.impl;

import java.util.function.Function;

import javax.annotation.Nonnull;

import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.handler.AbstractMessageHandler;
import org.opensaml.messaging.handler.MessageHandler;
import org.opensaml.messaging.handler.MessageHandlerException;
import org.opensaml.saml.common.messaging.context.AbstractSAMLEntityContext;
import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.oauth2.sdk.id.ClientID;

import net.shibboleth.idp.plugin.oidc.op.profile.context.navigate.DefaultClientIDLookupFunction;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

/**
 * {@link MessageHandler} that sets the entityID of a SAML entity context to the client ID of the current
 * request.
 *
 * <p>The target context class defaults to {@link SAMLPeerEntityContext} but may be changed via
 * {@link #setEntityContextClass(Class)}. The context must already be present as a subcontext of the
 * message context; this handler does not create it. The client ID is resolved via a configurable lookup
 * strategy, which defaults to {@link DefaultClientIDLookupFunction}.</p>
 */
public class SetEntityIdToSAMLPeerEntityContext extends AbstractMessageHandler {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(SetEntityIdToSAMLPeerEntityContext.class);
    
    /** Strategy used to obtain the client id value for authorize/token request. */
    @Nonnull private Function<MessageContext, ClientID> clientIDLookupStrategy;

    /** The context class representing the SAML entity for whom data is to be attached. 
     * Defaults to: {@link SAMLPeerEntityContext}. */
    @Nonnull private Class<? extends AbstractSAMLEntityContext> entityContextClass = SAMLPeerEntityContext.class;

    /**
     * Constructor.
     * 
     * <p>Installs a {@link DefaultClientIDLookupFunction} as the client ID lookup strategy.</p>
     */
    public SetEntityIdToSAMLPeerEntityContext() {
        clientIDLookupStrategy = new DefaultClientIDLookupFunction();
    }
    
    /**
     * Set the strategy used to locate the client id of the request.
     * 
     * <p>May only be called before this component is initialized.</p>
     * 
     * @param strategy lookup strategy; must not be null
     */
    public void setClientIDLookupStrategy(@Nonnull final Function<MessageContext, ClientID> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        clientIDLookupStrategy =
                Constraint.isNotNull(strategy, "ClientIDLookupStrategy lookup strategy cannot be null");
    }
    
    /**
     * Set the class type holding the SAML entity data.
     * 
     * <p>Defaults to: {@link SAMLPeerEntityContext}. May only be called before this component is
     * initialized.</p>
     * 
     * @param clazz the entity context class type; must not be null
     */
    public void setEntityContextClass(@Nonnull final Class<? extends AbstractSAMLEntityContext> clazz) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        entityContextClass = Constraint.isNotNull(clazz, "SAML entity context class may not be null");
    }
    
    /**
     * Copy the resolved client ID into the entityID of the configured entity context.
     * 
     * <p>If the lookup strategy yields no client ID, the entity context is left unchanged.</p>
     * 
     * @param messageContext the message context being processed
     * @throws MessageHandlerException if no subcontext of the configured entity context class is present
     */
    @Override
    protected void doInvoke(@Nonnull final MessageContext messageContext) throws MessageHandlerException {
        
        // The entity context is expected to have been created earlier in the pipeline; do not auto-create it.
        final AbstractSAMLEntityContext entityCtx = messageContext.getSubcontext(entityContextClass);
        if (entityCtx == null) {
            throw new MessageHandlerException("Unable to locate subcontext of type " + entityContextClass);
        }
        
        final ClientID clientID = clientIDLookupStrategy.apply(messageContext);
        if (clientID != null) {
            // Name the actual configured context class rather than assuming it is the peer context.
            log.debug("{} Set clientID '{}' as entityID of {}", getLogPrefix(), clientID.getValue(),
                    entityContextClass.getSimpleName());
            entityCtx.setEntityId(clientID.getValue());
        } else {
            log.debug("{} No clientID could be resolved, nothing to do", getLogPrefix());
        }
    }

}