/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A {@link Queue} of {@link DataSet}s that can spread its contents over several per-device buckets.
 *
 * <p>When built with more than one bucket, one daemon {@code QueueHandler} thread is created per
 * bucket and attached to device {@code i} via {@code Nd4j.getAffinityManager()}. Each
 * {@link #add(DataSet)} hands the DataSet to the next handler in round-robin order. That handler
 * calls {@code AffinityManager.touch(...)} on the DataSet's arrays and then appends the DataSet to
 * its bucket.
 *
 * <p>Consumers read through {@link #poll()}, {@link #poll(long, TimeUnit)}, {@link #peek()},
 * {@link #remove()} and {@link #element()}. In multi-bucket mode, all of these operate only on the
 * bucket for the device the calling thread is attached to.
 *
 * <p>With one bucket (or a non-positive bucket count), no handler threads are created and every
 * operation acts directly on a single backing queue.
 *
 * <p>{@link #contains(Object)}, {@link #iterator()}, {@link #toArray()}, {@link #toArray(Object[])},
 * {@link #remove(Object)}, {@link #containsAll(Collection)}, {@link #removeAll(Collection)} and
 * {@link #retainAll(Collection)} are not supported.
 */
public class MagicQueue
implements Queue<DataSet> {
    /** One unbounded queue per bucket; index {@code i} corresponds to device {@code i}. */
    protected final List<LinkedBlockingQueue<DataSet>> backingQueues;
    /** Round-robin cursor into {@link #handlers}; only advanced by {@link #nextBucketIndex()}. */
    protected final AtomicInteger nextBucket = new AtomicInteger(0);
    /** Bucket count requested at construction; values {@code <= 1} mean single-queue mode. */
    protected final int numberOfBuckets;
    /** Per-bucket feeder threads; empty in single-queue mode. */
    protected final List<QueueHandler> handlers;

    /**
     * Creates the queue and, in multi-bucket mode, starts one handler thread per bucket.
     *
     * @param numberOfFlows number of buckets; any value {@code <= 1} creates a single backing queue
     */
    protected MagicQueue(int numberOfFlows) {
        this.backingQueues = new ArrayList<LinkedBlockingQueue<DataSet>>();
        this.handlers = new ArrayList<QueueHandler>();
        if (numberOfFlows > 1) {
            for (int i = 0; i < numberOfFlows; ++i) {
                LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<DataSet>();
                this.backingQueues.add(queue);
                QueueHandler handler = new QueueHandler(queue);
                // Bind the handler thread to device i before it starts running.
                Nd4j.getAffinityManager().attachThreadToDevice((Thread) handler, Integer.valueOf(i));
                handler.start();
                this.handlers.add(handler);
            }
        } else {
            this.backingQueues.add(new LinkedBlockingQueue<DataSet>());
        }
        this.numberOfBuckets = numberOfFlows;
    }

    /**
     * Returns the number of queued DataSets.
     *
     * <p>In multi-bucket mode this is the integer (floored) average of the bucket sizes, not the
     * total. DataSets still waiting inside a handler thread's buffer are not counted.
     *
     * @return queued element count (single bucket) or average per-bucket count (multi-bucket)
     */
    @Override
    public int size() {
        if (this.numberOfBuckets > 1) {
            long cnt = 0L;
            for (int i = 0; i < this.numberOfBuckets; ++i) {
                cnt += (long) this.backingQueues.get(i).size();
            }
            // Non-negative long division already floors.
            return (int) (cnt / (long) this.numberOfBuckets);
        }
        return this.backingQueues.get(0).size();
    }

    /**
     * Returns the number of DataSets currently in the bucket for {@code deviceId}.
     *
     * @param deviceId bucket/device index
     * @return size of that bucket
     * @throws RuntimeException if {@code deviceId} is negative or not below the number of buckets
     */
    protected int size(int deviceId) {
        if (deviceId < 0 || deviceId >= this.backingQueues.size()) {
            throw new RuntimeException("DeviceID exceeds number of actual backing queues");
        }
        return this.backingQueues.get(deviceId).size();
    }

    /**
     * Returns {@code true} when {@link #size()} is below one.
     *
     * <p>NOTE(review): in multi-bucket mode {@link #size()} is a per-bucket average, so this can
     * report empty while some buckets still hold items. Kept as-is because callers may depend on it.
     */
    @Override
    public boolean isEmpty() {
        return this.size() < 1;
    }

    @Override
    public boolean contains(Object o) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Iterator<DataSet> iterator() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    @Override
    public <T> T[] toArray(T[] a) {
        throw new UnsupportedOperationException();
    }

    /**
     * Enqueues a DataSet.
     *
     * <p>In multi-bucket mode the DataSet goes to the next handler thread in round-robin order.
     * That handler places it into its device's bucket asynchronously.
     *
     * @param dataSet element to enqueue; must not be null
     * @return always {@code true} (all backing queues are unbounded)
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public boolean add(DataSet dataSet) {
        if (this.numberOfBuckets > 1) {
            this.handlers.get(this.nextBucketIndex()).put(dataSet);
            return true;
        }
        this.backingQueues.get(0).add(dataSet);
        return true;
    }

    /**
     * Atomically returns the current round-robin handler index and advances the cursor, wrapping to 0.
     *
     * <p>The compare-and-set loop keeps the stored cursor within {@code [0, handlers.size())}. The
     * previous check-then-increment could hand two racing producers an out-of-range index.
     */
    private int nextBucketIndex() {
        int bucketCount = this.handlers.size();
        while (true) {
            int current = this.nextBucket.get();
            int next = current + 1 >= bucketCount ? 0 : current + 1;
            if (this.nextBucket.compareAndSet(current, next)) {
                return current;
            }
        }
    }

    @Override
    public boolean remove(Object o) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Enqueues every element of {@code c} via {@link #add(DataSet)}.
     *
     * @param c elements to add
     * @return {@code true} if {@code c} contained at least one element
     * @throws IllegalArgumentException if {@code c} is this queue
     * @throws NullPointerException if {@code c} or any of its elements is null
     */
    @Override
    public boolean addAll(Collection<? extends DataSet> c) {
        if (c == this) {
            throw new IllegalArgumentException("Cannot add a MagicQueue to itself");
        }
        boolean modified = false;
        for (DataSet dataSet : c) {
            this.add(dataSet);
            modified = true;
        }
        return modified;
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Discards pending DataSets from every handler buffer and every bucket.
     *
     * <p>A DataSet that a handler thread has already taken from its buffer, but not yet placed into
     * its bucket, may still appear after this returns.
     */
    @Override
    public void clear() {
        for (QueueHandler handler : this.handlers) {
            handler.clearBuffer();
        }
        for (LinkedBlockingQueue<DataSet> queue : this.backingQueues) {
            queue.clear();
        }
    }

    /**
     * Enqueues a DataSet; equivalent to {@link #add(DataSet)} because the queues are unbounded.
     *
     * @param dataSet element to enqueue; must not be null
     * @return always {@code true}
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public boolean offer(DataSet dataSet) {
        return this.add(dataSet);
    }

    /**
     * Removes and returns the head of the current device's bucket.
     *
     * @return the head element
     * @throws NoSuchElementException if that bucket is empty
     */
    @Override
    public DataSet remove() {
        DataSet head = this.poll();
        if (head == null) {
            throw new NoSuchElementException("MagicQueue is empty for the current device");
        }
        return head;
    }

    /**
     * Removes and returns the head of the current device's bucket.
     *
     * <p>Waits up to the given time for an element to become available.
     *
     * @param time maximum time to wait
     * @param timeUnit unit of {@code time}
     * @return the head element, or {@code null} if none arrived in time
     * @throws InterruptedException if interrupted while waiting
     */
    public DataSet poll(long time, TimeUnit timeUnit) throws InterruptedException {
        return this.currentQueue().poll(time, timeUnit);
    }

    /**
     * Removes and returns the head of the current device's bucket.
     *
     * @return the head element, or {@code null} if that bucket is empty
     */
    @Override
    public DataSet poll() {
        return this.currentQueue().poll();
    }

    /**
     * Returns, without removing, the head of the current device's bucket.
     *
     * @return the head element
     * @throws NoSuchElementException if that bucket is empty
     */
    @Override
    public DataSet element() {
        DataSet head = this.peek();
        if (head == null) {
            throw new NoSuchElementException("MagicQueue is empty for the current device");
        }
        return head;
    }

    /**
     * Returns, without removing, the head of the current device's bucket.
     *
     * @return the head element, or {@code null} if that bucket is empty
     */
    @Override
    public DataSet peek() {
        return this.currentQueue().peek();
    }

    /**
     * Returns the bucket serving the calling thread.
     *
     * <p>In multi-bucket mode this is the bucket indexed by the device the affinity manager reports
     * for the current thread. Otherwise it is the single backing queue.
     */
    private LinkedBlockingQueue<DataSet> currentQueue() {
        if (this.numberOfBuckets > 1) {
            int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            return this.backingQueues.get(deviceId);
        }
        return this.backingQueues.get(0);
    }

    /**
     * Daemon thread that moves DataSets from its private buffer into one bucket.
     *
     * <p>Before each move it calls {@code AffinityManager.touch(...)} on the DataSet's arrays. The
     * thread stops when interrupted.
     */
    private static class QueueHandler
    extends Thread {
        private final Queue<DataSet> targetQueue;
        private final LinkedBlockingQueue<DataSet> bufferQueue;

        /**
         * Creates an unstarted daemon handler.
         *
         * @param queue bucket this handler feeds
         */
        public QueueHandler(Queue<DataSet> queue) {
            this.targetQueue = queue;
            this.bufferQueue = new LinkedBlockingQueue<DataSet>();
            this.setDaemon(true);
        }

        /**
         * Buffers a DataSet for asynchronous transfer.
         *
         * @throws NullPointerException if {@code dataSet} is null
         */
        public void put(DataSet dataSet) {
            this.bufferQueue.add(dataSet);
        }

        /** Drops all DataSets still waiting in this handler's buffer. */
        void clearBuffer() {
            this.bufferQueue.clear();
        }

        @Override
        public void run() {
            while (!this.isInterrupted()) {
                DataSet ds;
                try {
                    // Bounded wait so the loop re-checks the interrupt flag periodically.
                    ds = this.bufferQueue.poll(1L, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    // Restore the flag and stop; previously interrupts were swallowed forever.
                    Thread.currentThread().interrupt();
                    return;
                }
                if (ds == null) {
                    continue;
                }
                try {
                    touchArrays(ds);
                } catch (RuntimeException ignored) {
                    // touch() is a best-effort placement hint. Deliver the batch anyway rather
                    // than silently losing it, as the previous catch-all did.
                }
                this.targetQueue.add(ds);
            }
        }

        /** Calls {@code AffinityManager.touch} on every non-null array the DataSet carries. */
        private static void touchArrays(DataSet ds) {
            if (ds.getFeaturesMaskArray() != null) {
                Nd4j.getAffinityManager().touch(ds.getFeaturesMaskArray());
            }
            if (ds.getLabelsMaskArray() != null) {
                Nd4j.getAffinityManager().touch(ds.getLabelsMaskArray());
            }
            Nd4j.getAffinityManager().touch(ds.getFeatures());
            Nd4j.getAffinityManager().touch(ds.getLabels());
        }
    }

    /** Fluent builder for {@link MagicQueue}. */
    public static class Builder {
        private int numberOfBuckets = -1;

        /**
         * Sets the bucket count; values below 1 fall back to the number of devices at build time.
         *
         * @param number desired bucket count
         * @return this builder
         */
        public Builder setNumberOfBuckets(int number) {
            this.numberOfBuckets = number;
            return this;
        }

        /**
         * Builds the queue.
         *
         * <p>When no positive bucket count was set, the count comes from
         * {@code Nd4j.getAffinityManager().getNumberOfDevices()}.
         *
         * @return a new {@link MagicQueue}
         */
        public MagicQueue build() {
            if (this.numberOfBuckets < 1) {
                this.numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
            }
            return new MagicQueue(this.numberOfBuckets);
        }
    }
}

