package org.jfrog.security.crypto.signing.gpg;

import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.x9.ECNamedCurveTable;
import org.bouncycastle.bcpg.ECDSAPublicBCPGKey;
import org.bouncycastle.bcpg.PublicKeyAlgorithmTags;
import org.bouncycastle.bcpg.PublicKeyPacket;
import org.bouncycastle.openpgp.PGPException;
import org.bouncycastle.openpgp.PGPObjectFactory;
import org.bouncycastle.openpgp.PGPPrivateKey;
import org.bouncycastle.openpgp.PGPPublicKey;
import org.bouncycastle.openpgp.PGPPublicKeyRing;
import org.bouncycastle.openpgp.PGPSecretKey;
import org.bouncycastle.openpgp.PGPSecretKeyRing;
import org.bouncycastle.openpgp.PGPSecretKeyRingCollection;
import org.bouncycastle.openpgp.PGPUtil;
import org.bouncycastle.openpgp.operator.KeyFingerPrintCalculator;
import org.bouncycastle.openpgp.operator.PBESecretKeyDecryptor;
import org.bouncycastle.openpgp.operator.bc.BcKeyFingerprintCalculator;
import org.bouncycastle.openpgp.operator.jcajce.JcaPGPKeyConverter;
import org.bouncycastle.openpgp.operator.jcajce.JcePBESecretKeyDecryptorBuilder;
import org.jfrog.security.crypto.exception.CryptoRuntimeException;
import org.jfrog.security.util.BCProviderFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import static org.bouncycastle.openpgp.PGPUtil.getDecoderStream;

/**
 * Created by tomerm on 4.7.17.
 */
public class PGPKeyParser {

    private static final Logger log = LoggerFactory.getLogger(PGPKeyParser.class);
    private static Map<String, String> curveToJwsAlgoNameConverter = new HashMap<>();

    static {
        curveToJwsAlgoNameConverter.put("P-256", "ES256");
        curveToJwsAlgoNameConverter.put("P-384", "ES384");
        curveToJwsAlgoNameConverter.put("P-521", "ES512");
    }

    public static WrappedPrivateKey privateKeyParse(byte[] privateKey, String passphrase) {
        try {
            PGPSecretKey pgpSec = findSecretGPGKey(privateKey);
            PGPPublicKey pgpPublicKey = pgpSec.getPublicKey();
            String kid = calculateKeyId(pgpPublicKey);
            PBESecretKeyDecryptor decryptor = new JcePBESecretKeyDecryptorBuilder()
                    .setProvider(BCProviderFactory.getProvider())
                    .build(passphrase.toCharArray());
            PGPPrivateKey pgpPriv = pgpSec.extractPrivateKey(decryptor);
            String algName = getAlgorithmFromKey(pgpPriv);
            JcaPGPKeyConverter converter = new JcaPGPKeyConverter();
            Provider bcProvider = BCProviderFactory.getProvider();
            converter.setProvider(bcProvider);
            PrivateKey privateKeyObj = converter.getPrivateKey(pgpPriv);
            return new WrappedPrivateKey(privateKeyObj, kid, algName);
        } catch (Exception e) {
            log.error("Error Parsing GPG Private Key " + e.getMessage());
            throw new CryptoRuntimeException(e);
        }
    }

    public static String calculateKeyId(PGPPublicKey pgpPublicKey) throws IOException {
        return DigestUtils.sha256Hex(pgpPublicKey.getEncoded()).substring(0, 6);
    }

    public static PGPPublicKey pgpPublicKeyParse(byte[] publicKeyBytes) {
        try {
            InputStream pgpIn = PGPUtil.getDecoderStream(new ByteArrayInputStream(publicKeyBytes));
            PGPObjectFactory pgpFact = new PGPObjectFactory(pgpIn, new BcKeyFingerprintCalculator());
            PGPPublicKeyRing pgpSecRing = (PGPPublicKeyRing) pgpFact.nextObject();
            return pgpSecRing.getPublicKey();
        } catch (IOException e) {
            log.error("Error Parsing GPG Public Key " + e.getMessage());
            throw new CryptoRuntimeException(e);
        }
    }

    public static PublicKey publicKeyParse(byte[] publicKeyBytes) {
        try {
            PGPPublicKey publicKey = pgpPublicKeyParse(publicKeyBytes);
            JcaPGPKeyConverter converter = new JcaPGPKeyConverter();
            Provider bcProvider = BCProviderFactory.getProvider();
            converter.setProvider(bcProvider);
            return converter.getPublicKey(publicKey);
        } catch (Exception e) {
            log.error("Error Parsing GPG Public Key " + e.getMessage());
            throw new CryptoRuntimeException(e);
        }
    }

    public static PGPSecretKey findSecretGPGKey(byte[] privateKey) {
        try {
            InputStream decoderStream = getDecoderStream(new ByteArrayInputStream(privateKey));
            KeyFingerPrintCalculator keyFingerPrintCalculator = new BcKeyFingerprintCalculator();
            PGPSecretKeyRingCollection secretKeyRings = new PGPSecretKeyRingCollection(decoderStream,
                    keyFingerPrintCalculator);
            Iterator privateKeys = secretKeyRings.getKeyRings();
            if (privateKeys.hasNext()) {
                return ((PGPSecretKeyRing) privateKeys.next()).getSecretKey();
            } else {
                throw new PGPException("No private key found!");
            }
        } catch (Exception e) {
            log.error(e.getMessage());
            throw new CryptoRuntimeException(e);
        }
    }

    public static String getAlgorithmFromKey(PGPPrivateKey privKey) {
        PublicKeyPacket pubPk = privKey.getPublicKeyPacket();
        switch (pubPk.getAlgorithm()) {
            case PGPPublicKey.RSA_ENCRYPT:
            case PGPPublicKey.RSA_GENERAL:
            case PGPPublicKey.RSA_SIGN:
                return "RSA";

            case PublicKeyAlgorithmTags.ECDSA:
                ECDSAPublicBCPGKey ecdsaPub = (ECDSAPublicBCPGKey) pubPk.getKey();

                // Bouncycastle changed the order of the ECNamedCurveTable parsing, using the old first algorithm
                // resolver
                String curveName = NISTNamedCurves.getName(ecdsaPub.getCurveOID());
                if (StringUtils.isBlank(curveName)) {
                    curveName = ECNamedCurveTable.getName(ecdsaPub.getCurveOID());
                }

                String convertedVal = curveToJwsAlgoNameConverter.get(curveName);
                return convertedVal == null ? curveName : convertedVal;
        }

        return null;
    }

}
