/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.tls.handshake;

import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import net.luminis.tls.TlsConstants;
import net.luminis.tls.TlsProtocolException;
import net.luminis.tls.alert.DecodeErrorException;
import net.luminis.tls.alert.IllegalParameterAlert;
import net.luminis.tls.extension.ApplicationLayerProtocolNegotiationExtension;
import net.luminis.tls.extension.CertificateAuthoritiesExtension;
import net.luminis.tls.extension.ClientHelloPreSharedKeyExtension;
import net.luminis.tls.extension.EarlyDataExtension;
import net.luminis.tls.extension.Extension;
import net.luminis.tls.extension.ExtensionParser;
import net.luminis.tls.extension.KeyShareExtension;
import net.luminis.tls.extension.PskKeyExchangeModesExtension;
import net.luminis.tls.extension.ServerNameExtension;
import net.luminis.tls.extension.ServerPreSharedKeyExtension;
import net.luminis.tls.extension.SignatureAlgorithmsExtension;
import net.luminis.tls.extension.SupportedGroupsExtension;
import net.luminis.tls.extension.SupportedVersionsExtension;
import net.luminis.tls.extension.UnknownExtension;
import net.luminis.tls.log.Logger;

public abstract class HandshakeMessage {

    /** Size of a handshake message header: 1 byte type followed by a 3-byte (uint24) body length. */
    private static final int HANDSHAKE_HEADER_SIZE = 4;

    /** Size of an extension header: 2 bytes type followed by 2 bytes data length. */
    private static final int EXTENSION_HEADER_SIZE = 4;

    /**
     * Returns the handshake type of this message.
     *
     * @return the handshake type
     */
    public abstract TlsConstants.HandshakeType getType();

    /**
     * Reads and validates the 4-byte handshake header at the buffer's current position.
     *
     * <p>On success the buffer position is advanced past the header, so it points at the message body.
     *
     * @param buffer             buffer positioned at the start of a handshake message
     * @param expectedType       handshake type the message must have
     * @param minimumMessageSize minimum total size of the message in bytes, header included
     * @return length of the message body in bytes (header excluded)
     * @throws DecodeErrorException  if the buffer holds fewer than 4 bytes, if the declared size is below
     *                               {@code minimumMessageSize}, or if the buffer does not contain the full body
     * @throws IllegalStateException if the type byte does not match {@code expectedType}
     */
    protected int parseHandshakeHeader(ByteBuffer buffer, TlsConstants.HandshakeType expectedType, int minimumMessageSize) throws DecodeErrorException {
        if (buffer.remaining() < HANDSHAKE_HEADER_SIZE) {
            throw new DecodeErrorException("handshake message underflow");
        }
        int handshakeType = buffer.get() & 0xFF;
        if (handshakeType != expectedType.value) {
            // Unchecked rather than a TLS alert: callers presumably dispatch on the type byte before
            // choosing the message class, so a mismatch indicates a caller error (TODO confirm).
            throw new IllegalStateException("unexpected handshake type " + handshakeType + ", expected " + expectedType);
        }
        // Body length is a big-endian uint24.
        int messageDataLength = (buffer.get() & 0xFF) << 16 | (buffer.get() & 0xFF) << 8 | buffer.get() & 0xFF;
        if (HANDSHAKE_HEADER_SIZE + messageDataLength < minimumMessageSize) {
            throw new DecodeErrorException(this.getClass().getSimpleName() + " can't be less than " + minimumMessageSize + " bytes");
        }
        if (buffer.remaining() < messageDataLength) {
            throw new DecodeErrorException("handshake message underflow");
        }
        return messageDataLength;
    }

    /**
     * Returns the serialized form of this handshake message.
     *
     * @return the message bytes
     */
    public abstract byte[] getBytes();

