/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Simulates a slow consumer of {@link DataSet}s.
 *
 * <p>Each call to {@link #consumeOnce(DataSet, boolean)} holds the calling thread for roughly
 * {@code delay} milliseconds and then increments a running counter. The wait is either a
 * busy-spin or a sleep. A progress message is logged every 100 consumed datasets.
 */
public class TestDataSetConsumer {
    /**
     * Source drained by {@link #consumeWhileHasNext(boolean)}. It is {@code null} when the
     * delay-only constructor was used.
     */
    private final DataSetIterator iterator;
    /** Simulated processing time per dataset, in milliseconds. */
    private final long delay;
    /** Number of datasets consumed so far. */
    private final AtomicLong count = new AtomicLong(0L);
    protected static final Logger logger = LoggerFactory.getLogger(TestDataSetConsumer.class);

    /**
     * Creates a consumer without a backing iterator. Only {@link #consumeOnce(DataSet, boolean)}
     * may be used; {@link #consumeWhileHasNext(boolean)} will throw.
     *
     * @param delay simulated processing time per dataset, in milliseconds
     */
    public TestDataSetConsumer(long delay) {
        this.iterator = null;
        this.delay = delay;
    }

    /**
     * Creates a consumer that can drain {@code iterator} via {@link #consumeWhileHasNext(boolean)}.
     *
     * @param iterator source of datasets; must not be {@code null}
     * @param delay simulated processing time per dataset, in milliseconds
     * @throws NullPointerException if {@code iterator} is {@code null}
     */
    public TestDataSetConsumer(@NonNull DataSetIterator iterator, long delay) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        this.iterator = iterator;
        this.delay = delay;
    }

    /**
     * Consumes every remaining element of the backing iterator, one at a time.
     *
     * @param consumeWithSleep {@code true} to sleep during each simulated delay, {@code false}
     *     to busy-spin
     * @return the total number of datasets consumed by this instance so far
     * @throws IllegalStateException if this instance was created without an iterator
     */
    public long consumeWhileHasNext(boolean consumeWithSleep) {
        if (iterator == null) {
            throw new IllegalStateException("Can't use consumeWhileHasNext() if iterator isn't set");
        }
        while (iterator.hasNext()) {
            consumeOnce(iterator.next(), consumeWithSleep);
        }
        return count.get();
    }

    /**
     * Simulates processing a single dataset: waits {@code delay} ms, then bumps the counter.
     *
     * @param dataSet the dataset being "consumed"; its contents are not inspected
     * @param consumeWithSleep {@code true} to sleep during the delay, {@code false} to busy-spin
     * @return the counter value produced by this call's increment
     * @throws NullPointerException if {@code dataSet} is {@code null}
     * @throws RuntimeException wrapping {@link InterruptedException} if the thread is interrupted
     *     while sleeping (the interrupt flag is restored first)
     */
    public long consumeOnce(@NonNull DataSet dataSet, boolean consumeWithSleep) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        waitForDelay(consumeWithSleep);
        // Use the value returned by the increment itself so the log and the return value
        // belong to this call even if other threads are consuming concurrently.
        long consumed = count.incrementAndGet();
        if (consumed % 100L == 0L) {
            logger.info("Passed {} datasets...", consumed);
        }
        return consumed;
    }

    /**
     * Returns the number of datasets consumed so far.
     *
     * @return current counter value
     */
    public long getCount() {
        return count.get();
    }

    /** Holds the calling thread for {@code delay} ms, timed with the monotonic clock. */
    private void waitForDelay(boolean consumeWithSleep) {
        // nanoTime is immune to wall-clock adjustments, unlike currentTimeMillis.
        final long budgetNanos = TimeUnit.MILLISECONDS.toNanos(delay);
        final long start = System.nanoTime();
        long remainingNanos;
        while ((remainingNanos = budgetNanos - (System.nanoTime() - start)) > 0L) {
            if (!consumeWithSleep) {
                continue; // busy-spin until the budget is used up
            }
            try {
                // Sleep only for what is left, so the total wait does not overshoot the delay.
                TimeUnit.NANOSECONDS.sleep(remainingNanos);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }
}

