/*
 * Decompiled with CFR 0.152.
 */
package convex.core.data;

import convex.core.data.ACell;
import convex.core.data.AHashMap;
import convex.core.data.AHashSet;
import convex.core.data.ASet;
import convex.core.data.Format;
import convex.core.data.Hash;
import convex.core.data.IRefFunction;
import convex.core.data.Ref;
import convex.core.data.SetLeaf;
import convex.core.data.Sets;
import convex.core.exceptions.BadFormatException;
import convex.core.exceptions.InvalidDataException;
import convex.core.util.Bits;
import convex.core.util.Utils;
import java.nio.ByteBuffer;

public class SetTree<T extends ACell>
extends AHashSet<T> {
    private final Ref<AHashSet<T>>[] children;
    private final int shift;
    private final short mask;
    public static int MAX_ENCODING_LENGTH = 2244;

    private SetTree(Ref<AHashSet<T>>[] blocks, int shift, short mask, long count) {
        super(count);
        this.children = blocks;
        this.shift = shift;
        this.mask = mask;
    }

    private static <T extends ACell> long computeCount(Ref<AHashSet<T>>[] children) {
        long n = 0L;
        for (Ref<AHashSet<AHashSet<T>>> ref : children) {
            if (ref == null) continue;
            ASet m = ref.getValue();
            n += m.count();
        }
        return n;
    }

    public static <V extends ACell> SetTree<V> create(Ref<V>[] newEntries, int shift) {
        int n = newEntries.length;
        if (n <= 16) {
            throw new IllegalArgumentException("Insufficient distinct entries for TreeMap construction: " + newEntries.length);
        }
        Ref[] children = new Ref[16];
        for (int i = 0; i < n; ++i) {
            Ref<V> e = newEntries[i];
            int ix = e.getHash().getHexDigit(shift);
            Ref ref = children[ix];
            if (ref == null) {
                children[ix] = SetLeaf.create(new Ref[]{e}).getRef();
                continue;
            }
            AHashSet<V> newChild = ((AHashSet)ref.getValue()).includeRef(e, shift + 1);
            children[ix] = newChild.getRef();
        }
        return (SetTree)SetTree.createFull(children, shift);
    }

    private static <T extends ACell> AHashSet<T> createFull(Ref<AHashSet<T>>[] children, int shift, long count) {
        if (children.length != 16) {
            throw new IllegalArgumentException("16 children required!");
        }
        Ref[] newChildren = Utils.filterArray(children, a -> {
            if (a == null) {
                return false;
            }
            AHashSet m = (AHashSet)a.getValue();
            return m != null && !m.isEmpty();
        });
        if (children != newChildren) {
            return SetTree.create(newChildren, shift, Utils.computeMask(children, newChildren), count);
        }
        return SetTree.create(children, shift, (short)-1, count);
    }

    private static <T extends ACell> AHashSet<T> createFull(Ref<AHashSet<T>>[] newChildren, int shift) {
        long count = SetTree.computeCount(newChildren);
        return SetTree.createFull(newChildren, shift, count);
    }

    private static <V extends ACell> AHashSet<V> create(Ref<AHashSet<V>>[] children, int shift, short mask, long count) {
        int cLen = children.length;
        if (Integer.bitCount(mask & 0xFFFF) != cLen) {
            throw new IllegalArgumentException("Invalid child array length " + cLen + " for bit mask " + Utils.toHexString(mask));
        }
        if (count <= 16L) {
            Ref[] entries = new Ref[Utils.checkedInt(count)];
            int ix = 0;
            for (Ref<AHashSet<V>> childRef : children) {
                AHashSet<V> child = childRef.getValue();
                long cc = child.count();
                for (long i = 0L; i < cc; ++i) {
                    entries[ix++] = child.getElementRef(i);
                }
            }
            assert ((long)ix == count);
            return SetLeaf.create(entries);
        }
        int sel = (1 << cLen) - 1;
        short newMask = mask;
        for (int i = 0; i < cLen; ++i) {
            AHashSet<V> child = children[i].getValue();
            if (!child.isEmpty()) continue;
            newMask = (short)(newMask & ~(1 << SetTree.digitForIndex(i, mask)));
            sel &= ~(1 << i);
        }
        if (mask != newMask) {
            return new SetTree<V>(Utils.filterSmallArray(children, sel), shift, newMask, count);
        }
        return new SetTree<V>(children, shift, mask, count);
    }

    @Override
    public Ref<T> getElementRef(long i) {
        long pos = i;
        for (Ref<AHashSet<AHashSet<T>>> ref : this.children) {
            AHashSet<T> child = ref.getValue();
            long cc = child.count();
            if (pos < cc) {
                return child.getElementRef(pos);
            }
            pos -= cc;
        }
        throw new IndexOutOfBoundsException("Entry index: " + i);
    }

    @Override
    protected Ref<T> getRefByHash(Hash hash) {
        int digit = Utils.extractDigit(hash, this.shift);
        int i = Bits.indexForDigit(digit, this.mask);
        if (i < 0) {
            return null;
        }
        return this.children[i].getValue().getRefByHash(hash);
    }

    @Override
    public AHashSet<T> exclude(ACell key) {
        return this.excludeRef(Ref.get(key));
    }

    @Override
    public AHashSet<T> excludeRef(Ref<T> keyRef) {
        AHashSet<T> newChild;
        int digit = Utils.extractDigit(keyRef.getHash(), this.shift);
        int i = Bits.indexForDigit(digit, this.mask);
        if (i < 0) {
            return this;
        }
        AHashSet<T> child = this.children[i].getValue();
        if (child == (newChild = child.excludeRef(keyRef))) {
            return this;
        }
        AHashSet<T> result = newChild.isEmpty() ? this.dissocChild(i) : this.replaceChild(i, newChild.getRef());
        return result.toCanonical();
    }

    @Override
    public AHashSet<T> toCanonical() {
        if (this.count > 16L) {
            return this;
        }
        int n = Utils.checkedInt(this.count);
        Ref[] newEntries = new Ref[n];
        for (int i = 0; i < n; ++i) {
            newEntries[i] = this.getElementRef(i);
        }
        return new SetLeaf(newEntries);
    }

    private AHashSet<T> dissocChild(int i) {
        int bsize = this.children.length;
        AHashSet<T> child = this.children[i].getValue();
        Ref[] newBlocks = new Ref[bsize - 1];
        System.arraycopy(this.children, 0, newBlocks, 0, i);
        System.arraycopy(this.children, i + 1, newBlocks, i, bsize - i - 1);
        short newMask = (short)(this.mask & ~(1 << SetTree.digitForIndex(i, this.mask)));
        long newCount = this.count - child.count();
        return SetTree.create(newBlocks, this.shift, newMask, newCount);
    }

    private SetTree<T> insertChild(int digit, Ref<AHashSet<T>> newChild) {
        int bsize = this.children.length;
        int i = Bits.positionForDigit(digit, this.mask);
        short newMask = (short)(this.mask | 1 << digit);
        if (this.mask == newMask) {
            throw new Error("Digit already present!");
        }
        Ref[] newChildren = new Ref[bsize + 1];
        System.arraycopy(this.children, 0, newChildren, 0, i);
        System.arraycopy(this.children, i, newChildren, i + 1, bsize - i);
        newChildren[i] = newChild;
        long newCount = this.count + newChild.getValue().count();
        return (SetTree)SetTree.create(newChildren, this.shift, newMask, newCount);
    }

    private AHashSet<T> replaceChild(int i, Ref<AHashSet<T>> newChild) {
        if (this.children[i] == newChild) {
            return this;
        }
        AHashSet<T> oldChild = this.children[i].getValue();
        Ref[] newChildren = (Ref[])this.children.clone();
        newChildren[i] = newChild;
        long newCount = this.count + newChild.getValue().count() - oldChild.count();
        return SetTree.create(newChildren, this.shift, this.mask, newCount);
    }

    public static int digitForIndex(int index, short mask) {
        int found = 0;
        for (int i = 0; i < 16; ++i) {
            if ((mask & 1 << i) == 0 || found++ != index) continue;
            return i;
        }
        throw new IllegalArgumentException("Index " + index + " not available in mask map: " + Utils.toHexString(mask));
    }

    @Override
    public SetTree<T> include(ACell value) {
        Ref<ACell> keyRef = Ref.get(value);
        return this.includeRef((Ref)keyRef, this.shift);
    }

    @Override
    protected SetTree<T> includeRef(Ref<T> e, int shift) {
        AHashSet<T> newChild;
        if (this.shift != shift) {
            throw new Error("Invalid shift!");
        }
        Ref<T> keyRef = e;
        int digit = Utils.extractDigit(keyRef.getHash(), shift);
        int i = Bits.indexForDigit(digit, this.mask);
        if (i < 0) {
            SetLeaf<Ref[]> newChild2 = SetLeaf.create(new Ref[]{e});
            return this.insertChild(digit, newChild2.getRef());
        }
        AHashSet<T> child = this.children[i].getValue();
        if (child == (newChild = child.includeRef(e, shift + 1))) {
            return this;
        }
        return (SetTree)this.replaceChild(i, newChild.getRef());
    }

    @Override
    public AHashSet<T> includeRef(Ref<T> ref) {
        return this.includeRef((Ref)ref, this.shift);
    }

    @Override
    public int encode(byte[] bs, int pos) {
        bs[pos++] = -125;
        return this.encodeRaw(bs, pos);
    }

    @Override
    public int encodeRaw(byte[] bs, int pos) {
        pos = Format.writeVLCLong(bs, pos, this.count);
        bs[pos++] = (byte)this.shift;
        pos = Utils.writeShort(bs, pos, this.mask);
        int ilength = this.children.length;
        for (int i = 0; i < ilength; ++i) {
            pos = this.children[i].encode(bs, pos);
        }
        return pos;
    }

    @Override
    public int estimatedEncodingSize() {
        return 4 + 140 * this.children.length;
    }

    public static <V extends ACell> SetTree<V> read(ByteBuffer bb, long count) throws BadFormatException {
        byte shift = bb.get();
        short mask = bb.getShort();
        int ilength = Integer.bitCount(mask & 0xFFFF);
        Ref[] blocks = new Ref[ilength];
        for (int i = 0; i < ilength; ++i) {
            Ref ref;
            blocks[i] = ref = Format.readRef(bb);
        }
        SetTree result = new SetTree(blocks, shift, mask, count);
        if (!result.isValidStructure()) {
            throw new BadFormatException("Problem with TreeMap invariants");
        }
        return result;
    }

    @Override
    public boolean isCanonical() {
        return this.count > 8L;
    }

    @Override
    public final boolean isCVMValue() {
        return this.shift == 0;
    }

    @Override
    public int getRefCount() {
        return this.children.length;
    }

    @Override
    public <R extends ACell> Ref<R> getRef(int i) {
        return this.children[i];
    }

    @Override
    public SetTree<T> updateRefs(IRefFunction func) {
        int n = this.children.length;
        if (n == 0) {
            return this;
        }
        Ref<AHashSet<T>>[] newChildren = this.children;
        for (int i = 0; i < n; ++i) {
            Ref<AHashSet<AHashSet<T>>> child = this.children[i];
            Ref<?> newChild = func.apply(child);
            if (child == newChild) continue;
            if (this.children == newChildren) {
                newChildren = (Ref[])this.children.clone();
            }
            newChildren[i] = newChild;
        }
        if (newChildren == this.children) {
            return this;
        }
        return new SetTree<T>(newChildren, this.shift, this.mask, this.count);
    }

    @Override
    public AHashSet<T> mergeWith(AHashSet<T> b, int setOp) {
        return this.mergeWith(b, setOp, this.shift);
    }

    @Override
    protected AHashSet<T> mergeWith(AHashSet<T> b, int setOp, int shift) {
        if (b instanceof SetTree) {
            SetTree bt = (SetTree)b;
            if (this.shift != bt.shift) {
                throw new Error("Misaligned shifts!");
            }
            return this.mergeWith(bt, setOp, shift);
        }
        if (b instanceof SetLeaf) {
            return this.mergeWith((SetLeaf)b, setOp, shift);
        }
        throw new Error("Unrecognised map type: " + b.getClass());
    }

    @Override
    private AHashSet<T> mergeWith(SetTree<T> b, int setOp, int shift) {
        assert (b.shift == shift);
        int fullMask = this.mask | b.mask;
        Ref[] newChildren = null;
        for (int digit = 0; digit < 16; ++digit) {
            AHashSet<T> bc;
            AHashSet<T> rc;
            int bitMask = 1 << digit;
            if ((fullMask & bitMask) == 0) continue;
            AHashSet<T> ac = this.childForDigit(digit).getValue();
            if (ac != (rc = ac.mergeWith(bc = b.childForDigit(digit).getValue(), setOp, shift + 1)) && newChildren == null) {
                newChildren = new Ref[16];
                for (int ii = 0; ii < digit; ++ii) {
                    int chi = Bits.indexForDigit(ii, this.mask);
                    if (chi < 0) continue;
                    newChildren[ii] = this.children[chi];
                }
            }
            if (newChildren == null) continue;
            newChildren[digit] = rc.getRef();
        }
        if (newChildren == null) {
            return this;
        }
        return SetTree.createFull(newChildren, shift);
    }

    @Override
    private AHashSet<T> mergeWith(SetLeaf<T> b, int setOp, int shift) {
        Ref[] newChildren = null;
        int ix = 0;
        for (int i = 0; i < 16; ++i) {
            SetLeaf<T> bSubset;
            AHashSet<T> newChild;
            Ref<AHashSet<T>> cref;
            AHashSet<T> child;
            int imask = 1 << i;
            if ((this.mask & imask) == 0) continue;
            if ((child = (cref = this.children[ix++]).getValue()) != (newChild = child.mergeWith(bSubset = b.filterHexDigits(shift, imask), setOp, shift + 1)) && newChildren == null) {
                newChildren = new Ref[16];
                for (int ii = 0; ii < this.children.length; ++ii) {
                    int chi = SetTree.digitForIndex(ii, this.mask);
                    newChildren[chi] = this.children[ii];
                }
            }
            if (newChildren == null) continue;
            newChildren[i] = newChild.getRef();
        }
        assert (ix == this.children.length);
        AHashSet result = newChildren == null ? this : SetTree.createFull(newChildren, shift);
        SetLeaf<T> extras = b.filterHexDigits(shift, ~this.mask);
        int en = extras.size();
        for (int i = 0; i < en; ++i) {
            Ref<T> e = extras.getRef(i);
            Ref<T> newE = this.applyOp(setOp, null, e);
            if (newE == null) continue;
            result = ((AHashSet)result).includeRef(newE, shift);
        }
        return result;
    }

    private Ref<AHashSet<T>> childForDigit(int digit) {
        int ix = Bits.indexForDigit(digit, this.mask);
        if (ix < 0) {
            return Sets.emptyRef();
        }
        return this.children[ix];
    }

    @Override
    public boolean equals(ASet<T> a) {
        if (!(a instanceof SetTree)) {
            return false;
        }
        return this.equals((SetTree)a);
    }

    @Override
    boolean equals(SetTree<T> b) {
        if (this == b) {
            return true;
        }
        long n = this.count;
        if (n != b.count) {
            return false;
        }
        if (this.mask != b.mask) {
            return false;
        }
        if (this.shift != b.shift) {
            return false;
        }
        return this.getHash().equals(b.getHash());
    }

    @Override
    public void validate() throws InvalidDataException {
        Hash firstHash;
        super.validate();
        if (this.mask == 0) {
            throw new InvalidDataException("TreeMap must have children!", this);
        }
        if (this.shift < 0 || this.shift > 63) {
            throw new InvalidDataException("Invalid shift for SetTree", this);
        }
        if (this.count <= 16L) {
            throw new InvalidDataException("Count too small [" + this.count + "] for SetTree", this);
        }
        try {
            firstHash = this.getElementRef(0L).getHash();
        }
        catch (ClassCastException e) {
            throw new InvalidDataException("Bad child type:" + e.getMessage(), this);
        }
        int bsize = this.children.length;
        long childCount = 0L;
        for (int i = 0; i < bsize; ++i) {
            Hash childHash;
            long pmatch;
            if (this.children[i] == null) {
                throw new InvalidDataException("Null child ref at index " + i, this);
            }
            AHashSet<T> o = this.children[i].getValue();
            if (!(o instanceof AHashMap)) {
                throw new InvalidDataException("Expected AHashSet child at index " + i + " but got " + Utils.getClassName(o), this);
            }
            AHashSet<T> child = o;
            if (child.isEmpty()) {
                throw new InvalidDataException("Empty child at index " + i, this);
            }
            if (child instanceof SetTree) {
                SetTree childTree = (SetTree)child;
                int expectedShift = this.shift + 1;
                if (childTree.shift != expectedShift) {
                    throw new InvalidDataException("Wrong child shift [" + childTree.shift + "], expected [" + expectedShift + "]", this);
                }
            }
            if ((pmatch = firstHash.commonHexPrefixLength(childHash = child.getElementRef(0L).getHash())) < (long)this.shift) {
                throw new InvalidDataException("Mismatched child hash [" + childHash + "] with this [" + firstHash + "]", this);
            }
            int d = SetTree.digitForIndex(i, this.mask);
            child.validateWithPrefix(firstHash, d, this.shift + 1);
            childCount += child.count();
        }
        if (this.count != childCount) {
            throw new InvalidDataException("Bad child count, expected " + this.count + " but children had: " + childCount, this);
        }
    }

    @Override
    protected void validateWithPrefix(Hash base, int digit, int shift) {
    }

    private boolean isValidStructure() {
        if (this.count <= 8L) {
            return false;
        }
        if (this.children.length != Integer.bitCount(this.mask & 0xFFFF)) {
            return false;
        }
        for (int i = 0; i < this.children.length; ++i) {
            if (this.children[i] != null) continue;
            return false;
        }
        return true;
    }

    @Override
    public void validateCell() throws InvalidDataException {
        if (!this.isValidStructure()) {
            throw new InvalidDataException("Bad structure", this);
        }
    }

    @Override
    public boolean containsAll(ASet<T> b) {
        if (b instanceof SetTree) {
            return this.containsAll((SetTree)b);
        }
        long n = b.count;
        for (long i = 0L; i < n; ++i) {
            Ref me = b.getElementRef(i);
            if (this.containsHash(me.getHash())) continue;
            return false;
        }
        return true;
    }

    @Override
    protected boolean containsAll(SetTree<T> map) {
        if ((this.mask | map.mask) != this.mask) {
            return false;
        }
        for (int i = 0; i < 16; ++i) {
            Ref<AHashSet<T>> mchild;
            Ref<AHashSet<T>> child = this.childForDigit(i);
            if (child == null || (mchild = map.childForDigit(i)) == null || child.getValue().containsAll(mchild.getValue())) continue;
            return false;
        }
        return true;
    }

    @Override
    public Ref<T> getValueRef(ACell k) {
        return this.getRefByHash(Hash.compute(k));
    }

    @Override
    protected <R> void copyToArray(R[] arr, int offset) {
        for (int i = 0; i < this.children.length; ++i) {
            AHashSet<T> child = this.children[i].getValue();
            child.copyToArray(arr, offset);
            offset = Utils.checkedInt((long)offset + child.count());
        }
    }

    @Override
    public boolean containsHash(Hash hash) {
        return this.getRefByHash(hash) != null;
    }
}

