/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multi-consumer store of (possibly encoded) updates in which every consumer tracks its own read
 * position and drains each update it has not yet seen into a caller-supplied target array.
 *
 * <p>Updates are stored under monotonically increasing indices issued by {@code updatesCounter}.
 * Consumer positions are keyed by thread id (or by the explicit id passed to
 * {@link #drainTo(long, INDArray)}) and hold the next index that consumer will read; {@code -1}
 * means "registered but never drained". Once at least {@code expectedConsumers} consumers are
 * registered, {@link #maintenance()} removes updates that every consumer has already read.
 *
 * <p>When {@code allowCollapse} is enabled and the backlog of updates no consumer has read reaches
 * {@link #collapseThreshold}, that backlog is merged into a single dense array of {@code shape}.
 * Further puts accumulate into that array until some consumer drains.
 *
 * <p>{@link #put} and {@link #purge} hold the write lock. {@link #drainTo(long, INDArray)} holds the
 * read lock only while collecting references, and decompresses outside the lock.
 */
public class IndexedTail {
    private static final Logger log = LoggerFactory.getLogger(IndexedTail.class);
    /** Per-consumer next-read index, keyed by consumer id; -1 until that consumer first drains. */
    protected ConcurrentHashMap<Long, AtomicLong> positions = new ConcurrentHashMap<Long, AtomicLong>();
    /** Stored updates keyed by sequence index. */
    protected Map<Long, INDArray> updates = new ConcurrentHashMap<Long, INDArray>();
    /** Next index to be issued by {@link #put}. */
    protected AtomicLong updatesCounter = new AtomicLong(0L);
    /** Index from which {@link #maintenance()} resumes deleting consumed updates. */
    protected AtomicLong lastDeletedIndex = new AtomicLong(-1L);
    protected final int expectedConsumers;
    protected AtomicBoolean dead = new AtomicBoolean(false);
    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    protected final boolean allowCollapse;
    protected final long[] shape;
    /** Minimum backlog of unread updates that triggers a collapse. */
    protected final int collapseThreshold = 32;
    protected AtomicBoolean collapsedMode = new AtomicBoolean(false);
    protected AtomicLong collapsedIndex = new AtomicLong(-1L);

    /**
     * Creates a tail without collapse support.
     *
     * @param expectedConsumers number of consumers that must register before maintenance runs
     */
    public IndexedTail(int expectedConsumers) {
        this(expectedConsumers, false, null);
    }

    /**
     * Creates a tail.
     *
     * @param expectedConsumers number of consumers that must register before maintenance/collapse runs
     * @param allowCollapse whether a large unread backlog may be merged into one dense array
     * @param shape shape of the dense collapse array; required when {@code allowCollapse} is true
     * @throws IllegalArgumentException (via Preconditions) if collapse is allowed and shape is null
     */
    public IndexedTail(int expectedConsumers, boolean allowCollapse, long[] shape) {
        this.expectedConsumers = expectedConsumers;
        this.allowCollapse = allowCollapse;
        if (allowCollapse) {
            Preconditions.checkArgument(shape != null, "shape can't be null if collapse is allowed");
        }
        this.shape = shape;
    }

    /**
     * Appends an update, or merges it into the current collapsed array.
     *
     * @param update the update to store; must not be null
     * @throws NullPointerException if {@code update} is null
     */
    public void put(@NonNull INDArray update) {
        if (update == null) {
            throw new NullPointerException("update is marked @NonNull but is null");
        }
        // Acquire before try: if lock() itself fails, finally must not unlock a lock we don't hold.
        this.lock.writeLock().lock();
        try {
            if (this.collapsedMode.get()) {
                // Keep accumulating into the collapsed array until some consumer drains (drainTo clears the flag).
                INDArray lastUpdate = this.updates.get(this.collapsedIndex.get());
                Preconditions.checkArgument(!lastUpdate.isCompressed(), "lastUpdate should NOT be compressed during collapse mode");
                this.smartDecompress(update, lastUpdate);
            } else if (this.allowCollapse && this.positions.size() >= this.expectedConsumers) {
                long lastUpdateIndex = this.updatesCounter.get();
                long maxIdx = this.firstNotAppliedIndexEverywhere();
                long delta = lastUpdateIndex - maxIdx;
                if (delta >= this.collapseThreshold) {
                    log.info("Max delta to collapse: {}; Range: <{}...{}>", new Object[]{delta, maxIdx, lastUpdateIndex});
                    // Only allocate the dense accumulator when a collapse actually happens.
                    INDArray array = Nd4j.create(this.shape);
                    for (long e = maxIdx; e < lastUpdateIndex; ++e) {
                        INDArray u = this.updates.get(e);
                        if (u == null) {
                            // Nothing to merge; decoding null would throw an NPE. drainTo also skips
                            // absent indices when collapse is allowed.
                            log.error("Failed on index {}", e);
                            continue;
                        }
                        this.smartDecompress(u, array);
                        this.updates.remove(e);
                    }
                    this.smartDecompress(update, array);
                    this.updates.put(lastUpdateIndex, array);
                    this.collapsedIndex.set(lastUpdateIndex);
                    this.updatesCounter.getAndIncrement();
                    this.collapsedMode.set(true);
                } else {
                    this.updates.put(this.updatesCounter.getAndIncrement(), update);
                }
            } else {
                this.updates.put(this.updatesCounter.getAndIncrement(), update);
            }
        } finally {
            this.lock.writeLock().unlock();
        }
    }

    /**
     * Returns one past the largest recorded consumer position, or -1 if no update was ever put.
     * NOTE(review): positions already hold the next index to read, so the "+1" leaves that single
     * index out of a collapse. This looks intentional but conservative; confirm before changing.
     */
    protected long firstNotAppliedIndexEverywhere() {
        long maxIdx = -1L;
        if (this.updatesCounter.get() == 0L) {
            return maxIdx;
        }
        for (AtomicLong v : this.positions.values()) {
            maxIdx = Math.max(maxIdx, v.get());
        }
        return maxIdx + 1L;
    }

    /**
     * Returns the smallest recorded consumer position ({@code Long.MAX_VALUE} if none registered).
     * All indices below it have been read by every registered consumer.
     */
    protected long maxAppliedIndexEverywhere() {
        long minIdx = Long.MAX_VALUE;
        for (AtomicLong v : this.positions.values()) {
            minIdx = Math.min(minIdx, v.get());
        }
        return minIdx;
    }

    /** Returns whether the current thread has unread updates (registers it if new). */
    public boolean hasAnything() {
        return this.hasAnything(Thread.currentThread().getId());
    }

    /**
     * Returns whether the given consumer has unread updates; registers the consumer if new.
     *
     * @param threadId consumer id
     */
    public boolean hasAnything(long threadId) {
        long threadPosition = this.getLocalPosition(threadId);
        boolean r = threadPosition < this.updatesCounter.get();
        log.info("hasAnything({}): {}; position: {}; updates: {}", new Object[]{threadId, r, threadPosition, this.updatesCounter.get()});
        return r;
    }

    /**
     * Drains all unread updates for the current thread into {@code array}.
     *
     * @return true if at least one index was consumed
     */
    public boolean drainTo(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        return this.drainTo(Thread.currentThread().getId(), array);
    }

    /** Returns the next index to be issued, read under the read lock. */
    protected long getGlobalPosition() {
        this.lock.readLock().lock();
        try {
            return this.updatesCounter.get();
        } finally {
            this.lock.readLock().unlock();
        }
    }

    protected long getLocalPosition() {
        return this.getLocalPosition(Thread.currentThread().getId());
    }

    protected long getDelta() {
        return this.getDelta(Thread.currentThread().getId());
    }

    /** Returns how many indices the given consumer has not read yet. */
    protected long getDelta(long threadId) {
        return this.getGlobalPosition() - this.getLocalPosition(threadId);
    }

    /**
     * Returns the consumer's next-read index, clamped to be non-negative; registers it if new.
     *
     * @param threadId consumer id
     */
    protected long getLocalPosition(long threadId) {
        return Math.max(0L, this.positionOf(threadId).get());
    }

    /** Returns the position counter for a consumer, atomically registering one at -1 if absent. */
    private AtomicLong positionOf(long threadId) {
        AtomicLong position = this.positions.get(threadId);
        if (position == null) {
            AtomicLong fresh = new AtomicLong(-1L);
            AtomicLong existing = this.positions.putIfAbsent(threadId, fresh);
            position = existing != null ? existing : fresh;
        }
        return position;
    }

    /**
     * Decodes or adds every update the given consumer has not read yet into {@code array}. It then
     * advances that consumer's position and runs maintenance.
     *
     * @param threadId consumer id
     * @param array target accumulator, modified in place
     * @return true if at least one index was consumed
     * @throws RuntimeException if an expected update is absent and collapse is not allowed
     */
    public boolean drainTo(long threadId, @NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        AtomicLong threadPosition = this.positionOf(threadId);
        List<INDArray> sessionUpdates = new ArrayList<INDArray>();
        long delta;
        this.lock.readLock().lock();
        try {
            // Once any consumer reads, the current collapsed array must no longer be mutated by put().
            this.collapsedMode.set(false);
            long globalPos = this.updatesCounter.get();
            long localPos = this.getLocalPosition(threadId);
            // put() is excluded by the read lock, so the counter cannot move while we hold it.
            delta = globalPos - localPos;
            for (long e = localPos; e < globalPos; ++e) {
                INDArray update = this.updates.get(e);
                if (update == null) {
                    if (this.allowCollapse) {
                        // Index was merged into a collapsed array stored at a later index.
                        continue;
                    }
                    log.info("Global: [{}]; Local: [{}]", globalPos, localPos);
                    throw new RuntimeException("Element [" + e + "] is absent");
                }
                sessionUpdates.add(update);
            }
            threadPosition.set(globalPos);
        } finally {
            this.lock.readLock().unlock();
        }
        // Decode outside the lock, working on a duplicate rather than the shared stored instance.
        for (INDArray u : sessionUpdates) {
            this.smartDecompress(u.unsafeDuplication(true), array);
        }
        this.maintenance();
        return delta > 0L;
    }

    /**
     * Removes updates that every registered consumer has already read. Skipped until
     * {@code expectedConsumers} consumers are registered.
     */
    protected synchronized void maintenance() {
        if (this.positions.size() < this.expectedConsumers) {
            log.info("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", this.positions.size(), this.expectedConsumers);
            return;
        }
        long minIdx = this.maxAppliedIndexEverywhere();
        // A growable list: positions may gain entries concurrently, so a pre-sized array could overflow.
        List<Long> allPositions = new ArrayList<Long>();
        for (AtomicLong p : this.positions.values()) {
            allPositions.add(p.get());
        }
        log.info("Min idx: {}; last deleted index: {}; stored updates: {}; positions: {}", new Object[]{minIdx, this.lastDeletedIndex.get(), this.updates.size(), allPositions});
        long lastDeleted = this.lastDeletedIndex.get();
        if (minIdx > lastDeleted) {
            for (long e = lastDeleted; e < minIdx; ++e) {
                this.updates.remove(e);
            }
            this.lastDeletedIndex.set(minIdx);
        }
    }

    protected int updatesSize() {
        return this.updates.size();
    }

    /**
     * Applies {@code encoded} to {@code target} in place and returns {@code target}.
     *
     * <p>Compressed or INT-typed inputs are decoded. The mode is read from element 3 of the encoded
     * buffer: 0 selects threshold decoding and 1 selects bitmap decoding. Any other input is added
     * element-wise.
     *
     * @throws ND4JIllegalStateException for an unknown encoding mode
     */
    protected INDArray smartDecompress(INDArray encoded, @NonNull INDArray target) {
        if (target == null) {
            throw new NullPointerException("target is marked @NonNull but is null");
        }
        if (encoded.isCompressed() || encoded.data().dataType() == DataBuffer.Type.INT) {
            int encoding = encoded.data().getInt(3L);
            switch (encoding) {
                case 0:
                    Nd4j.getExecutioner().thresholdDecode(encoded, target);
                    break;
                case 1:
                    Nd4j.getExecutioner().bitmapDecode(encoded, target);
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown encoding mode: [" + encoding + "]");
            }
        } else {
            target.addi(encoded);
        }
        return target;
    }

    protected boolean isDead() {
        return this.dead.get();
    }

    protected void notifyDead() {
        this.dead.set(true);
    }

    /**
     * Clears all consumers and stored updates and resets counters. It holds the write lock so it
     * cannot interleave with {@link #put}.
     */
    public void purge() {
        this.lock.writeLock().lock();
        try {
            this.positions.clear();
            this.updates.clear();
            this.updatesCounter.set(0L);
            this.lastDeletedIndex.set(-1L);
            this.collapsedMode.set(false);
            this.collapsedIndex.set(-1L);
        } finally {
            this.lock.writeLock().unlock();
        }
    }
}

