/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.oauth2.profile.impl;

import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.saml.saml2.profile.context.EncryptionContext;
import org.opensaml.security.credential.Credential;
import org.opensaml.xmlsec.EncryptionParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESEncrypter;
import com.nimbusds.jose.crypto.ECDHEncrypter;
import com.nimbusds.jose.crypto.RSAEncrypter;
import com.nimbusds.jwt.EncryptedJWT;

import net.shibboleth.idp.plugin.oidc.op.profile.impl.AbstractOIDCResponseAction;
import net.shibboleth.idp.profile.context.RelyingPartyContext;
import net.shibboleth.oidc.security.impl.CredentialConversionUtil;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

/**
 * Action that encrypts a source object into an {@link EncryptedJWT}.
 * 
 * <p>The existence of encryption parameters is used to decide whether the encryption should take place:
 * when the located {@link EncryptionContext} carries no parameters, the action does nothing.</p>
 * 
 * <p>The key management algorithm determines the encrypter: the RSA family requires an {@link RSAPublicKey},
 * the ECDH-ES family requires an {@link ECPublicKey}, and the symmetric family is handed to an
 * {@link AESEncrypter} and requires a secret key. Any failure to encrypt results in the
 * {@link EventIds#UNABLE_TO_ENCRYPT} event and no token being stored.</p>
 * 
 * @since 3.1.0
 */
public abstract class AbstractEncryptTokenAction extends AbstractOIDCResponseAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(AbstractEncryptTokenAction.class);

    /** Strategy used to look up the {@link EncryptionContext} to read encryption parameters from. */
    @Nonnull private Function<ProfileRequestContext,EncryptionContext> encryptionContextLookupStrategy;

    /**
     * Encryption parameters for encrypting payload; populated by
     * {@link #doPreExecute(ProfileRequestContext)}.
     */
    @Nullable private EncryptionParameters params;

    /**
     * Constructor.
     * 
     * <p>The default lookup strategy locates the {@link EncryptionContext} child of the
     * {@link RelyingPartyContext} child of the profile request context.</p>
     */
    public AbstractEncryptTokenAction() {
        encryptionContextLookupStrategy = new ChildContextLookup<>(EncryptionContext.class).compose(
                new ChildContextLookup<>(RelyingPartyContext.class));
    }

    /**
     * Set the strategy used to look up the {@link EncryptionContext} to read encryption parameters from.
     * 
     * @param strategy lookup strategy
     */
    public void setEncryptionContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,EncryptionContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        encryptionContextLookupStrategy =
                Constraint.isNotNull(strategy, "EncryptionContext lookup strategy cannot be null");
    }

    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }

        final EncryptionContext encryptCtx = encryptionContextLookupStrategy.apply(profileRequestContext);
        if (encryptCtx == null) {
            log.error("{} No EncryptionContext returned by lookup strategy", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_PROFILE_CTX);
            return false;
        }
        // The parameters are read from the context's "assertion" encryption slot; absence means
        // no encryption is configured, which is not an error.
        params = encryptCtx.getAssertionEncryptionParameters();
        if (params == null) {
            log.debug("{} No Encryption parameters, nothing to do", getLogPrefix());
            return false;
        }
        
        return true;
    }

    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        final Payload payload = getPayload(profileRequestContext);
        if (payload == null) {
            log.debug("{} No plain text source provided to encrypt", getLogPrefix());
            return;
        }

        // The algorithm parsers and the encrypters cannot cope with missing values, so reject
        // incomplete parameters up front rather than failing with a NullPointerException.
        final String algName = params.getKeyTransportEncryptionAlgorithm();
        final String encName = params.getDataEncryptionAlgorithm();
        final Credential credential = params.getKeyTransportEncryptionCredential();
        if (algName == null || encName == null || credential == null) {
            log.error("{} Incomplete encryption parameters, alg: {} enc: {} credential present: {}",
                    getLogPrefix(), algName, encName, credential != null);
            ActionSupport.buildEvent(profileRequestContext, EventIds.UNABLE_TO_ENCRYPT);
            return;
        }

        final JWEAlgorithm encAlg = JWEAlgorithm.parse(algName);
        final EncryptionMethod encEnc = EncryptionMethod.parse(encName);
        final String kid = CredentialConversionUtil.resolveKid(credential);

        log.debug("{} Encrypting with kid {} and params alg: {} enc: {}", getLogPrefix(), kid, encAlg.getName(),
                encEnc.getName());

        try {
            final JWEEncrypter encrypter = createEncrypter(encAlg, credential);
            if (encrypter == null) {
                // The reason has already been logged. Stop here: an unencrypted JWEObject cannot be
                // serialized, and no token must be stored.
                ActionSupport.buildEvent(profileRequestContext, EventIds.UNABLE_TO_ENCRYPT);
                return;
            }
            final JWEObject jweObject = new JWEObject(
                    new JWEHeader.Builder(encAlg, encEnc).contentType("JWT").keyID(kid).build(), payload);
            jweObject.encrypt(encrypter);
            setProcessedToken(profileRequestContext, EncryptedJWT.parse(jweObject.serialize()));
        } catch (final JOSEException | ParseException e) {
            // Trailing Throwable argument so SLF4J records the stack trace.
            log.error("{} Encryption failed", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.UNABLE_TO_ENCRYPT);
        }
    }

    /**
     * Build the encrypter matching the key management algorithm family, checking that the credential
     * carries key material of the required type.
     * 
     * @param encAlg key management algorithm
     * @param credential credential supplying the key
     * 
     * @return the encrypter, or null (after logging the reason) if the algorithm is unsupported or the
     *         credential lacks a suitable key
     * 
     * @throws JOSEException if the encrypter cannot be constructed from the key
     */
    @Nullable private JWEEncrypter createEncrypter(@Nonnull final JWEAlgorithm encAlg,
            @Nonnull final Credential credential) throws JOSEException {
        if (JWEAlgorithm.Family.RSA.contains(encAlg)) {
            if (credential.getPublicKey() instanceof RSAPublicKey) {
                return new RSAEncrypter((RSAPublicKey) credential.getPublicKey());
            }
            log.error("{} Algorithm {} requires an RSA public key", getLogPrefix(), encAlg.getName());
        } else if (JWEAlgorithm.Family.ECDH_ES.contains(encAlg)) {
            if (credential.getPublicKey() instanceof ECPublicKey) {
                return new ECDHEncrypter((ECPublicKey) credential.getPublicKey());
            }
            log.error("{} Algorithm {} requires an EC public key", getLogPrefix(), encAlg.getName());
        } else if (JWEAlgorithm.Family.SYMMETRIC.contains(encAlg)) {
            if (credential.getSecretKey() != null) {
                return new AESEncrypter(credential.getSecretKey());
            }
            log.error("{} Algorithm {} requires a secret key", getLogPrefix(), encAlg.getName());
        } else {
            log.error("{} Unsupported algorithm {}", getLogPrefix(), encAlg.getName());
        }
        return null;
    }

    /**
     * Get the payload to encrypt.
     * 
     * @param profileRequestContext profile request context
     * 
     * @return payload to encrypt, or null if there is nothing to encrypt (the action then does nothing)
     */
    @Nullable protected abstract Payload getPayload(@Nonnull final ProfileRequestContext profileRequestContext);
    
    /**
     * Store the resulting token.
     * 
     * @param profileRequestContext profile request context
     * @param token encrypted token
     */
    protected abstract void setProcessedToken(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final EncryptedJWT token);
            
}