/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.californium.scandium.dtls;

import java.net.InetSocketAddress;
import java.security.Principal;
import java.security.PublicKey;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.eclipse.californium.scandium.dtls.CompressionMethod;
import org.eclipse.californium.scandium.dtls.DTLSConnectionState;
import org.eclipse.californium.scandium.dtls.SessionId;
import org.eclipse.californium.scandium.dtls.cipher.CipherSuite;

/**
 * Holds the state of a DTLS session with a single peer.
 * <p>
 * Besides the negotiated parameters (cipher suite, compression method, master
 * secret, peer identity) the session keeps track of
 * <ul>
 * <li>the current read and write connection states and their epochs,</li>
 * <li>the next sequence number to use for outbound records of each write
 * epoch, and</li>
 * <li>a sliding receive window of {@code RECEIVE_WINDOW_SIZE} sequence numbers
 * used to detect duplicate inbound records of the current read epoch.</li>
 * </ul>
 * <p>
 * All mutators and the receive window/sequence number operations synchronize
 * on the session instance. Several plain getters do not.
 */
public class DTLSSession {
    private static final Logger LOGGER = Logger.getLogger(DTLSSession.class.getName());
    /** Number of sequence numbers covered by the receive window (one bit each). */
    private static final int RECEIVE_WINDOW_SIZE = 64;
    /** Largest sequence number that can be used within an epoch (2^48 - 1). */
    private static final long MAX_SEQUENCE_NO = 0xFFFFFFFFFFFFL;
    /** Number of bytes a master secret must consist of. */
    private static final int MASTER_SECRET_LENGTH = 48;

    private final InetSocketAddress peer;
    private SessionId sessionIdentifier = null;
    private Principal peerIdentity;
    private CompressionMethod compressionMethod;
    private CipherSuite cipherSuite;
    private byte[] masterSecret = null;
    private String pskIdentity;
    private PublicKey peerRawPublicKey;
    private boolean active = false;
    private final boolean isClient;
    private DTLSConnectionState readState;
    private DTLSConnectionState writeState;
    private int readEpoch = 0;
    private int writeEpoch = 0;
    /** Next sequence number to hand out, keyed by write epoch. */
    private final Map<Integer, Long> sequenceNumbers = new HashMap<Integer, Long>();
    private CipherSuite.KeyExchangeAlgorithm keyExchange;
    private boolean sendRawPublicKey = false;
    private boolean receiveRawPublicKey = false;
    /** Highest sequence number covered by the receive window. */
    private volatile long receiveWindowUpperBoundary = RECEIVE_WINDOW_SIZE - 1L;
    /** Lowest sequence number covered by the receive window. */
    private volatile long receiveWindowLowerBoundary = 0L;
    /**
     * Bit vector of received records; bit {@code i} represents sequence number
     * {@code receiveWindowLowerBoundary + i}.
     */
    private volatile long receivedRecordsVector = 0L;

    /**
     * Creates a session whose first outbound record uses sequence number 0.
     *
     * @param peerAddress the address of the peer
     * @param isClient whether the local entity is the client of this session
     * @throws NullPointerException if the peer address is {@code null}
     */
    public DTLSSession(InetSocketAddress peerAddress, boolean isClient) {
        this(peerAddress, isClient, 0L);
    }

    /**
     * Creates a session with the null cipher suite, null compression and the
     * given initial sequence number for epoch 0.
     *
     * @param peerAddress the address of the peer
     * @param isClient whether the local entity is the client of this session
     * @param initialSequenceNo the first sequence number to use in epoch 0
     * @throws NullPointerException if the peer address is {@code null}
     * @throws IllegalArgumentException if the sequence number is negative or
     *             greater than 2^48 - 1
     */
    public DTLSSession(InetSocketAddress peerAddress, boolean isClient, long initialSequenceNo) {
        if (peerAddress == null) {
            throw new NullPointerException("Peer address must not be null");
        }
        if (initialSequenceNo < 0L || initialSequenceNo > MAX_SEQUENCE_NO) {
            throw new IllegalArgumentException("Initial sequence number must be greater than or equal to 0 and less than 2^48");
        }
        this.peer = peerAddress;
        this.isClient = isClient;
        this.cipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
        this.compressionMethod = CompressionMethod.NULL;
        this.sequenceNumbers.put(0, initialSequenceNo);
        this.readState = new DTLSConnectionState();
        this.writeState = new DTLSConnectionState();
    }

    /**
     * @return the session identifier, or {@code null} if none has been set
     */
    public SessionId getSessionIdentifier() {
        return this.sessionIdentifier;
    }

