/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.mls;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidParameterException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bouncycastle.mls.KeyGeneration;
import org.bouncycastle.mls.TreeKEM.LeafIndex;
import org.bouncycastle.mls.TreeKEM.NodeIndex;
import org.bouncycastle.mls.TreeSize;
import org.bouncycastle.mls.codec.ContentType;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;

/**
 * Key material for a group, organised as a tree of secrets from which one
 * handshake ratchet and one application ratchet are derived per sender leaf.
 *
 * <p>Ratchets are created lazily: the first request for a sender's handshake or
 * application ratchet derives both ratchets for that sender from the sender's
 * leaf secret in the {@link SecretTree}.
 */
public class GroupKeySet {
    /** Cipher suite used for every derivation performed by this key set. */
    final MlsCipherSuite suite;
    /** Size of derived secrets: the hash length reported by the suite's KDF. */
    final int secretSize;
    /** Secret derived from the encryption secret with the label "commitment"; compared by {@link #equals(Object)}. */
    final Secret encryptionSecretCommit;
    /** Tree of secrets rooted at the encryption secret. */
    public SecretTree secretTree;
    /** Handshake ratchets created so far, keyed by sender leaf. */
    Map<LeafIndex, HashRatchet> handshakeRatchets;
    /** Application ratchets created so far, keyed by sender leaf. */
    Map<LeafIndex, HashRatchet> applicationRatchets;

    /**
     * Creates a key set whose secret tree is rooted at {@code encryptionSecret}.
     *
     * @param suite            cipher suite used for all derivations
     * @param treeSize         size of the tree the secret tree is built over
     * @param encryptionSecret secret placed at the root of the secret tree
     * @throws IOException            declared by the underlying secret derivation
     * @throws IllegalAccessException declared by the underlying secret derivation
     */
    public GroupKeySet(MlsCipherSuite suite, TreeSize treeSize, Secret encryptionSecret) throws IOException, IllegalAccessException {
        this.suite = suite;
        this.secretSize = suite.getKDF().getHashLength();
        // Derive the commitment before the secret is handed to the tree: SecretTree.get()
        // consumes non-leaf secrets (including the root) once they have been expanded.
        this.encryptionSecretCommit = encryptionSecret.deriveSecret(suite, "commitment");
        this.secretTree = new SecretTree(treeSize, encryptionSecret);
        this.handshakeRatchets = new HashMap<LeafIndex, HashRatchet>();
        this.applicationRatchets = new HashMap<LeafIndex, HashRatchet>();
    }

    /**
     * Two key sets are equal when they are of the same class and agree on the
     * secret size, the cipher suite and the commitment derived from the
     * encryption secret. Ratchet and secret-tree state is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        GroupKeySet that = (GroupKeySet)o;
        return this.secretSize == that.secretSize && this.suite.equals(that.suite) && this.encryptionSecretCommit.equals(that.encryptionSecretCommit);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>Only {@code secretSize} is used. It is one of the values compared by
     * {@code equals}, so equal key sets always produce equal hash codes without
     * depending on how the suite and secret types implement {@code hashCode()}.
     */
    @Override
    public int hashCode() {
        return 31 + this.secretSize;
    }

    /**
     * Derives and stores both the handshake and the application ratchet for
     * {@code sender} from that sender's leaf secret, replacing any existing entries.
     *
     * @param sender leaf whose ratchets are initialised
     */
    void initRatchets(LeafIndex sender) throws IOException, IllegalAccessException {
        Secret leafSecret = this.secretTree.get(sender);
        Secret handshakeRatchetSecret = leafSecret.expandWithLabel(this.suite, "handshake", new byte[0], this.secretSize);
        Secret applicationRatchetSecret = leafSecret.expandWithLabel(this.suite, "application", new byte[0], this.secretSize);
        this.handshakeRatchets.put(sender, new HashRatchet(handshakeRatchetSecret));
        this.applicationRatchets.put(sender, new HashRatchet(applicationRatchetSecret));
    }

    /**
     * Returns the keys of a specific generation for the given sender, with the
     * reuse guard XORed into the nonce.
     *
     * <p>NOTE(review): the guard is applied in place to the nonce array of the
     * {@link KeyGeneration} held in the ratchet cache, so requesting the same
     * generation again without erasing it applies a guard to an already
     * modified nonce.
     *
     * @param contentType content type selecting the ratchet (APPLICATION uses the
     *                    application ratchet, PROPOSAL and COMMIT the handshake ratchet)
     * @param sender      leaf of the sender
     * @param generation  generation whose keys are requested
     * @param reuseGuard  bytes XORed into the leading bytes of the nonce
     * @return the keys, or {@code null} for any other content type
     * @throws InvalidParameterException if the generation has already been passed and is not cached
     */
    public KeyGeneration get(ContentType contentType, LeafIndex sender, int generation, byte[] reuseGuard) throws IOException, IllegalAccessException {
        HashRatchet chain = this.ratchetFor(contentType, sender);
        if (chain == null) {
            return null;
        }
        KeyGeneration keys = chain.get(generation);
        applyReuseGuard(reuseGuard, keys.nonce);
        return keys;
    }

