/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.californium.scandium.dtls;

import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.eclipse.californium.scandium.dtls.AlertMessage;
import org.eclipse.californium.scandium.dtls.CompressionMethod;
import org.eclipse.californium.scandium.dtls.ContentType;
import org.eclipse.californium.scandium.dtls.DTLSConnectionState;
import org.eclipse.californium.scandium.dtls.DTLSFlight;
import org.eclipse.californium.scandium.dtls.DTLSMessage;
import org.eclipse.californium.scandium.dtls.DTLSSession;
import org.eclipse.californium.scandium.dtls.FragmentedHandshakeMessage;
import org.eclipse.californium.scandium.dtls.HandshakeException;
import org.eclipse.californium.scandium.dtls.HandshakeMessage;
import org.eclipse.californium.scandium.dtls.HandshakeType;
import org.eclipse.californium.scandium.dtls.ProtocolVersion;
import org.eclipse.californium.scandium.dtls.Random;
import org.eclipse.californium.scandium.dtls.Record;
import org.eclipse.californium.scandium.dtls.SessionListener;
import org.eclipse.californium.scandium.dtls.cipher.CipherSuite;
import org.eclipse.californium.scandium.dtls.cipher.ECDHECryptography;
import org.eclipse.californium.scandium.util.ByteArrayUtils;

/**
 * Common base for the DTLS handshake state machines.
 * <p>
 * The class keeps the state shared by both handshake roles: the negotiated
 * parameters, the derived key material, the handshake message sequence
 * counters and the buffers used for re-ordering and re-assembling received
 * handshake messages. Role specific behavior is supplied by subclasses via
 * {@link #doProcessMessage(Record)} and {@link #getStartHandshakeMessage()}.
 * <p>
 * Nothing in this class is synchronized; callers are responsible for
 * confining an instance to a single thread or for external locking.
 */
public abstract class Handshaker {
    private static final String MESSAGE_DIGEST_ALGORITHM_NAME = "SHA-256";
    /** Charset used to encode the PRF labels; fixed so that the result is platform independent. */
    private static final Charset LABEL_CHARSET = Charset.forName("US-ASCII");
    /** Orders fragments of one handshake message by ascending fragment offset. */
    private static final Comparator<FragmentedHandshakeMessage> FRAGMENT_OFFSET_ORDER = new Comparator<FragmentedHandshakeMessage>(){

        @Override
        public int compare(FragmentedHandshakeMessage o1, FragmentedHandshakeMessage o2) {
            int first = o1.getFragmentOffset();
            int second = o2.getFragmentOffset();
            if (first == second) {
                return 0;
            }
            return first < second ? -1 : 1;
        }
    };
    protected static final Logger LOGGER = Logger.getLogger(Handshaker.class.getCanonicalName());
    /** PRF label id selecting the label "master secret" (48 bytes of output). */
    public static final int MASTER_SECRET_LABEL = 1;
    /** PRF label id selecting the label "key expansion" (128 bytes of output). */
    public static final int KEY_EXPANSION_LABEL = 2;
    /** PRF label id selecting the label "client finished" (12 bytes of output). */
    public static final int CLIENT_FINISHED_LABEL = 3;
    /** PRF label id selecting the label "server finished" (12 bytes of output). */
    public static final int SERVER_FINISHED_LABEL = 4;
    /** {@code true} if this handshaker plays the client role. */
    protected boolean isClient;
    /** Handshake progress marker maintained by subclasses; this class only initializes it to -1. */
    protected int state = -1;
    protected ProtocolVersion usedProtocol;
    protected Random clientRandom;
    protected Random serverRandom;
    private CipherSuite cipherSuite;
    private CompressionMethod compressionMethod;
    protected CipherSuite.KeyExchangeAlgorithm keyExchange;
    protected ECDHECryptography ecdhe;
    private byte[] masterSecret;
    private SecretKey clientWriteMACKey;
    private SecretKey serverWriteMACKey;
    private IvParameterSpec clientWriteIV;
    private IvParameterSpec serverWriteIV;
    private SecretKey clientWriteKey;
    private SecretKey serverWriteKey;
    protected DTLSSession session = null;
    /** message_seq assigned to the next handshake message sent. */
    private int sequenceNumber = 0;
    /** message_seq of the next handshake message that may be processed. */
    private int nextReceiveSeq = 0;
    /** Records that arrived too early (future message_seq or future epoch). */
    protected Collection<Record> queuedMessages;
    /** Fragments received so far, keyed by the message_seq of the message they belong to. */
    protected Map<Integer, List<FragmentedHandshakeMessage>> fragmentedMessages = new HashMap<Integer, List<FragmentedHandshakeMessage>>();
    protected MessageDigest md;
    protected byte[] handshakeMessages = new byte[0];
    protected DTLSFlight lastFlight = null;
    protected PrivateKey privateKey;
    protected PublicKey publicKey;
    protected Certificate[] certificates;
    protected final Certificate[] rootCertificates;
    /** Largest handshake message body (in bytes) that is sent without fragmentation. */
    private int maxFragmentLength = 4096;
    private SessionListener sessionListener;

