/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.callbacks.DataSetCallback;
import org.deeplearning4j.datasets.iterator.callbacks.DefaultCallback;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AsyncMultiDataSetIterator
implements MultiDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(AsyncMultiDataSetIterator.class);
    protected MultiDataSetIterator backedIterator;
    protected org.nd4j.linalg.dataset.api.MultiDataSet terminator = new MultiDataSet();
    protected org.nd4j.linalg.dataset.api.MultiDataSet nextElement = null;
    protected BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> buffer;
    protected AsyncPrefetchThread thread;
    protected AtomicBoolean shouldWork = new AtomicBoolean(true);
    protected volatile RuntimeException throwable = null;
    protected boolean useWorkspaces;
    protected int prefetchSize;
    protected String workspaceId;
    protected DataSetCallback callback;
    protected Integer deviceId;
    protected AtomicBoolean hasDepleted = new AtomicBoolean(false);

    protected AsyncMultiDataSetIterator() {
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet>(queueSize));
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet>(queueSize), useWorkspace);
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet>(queueSize), useWorkspace, null, deviceId);
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue, boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, new DefaultCallback());
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        if (queueSize < 2) {
            queueSize = 2;
        }
        this.callback = callback;
        this.buffer = queue;
        this.backedIterator = iterator;
        this.useWorkspaces = useWorkspace;
        this.prefetchSize = queueSize;
        this.workspaceId = "AMDSI_ITER-" + UUID.randomUUID().toString();
        this.deviceId = deviceId;
        if (iterator.resetSupported() && !iterator.hasNext()) {
            this.backedIterator.reset();
        }
        this.thread = new AsyncPrefetchThread(this.buffer, iterator, this.terminator);
        Nd4j.getAffinityManager().attachThreadToDevice((Thread)this.thread, deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
        throw new UnsupportedOperationException();
    }

    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.backedIterator.setPreProcessor(preProcessor);
    }

    public MultiDataSetPreProcessor getPreProcessor() {
        return this.backedIterator.getPreProcessor();
    }

    public boolean resetSupported() {
        return this.backedIterator.resetSupported();
    }

    public boolean asyncSupported() {
        return false;
    }

    public void reset() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
        this.backedIterator.reset();
        this.shouldWork.set(true);
        this.thread = new AsyncPrefetchThread(this.buffer, this.backedIterator, this.terminator);
        Nd4j.getAffinityManager().attachThreadToDevice((Thread)this.thread, this.deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
        this.hasDepleted.set(false);
        this.nextElement = null;
    }

    public void shutdown() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
    }

    public boolean hasNext() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        try {
            if (this.hasDepleted.get()) {
                return false;
            }
            if (this.nextElement != null && this.nextElement != this.terminator) {
                return true;
            }
            if (this.nextElement == this.terminator) {
                return false;
            }
            this.nextElement = this.buffer.take();
            if (this.nextElement == this.terminator) {
                this.hasDepleted.set(true);
                return false;
            }
            return true;
        }
        catch (Exception e) {
            log.error("Premature end of loop!");
            throw new RuntimeException(e);
        }
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        if (this.hasDepleted.get()) {
            return null;
        }
        org.nd4j.linalg.dataset.api.MultiDataSet temp = this.nextElement;
        this.nextElement = null;
        return temp;
    }

    public void remove() {
    }

    protected void externalCall() {
    }

    protected class AsyncPrefetchThread
    extends Thread
    implements Runnable {
        private BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue;
        private MultiDataSetIterator iterator;
        private org.nd4j.linalg.dataset.api.MultiDataSet terminator;
        private boolean isShutdown = false;
        private WorkspaceConfiguration configuration;
        private MemoryWorkspace workspace;

        protected AsyncPrefetchThread(@NonNull BlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, org.nd4j.linalg.dataset.api.MultiDataSet terminator) {
            this.configuration = WorkspaceConfiguration.builder().minSize(0xA00000L).overallocationLimit((double)(AsyncMultiDataSetIterator.this.prefetchSize + 1)).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyLearning(LearningPolicy.FIRST_LOOP).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).build();
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            if (iterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            if (terminator == null) {
                throw new NullPointerException("terminator is marked @NonNull but is null");
            }
            this.queue = queue;
            this.iterator = iterator;
            this.terminator = terminator;
            this.setDaemon(true);
            this.setName("AMDSI prefetch thread");
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         * Unable to fully structure code
         */
        @Override
        public void run() {
            AsyncMultiDataSetIterator.this.externalCall();
            try {
                if (AsyncMultiDataSetIterator.this.useWorkspaces) {
                    this.workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(this.configuration, AsyncMultiDataSetIterator.this.workspaceId);
                }
                while (this.iterator.hasNext() && AsyncMultiDataSetIterator.this.shouldWork.get()) {
                    smth = null;
                    if (AsyncMultiDataSetIterator.this.useWorkspaces) {
                        ws = this.workspace.notifyScopeEntered();
                        var3_7 = null;
                        try {
                            smth = (org.nd4j.linalg.dataset.api.MultiDataSet)this.iterator.next();
                            if (AsyncMultiDataSetIterator.this.callback == null) ** GOTO lbl32
                            AsyncMultiDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.MultiDataSet)smth);
                        }
                        catch (Throwable var4_9) {
                            var3_7 = var4_9;
                            throw var4_9;
                        }
                        finally {
                            if (ws != null) {
                                if (var3_7 != null) {
                                    try {
                                        ws.close();
                                    }
                                    catch (Throwable var4_8) {
                                        var3_7.addSuppressed(var4_8);
                                    }
                                } else {
                                    ws.close();
                                }
                            }
                        }
                    } else {
                        smth = (org.nd4j.linalg.dataset.api.MultiDataSet)this.iterator.next();
                        if (AsyncMultiDataSetIterator.this.callback != null) {
                            AsyncMultiDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.MultiDataSet)smth);
                        }
                    }
lbl32:
                    // 5 sources

                    Nd4j.getExecutioner().commit();
                    if (smth == null) continue;
                    this.queue.put((org.nd4j.linalg.dataset.api.MultiDataSet)smth);
                }
                this.queue.put(this.terminator);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                AsyncMultiDataSetIterator.this.shouldWork.set(false);
            }
            catch (RuntimeException e) {
                AsyncMultiDataSetIterator.this.throwable = e;
                throw new RuntimeException(e);
            }
            catch (Exception e) {
                AsyncMultiDataSetIterator.this.throwable = new RuntimeException(e);
                throw new RuntimeException(e);
            }
            finally {
                e = this;
                synchronized (e) {
                    this.isShutdown = true;
                    this.notifyAll();
                }
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void shutdown() {
            AsyncPrefetchThread asyncPrefetchThread = this;
            synchronized (asyncPrefetchThread) {
                while (!this.isShutdown) {
                    try {
                        this.wait();
                    }
                    catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
            }
            if (this.workspace != null) {
                log.debug("Manually destroying AMDSI workspace");
                this.workspace.destroyWorkspace(true);
                this.workspace = null;
            }
        }
    }
}