    /**
     * Advances the sender's ratchet by one step and returns the resulting keys,
     * with the reuse guard XORed into the nonce.
     *
     * @param contentType content type selecting the ratchet (APPLICATION uses the
     *                    application ratchet, PROPOSAL and COMMIT the handshake ratchet)
     * @param sender      leaf of the sender
     * @param reuseGuard  bytes XORed into the leading bytes of the nonce
     * @return the keys, or {@code null} for any other content type
     */
    public KeyGeneration get(ContentType contentType, LeafIndex sender, byte[] reuseGuard) throws IOException, IllegalAccessException {
        HashRatchet chain = this.ratchetFor(contentType, sender);
        if (chain == null) {
            return null;
        }
        KeyGeneration keys = chain.next();
        applyReuseGuard(reuseGuard, keys.nonce);
        return keys;
    }

    /**
     * Selects the ratchet that serves {@code contentType} for {@code sender},
     * creating the sender's ratchets on first use.
     *
     * @return the matching ratchet, or {@code null} when the content type is
     *         neither APPLICATION, PROPOSAL nor COMMIT
     */
    private HashRatchet ratchetFor(ContentType contentType, LeafIndex sender) throws IOException, IllegalAccessException {
        switch (contentType) {
            case APPLICATION: {
                return this.applicationRatchet(sender);
            }
            case PROPOSAL:
            case COMMIT: {
                return this.handshakeRatchet(sender);
            }
            default: {
                return null;
            }
        }
    }

    /** XORs {@code guard} into the first {@code guard.length} bytes of {@code nonce}, in place. */
    private static void applyReuseGuard(byte[] guard, byte[] nonce) {
        for (int i = 0; i < guard.length; ++i) {
            nonce[i] ^= guard[i];
        }
    }

    /**
     * Consumes and forgets the cached keys of one generation, if present.
     * Content types other than APPLICATION, PROPOSAL and COMMIT are ignored.
     *
     * @param contentType content type selecting the ratchet
     * @param sender      leaf of the sender
     * @param generation  generation whose cached keys are erased
     */
    public void erase(ContentType contentType, LeafIndex sender, int generation) throws IOException, IllegalAccessException {
        HashRatchet chain = this.ratchetFor(contentType, sender);
        if (chain != null) {
            chain.erase(generation);
        }
    }

    /**
     * Returns the handshake ratchet for {@code sender}, initialising the
     * sender's ratchets first when none is stored yet.
     */
    public HashRatchet handshakeRatchet(LeafIndex sender) throws IOException, IllegalAccessException {
        if (!this.handshakeRatchets.containsKey(sender)) {
            this.initRatchets(sender);
        }
        return this.handshakeRatchets.get(sender);
    }

    /**
     * Returns the application ratchet for {@code sender}, initialising the
     * sender's ratchets first when none is stored yet.
     */
    public HashRatchet applicationRatchet(LeafIndex sender) throws IOException, IllegalAccessException {
        if (!this.applicationRatchets.containsKey(sender)) {
            this.initRatchets(sender);
        }
        return this.applicationRatchets.get(sender);
    }

    /** Reports whether {@code sender} lies within the secret tree's leaf range. */
    public boolean hasLeaf(LeafIndex sender) {
        return this.secretTree.hasLeaf(sender);
    }

    /**
     * A forward-only chain of secrets. Each step derives a key, a nonce and the
     * next chain secret from the current one, then consumes the current secret.
     * Every generation produced is kept in {@link #cache} until erased.
     */
    public class HashRatchet {
        /** Key size reported by the suite's AEAD. */
        final int keySize;
        /** Nonce size reported by the suite's AEAD. */
        final int nonceSize;
        /** Secret from which the next generation will be derived. */
        Secret nextSecret;
        /** Number of the generation that the next call to {@link #next()} produces. */
        int nextGeneration;
        /** Generations produced so far and not yet erased, keyed by generation number. */
        Map<Integer, KeyGeneration> cache;

        /**
         * Starts a ratchet at generation 0.
         *
         * @param baseSecret secret used to derive generation 0
         */
        HashRatchet(Secret baseSecret) {
            this.keySize = GroupKeySet.this.suite.getAEAD().getKeySize();
            this.nonceSize = GroupKeySet.this.suite.getAEAD().getNonceSize();
            this.nextGeneration = 0;
            this.nextSecret = baseSecret;
            this.cache = new HashMap<Integer, KeyGeneration>();
        }

        /**
         * Produces the keys for the current generation, caches them, and moves
         * the ratchet forward by one generation.
         *
         * @return the newly derived keys
         */
        public KeyGeneration next() throws IOException, IllegalAccessException {
            MlsCipherSuite cipherSuite = GroupKeySet.this.suite;
            Secret key = this.nextSecret.deriveTreeSecret(cipherSuite, "key", this.nextGeneration, this.keySize);
            Secret nonce = this.nextSecret.deriveTreeSecret(cipherSuite, "nonce", this.nextGeneration, this.nonceSize);
            Secret successor = this.nextSecret.deriveTreeSecret(cipherSuite, "secret", this.nextGeneration, GroupKeySet.this.secretSize);
            KeyGeneration produced = new KeyGeneration(this.nextGeneration, key, nonce);
            ++this.nextGeneration;
            // The old chain secret is no longer needed once its successor exists.
            this.nextSecret.consume();
            this.nextSecret = successor;
            this.cache.put(produced.generation, produced);
            return produced;
        }