    /**
     * Creates a handshaker whose send and receive message sequence numbers start at 0.
     *
     * @param isClient {@code true} for the client role
     * @param session the session to negotiate, must not be {@code null}
     * @param sessionListener listener to notify about handshake progress, may be {@code null}
     * @param rootCertificates trusted root certificates, {@code null} is treated as an empty array
     * @param maxFragmentLength largest handshake message body sent in a single record
     * @throws HandshakeException if the message digest cannot be created
     * @throws NullPointerException if {@code session} is {@code null}
     */
    protected Handshaker(boolean isClient, DTLSSession session, SessionListener sessionListener, Certificate[] rootCertificates, int maxFragmentLength) throws HandshakeException {
        this(isClient, 0, session, sessionListener, rootCertificates, maxFragmentLength);
    }

    /**
     * Creates a handshaker with an explicit initial message sequence number.
     *
     * @param isClient {@code true} for the client role
     * @param initialMessageSeq value used to initialize both the send and the expected receive message_seq
     * @param session the session to negotiate, must not be {@code null}
     * @param sessionListener listener to notify about handshake progress, may be {@code null}
     * @param rootCertificates trusted root certificates, {@code null} is treated as an empty array
     * @param maxFragmentLength largest handshake message body sent in a single record
     * @throws HandshakeException if the message digest cannot be created
     * @throws NullPointerException if {@code session} is {@code null}
     * @throws IllegalArgumentException if {@code initialMessageSeq} is negative
     */
    protected Handshaker(boolean isClient, int initialMessageSeq, DTLSSession session, SessionListener sessionListener, Certificate[] rootCertificates, int maxFragmentLength) throws HandshakeException {
        if (session == null) {
            throw new NullPointerException("DTLS Session must not be null");
        }
        if (initialMessageSeq < 0) {
            throw new IllegalArgumentException("Initial message sequence number must not be negative");
        }
        this.sessionListener = sessionListener;
        this.nextReceiveSeq = initialMessageSeq;
        this.sequenceNumber = initialMessageSeq;
        this.isClient = isClient;
        this.session = session;
        this.queuedMessages = new HashSet<Record>();
        this.rootCertificates = rootCertificates == null ? new Certificate[]{} : rootCertificates;
        this.maxFragmentLength = maxFragmentLength;
        this.md = Handshaker.createMessageDigest();
    }

    /**
     * Creates a handshaker for a session that must belong to the given peer.
     * <p>
     * No session listener is registered by this constructor, the message
     * sequence numbers start at 0 and the maximum fragment length keeps its
     * default of 4096 bytes.
     *
     * @param peerAddress the address the session is expected to be bound to
     * @param isClient {@code true} for the client role
     * @param session the session to negotiate, must not be {@code null}
     * @param rootCertificates trusted root certificates, {@code null} is treated as an empty array
     * @throws HandshakeException if the message digest cannot be created
     * @throws NullPointerException if {@code session} is {@code null}
     * @throws IllegalArgumentException if the session's peer differs from {@code peerAddress}
     */
    public Handshaker(InetSocketAddress peerAddress, boolean isClient, DTLSSession session, Certificate[] rootCertificates) throws HandshakeException {
        if (session == null) {
            throw new NullPointerException("DTLS Session must not be null");
        }
        if (!session.getPeer().equals(peerAddress)) {
            throw new IllegalArgumentException("Peer address must be the same as in session");
        }
        this.isClient = isClient;
        this.session = session;
        this.queuedMessages = new HashSet<Record>();
        this.rootCertificates = rootCertificates == null ? new Certificate[]{} : rootCertificates;
        this.md = Handshaker.createMessageDigest();
    }