    /**
     * @param sessionIdentifier the new session identifier
     */
    final synchronized void setSessionIdentifier(SessionId sessionIdentifier) {
        this.sessionIdentifier = sessionIdentifier;
    }

    /**
     * @return the peer's raw public key, or {@code null} if none has been set
     */
    public final PublicKey getPeerRawPublicKey() {
        return this.peerRawPublicKey;
    }

    /**
     * @param key the peer's raw public key
     */
    final synchronized void setPeerRawPublicKey(PublicKey key) {
        this.peerRawPublicKey = key;
    }

    /**
     * @return the compression method of this session
     */
    public CompressionMethod getCompressionMethod() {
        return this.compressionMethod;
    }

    /**
     * @param compressionMethod the compression method to use
     */
    public void setCompressionMethod(CompressionMethod compressionMethod) {
        this.compressionMethod = compressionMethod;
    }

    /**
     * @return the cipher suite of this session
     */
    final CipherSuite getCipherSuite() {
        return this.cipherSuite;
    }

    /**
     * Sets the cipher suite and derives the key exchange algorithm from it.
     *
     * @param cipherSuite the cipher suite to use
     * @throws NullPointerException if the cipher suite is {@code null}
     */
    final synchronized void setCipherSuite(CipherSuite cipherSuite) {
        // validate before assigning so that a null argument cannot leave the
        // session with a cipher suite that does not match its key exchange
        if (cipherSuite == null) {
            throw new NullPointerException("Cipher suite must not be null");
        }
        this.cipherSuite = cipherSuite;
        this.keyExchange = cipherSuite.getKeyExchange();
    }

    /**
     * @return the value of the active flag
     */
    public final synchronized boolean isActive() {
        return this.active;
    }

    /**
     * @param isActive the new value of the active flag
     */
    public final synchronized void setActive(boolean isActive) {
        this.active = isActive;
    }

    /**
     * @return {@code true} if the session was created for the client role
     */
    public final boolean isClient() {
        return this.isClient;
    }

    /**
     * @return the current write epoch
     */
    public final int getWriteEpoch() {
        return this.writeEpoch;
    }

    /**
     * Sets the current write epoch.
     * <p>
     * If no sequence number counter exists for the epoch yet, one is created
     * starting at 0 so that {@link #getSequenceNumber()} can be used right
     * away. An existing counter is left untouched.
     *
     * @param epoch the new write epoch
     * @throws IllegalArgumentException if the epoch is negative
     */
    final synchronized void setWriteEpoch(int epoch) {
        if (epoch < 0) {
            throw new IllegalArgumentException("Write epoch must not be negative");
        }
        this.writeEpoch = epoch;
        if (!this.sequenceNumbers.containsKey(epoch)) {
            this.sequenceNumbers.put(epoch, 0L);
        }
    }

    /**
     * @return the current read epoch
     */
    public final int getReadEpoch() {
        return this.readEpoch;
    }

    /**
     * Sets the current read epoch and resets the receive window.
     *
     * @param epoch the new read epoch
     * @throws IllegalArgumentException if the epoch is negative
     */
    final synchronized void setReadEpoch(int epoch) {
        if (epoch < 0) {
            throw new IllegalArgumentException("Read epoch must not be negative");
        }
        this.resetReceiveWindow();
        this.readEpoch = epoch;
    }

    /** Advances the read epoch by one and resets the receive window. */
    private synchronized void incrementReadEpoch() {
        this.resetReceiveWindow();
        ++this.readEpoch;
    }

    /** Advances the write epoch by one and starts its sequence numbers at 0. */
    private synchronized void incrementWriteEpoch() {
        ++this.writeEpoch;
        this.sequenceNumbers.put(this.writeEpoch, 0L);
    }

    /**
     * Gets the next sequence number for the current write epoch and advances
     * the counter.
     *
     * @return the sequence number to use
     * @throws IllegalStateException if the epoch's sequence numbers are
     *             exhausted
     */
    public final synchronized long getSequenceNumber() {
        return this.getSequenceNumber(this.writeEpoch);
    }

    /**
     * Gets the next sequence number for the given epoch and advances the
     * epoch's counter.
     *
     * @param epoch the epoch to get the sequence number for
     * @return the sequence number to use
     * @throws IllegalStateException if no counter exists for the epoch or if
     *             the epoch's maximum sequence number has been reached
     */
    public final synchronized long getSequenceNumber(int epoch) {
        Long next = this.sequenceNumbers.get(epoch);
        if (next == null) {
            // fail with a clear message instead of an NPE from unboxing null
            throw new IllegalStateException("No sequence number available for epoch " + epoch);
        }
        long sequenceNumber = next;
        if (sequenceNumber < MAX_SEQUENCE_NO) {
            this.sequenceNumbers.put(epoch, sequenceNumber + 1L);
            return sequenceNumber;
        }
        throw new IllegalStateException("Maximum sequence number for epoch has been reached");
    }