    /**
     * Parses an extensions block without a custom extension parser.
     *
     * @param buffer  buffer positioned at the 2-byte extensions length field
     * @param context handshake message type the extensions belong to
     * @return the parsed extensions, in wire order
     * @throws TlsProtocolException if the extensions block is malformed or contains a disallowed extension
     * @see #parseExtensions(ByteBuffer, TlsConstants.HandshakeType, ExtensionParser)
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context) throws TlsProtocolException {
        return HandshakeMessage.parseExtensions(buffer, context, null);
    }

    /**
     * Parses an extensions block: a 2-byte total length followed by a sequence of extensions, each
     * consisting of a 2-byte type, a 2-byte data length and the data.
     *
     * <p>On success the buffer position is advanced past the whole extensions block.
     *
     * @param buffer                buffer positioned at the 2-byte extensions length field
     * @param context               handshake message type the extensions belong to; passed on to
     *                              context-dependent extension parsers
     * @param customExtensionParser optional parser (may be {@code null}) consulted for extension types not
     *                              handled here; if it returns {@code null} the extension is parsed as
     *                              {@link UnknownExtension}
     * @return the parsed extensions, in wire order
     * @throws DecodeErrorException if the length fields are inconsistent with each other or with the buffer
     * @throws TlsProtocolException if an individual extension fails to parse or is not allowed in {@code context}
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context, ExtensionParser customExtensionParser) throws TlsProtocolException {
        if (buffer.remaining() < 2) {
            throw new DecodeErrorException("Extension field must be at least 2 bytes long");
        }
        int remainingExtensionsLength = buffer.getShort() & 0xFFFF;
        if (buffer.remaining() < remainingExtensionsLength) {
            throw new DecodeErrorException("Extensions too short");
        }
        List<Extension> extensions = new ArrayList<>();
        while (remainingExtensionsLength >= EXTENSION_HEADER_SIZE) {
            int extensionStartPosition = buffer.position();
            // Peek at the header with absolute reads: the individual extension parsers consume the header
            // themselves (see the length check below), and absolute reads leave any caller mark intact.
            int extensionType = buffer.getShort(extensionStartPosition) & 0xFFFF;
            int extensionLength = buffer.getShort(extensionStartPosition + 2) & 0xFFFF;
            remainingExtensionsLength -= EXTENSION_HEADER_SIZE;
            if (extensionLength > remainingExtensionsLength) {
                throw new DecodeErrorException("Extension length exceeds extensions length");
            }
            extensions.add(parseExtension(buffer, extensionType, context, customExtensionParser));
            // Each parser must consume exactly the header plus the declared data length.
            if (buffer.position() - extensionStartPosition != EXTENSION_HEADER_SIZE + extensionLength) {
                throw new DecodeErrorException("Incorrect extension length");
            }
            remainingExtensionsLength -= extensionLength;
        }
        if (remainingExtensionsLength != 0) {
            // 1-3 leftover bytes cannot hold an extension header: the declared extensions length is inconsistent.
            throw new DecodeErrorException("Incorrect extensions length");
        }
        return extensions;
    }

    /**
     * Parses one extension, header included, starting at the buffer's current position, selecting the
     * parser by extension type.
     */
    private static Extension parseExtension(ByteBuffer buffer, int extensionType, TlsConstants.HandshakeType context, ExtensionParser customExtensionParser) throws TlsProtocolException {
        if (extensionType == TlsConstants.ExtensionType.server_name.value) {
            return new ServerNameExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.supported_groups.value) {
            return new SupportedGroupsExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.signature_algorithms.value) {
            return new SignatureAlgorithmsExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.application_layer_protocol_negotiation.value) {
            return new ApplicationLayerProtocolNegotiationExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.pre_shared_key.value) {
            // pre_shared_key has a different structure in ServerHello and ClientHello and is only
            // accepted in those two messages.
            if (context == TlsConstants.HandshakeType.server_hello) {
                return new ServerPreSharedKeyExtension().parse(buffer);
            }
            if (context == TlsConstants.HandshakeType.client_hello) {
                return new ClientHelloPreSharedKeyExtension().parse(buffer);
            }
            throw new IllegalParameterAlert("Extension not allowed in " + context);
        }
        if (extensionType == TlsConstants.ExtensionType.early_data.value) {
            return new EarlyDataExtension(buffer, context);
        }
        if (extensionType == TlsConstants.ExtensionType.supported_versions.value) {
            return new SupportedVersionsExtension(buffer, context);
        }
        if (extensionType == TlsConstants.ExtensionType.psk_key_exchange_modes.value) {
            return new PskKeyExchangeModesExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.certificate_authorities.value) {
            return new CertificateAuthoritiesExtension(buffer);
        }
        if (extensionType == TlsConstants.ExtensionType.key_share.value) {
            return new KeyShareExtension(buffer, context);
        }
        Extension extension = customExtensionParser != null ? customExtensionParser.apply(buffer, context) : null;
        if (extension != null) {
            return extension;
        }
        Logger.debug("Unsupported extension, type is: " + extensionType);
        return new UnknownExtension().parse(buffer);
    }

    /**
     * Walks an extensions block and returns the absolute buffer position at which its last extension starts.
     *
     * <p>The buffer position is advanced past the 2-byte length and every extension walked.
     *
     * @param buffer buffer positioned at the 2-byte extensions length field
     * @return start position of the last extension's header, or 0 if the block contains no extension
     * @throws BufferUnderflowException if the buffer ends before the declared extensions do
     */
    public static int findPositionLastExtension(ByteBuffer buffer) {
        int remaining = buffer.getShort() & 0xFFFF;
        int lastExtensionStart = 0;
        // A zero-length extension is exactly EXTENSION_HEADER_SIZE bytes, so it must still be visited
        // (same loop condition as parseExtensions).
        while (remaining >= EXTENSION_HEADER_SIZE) {
            lastExtensionStart = buffer.position();
            buffer.getShort(); // extension type, not needed here
            int length = buffer.getShort() & 0xFFFF;
            if (length > buffer.remaining()) {
                throw new BufferUnderflowException();
            }
            // Skip the extension data without copying it.
            buffer.position(buffer.position() + length);
            remaining -= EXTENSION_HEADER_SIZE + length;
        }
        return lastExtensionStart;
    }
}

