package org.jfrog.security.crypto.signing.rsa;

import lombok.experimental.UtilityClass;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMDecryptorProvider;
import org.bouncycastle.openssl.PEMEncryptedKeyPair;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;
import org.jfrog.security.util.BCProviderFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;

/**
 * Parses PEM-encoded RSA keys and checks that a private/public key pair belong together.
 * <p>
 * Lombok's {@link UtilityClass} makes every member below {@code static} and prevents instantiation.
 *
 * @author BarakH
 */
@UtilityClass
public class RSAKeyParser {
    private final Logger log = LoggerFactory.getLogger(RSAKeyParser.class);
    /**
     * Fixed payload that is signed and then verified to prove the two keys match. It is encoded
     * explicitly so the bytes never depend on the platform default charset.
     */
    private final byte[] DATA_TO_SIGN = "ACBDE".getBytes(StandardCharsets.UTF_8);
    /** Start of a PEM header line such as {@code -----BEGIN PUBLIC KEY-----}. */
    private final String PEM_BEGIN_MARKER = "-----BEGIN";
    /** Start of a PEM footer line such as {@code -----END PUBLIC KEY-----}. */
    private final String PEM_END_MARKER = "-----END";
    /** Dash run that closes a PEM header/footer label. */
    private final String PEM_DASHES = "-----";

    /**
     * Verifies that the given private and public keys form a matching pair. It signs a fixed
     * payload with the private key and verifies that signature with the public key.
     *
     * @param privateKeyString PEM-encoded private key
     * @param publicKeyString  PEM-encoded (or bare Base64) X.509 public key
     * @param pass             password for an encrypted private key, or {@code null}/empty if the key
     *                         is not encrypted
     * @throws IllegalArgumentException if parsing, signing or verification throws; the original
     *                                  exception is attached as the cause. {@code throws Exception}
     *                                  is kept only for source compatibility with existing callers.
     */
    public void verifyRSAKeyPair(String privateKeyString, String publicKeyString, @Nullable String pass)
            throws Exception {
        try {
            PrivateKey privateKey = parsePrivateKey(privateKeyString, pass);
            PublicKey publicKey = parsePublicKey(publicKeyString);
            byte[] signature = RSASigner.signDataWithSha1(DATA_TO_SIGN, privateKey);
            // NOTE(review): this relies on verifySha1Signature throwing on a mismatch. If it instead
            // returns a boolean, that result is discarded here and a mismatched pair would pass -- confirm.
            RSASigner.verifySha1Signature(signature, DATA_TO_SIGN, publicKey);
        } catch (Exception e) {
            log.error("Failed to verify key pair with an error {}", e.getMessage());
            log.debug("Failed to verify key pair with an error", e);
            throw new IllegalArgumentException("Failed to verify key pair", e);
        }
    }

    /**
     * Parses a PEM-encoded private key.
     * <p>
     * Three PEM object types are supported:
     * <ul>
     *     <li>an OpenSSL key pair (traditional {@code BEGIN RSA PRIVATE KEY});</li>
     *     <li>the same key pair encrypted with a password;</li>
     *     <li>an unencrypted PKCS#8 {@code PrivateKeyInfo} ({@code BEGIN PRIVATE KEY}).</li>
     * </ul>
     * If a password is supplied but the key turns out not to be encrypted, the password is
     * ignored and the key is parsed as-is.
     *
     * @param privateKey PEM text of the private key
     * @param password   password for an encrypted key; may be {@code null} or empty
     * @return the parsed private key
     * @throws IOException              if the PEM text cannot be read or the key cannot be decrypted
     * @throws IllegalArgumentException if the key is encrypted and no password was given, or if the
     *                                  PEM content is not one of the supported key types
     */
    public PrivateKey parsePrivateKey(String privateKey, @Nullable String password) throws IOException {
        log.debug("Parsing RSA private key");
        Object key;
        // Close the parser deterministically; the whole key is read in a single readObject() call.
        try (PEMParser pemParser = new PEMParser(new StringReader(privateKey))) {
            key = pemParser.readObject();
        }
        JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider(BCProviderFactory.getProvider());
        if (key instanceof PEMEncryptedKeyPair) {
            if (StringUtils.isEmpty(password)) {
                throw new IllegalArgumentException("Encrypted RSA private key requires a password");
            }
            PEMDecryptorProvider decryptorProvider = new JcePEMDecryptorProviderBuilder().build(password.toCharArray());
            PEMKeyPair pemKeyPair = ((PEMEncryptedKeyPair) key).decryptKeyPair(decryptorProvider);
            return converter.getPrivateKey(pemKeyPair.getPrivateKeyInfo());
        }
        if (key instanceof PEMKeyPair) {
            return converter.getPrivateKey(((PEMKeyPair) key).getPrivateKeyInfo());
        }
        if (key instanceof PrivateKeyInfo) {
            return converter.getPrivateKey((PrivateKeyInfo) key);
        }
        // Also covers key == null, i.e. the input contained no PEM object at all.
        throw new IllegalArgumentException("Unknown RSA key type");
    }

    /**
     * Parses an RSA public key in X.509 {@code SubjectPublicKeyInfo} form.
     * <p>
     * Accepts either a PEM block ({@code -----BEGIN ...-----} ... {@code -----END ...-----}) or just
     * the Base64 body.
     *
     * @param publicKey PEM or Base64 text of the public key
     * @return the parsed public key
     * @throws NoSuchAlgorithmException if no RSA {@link KeyFactory} is available
     * @throws IllegalArgumentException if the decoded bytes are not a valid RSA public key; the
     *                                  underlying exception is attached as the cause
     */
    public PublicKey parsePublicKey(String publicKey) throws NoSuchAlgorithmException {
        log.debug("Parsing RSA public key");
        byte[] decodedPub = decodePublic(publicKey);
        X509EncodedKeySpec spec = new X509EncodedKeySpec(decodedPub);
        KeyFactory kf = KeyFactory.getInstance("RSA");
        try {
            return kf.generatePublic(spec);
        } catch (Exception e) {
            log.error("Failed to parse public key with an error {}", e.getMessage());
            log.debug("Failed to parse public key with an error", e);
            throw new IllegalArgumentException("Failed to parse RSA public key", e);
        }
    }

    /**
     * Strips an optional PEM header and footer and Base64-decodes what remains.
     * <p>
     * The header and footer are handled independently, so a missing footer, a single-line PEM
     * block, or leading blank lines do not break decoding. Leftover line breaks are left to
     * {@link Base64#decodeBase64(String)}, which skips characters outside the Base64 alphabet.
     */
    private byte[] decodePublic(String publicKey) {
        String body = publicKey;
        int beginIndex = body.indexOf(PEM_BEGIN_MARKER);
        if (beginIndex != -1) {
            // Skip past the dashes that close the "-----BEGIN <label>-----" header.
            int headerLabelEnd = body.indexOf(PEM_DASHES, beginIndex + PEM_BEGIN_MARKER.length());
            body = headerLabelEnd == -1 ? "" : body.substring(headerLabelEnd + PEM_DASHES.length());
        }
        int footerStart = body.indexOf(PEM_END_MARKER);
        if (footerStart != -1) {
            body = body.substring(0, footerStart);
        }
        return Base64.decodeBase64(body);
    }
}
