/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.berkeley;

import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import org.deeplearning4j.berkeley.Iterators;
import org.deeplearning4j.berkeley.MapFactory;
import org.deeplearning4j.berkeley.PriorityQueue;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

@JsonIgnoreProperties(value={"mf"})
public class Counter<E>
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Backing map from key to its stored count. */
    Map<E, Float> entries;
    /** True when {@link #cacheTotal} may be stale and must be recomputed by {@link #totalCount()}. */
    boolean dirty = true;
    /** Cached sum of all stored counts; valid only while {@link #dirty} is false. */
    float cacheTotal = 0.0f;
    /** Factory used to (re)build {@link #entries}; excluded from JSON via the class annotation. */
    MapFactory<E, Float> mf;
    /** Value returned by {@link #getCount(Object)} for keys that are not stored. */
    float deflt = 0.0f;

    /**
     * @return the count reported for keys that are not present in this counter
     */
    public float getDeflt() {
        return this.deflt;
    }

    /**
     * Sets the count reported for absent keys. This does not change {@link #totalCount()},
     * which only sums stored entries.
     *
     * @param deflt the new default count
     */
    public void setDeflt(float deflt) {
        this.deflt = deflt;
    }

    /**
     * @return a live view of the stored keys (backed by the internal map)
     */
    public Set<E> keySet() {
        return this.entries.keySet();
    }

    /**
     * Returns a live view of the stored entries. Note: modifying values through
     * {@code Map.Entry.setValue} bypasses the total-count cache; call
     * {@link #setDirty(boolean)} with {@code true} afterwards.
     *
     * @return the backing map's entry set
     */
    public Set<Map.Entry<E, Float>> entrySet() {
        return this.entries.entrySet();
    }

    /**
     * @return number of stored keys
     */
    public int size() {
        return this.entries.size();
    }

    /**
     * @return true when no keys are stored
     */
    public boolean isEmpty() {
        return this.size() == 0;
    }

    /**
     * @param key key to look up
     * @return true if the key has a stored count (even a zero count)
     */
    public boolean containsKey(E key) {
        return this.entries.containsKey(key);
    }

    /**
     * @param key key to look up
     * @return the stored count, or {@link #getDeflt()} when the key is absent
     */
    public float getCount(E key) {
        Float value = this.entries.get(key);
        if (value == null) {
            return this.deflt;
        }
        return value.floatValue();
    }

    /**
     * Returns {@code getCount(key) / totalCount()}, or 0 when the total is 0.
     *
     * @param key key to look up
     * @return the key's share of the total count
     * @throws RuntimeException if the total count is negative
     */
    public float getProbability(E key) {
        float count = this.getCount(key);
        float total = this.totalCount();
        if ((double)total < 0.0) {
            throw new RuntimeException("Can't call getProbability() with totalCount < 0.0");
        }
        return total > 0.0f ? count / total : 0.0f;
    }

    /**
     * Divides every stored count by the current total so that the counts sum to 1.
     * When the total is exactly 0 the counter is left unchanged; dividing would turn
     * every count into NaN or an infinity.
     */
    public void normalize() {
        float totalCount = this.totalCount();
        if (totalCount == 0.0f) {
            return;
        }
        for (E key : this.keySet()) {
            // Replacing the value of an existing key does not structurally modify the map.
            this.setCount(key, this.getCount(key) / totalCount);
        }
        this.dirty = true;
    }

    /**
     * Stores {@code count} for {@code key}, replacing any previous count.
     *
     * @param key   key to set
     * @param count new count
     */
    public void setCount(E key, float count) {
        this.entries.put(key, Float.valueOf(count));
        this.dirty = true;
    }

    /**
     * Stores a count. When {@code keepHigher} is set and the key already exists, the
     * stored count is only replaced if {@code count} is larger.
     *
     * @param key        key to set
     * @param count      candidate count
     * @param keepHigher whether an existing larger-or-equal count should be kept
     */
    public void put(E key, float count, boolean keepHigher) {
        if (keepHigher && this.entries.containsKey(key)) {
            float oldCount = this.entries.get(key).floatValue();
            if (count > oldCount) {
                this.entries.put(key, Float.valueOf(count));
            }
        } else {
            this.entries.put(key, Float.valueOf(count));
        }
        this.dirty = true;
    }

    /**
     * Draws a key with probability proportional to its count.
     *
     * @param rand source of randomness
     * @return the sampled key
     * @throws RuntimeException      if the total count is not positive
     * @throws IllegalStateException if no key could be selected (e.g. all counts are NaN)
     */
    public E sample(Random rand) {
        float total = this.totalCount();
        if ((double)total <= 0.0) {
            throw new RuntimeException(String.format("Attempting to sample() with totalCount() %.3f%n", Float.valueOf(total)));
        }
        float sum = 0.0f;
        float r = rand.nextFloat();
        E lastPositiveKey = null;
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            float count = entry.getValue().floatValue();
            if (count > 0.0f) {
                lastPositiveKey = entry.getKey();
            }
            sum += count / total;
            if (r < sum) {
                return entry.getKey();
            }
        }
        // Float rounding can leave the cumulative fraction slightly below 1.0, so a draw
        // close to 1.0 may fall past the last bucket; attribute it to the last key with mass.
        if (lastPositiveKey != null) {
            return lastPositiveKey;
        }
        throw new IllegalStateException("Shoudl've have returned a sample by now....");
    }

    /**
     * Draws a key using a freshly created {@link Random}.
     *
     * @return the sampled key
     */
    public E sample() {
        return this.sample(new Random());
    }

    /**
     * Removes a key. The count is zeroed first and removal is delegated to the
     * overridable {@link #removeKeyFromEntries(Object)} hook.
     *
     * @param key key to remove
     */
    public void removeKey(E key) {
        this.setCount(key, 0.0f);
        this.dirty = true;
        this.removeKeyFromEntries(key);
    }

    /**
     * Hook that performs the actual removal from the backing map.
     *
     * @param key key to remove
     */
    protected void removeKeyFromEntries(E key) {
        this.entries.remove(key);
    }

    /**
     * Sets the count to {@code val} if the key is absent or {@code val} exceeds the stored count.
     *
     * @param key key to update
     * @param val candidate count
     */
    public void setMaxCount(E key, float val) {
        Float value = this.entries.get(key);
        if (value == null || val > value.floatValue()) {
            this.setCount(key, val);
        }
    }

    /**
     * Sets the count to {@code val} if the key is absent or {@code val} is below the stored count.
     *
     * @param key key to update
     * @param val candidate count
     */
    public void setMinCount(E key, float val) {
        Float value = this.entries.get(key);
        if (value == null || val < value.floatValue()) {
            this.setCount(key, val);
        }
    }

    /**
     * Adds {@code increment} to the key's count (absent keys start from {@link #getDeflt()}).
     *
     * @param key       key to update
     * @param increment amount to add
     * @return the new count
     */
    public float incrementCount(E key, float increment) {
        float newVal = this.getCount(key) + increment;
        this.setCount(key, newVal);
        return newVal;
    }

    /**
     * Adds {@code count} to every element of the collection (duplicates are added repeatedly).
     *
     * @param collection keys to increment
     * @param count      amount added per occurrence
     */
    public void incrementAll(Collection<? extends E> collection, float count) {
        for (E key : collection) {
            this.incrementCount(key, count);
        }
        this.dirty = true;
    }

    /**
     * Adds every stored count of {@code counter} to this counter.
     *
     * @param counter source of counts
     * @param <T>     key type of the source counter
     */
    public <T extends E> void incrementAll(Counter<T> counter) {
        for (T key : counter.keySet()) {
            this.incrementCount(key, counter.getCount(key));
        }
        this.dirty = true;
    }

    /**
     * Returns the sum of all stored counts, reusing the cached sum when no tracked
     * mutation has happened since it was last computed.
     *
     * @return the total count
     */
    public float totalCount() {
        if (!this.dirty) {
            return this.cacheTotal;
        }
        float total = 0.0f;
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            total += entry.getValue().floatValue();
        }
        this.cacheTotal = total;
        this.dirty = false;
        return total;
    }

    /**
     * @return all stored keys in the order produced by {@link #asPriorityQueue()}
     */
    public List<E> getSortedKeys() {
        PriorityQueue<E> pq = this.asPriorityQueue();
        ArrayList<E> keys = new ArrayList<E>();
        while (pq.hasNext()) {
            keys.add(pq.next());
        }
        return keys;
    }

    /**
     * @return a key with the largest stored count (the first key visited when no count
     *         is strictly larger), or null when the counter is empty
     */
    public E argMax() {
        float maxCount = Float.NEGATIVE_INFINITY;
        E maxKey = null;
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            float value = entry.getValue().floatValue();
            if (maxKey == null || value > maxCount) {
                maxKey = entry.getKey();
                maxCount = value;
            }
        }
        return maxKey;
    }

    /**
     * @return the smallest stored count, or positive infinity when empty
     */
    public float min() {
        return this.maxMinHelp(false);
    }

    /**
     * @return the largest stored count, or negative infinity when empty
     */
    public float max() {
        return this.maxMinHelp(true);
    }

    /** Scans the stored counts for the maximum ({@code max == true}) or minimum. */
    private float maxMinHelp(boolean max) {
        float best = max ? Float.NEGATIVE_INFINITY : Float.POSITIVE_INFINITY;
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            float value = entry.getValue().floatValue();
            if (max ? value > best : value < best) {
                best = value;
            }
        }
        return best;
    }

    @Override
    public String toString() {
        return this.toString(this.keySet().size());
    }

    /**
     * Formats all entries as {@code [k1 : c1, k2 : c2]} in natural key order, with counts
     * rendered by the default-locale {@link NumberFormat} (at most 5 fraction digits).
     * Keys must be mutually comparable.
     *
     * @return the formatted counter
     */
    public String toStringSortedByKeys() {
        StringBuilder sb = new StringBuilder("[");
        NumberFormat f = NumberFormat.getInstance();
        f.setMaximumFractionDigits(5);
        int numKeysPrinted = 0;
        for (E element : new TreeSet<E>(this.keySet())) {
            sb.append(element.toString());
            sb.append(" : ");
            sb.append(f.format(this.getCount(element)));
            if (numKeysPrinted < this.size() - 1) {
                sb.append(", ");
            }
            ++numKeysPrinted;
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * @param maxKeysToPrint maximum number of entries to render
     * @return single-line rendering produced by the priority queue
     */
    public String toString(int maxKeysToPrint) {
        return this.asPriorityQueue().toString(maxKeysToPrint, false);
    }

    /**
     * @param maxKeysToPrint maximum number of entries to render
     * @param multiline      whether to render one entry per line
     * @return rendering produced by the priority queue
     */
    public String toString(int maxKeysToPrint, boolean multiline) {
        return this.asPriorityQueue().toString(maxKeysToPrint, multiline);
    }

    /**
     * @return a new priority queue holding every key with its count as priority
     */
    public PriorityQueue<E> asPriorityQueue() {
        PriorityQueue<E> pq = new PriorityQueue<E>(this.entries.size());
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            pq.add(entry.getKey(), entry.getValue().floatValue());
        }
        return pq;
    }

    /**
     * @return a new priority queue holding every key with its negated count as priority
     */
    public PriorityQueue<E> asMinPriorityQueue() {
        PriorityQueue<E> pq = new PriorityQueue<E>(this.entries.size());
        for (Map.Entry<E, Float> entry : this.entries.entrySet()) {
            pq.add(entry.getKey(), -entry.getValue().floatValue());
        }
        return pq;
    }

    /** Creates an empty counter backed by a hash-map factory. */
    public Counter() {
        this(false);
    }

    /**
     * @param identityHashMap true to use an identity-hash-map factory, false for a hash-map factory
     */
    public Counter(boolean identityHashMap) {
        this(identityHashMap ? new MapFactory.IdentityHashMapFactory() : new MapFactory.HashMapFactory());
    }

    /**
     * @param mf factory used to build the backing map (also used by {@link #clear()})
     */
    public Counter(MapFactory<E, Float> mf) {
        this.mf = mf;
        this.entries = mf.buildMap();
    }

    /**
     * Creates a counter whose counts are copied from a map.
     *
     * @param mapCounts source counts; values must be non-null
     */
    public Counter(Map<? extends E, Float> mapCounts) {
        this(false);
        this.entries = new HashMap<E, Float>();
        // The entry type must carry the wildcard, or the loop does not type-check.
        for (Map.Entry<? extends E, Float> entry : mapCounts.entrySet()) {
            this.incrementCount(entry.getKey(), entry.getValue().floatValue());
        }
    }

    /**
     * Copies the stored counts of another counter (its default count is not copied).
     *
     * @param counter source counter
     */
    public Counter(Counter<? extends E> counter) {
        this();
        this.incrementAll(counter);
    }

    /**
     * Counts each occurrence of an element in the collection once.
     *
     * @param collection elements to count
     */
    public Counter(Collection<? extends E> collection) {
        this();
        this.incrementAll(collection, 1.0f);
    }

    /**
     * Removes every key whose count is strictly below {@code cutoff}.
     *
     * @param cutoff minimum count to keep
     */
    public void pruneKeysBelowThreshold(float cutoff) {
        Iterator<Map.Entry<E, Float>> it = this.entries.entrySet().iterator();
        while (it.hasNext()) {
            if (it.next().getValue().floatValue() < cutoff) {
                it.remove();
            }
        }
        this.dirty = true;
    }

    /**
     * @return the backing map's entry set (same as {@link #entrySet()})
     */
    public Set<Map.Entry<E, Float>> getEntrySet() {
        return this.entries.entrySet();
    }

    /**
     * Compares counts key by key, over the union of both counters' keys. Absent keys
     * use each counter's default count, and comparison uses {@code ==}, so NaN never
     * matches.
     *
     * @param counter counter to compare with
     * @return true if every key has the same count in both counters
     */
    public boolean isEqualTo(Counter<E> counter) {
        // Both key sets must be walked: keys present only in one counter are otherwise missed.
        for (E e : this.keySet()) {
            if (counter.getCount(e) != this.getCount(e)) {
                return false;
            }
        }
        for (E e : counter.keySet()) {
            if (counter.getCount(e) != this.getCount(e)) {
                return false;
            }
        }
        return true;
    }

    /** Prints a small demonstration of the counter API to standard output. */
    public static void main(String[] args) {
        Counter<String> counter = new Counter<String>();
        System.out.println(counter);
        counter.incrementCount("planets", 7.0f);
        System.out.println(counter);
        counter.incrementCount("planets", 1.0f);
        System.out.println(counter);
        counter.setCount("suns", 1.0f);
        System.out.println(counter);
        counter.setCount("aliens", 0.0f);
        System.out.println(counter);
        System.out.println(counter.toString(2));
        System.out.println("Total: " + counter.totalCount());
    }

    /** Removes all keys by replacing the backing map with a fresh, empty one. */
    public void clear() {
        this.entries = this.newEntryMap();
        this.dirty = true;
    }

    /**
     * Builds an empty backing map via {@link #mf}, falling back to a {@link HashMap}
     * if no factory is set (NOTE(review): mf is JSON-ignored; confirm it is always restored).
     */
    private Map<E, Float> newEntryMap() {
        return this.mf != null ? this.mf.buildMap() : new HashMap<E, Float>();
    }

    /**
     * Keeps only the first {@code keepN} keys of {@link #asPriorityQueue()}.
     *
     * @param keepN number of keys to keep; non-positive values empty the counter
     */
    public void keepTopNKeys(int keepN) {
        this.keepKeysHelper(keepN, true);
    }

    /**
     * Keeps only the first {@code keepN} keys of {@link #asMinPriorityQueue()}.
     *
     * @param keepN number of keys to keep; non-positive values empty the counter
     */
    public void keepBottomNKeys(int keepN) {
        this.keepKeysHelper(keepN, false);
    }

    /**
     * Retains exactly the first {@code keepN} keys drawn from the chosen queue, with
     * their counts copied unchanged into a new backing map.
     */
    private void keepKeysHelper(int keepN, boolean top) {
        PriorityQueue<E> pq = top ? this.asPriorityQueue() : this.asMinPriorityQueue();
        Map<E, Float> kept = this.newEntryMap();
        int n = 0;
        // Stop after keepN keys: the previous "n <= keepN" test kept one key too many.
        while (n < keepN && pq.hasNext()) {
            E key = pq.next();
            kept.put(key, this.entries.get(key));
            ++n;
        }
        // Copy counts directly rather than incrementing from the default count, which
        // would shift every kept count by deflt.
        this.entries = kept;
        this.dirty = true;
    }

    /**
     * Sets every stored key's count to {@code val}.
     *
     * @param val new count for every key
     */
    public void setAllCounts(float val) {
        for (E e : this.keySet()) {
            this.setCount(e, val);
        }
    }

    /**
     * Sums {@code this[k] * other[k]} over this counter's keys, skipping zero terms.
     *
     * @param other counter to multiply with
     * @return the dot product
     */
    public float dotProduct(Counter<E> other) {
        float sum = 0.0f;
        for (Map.Entry<E, Float> entry : this.getEntrySet()) {
            float otherCount = other.getCount(entry.getKey());
            float value = entry.getValue().floatValue();
            if ((double)otherCount == 0.0 || (double)value == 0.0) {
                continue;
            }
            sum += value * otherCount;
        }
        return sum;
    }

    /**
     * Multiplies every stored count by {@code c} in place.
     *
     * @param c scale factor
     */
    public void scale(float c) {
        for (Map.Entry<E, Float> entry : this.getEntrySet()) {
            entry.setValue(Float.valueOf(entry.getValue().floatValue() * c));
        }
        // Entry.setValue bypasses setCount, so the cached total must be invalidated here.
        this.dirty = true;
    }

    /**
     * @param c scale factor
     * @return a new counter whose stored counts are this counter's counts times {@code c}
     */
    public Counter<E> scaledClone(float c) {
        Counter<E> newCounter = new Counter<E>();
        for (Map.Entry<E, Float> entry : this.getEntrySet()) {
            newCounter.setCount(entry.getKey(), entry.getValue().floatValue() * c);
        }
        return newCounter;
    }

    /**
     * @param counter counts to subtract
     * @return a new counter holding {@code this - counter} over the union of stored keys
     */
    public Counter<E> difference(Counter<E> counter) {
        Counter<E> clone = new Counter<E>(this);
        for (E key : counter.keySet()) {
            clone.incrementCount(key, -1.0f * counter.getCount(key));
        }
        return clone;
    }

    /**
     * @return a new counter whose counts are the natural logarithm of this counter's counts
     */
    public Counter<E> toLogSpace() {
        Counter<E> newCounter = new Counter<E>(this);
        for (E key : newCounter.keySet()) {
            newCounter.setCount(key, (float)Math.log(this.getCount(key)));
        }
        return newCounter;
    }

    /**
     * @param other counter to compare with
     * @param tol   maximum allowed absolute difference per key
     * @return true if every key of either counter differs by at most {@code tol}
     */
    public boolean approxEquals(Counter<E> other, float tol) {
        for (E key : this.keySet()) {
            if (Math.abs(this.getCount(key) - other.getCount(key)) > tol) {
                return false;
            }
        }
        for (E key : other.keySet()) {
            if (Math.abs(this.getCount(key) - other.getCount(key)) > tol) {
                return false;
            }
        }
        return true;
    }

    /**
     * Marks the cached total as stale ({@code true}) or valid ({@code false}).
     *
     * @param dirty new cache state
     */
    public void setDirty(boolean dirty) {
        this.dirty = dirty;
    }

    /**
     * @return one {@code key<TAB>count} line per key, in {@link #getSortedKeys()} order
     */
    public String toStringTabSeparated() {
        StringBuilder sb = new StringBuilder();
        for (E key : this.getSortedKeys()) {
            sb.append(key.toString()).append('\t').append(this.getCount(key)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Two counters are equal when they are of the same class, have equal stored entries
     * and the same default count. The total-count cache ({@code dirty}/{@code cacheTotal})
     * is derived state and deliberately ignored.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Counter<?> counter = (Counter<?>)o;
        if (Float.compare(counter.deflt, this.deflt) != 0) {
            return false;
        }
        return this.entries == null ? counter.entries == null : this.entries.equals(counter.entries);
    }

    /** Hashes exactly the fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = this.entries != null ? this.entries.hashCode() : 0;
        result = 31 * result + Float.floatToIntBits(this.deflt);
        return result;
    }

    private Map<E, Float> getEntries() {
        return this.entries;
    }

    private boolean isDirty() {
        return this.dirty;
    }

    private float getCacheTotal() {
        return this.cacheTotal;
    }

    private MapFactory<E, Float> getMf() {
        return this.mf;
    }

    private void setEntries(Map<E, Float> entries) {
        this.entries = entries;
    }

    private void setCacheTotal(float cacheTotal) {
        this.cacheTotal = cacheTotal;
    }

    private void setMf(MapFactory<E, Float> mf) {
        this.mf = mf;
    }
}