    /**
     * Creates the digest used for hashing handshake messages.
     *
     * @return a new SHA-256 message digest
     * @throws HandshakeException (fatal, internal error) if the algorithm is not available
     */
    private static MessageDigest createMessageDigest() throws HandshakeException {
        try {
            return MessageDigest.getInstance(MESSAGE_DIGEST_ALGORITHM_NAME);
        }
        catch (NoSuchAlgorithmException e) {
            LOGGER.log(Level.SEVERE, "Could not initialize message digest algorithm for Handshaker.", e);
            throw new HandshakeException("Could not initialize handshake", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR));
        }
    }

    /**
     * Processes a record received from the peer.
     * <p>
     * Records the session reports as duplicates are discarded. All other
     * records are bound to the session, handed to
     * {@link #doProcessMessage(Record)} and afterwards marked as read.
     *
     * @param message the received record
     * @return the flight returned by {@link #doProcessMessage(Record)}, or
     *         {@code null} if the record was a duplicate
     * @throws HandshakeException if processing fails; security failures are
     *         reported as a fatal internal error alert
     */
    public final DTLSFlight processMessage(Record message) throws HandshakeException {
        if (this.session.isDuplicate(message.getSequenceNumber())) {
            LOGGER.log(Level.FINER, "Discarding duplicate HANDSHAKE message received from peer [{0}]:\n{1}", new Object[]{this.getPeerAddress(), message});
            return null;
        }
        try {
            message.setSession(this.session);
            DTLSFlight nextFlight = this.doProcessMessage(message);
            // only mark the record as read once it has been processed successfully
            this.session.markRecordAsRead(message.getEpoch(), message.getSequenceNumber());
            return nextFlight;
        }
        catch (GeneralSecurityException e) {
            LOGGER.log(Level.WARNING, String.format("Cannot process handshake message from peer [%s] due to [%s]", this.getSession().getPeer(), e.getMessage()), e);
            AlertMessage alert = new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR);
            throw new HandshakeException("Cannot process handshake message", alert);
        }
    }

    /**
     * Role specific processing of a non-duplicate record. This default
     * implementation does nothing.
     *
     * @param record the record to process
     * @return the flight to send in response, {@code null} in this default implementation
     * @throws HandshakeException if the record violates the handshake protocol
     * @throws GeneralSecurityException if a cryptographic operation fails
     */
    protected DTLSFlight doProcessMessage(Record record) throws HandshakeException, GeneralSecurityException {
        return null;
    }

    /**
     * Gets the flight that starts the handshake for this role.
     *
     * @return the initial flight
     * @throws HandshakeException if the flight cannot be created
     */
    public abstract DTLSFlight getStartHandshakeMessage() throws HandshakeException;

    /**
     * Derives the master secret from the premaster secret, stores it in the
     * session and expands it into the MAC keys, encryption keys and IVs.
     * Requires {@code clientRandom} and {@code serverRandom} to be set.
     *
     * @param premasterSecret the premaster secret agreed on during key exchange
     */
    protected final void generateKeys(byte[] premasterSecret) {
        this.masterSecret = this.generateMasterSecret(premasterSecret);
        this.session.setMasterSecret(this.masterSecret);
        this.calculateKeys(this.masterSecret);
    }

    /**
     * Expands the master secret into a key block and slices it, in this
     * order, into: client MAC key, server MAC key, client write key, server
     * write key, client write IV, server write IV.
     *
     * @param masterSecret the master secret to expand
     */
    private void calculateKeys(byte[] masterSecret) {
        // note the seed order: server random first, then client random
        byte[] seed = ByteArrayUtils.concatenate(this.serverRandom.getRandomBytes(), this.clientRandom.getRandomBytes());
        byte[] keyBlock = Handshaker.doPRF(masterSecret, KEY_EXPANSION_LABEL, seed);
        if (this.cipherSuite == null) {
            // no suite set on the handshaker: fall back to the one stored in the session
            this.cipherSuite = this.session.getCipherSuite();
        }
        int macKeyLength = this.cipherSuite.getMacKeyLength();
        int encKeyLength = this.cipherSuite.getEncKeyLength();
        int fixedIvLength = this.cipherSuite.getFixedIvLength();
        int position = 0;
        this.clientWriteMACKey = new SecretKeySpec(keyBlock, position, macKeyLength, "Mac");
        this.serverWriteMACKey = new SecretKeySpec(keyBlock, position += macKeyLength, macKeyLength, "Mac");
        this.clientWriteKey = new SecretKeySpec(keyBlock, position += macKeyLength, encKeyLength, "AES");
        this.serverWriteKey = new SecretKeySpec(keyBlock, position += encKeyLength, encKeyLength, "AES");
        this.clientWriteIV = new IvParameterSpec(keyBlock, position += encKeyLength, fixedIvLength);
        this.serverWriteIV = new IvParameterSpec(keyBlock, position += fixedIvLength, fixedIvLength);
    }

    /**
     * Computes the 48 byte master secret from the premaster secret, seeded
     * with client random followed by server random.
     *
     * @param premasterSecret the premaster secret
     * @return the master secret
     */
    private byte[] generateMasterSecret(byte[] premasterSecret) {
        byte[] randomSeed = ByteArrayUtils.concatenate(this.clientRandom.getRandomBytes(), this.serverRandom.getRandomBytes());
        return Handshaker.doPRF(premasterSecret, MASTER_SECRET_LABEL, randomSeed);
    }

    /**
     * Builds the premaster secret for a pre-shared key. With N being the
     * length of the key, the result is the concatenation of: N as a two byte
     * big-endian value, padding of length N obtained from
     * {@code ByteArrayUtils.padArray} with pad value 0, N as a two byte
     * big-endian value again, and the key itself.
     *
     * @param psk the pre-shared key
     * @return the premaster secret
     */
    protected final byte[] generatePremasterSecretFromPSK(byte[] psk) {
        int length = psk.length;
        byte[] lengthField = new byte[]{(byte)(length >> 8), (byte)length};
        byte[] zero = ByteArrayUtils.padArray(new byte[0], (byte)0, length);
        return ByteArrayUtils.concatenate(lengthField, ByteArrayUtils.concatenate(zero, ByteArrayUtils.concatenate(lengthField, psk)));
    }

    /**
     * Runs the HMAC-SHA256 based pseudo random function over
     * {@code label + seed} using {@code secret} as the HMAC key.
     *
     * @param secret the HMAC key
     * @param label the label bytes
     * @param seed the seed bytes
     * @param length number of output bytes to produce
     * @return {@code length} pseudo random bytes, or {@code null} if the HMAC
     *         could not be set up (the failure is logged at SEVERE level)
     */
    static byte[] doPRF(byte[] secret, byte[] label, byte[] seed, int length) {
        try {
            Mac hmac = Mac.getInstance("HmacSHA256");
            hmac.init(new SecretKeySpec(secret, "MAC"));
            return Handshaker.doExpansion(hmac, ByteArrayUtils.concatenate(label, seed), length);
        }
        catch (GeneralSecurityException e) {
            LOGGER.log(Level.SEVERE, "Message digest algorithm not available", e);
            return null;
        }
    }

    /**
     * Runs the pseudo random function with one of the predefined labels.
     *
     * @param secret the HMAC key
     * @param labelId one of {@link #MASTER_SECRET_LABEL},
     *        {@link #KEY_EXPANSION_LABEL}, {@link #CLIENT_FINISHED_LABEL} or
     *        {@link #SERVER_FINISHED_LABEL}; the id also determines the output length
     * @param seed the seed bytes
     * @return the pseudo random bytes, or {@code null} if {@code labelId} is
     *         unknown or the HMAC could not be set up
     */
    static final byte[] doPRF(byte[] secret, int labelId, byte[] seed) {
        int length;
        String label;
        switch (labelId) {
            case MASTER_SECRET_LABEL: {
                label = "master secret";
                length = 48;
                break;
            }
            case KEY_EXPANSION_LABEL: {
                label = "key expansion";
                length = 128;
                break;
            }
            case CLIENT_FINISHED_LABEL: {
                label = "client finished";
                length = 12;
                break;
            }
            case SERVER_FINISHED_LABEL: {
                label = "server finished";
                length = 12;
                break;
            }
            default: {
                LOGGER.log(Level.SEVERE, "Unknown label: {0}", labelId);
                return null;
            }
        }
        // encode with a fixed charset: the platform default must not influence key material
        return Handshaker.doPRF(secret, label.getBytes(LABEL_CHARSET), seed, length);
    }

    /**
     * Iterated HMAC expansion: starting with {@code A = data}, each round
     * computes {@code A = HMAC(A)} and appends {@code HMAC(A + data)} to the
     * output until {@code length} bytes are available.
     *
     * @param hmac an initialized HMAC instance
     * @param data the input (label and seed)
     * @param length number of output bytes to produce
     * @return exactly {@code length} bytes of expanded output
     */
    static final byte[] doExpansion(Mac hmac, byte[] data, int length) {
        byte[] expansion = new byte[length];
        byte[] a = data;
        int written = 0;
        while (written < length) {
            a = hmac.doFinal(a);
            // HMAC(A + data) without materializing the concatenation
            hmac.update(a);
            hmac.update(data);
            byte[] chunk = hmac.doFinal();
            // the last chunk is truncated to fit the requested length
            int count = Math.min(chunk.length, length - written);
            System.arraycopy(chunk, 0, expansion, written, count);
            written += count;
        }
        return expansion;
    }

    /**
     * Installs the session's read state from the negotiated parameters. A
     * client reads with the server's write keys, a server with the client's.
     */
    protected final void setCurrentReadState() {
        DTLSConnectionState connectionState;
        if (this.isClient) {
            connectionState = new DTLSConnectionState(this.cipherSuite, this.compressionMethod, this.serverWriteKey, this.serverWriteIV, this.serverWriteMACKey);
        } else {
            connectionState = new DTLSConnectionState(this.cipherSuite, this.compressionMethod, this.clientWriteKey, this.clientWriteIV, this.clientWriteMACKey);
        }
        this.session.setReadState(connectionState);
    }

    /**
     * Installs the session's write state from the negotiated parameters. A
     * client writes with the client write keys, a server with the server write keys.
     */
    protected final void setCurrentWriteState() {
        DTLSConnectionState connectionState;
        if (this.isClient) {
            connectionState = new DTLSConnectionState(this.cipherSuite, this.compressionMethod, this.clientWriteKey, this.clientWriteIV, this.clientWriteMACKey);
        } else {
            connectionState = new DTLSConnectionState(this.cipherSuite, this.compressionMethod, this.serverWriteKey, this.serverWriteIV, this.serverWriteMACKey);
        }
        this.session.setWriteState(connectionState);
    }

    /**
     * Wraps a DTLS message into one or more records for the session's
     * current write epoch. Handshake messages get a message_seq assigned and
     * are fragmented if they exceed the maximum fragment length; every other
     * message is wrapped into exactly one record.
     *
     * @param fragment the message to wrap
     * @return the records carrying the message
     * @throws HandshakeException (fatal, internal error) if a record cannot be created
     */
    protected final List<Record> wrapMessage(DTLSMessage fragment) throws HandshakeException {
        try {
            if (fragment.getContentType() == ContentType.HANDSHAKE) {
                return this.wrapHandshakeMessage((HandshakeMessage)fragment);
            }
            List<Record> records = new ArrayList<Record>();
            records.add(new Record(fragment.getContentType(), this.session.getWriteEpoch(), this.session.getSequenceNumber(), fragment, this.session));
            return records;
        }
        catch (GeneralSecurityException e) {
            // keep a trace of the root cause before it is converted into an alert
            LOGGER.log(Level.WARNING, String.format("Cannot create record for peer [%s] due to [%s]", this.getPeerAddress(), e.getMessage()), e);
            throw new HandshakeException("Cannot create record", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR));
        }
    }

    /**
     * Assigns the next message_seq to the handshake message and wraps it into
     * records. A message whose body fits into the maximum fragment length is
     * sent as a single record; a larger one is split into fragments that all
     * share the message's message_seq.
     *
     * @param handshakeMessage the handshake message to send
     * @return the record(s) carrying the message
     * @throws GeneralSecurityException if a record cannot be created
     */
    private List<Record> wrapHandshakeMessage(HandshakeMessage handshakeMessage) throws GeneralSecurityException {
        this.setSequenceNumber(handshakeMessage);
        List<Record> result = new ArrayList<Record>();
        byte[] messageBytes = handshakeMessage.fragmentToByteArray();
        if (messageBytes.length <= this.maxFragmentLength) {
            result.add(new Record(ContentType.HANDSHAKE, this.session.getWriteEpoch(), this.session.getSequenceNumber(), handshakeMessage, this.session));
            return result;
        }
        int messageSeq = handshakeMessage.getMessageSeq();
        // ceiling division: a body that is an exact multiple of the fragment
        // length must not produce a trailing empty fragment
        int numFragments = (messageBytes.length + this.maxFragmentLength - 1) / this.maxFragmentLength;
        int offset = 0;
        for (int i = 0; i < numFragments; ++i) {
            int fragmentLength = Math.min(this.maxFragmentLength, messageBytes.length - offset);
            byte[] fragmentBytes = new byte[fragmentLength];
            System.arraycopy(messageBytes, offset, fragmentBytes, 0, fragmentLength);
            FragmentedHandshakeMessage fragmentedMessage = new FragmentedHandshakeMessage(fragmentBytes, handshakeMessage.getMessageType(), offset, messageBytes.length);
            fragmentedMessage.setMessageSeq(messageSeq);
            offset += fragmentLength;
            result.add(new Record(ContentType.HANDSHAKE, this.session.getWriteEpoch(), this.session.getSequenceNumber(), fragmentedMessage, this.session));
        }
        return result;
    }

    /**
     * Decides whether a received record is the next one to be processed.
     * <ul>
     * <li>Records from an epoch older than the session's read epoch are discarded.</li>
     * <li>Records from a newer epoch are queued.</li>
     * <li>In the current epoch, ALERT and CHANGE_CIPHER_SPEC records are
     * always accepted. A HANDSHAKE record is accepted if its message_seq is
     * the expected one (the expected value is advanced here unless the
     * message is a fragment), queued if its message_seq is newer and
     * discarded if it is older. Any other content type is rejected.</li>
     * </ul>
     *
     * @param record the received record
     * @return {@code true} if the record should be processed now
     * @throws HandshakeException if the record's fragment cannot be obtained
     * @throws GeneralSecurityException if the record's fragment cannot be decrypted
     */
    protected final boolean processMessageNext(Record record) throws HandshakeException, GeneralSecurityException {
        int epoch = record.getEpoch();
        if (epoch < this.session.getReadEpoch()) {
            LOGGER.log(Level.FINER, "Discarding message from peer [{0}] from finished epoch [{1}] < current epoch [{2}]", new Object[]{this.getPeerAddress(), epoch, this.session.getReadEpoch()});
            return false;
        }
        if (epoch > this.session.getReadEpoch()) {
            this.queuedMessages.add(record);
            LOGGER.log(Level.FINER, "Queueing HANDSHAKE message from future epoch [{0}] > current epoch [{1}]", new Object[]{epoch, this.getSession().getReadEpoch()});
            return false;
        }
        DTLSMessage fragment = record.getFragment();
        switch (fragment.getContentType()) {
            case ALERT:
            case CHANGE_CIPHER_SPEC: {
                return true;
            }
            case HANDSHAKE: {
                int messageSeq = ((HandshakeMessage)fragment).getMessageSeq();
                if (messageSeq == this.nextReceiveSeq) {
                    // for fragments the counter is advanced once re-assembly completes
                    if (!(fragment instanceof FragmentedHandshakeMessage)) {
                        this.incrementNextReceiveSeq();
                    }
                    return true;
                }
                if (messageSeq > this.nextReceiveSeq) {
                    LOGGER.log(Level.FINER, "Queued newer message from same epoch, message_seq [{0}] > next_receive_seq [{1}]", new Object[]{messageSeq, this.nextReceiveSeq});
                    this.queuedMessages.add(record);
                    return false;
                }
                LOGGER.log(Level.FINER, "Discarding old message, message_seq [{0}] < next_receive_seq [{1}]", new Object[]{messageSeq, this.nextReceiveSeq});
                return false;
            }
            default: {
                LOGGER.log(Level.FINER, "Cannot process HANDSHAKE message of unknown type");
                return false;
            }
        }
    }

    /**
     * Stores a received fragment and tries to re-assemble the handshake
     * message it belongs to. On success the expected receive message_seq is
     * advanced and the stored fragments of that message are dropped.
     *
     * @param fragment the received fragment
     * @return the re-assembled message, or {@code null} if fragments are still missing
     * @throws HandshakeException if the re-assembled bytes cannot be parsed
     */
    protected final HandshakeMessage handleFragmentation(FragmentedHandshakeMessage fragment) throws HandshakeException {
        int messageSeq = fragment.getMessageSeq();
        List<FragmentedHandshakeMessage> fragments = this.fragmentedMessages.get(messageSeq);
        if (fragments == null) {
            fragments = new ArrayList<FragmentedHandshakeMessage>();
            this.fragmentedMessages.put(messageSeq, fragments);
        }
        fragments.add(fragment);
        HandshakeMessage reassembledMessage = this.reassembleFragments(messageSeq, fragment.getMessageLength(), fragment.getMessageType(), this.session);
        if (reassembledMessage != null) {
            this.incrementNextReceiveSeq();
            this.fragmentedMessages.remove(messageSeq);
        }
        return reassembledMessage;
    }

    /**
     * Tries to re-assemble a handshake message from the fragments stored for
     * the given message_seq. The stored fragments are sorted by offset (in
     * place) and stitched together; overlapping fragments contribute only
     * the bytes not yet covered, fragments behind a gap are skipped.
     *
     * @param messageSeq the message_seq of the message to re-assemble
     * @param totalLength the length of the complete message body
     * @param type the handshake message type
     * @param session the session providing the key exchange algorithm and
     *        raw public key setting for parsing, may be {@code null}
     * @return the parsed message, or {@code null} if no fragments are stored
     *         for {@code messageSeq} or the message is not yet complete
     * @throws HandshakeException if the re-assembled bytes cannot be parsed
     */
    protected final HandshakeMessage reassembleFragments(int messageSeq, int totalLength, HandshakeType type, DTLSSession session) throws HandshakeException {
        List<FragmentedHandshakeMessage> fragments = this.fragmentedMessages.get(messageSeq);
        if (fragments == null) {
            // nothing received for this message yet
            return null;
        }
        Collections.sort(fragments, FRAGMENT_OFFSET_ORDER);
        byte[] reassembly = new byte[]{};
        int offset = 0;
        for (FragmentedHandshakeMessage fragment : fragments) {
            int fragmentOffset = fragment.getFragmentOffset();
            int fragmentLength = fragment.getFragmentLength();
            if (fragmentOffset == offset) {
                // fragment continues exactly where the re-assembled data ends
                reassembly = ByteArrayUtils.concatenate(reassembly, fragment.fragmentToByteArray());
                offset = reassembly.length;
            } else if (fragmentOffset < offset && fragmentOffset + fragmentLength > offset) {
                // partial overlap: append only the bytes beyond the current end
                int skip = offset - fragmentOffset;
                int remaining = fragmentLength - skip;
                byte[] tail = new byte[remaining];
                System.arraycopy(fragment.fragmentToByteArray(), skip, tail, 0, remaining);
                reassembly = ByteArrayUtils.concatenate(reassembly, tail);
                offset = reassembly.length;
            }
            // otherwise: fully covered already, or separated by a gap -> ignore for now
        }
        if (reassembly.length != totalLength) {
            return null;
        }
        // wrap the body into a single full-length fragment to obtain the
        // serialized form (including header) expected by the parser
        FragmentedHandshakeMessage wholeMessage = new FragmentedHandshakeMessage(type, totalLength, messageSeq, 0, reassembly);
        reassembly = wholeMessage.toByteArray();
        CipherSuite.KeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuite.KeyExchangeAlgorithm.NULL;
        boolean receiveRawPublicKey = false;
        if (session != null) {
            keyExchangeAlgorithm = session.getKeyExchange();
            receiveRawPublicKey = session.receiveRawPublicKey();
        }
        return HandshakeMessage.fromByteArray(reassembly, keyExchangeAlgorithm, receiveRawPublicKey);
    }

    /**
     * @return the negotiated cipher suite, {@code null} if none has been set yet
     */
    final CipherSuite getCipherSuite() {
        return this.cipherSuite;
    }

    /**
     * Sets the negotiated cipher suite on this handshaker and on the session
     * and derives the key exchange algorithm from it.
     *
     * @param cipherSuite the negotiated cipher suite
     * @throws HandshakeException (fatal, handshake failure) if the suite is
     *         {@code null} or {@code TLS_NULL_WITH_NULL_NULL}
     */
    protected final void setCipherSuite(CipherSuite cipherSuite) throws HandshakeException {
        if (cipherSuite == null || CipherSuite.TLS_NULL_WITH_NULL_NULL == cipherSuite) {
            throw new HandshakeException("Negotiated cipher suite must not be null", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.HANDSHAKE_FAILURE));
        }
        this.cipherSuite = cipherSuite;
        this.keyExchange = cipherSuite.getKeyExchange();
        this.session.setCipherSuite(cipherSuite);
    }

    /**
     * @return the master secret (the internal array, not a copy), {@code null}
     *         before {@link #generateKeys(byte[])} has been called
     */
    final byte[] getMasterSecret() {
        return this.masterSecret;
    }

    /**
     * @return the client write MAC key, {@code null} before keys have been generated
     */
    final SecretKey getClientWriteMACKey() {
        return this.clientWriteMACKey;
    }

    /**
     * @return the server write MAC key, {@code null} before keys have been generated
     */
    final SecretKey getServerWriteMACKey() {
        return this.serverWriteMACKey;
    }

    /**
     * @return the client write IV, {@code null} before keys have been generated
     */
    final IvParameterSpec getClientWriteIV() {
        return this.clientWriteIV;
    }

    /**
     * @return the server write IV, {@code null} before keys have been generated
     */
    final IvParameterSpec getServerWriteIV() {
        return this.serverWriteIV;
    }

    /**
     * @return the client write key, {@code null} before keys have been generated
     */
    final SecretKey getClientWriteKey() {
        return this.clientWriteKey;
    }

    /**
     * @return the server write key, {@code null} before keys have been generated
     */
    final SecretKey getServerWriteKey() {
        return this.serverWriteKey;
    }

    /**
     * @return the session being negotiated
     */
    final DTLSSession getSession() {
        return this.session;
    }

    /**
     * @return the address of the peer, as stored in the session
     */
    public final InetSocketAddress getPeerAddress() {
        return this.session.getPeer();
    }

    /**
     * Assigns the next outgoing message_seq to the message and advances the counter.
     *
     * @param message the handshake message about to be sent
     */
    private void setSequenceNumber(HandshakeMessage message) {
        message.setMessageSeq(this.sequenceNumber);
        ++this.sequenceNumber;
    }

    /**
     * @return the message_seq of the next handshake message expected from the peer
     */
    final int getNextReceiveSeq() {
        return this.nextReceiveSeq;
    }

    /**
     * Advances the message_seq expected from the peer by one.
     */
    final void incrementNextReceiveSeq() {
        ++this.nextReceiveSeq;
    }

    /**
     * @return the negotiated compression method, {@code null} if none has been set yet
     */
    final CompressionMethod getCompressionMethod() {
        return this.compressionMethod;
    }

    /**
     * Sets the negotiated compression method on this handshaker and on the session.
     *
     * @param compressionMethod the negotiated compression method
     */
    final void setCompressionMethod(CompressionMethod compressionMethod) {
        this.compressionMethod = compressionMethod;
        this.session.setCompressionMethod(compressionMethod);
    }

    /**
     * @return the largest handshake message body (in bytes) sent without fragmentation
     */
    final int getMaxFragmentLength() {
        return this.maxFragmentLength;
    }

    /**
     * Sets the largest handshake message body sent without fragmentation.
     * The value is not validated here.
     *
     * @param maxFragmentLength the maximum length in bytes
     */
    public final void setMaxFragmentLength(int maxFragmentLength) {
        this.maxFragmentLength = maxFragmentLength;
    }

    /**
     * Notifies the session listener, if any, that the handshake has started.
     *
     * @throws HandshakeException if the listener rejects the handshake
     */
    protected final void handshakeStarted() throws HandshakeException {
        if (this.sessionListener != null) {
            this.sessionListener.handshakeStarted(this);
        }
    }

    /**
     * Notifies the session listener, if any, that the session has been established.
     *
     * @throws HandshakeException if the listener fails to handle the event
     */
    protected final void sessionEstablished() throws HandshakeException {
        if (this.sessionListener != null) {
            this.sessionListener.sessionEstablished(this, this.getSession());
        }
    }

    /**
     * Notifies the session listener, if any, that the handshake with the peer has completed.
     */
    protected final void handshakeCompleted() {
        if (this.sessionListener != null) {
            this.sessionListener.handshakeCompleted(this.getPeerAddress());
        }
    }
}

