/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link BlockingQueue} of {@link DataSet}s that spreads incoming data over one backing
 * queue per "flow".
 *
 * <p>With more than one flow, each flow owns a bounded backing queue plus a daemon
 * {@code QueueHandler} thread that is attached, through the ND4J affinity manager, to the
 * device with the same index. {@link #put} and {@link #add} hand DataSets to the handlers in
 * round-robin order. Each handler calls {@code AffinityManager.touch(...)} on the DataSet's
 * arrays from its device-bound thread, then moves the DataSet into its backing queue. With a
 * single flow, one unbounded backing queue is used directly and no handler thread is started.
 *
 * <p>How consumers read depends on {@link Mode}:
 * <ul>
 *   <li>{@link Mode#THREADED}: every consumer reads the backing queue whose index equals the
 *       device of the calling thread.</li>
 *   <li>{@link Mode#SEQUENTIAL}: consumers read the backing queues in round-robin order.</li>
 * </ul>
 *
 * <p>Inspection and bulk-removal methods ({@code contains}, {@code iterator}, {@code toArray},
 * {@code drainTo}, {@code remove(Object)}, ...) throw {@link UnsupportedOperationException}.
 */
public class MagicQueue
implements BlockingQueue<DataSet> {
    private static final Logger log = LoggerFactory.getLogger(MagicQueue.class);
    /** One backing queue per flow; a single unbounded queue when there is only one flow. */
    protected final List<LinkedBlockingQueue<DataSet>> backingQueues;
    /** Round-robin cursor used by producers to pick the next handler. */
    protected final AtomicInteger nextBucket = new AtomicInteger(0);
    protected final int numberOfBuckets;
    /** Device-bound feeder threads; empty when there is only one flow. */
    protected final List<QueueHandler> handlers;
    protected int capacity = 10;
    protected Mode mode = Mode.THREADED;
    /** Round-robin cursor used by consumers in {@link Mode#SEQUENTIAL}. */
    protected AtomicInteger interleavedCounter = new AtomicInteger(0);
    /** Not read by this class. */
    protected AtomicInteger interleavedPutter = new AtomicInteger(0);
    /** Number of DataSets accepted; {@code cntPut - cntGet} is the SEQUENTIAL-mode size. */
    protected AtomicLong cntPut = new AtomicLong(0L);
    /** Number of DataSets handed out to consumers. */
    protected AtomicLong cntGet = new AtomicLong(0L);

    /**
     * Creates the backing queues and, for more than one flow, starts one handler per flow.
     *
     * @param numberOfFlows number of flows/devices; values {@code <= 1} use a single
     *                      unbounded queue
     * @param capacity      per-flow capacity of each backing queue and each handler buffer
     *                      (ignored for the single-flow case)
     */
    protected MagicQueue(int numberOfFlows, int capacity) {
        this.backingQueues = new ArrayList<LinkedBlockingQueue<DataSet>>();
        this.capacity = capacity;
        this.handlers = new ArrayList<QueueHandler>();
        if (numberOfFlows > 1) {
            for (int i = 0; i < numberOfFlows; ++i) {
                LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<DataSet>(capacity);
                this.backingQueues.add(queue);
                QueueHandler handler = new QueueHandler(queue, capacity);
                // Bind the feeder thread to device i before starting it, so touch() runs there.
                Nd4j.getAffinityManager().attachThreadToDevice(handler, Integer.valueOf(i));
                handler.start();
                this.handlers.add(handler);
            }
        } else {
            this.backingQueues.add(new LinkedBlockingQueue<DataSet>());
        }
        this.numberOfBuckets = numberOfFlows;
    }

    /**
     * Atomically returns the cursor's current index in {@code [0, bound)} and advances it,
     * wrapping back to 0 after {@code bound - 1}.
     *
     * <p>A single compare-and-set keeps the cursor in range even with concurrent callers.
     * A check followed by a separate increment could let two threads both step past the
     * last index.
     */
    private static int nextRoundRobinIndex(AtomicInteger cursor, int bound) {
        while (true) {
            int current = cursor.get();
            int index = (current >= 0 && current < bound) ? current : 0;
            int next = (index + 1 < bound) ? index + 1 : 0;
            if (cursor.compareAndSet(current, next)) {
                return index;
            }
        }
    }

    /** Backing queue for the calling thread's device, or the only queue with a single flow. */
    private LinkedBlockingQueue<DataSet> queueForCurrentDevice() {
        if (this.numberOfBuckets > 1) {
            int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            return this.backingQueues.get(deviceId);
        }
        return this.backingQueues.get(0);
    }

    /** Picks the queue a consumer should read from, according to {@link #mode}. */
    private LinkedBlockingQueue<DataSet> consumerQueue() {
        if (this.mode == Mode.THREADED) {
            return queueForCurrentDevice();
        }
        return this.backingQueues.get(nextRoundRobinIndex(this.interleavedCounter, this.backingQueues.size()));
    }

    /**
     * Returns the number of queued DataSets.
     *
     * <p>In THREADED mode with several flows, the result is the floored average size per
     * backing queue, not the total. Elements still held in handler buffers are not counted.
     * In SEQUENTIAL mode it is {@code cntPut - cntGet}.
     */
    @Override
    public int size() {
        if (this.mode != Mode.THREADED) {
            return (int)(this.cntPut.get() - this.cntGet.get());
        }
        if (this.numberOfBuckets <= 1) {
            return this.backingQueues.get(0).size();
        }
        long total = 0L;
        for (int i = 0; i < this.numberOfBuckets; ++i) {
            total += this.backingQueues.get(i).size();
        }
        return (int)(total / this.numberOfBuckets);
    }

    /**
     * Returns the number of DataSets in the backing queue for {@code deviceId}.
     *
     * @throws RuntimeException if {@code deviceId} is not a valid backing-queue index
     */
    protected int size(int deviceId) {
        if (deviceId >= this.backingQueues.size()) {
            throw new RuntimeException("DeviceID exceeds number of actual backing queues");
        }
        return this.backingQueues.get(deviceId).size();
    }

    /** Returns {@code true} when {@link #size()} reports less than one element. */
    @Override
    public boolean isEmpty() {
        return this.size() < 1;
    }

    @Override
    public boolean contains(Object o) {
        throw new UnsupportedOperationException();
    }

    @Override
    public int drainTo(Collection<? super DataSet> c) {
        throw new UnsupportedOperationException();
    }

    @Override
    public int drainTo(Collection<? super DataSet> c, int maxElements) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Iterator<DataSet> iterator() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    @Override
    public <T> T[] toArray(T[] a) {
        throw new UnsupportedOperationException();
    }

    /**
     * Enqueues a DataSet, blocking while the selected handler buffer is full.
     *
     * <p>If the calling thread is interrupted while waiting, the DataSet is not enqueued. The
     * interrupt flag is restored and {@code false} is returned, so the loss is visible to the
     * caller instead of being silently dropped.
     *
     * @return {@code true} if the DataSet was enqueued, {@code false} if interrupted
     */
    @Override
    public boolean add(DataSet dataSet) {
        try {
            this.put(dataSet);
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    @Override
    public boolean remove(Object o) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adds every DataSet of {@code c} in iteration order.
     *
     * @return {@code false} as soon as one {@link #add} fails, otherwise {@code true}
     */
    @Override
    public boolean addAll(Collection<? extends DataSet> c) {
        for (DataSet dataSet : c) {
            if (!this.add(dataSet)) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Discards everything buffered in the handlers and backing queues, then resets the
     * put/get counters. Not atomic with respect to concurrent producers or consumers.
     */
    @Override
    public void clear() {
        // Clear handler buffers first so their contents do not reappear in the backing queues.
        for (QueueHandler handler : this.handlers) {
            handler.clearBuffer();
        }
        for (LinkedBlockingQueue<DataSet> queue : this.backingQueues) {
            queue.clear();
        }
        this.cntPut.set(0L);
        this.cntGet.set(0L);
    }

    /**
     * Offers a DataSet directly to the calling thread's device queue without blocking.
     * Unlike {@link #put}, this bypasses the handler threads.
     *
     * @return {@code true} if the backing queue accepted the DataSet
     */
    @Override
    public boolean offer(DataSet dataSet) {
        boolean accepted = queueForCurrentDevice().offer(dataSet);
        if (accepted) {
            this.cntPut.incrementAndGet();
        }
        return accepted;
    }

    /**
     * Enqueues a DataSet.
     *
     * <p>With several flows, the DataSet goes to the next handler in round-robin order and
     * this call blocks while that handler's buffer is full. With a single flow, it goes
     * straight into the unbounded backing queue.
     *
     * @throws InterruptedException if interrupted while waiting for buffer space; the
     *                              DataSet is then not enqueued and not counted
     */
    @Override
    public void put(DataSet dataSet) throws InterruptedException {
        if (this.numberOfBuckets > 1) {
            this.handlers.get(nextRoundRobinIndex(this.nextBucket, this.handlers.size())).put(dataSet);
        } else {
            this.backingQueues.get(0).put(dataSet);
        }
        this.cntPut.incrementAndGet();
    }

    /**
     * Offers a DataSet directly to the calling thread's device queue, waiting up to the given
     * timeout. Bypasses the handler threads.
     *
     * @return {@code true} if the backing queue accepted the DataSet in time
     */
    @Override
    public boolean offer(DataSet dataSet, long timeout, TimeUnit unit) throws InterruptedException {
        boolean accepted = queueForCurrentDevice().offer(dataSet, timeout, unit);
        if (accepted) {
            this.cntPut.incrementAndGet();
        }
        return accepted;
    }

    /**
     * Takes the next DataSet, blocking until one is available. The queue is chosen by
     * {@link #mode}.
     *
     * @throws InterruptedException if interrupted while waiting; nothing is counted as taken
     */
    @Override
    public DataSet take() throws InterruptedException {
        DataSet dataSet = consumerQueue().take();
        // Count only after a successful take, so an interrupted wait does not skew size().
        this.cntGet.incrementAndGet();
        return dataSet;
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public DataSet remove() {
        return null;
    }

    /**
     * Retrieves the next DataSet, waiting up to the given time. The queue is chosen by
     * {@link #mode}.
     *
     * @return the DataSet, or {@code null} if none became available in time
     */
    @Override
    public DataSet poll(long time, TimeUnit timeUnit) throws InterruptedException {
        DataSet dataSet = consumerQueue().poll(time, timeUnit);
        if (dataSet != null) {
            this.cntGet.incrementAndGet();
        }
        return dataSet;
    }

    /** Not tracked across flows: always returns 0. */
    @Override
    public int remainingCapacity() {
        return 0;
    }

    /**
     * Retrieves the next DataSet without waiting. The queue is chosen by {@link #mode}.
     *
     * @return the DataSet, or {@code null} if the selected queue is empty
     */
    @Override
    public DataSet poll() {
        DataSet dataSet = consumerQueue().poll();
        if (dataSet != null) {
            this.cntGet.incrementAndGet();
        }
        return dataSet;
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public DataSet element() {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public DataSet peek() {
        return null;
    }

    /**
     * Daemon thread that moves DataSets from its own bounded buffer into one backing queue.
     * Before forwarding a DataSet, it calls {@code AffinityManager.touch(...)} on the
     * features, the labels and any mask arrays.
     */
    private static class QueueHandler
    extends Thread {
        private final BlockingQueue<DataSet> targetQueue;
        private final LinkedBlockingQueue<DataSet> bufferQueue;

        public QueueHandler(BlockingQueue<DataSet> queue, int capacity) {
            this.targetQueue = queue;
            this.bufferQueue = new LinkedBlockingQueue<DataSet>(capacity);
            this.setDaemon(true);
        }

        /**
         * Buffers a DataSet for forwarding, blocking while the buffer is full.
         *
         * @throws InterruptedException if the caller is interrupted while waiting
         */
        public void put(DataSet dataSet) throws InterruptedException {
            this.bufferQueue.put(dataSet);
        }

        /** Drops every DataSet that has not yet been forwarded. */
        void clearBuffer() {
            this.bufferQueue.clear();
        }

        @Override
        public void run() {
            try {
                while (true) {
                    // Poll with a timeout instead of blocking indefinitely on the buffer.
                    DataSet ds = this.bufferQueue.poll(1L, TimeUnit.SECONDS);
                    if (ds == null) {
                        continue;
                    }
                    if (ds.getFeaturesMaskArray() != null) {
                        Nd4j.getAffinityManager().touch(ds.getFeaturesMaskArray());
                    }
                    if (ds.getLabelsMaskArray() != null) {
                        Nd4j.getAffinityManager().touch(ds.getLabelsMaskArray());
                    }
                    Nd4j.getAffinityManager().touch(ds.getFeatures());
                    Nd4j.getAffinityManager().touch(ds.getLabels());
                    this.targetQueue.put(ds);
                }
            }
            catch (InterruptedException e) {
                log.warn("Got InterruptedException...");
            }
        }
    }

    /**
     * Builder for {@link MagicQueue}.
     *
     * <p>Defaults:
     * <ul>
     *   <li>one flow per device reported by the affinity manager;</li>
     *   <li>capacity 16 per flow;</li>
     *   <li>{@link Mode#THREADED}.</li>
     * </ul>
     */
    public static class Builder {
        private int numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
        private int capacity = 16;
        private Mode mode = Mode.THREADED;

        /** Sets the number of flows; values below 1 fall back to the device count in {@link #build()}. */
        public Builder setNumberOfBuckets(int number) {
            this.numberOfBuckets = number;
            return this;
        }

        /** Sets the consumer read mode; must not be {@code null}. */
        public Builder setMode(@NonNull Mode mode) {
            if (mode == null) {
                throw new NullPointerException("mode");
            }
            this.mode = mode;
            return this;
        }

        /**
         * Sets the capacity of each flow's backing queue and handler buffer.
         *
         * @throws ND4JIllegalStateException if {@code capacityPerFlow} is not positive
         */
        public Builder setCapacityPerFlow(int capacityPerFlow) {
            if (capacityPerFlow <= 0) {
                throw new ND4JIllegalStateException("Capacity per flow value should be positive value");
            }
            this.capacity = capacityPerFlow;
            return this;
        }

        /** Creates the queue; with more than one flow this starts the handler threads. */
        public MagicQueue build() {
            if (this.numberOfBuckets < 1) {
                this.numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
            }
            MagicQueue queue = new MagicQueue(this.numberOfBuckets, this.capacity);
            queue.mode = this.mode;
            return queue;
        }
    }

    /** How consumers pick the backing queue to read from. */
    public enum Mode {
        /** Read the queue matching the calling thread's device. */
        THREADED,
        /** Read the queues in round-robin order. */
        SEQUENTIAL
    }
}