    /**
     * @return the current read state
     */
    final DTLSConnectionState getReadState() {
        return this.readState;
    }

    /**
     * Sets the current read state and advances the read epoch by one, which
     * also resets the receive window.
     *
     * @param readState the new read state
     * @throws NullPointerException if the read state is {@code null}
     */
    final synchronized void setReadState(DTLSConnectionState readState) {
        if (readState == null) {
            throw new NullPointerException("Read state must not be null");
        }
        this.readState = readState;
        this.incrementReadEpoch();
        LOGGER.log(Level.FINEST, "Setting current read state to\n{0}", readState);
    }

    /**
     * @return the current write state
     */
    final DTLSConnectionState getWriteState() {
        return this.writeState;
    }

    /**
     * Sets the current write state and advances the write epoch by one,
     * restarting sequence numbers at 0.
     *
     * @param writeState the new write state
     * @throws NullPointerException if the write state is {@code null}
     */
    final synchronized void setWriteState(DTLSConnectionState writeState) {
        if (writeState == null) {
            throw new NullPointerException("Write state must not be null");
        }
        this.writeState = writeState;
        this.incrementWriteEpoch();
        LOGGER.log(Level.FINEST, "Setting current write state to\n{0}", writeState);
    }

    /**
     * @return the key exchange algorithm of the cipher suite set via
     *         {@link #setCipherSuite(CipherSuite)}, or {@code null} if that
     *         method has not been invoked
     */
    final CipherSuite.KeyExchangeAlgorithm getKeyExchange() {
        return this.keyExchange;
    }

    /**
     * @return the master secret (the internal array, not a copy), or
     *         {@code null} if none has been set
     */
    final byte[] getMasterSecret() {
        return this.masterSecret;
    }

    /**
     * Sets the master secret. The secret can be set only once; invocations
     * after a secret has been set are ignored without any validation.
     *
     * @param masterSecret the secret, must consist of exactly 48 bytes
     * @throws NullPointerException if no secret has been set yet and the
     *             given one is {@code null}
     * @throws IllegalArgumentException if no secret has been set yet and the
     *             given one does not consist of 48 bytes
     */
    void setMasterSecret(byte[] masterSecret) {
        if (this.masterSecret == null) {
            if (masterSecret == null) {
                throw new NullPointerException("Master secret must not be null");
            }
            if (masterSecret.length != MASTER_SECRET_LENGTH) {
                throw new IllegalArgumentException(String.format("Master secret must consist of of exactly 48 bytes but has [%d] bytes", masterSecret.length));
            }
            this.masterSecret = masterSecret;
        }
    }

    /**
     * @return the value of the send-raw-public-key flag
     */
    final boolean sendRawPublicKey() {
        return this.sendRawPublicKey;
    }

    /**
     * @param sendRawPublicKey the new value of the send-raw-public-key flag
     */
    final synchronized void setSendRawPublicKey(boolean sendRawPublicKey) {
        this.sendRawPublicKey = sendRawPublicKey;
    }

    /**
     * @return the value of the receive-raw-public-key flag
     */
    final boolean receiveRawPublicKey() {
        return this.receiveRawPublicKey;
    }

    /**
     * @param receiveRawPublicKey the new value of the receive-raw-public-key
     *            flag
     */
    final synchronized void setReceiveRawPublicKey(boolean receiveRawPublicKey) {
        this.receiveRawPublicKey = receiveRawPublicKey;
    }

    /**
     * @return the address of the peer, never {@code null}
     */
    public InetSocketAddress getPeer() {
        return this.peer;
    }

    /**
     * @return the peer's identity, or {@code null} if none has been set
     */
    public final Principal getPeerIdentity() {
        return this.peerIdentity;
    }

    /**
     * @param peerIdentity the peer's identity
     * @throws NullPointerException if the identity is {@code null}
     */
    final synchronized void setPeerIdentity(Principal peerIdentity) {
        if (peerIdentity == null) {
            throw new NullPointerException("Peer identity must not be null");
        }
        this.peerIdentity = peerIdentity;
    }

    /**
     * @return the PSK identity, or {@code null} if none has been set
     */
    public final String getPskIdentity() {
        return this.pskIdentity;
    }