        /**
         * Returns the keys for {@code generation}, stepping the ratchet forward
         * as far as needed. Every intermediate generation produced on the way is
         * cached as well.
         *
         * @param generation requested generation number
         * @return the cached or newly derived keys
         * @throws InvalidParameterException if the ratchet has already moved past
         *                                   {@code generation} and it is not cached
         */
        public KeyGeneration get(int generation) throws IOException, IllegalAccessException {
            if (this.cache.containsKey(generation)) {
                return this.cache.get(generation);
            }
            if (this.nextGeneration > generation) {
                throw new InvalidParameterException("Request for expired key");
            }
            while (this.nextGeneration < generation) {
                this.next();
            }
            return this.next();
        }

        /**
         * Consumes the cached keys of {@code generation} and drops them from the
         * cache. Does nothing when that generation is not cached.
         */
        public void erase(int generation) {
            KeyGeneration cached = this.cache.get(generation);
            if (cached != null) {
                cached.consume();
                this.cache.remove(generation);
            }
        }
    }

    /**
     * Sparse tree of secrets. Only the root is populated initially; secrets for
     * lower nodes are derived on demand, and expanded non-leaf secrets on the
     * requested leaf's path are consumed and removed afterwards.
     */
    public class SecretTree {
        /** Size of the tree this secret tree is laid over. */
        final TreeSize treeSize;
        /** Secrets currently held, keyed by node. */
        public Map<NodeIndex, Secret> secrets;

        /**
         * Creates a tree holding only {@code encryptionSecret}, stored at the root.
         *
         * @param treeSizeIn       size of the tree
         * @param encryptionSecret secret placed at the root node
         */
        public SecretTree(TreeSize treeSizeIn, Secret encryptionSecret) {
            this.treeSize = treeSizeIn;
            this.secrets = new HashMap<NodeIndex, Secret>();
            this.secrets.put(NodeIndex.root(this.treeSize), encryptionSecret);
        }

        /** Reports whether the leaf's index is below the tree's leaf count. */
        protected boolean hasLeaf(LeafIndex sender) {
            return (long)sender.value() < this.treeSize.leafCount();
        }

        /**
         * Returns the secret for {@code leaf}, deriving it from the nearest
         * populated node on the path from the leaf towards the root.
         *
         * <p>Walking down from that node, each node's secret is expanded into
         * secrets for its left and right children. Afterwards every non-leaf
         * secret still stored for a node on the path is consumed and removed;
         * the leaf secret itself stays in {@link #secrets}.
         *
         * @param leaf leaf whose secret is requested
         * @return the leaf secret
         * @throws InvalidParameterException if no node on the path holds a secret
         */
        public Secret get(LeafIndex leaf) throws IOException, IllegalAccessException {
            byte[] leftLabel = "left".getBytes(StandardCharsets.UTF_8);
            byte[] rightLabel = "right".getBytes(StandardCharsets.UTF_8);
            NodeIndex rootNode = NodeIndex.root(this.treeSize);
            NodeIndex leafNode = new NodeIndex(leaf);

            // Path to search, ordered from the leaf upwards and ending at the root.
            List<NodeIndex> dirpath = leaf.directPath(this.treeSize);
            dirpath.add(0, leafNode);
            dirpath.add(rootNode);

            // Find the lowest node on the path that already holds a secret.
            int curr = 0;
            while (curr < dirpath.size() && !this.secrets.containsKey(dirpath.get(curr))) {
                ++curr;
            }
            // The search leaves curr == size() when nothing was found, so the guard
            // must be >=; a strict > would never fire and the lookup below would
            // fail with an IndexOutOfBoundsException instead.
            if (curr >= dirpath.size()) {
                throw new InvalidParameterException("No secret found to derive leaf key");
            }

            // Expand downwards: each step populates both children of the current node.
            while (curr > 0) {
                NodeIndex currNode = dirpath.get(curr);
                Secret parentSecret = this.secrets.get(currNode);
                this.secrets.put(currNode.left(), parentSecret.expandWithLabel(GroupKeySet.this.suite, "tree", leftLabel, GroupKeySet.this.secretSize));
                this.secrets.put(currNode.right(), parentSecret.expandWithLabel(GroupKeySet.this.suite, "tree", rightLabel, GroupKeySet.this.secretSize));
                --curr;
            }

            Secret leafSecret = this.secrets.get(leafNode);

            // Drop every expanded secret on the path except the leaf's own.
            for (NodeIndex node : dirpath) {
                if (node.equals(leafNode) || !this.secrets.containsKey(node)) {
                    continue;
                }
                this.secrets.get(node).consume();
                this.secrets.remove(node);
            }
            return leafSecret;
        }
    }
}