    /**
     * @param pskIdentity the PSK identity
     */
    final synchronized void setPskIdentity(String pskIdentity) {
        this.pskIdentity = pskIdentity;
    }

    /**
     * Checks whether a record with the given epoch and sequence number should
     * be processed.
     * <p>
     * A record is processable if it belongs to the current read epoch, is not
     * older than the receive window and has not been marked as read.
     *
     * @param epoch the record's epoch
     * @param sequenceNo the record's sequence number
     * @return {@code true} if the record should be processed
     */
    public final boolean isRecordProcessable(long epoch, long sequenceNo) {
        // epoch and window are checked under one lock so that a concurrent
        // epoch change cannot slip in between the checks
        synchronized (this) {
            if (epoch != (long) this.readEpoch) {
                return false;
            }
            if (sequenceNo < this.receiveWindowLowerBoundary) {
                return false;
            }
            return !this.isDuplicate(sequenceNo);
        }
    }

    /**
     * Checks whether a sequence number has already been marked as read in the
     * current receive window.
     *
     * @param sequenceNo the sequence number to check
     * @return {@code false} if the number lies beyond the window's upper
     *         boundary, {@code true} if it lies below the window's lower
     *         boundary (its status is no longer tracked, so it is
     *         conservatively reported as a duplicate), otherwise whether its
     *         bit is set in the window
     */
    synchronized boolean isDuplicate(long sequenceNo) {
        if (sequenceNo > this.receiveWindowUpperBoundary) {
            return false;
        }
        if (sequenceNo < this.receiveWindowLowerBoundary) {
            // a negative index would wrap around in the shift below and probe
            // an unrelated bit
            return true;
        }
        // zero based position within the window, always in [0, 63] here
        long idx = sequenceNo - this.receiveWindowLowerBoundary;
        long bitMask = 1L << (int)idx;
        if (LOGGER.isLoggable(Level.FINER)) {
            LOGGER.log(Level.FINER, "Checking sequence no [{0}] using bit mask [{1}] against received records [{2}] with lower boundary [{3}]", new Object[]{sequenceNo, Long.toBinaryString(bitMask), Long.toBinaryString(this.receivedRecordsVector), this.receiveWindowLowerBoundary});
        }
        return (this.receivedRecordsVector & bitMask) == bitMask;
    }

    /**
     * Marks a record as read so that it is detected as a duplicate afterwards.
     * <p>
     * Records of an epoch other than the current read epoch and records older
     * than the receive window are ignored. A sequence number beyond the upper
     * boundary slides the window to the right so that the number becomes its
     * new upper boundary.
     *
     * @param epoch the record's epoch
     * @param sequenceNo the record's sequence number
     */
    public final synchronized void markRecordAsRead(long epoch, long sequenceNo) {
        if (epoch != (long)this.getReadEpoch()) {
            return;
        }
        if (sequenceNo < this.receiveWindowLowerBoundary) {
            // too old to be tracked; a negative shift distance would wrap
            // around and flag an unrelated sequence number as received
            return;
        }
        if (sequenceNo > this.receiveWindowUpperBoundary) {
            long incr = sequenceNo - this.receiveWindowUpperBoundary;
            this.receiveWindowUpperBoundary = sequenceNo;
            // Java only uses the low 6 bits of a long shift distance, so a
            // jump of a whole window or more must clear the vector explicitly
            this.receivedRecordsVector = incr >= RECEIVE_WINDOW_SIZE ? 0L : this.receivedRecordsVector >>> (int)incr;
            this.receiveWindowLowerBoundary = Math.max(0L, this.receiveWindowUpperBoundary - RECEIVE_WINDOW_SIZE + 1L);
        }
        long bitMask = 1L << (int)(sequenceNo - this.receiveWindowLowerBoundary);
        this.receivedRecordsVector |= bitMask;
        if (LOGGER.isLoggable(Level.FINER)) {
            LOGGER.log(Level.FINER, "Updated receive window with sequence number [{0}]: new upper boundary [{1}], new bit vector [{2}]", new Object[]{sequenceNo, this.receiveWindowUpperBoundary, Long.toBinaryString(this.receivedRecordsVector)});
        }
    }

    /** Resets the receive window to cover sequence numbers 0 to 63 with no record received. */
    private synchronized void resetReceiveWindow() {
        this.receivedRecordsVector = 0L;
        this.receiveWindowUpperBoundary = RECEIVE_WINDOW_SIZE - 1L;
        this.receiveWindowLowerBoundary = 0L;
    }
}

