/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import java.util.List;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DataSetIterator} that replays its source data for a fixed number of epochs.
 *
 * <p>The source is either another {@link DataSetIterator} or a single in-memory
 * {@link org.nd4j.linalg.dataset.DataSet}:
 * <ul>
 *   <li>With an iterator source, the source is reset at the end of every epoch except the last.</li>
 *   <li>With an in-memory source, {@link #next()} returns the whole DataSet once per epoch.
 *       {@link #next(int)} splits it into mini-batches of the given size; the split is computed on
 *       the first call and reused afterwards.</li>
 * </ul>
 *
 * <p>When constructed with a {@code queueSize} other than 1, the source iterator is wrapped in an
 * {@link AsyncDataSetIterator}. This happens exactly once, on the first call to {@link #next(int)}.
 *
 * <p>The metadata accessors ({@link #totalExamples()}, {@link #inputColumns()}, {@link #batch()},
 * ...) delegate to the source iterator. They throw {@link NullPointerException} when this instance
 * was built from a single DataSet.
 *
 * <p>No synchronization is performed; all state lives in plain mutable fields.
 */
public class MultipleEpochsIterator
implements DataSetIterator {
    /** Number of epochs completed so far. */
    @VisibleForTesting
    protected int epochs = 0;
    /** Number of epochs to produce before {@link #hasNext()} returns false. */
    protected int numEpochs;
    /** Number of batches returned so far in the current epoch. */
    protected int batch = 0;
    /** Batch count of the most recently finished epoch; only used for progress logging. */
    protected int lastBatch = 0;
    /** Source iterator, or {@code null} when iterating over {@link #ds}. */
    protected DataSetIterator iter;
    /** In-memory source data; only used when {@link #iter} is {@code null}. */
    protected org.nd4j.linalg.dataset.DataSet ds;
    /** Mini-batches of {@link #ds}; empty until the first {@code next(num)} with {@code num > 0}. */
    protected List<org.nd4j.linalg.dataset.DataSet> batchedDS = Lists.newArrayList();
    protected static final Logger log = LoggerFactory.getLogger(MultipleEpochsIterator.class);
    /** Optional pre-processor applied to every returned DataSet. */
    protected DataSetPreProcessor preProcessor;
    /** Set when an epoch finishes so the next {@link #hasNext()} call logs progress. */
    protected boolean newEpoch = false;
    /** Queue size handed to {@link AsyncDataSetIterator} when {@link #async} is enabled. */
    protected int queueSize = 1;
    /** Whether {@link #iter} should be wrapped in an {@link AsyncDataSetIterator}. */
    protected boolean async = false;
    /** Whether {@link #iter} has already been wrapped; prevents re-wrapping on every call. */
    private boolean asyncWrapped = false;

    /**
     * Iterates {@code iter} synchronously for {@code numEpochs} epochs.
     *
     * @param numEpochs number of passes over the source
     * @param iter      source iterator; it is reset between epochs
     */
    public MultipleEpochsIterator(int numEpochs, DataSetIterator iter) {
        this.numEpochs = numEpochs;
        this.iter = iter;
    }

    /**
     * Iterates {@code iter} for {@code numEpochs} epochs.
     *
     * <p>Any {@code queueSize} other than 1 enables asynchronous prefetching through
     * {@link AsyncDataSetIterator}.
     *
     * @param numEpochs number of passes over the source
     * @param iter      source iterator; it is reset between epochs
     * @param queueSize prefetch queue size for the async wrapper
     */
    public MultipleEpochsIterator(int numEpochs, DataSetIterator iter, int queueSize) {
        this.numEpochs = numEpochs;
        this.iter = iter;
        this.queueSize = queueSize;
        this.async = queueSize != 1;
    }

    /**
     * Replays a single in-memory DataSet for {@code numEpochs} epochs.
     *
     * @param numEpochs number of passes over the data
     * @param ds        the data to replay
     */
    public MultipleEpochsIterator(int numEpochs, org.nd4j.linalg.dataset.DataSet ds) {
        this.numEpochs = numEpochs;
        this.ds = ds;
    }

    /**
     * Returns the next batch, advancing to the next epoch when the current one is exhausted.
     *
     * @param num requested batch size, or {@code -1} for the source's default. For an in-memory
     *            source, {@code -1} returns the whole DataSet. A positive value splits it into
     *            mini-batches; the split is computed on the first such call and reused afterwards.
     * @return the next DataSet, after the pre-processor (if any) has been applied
     * @throws IllegalArgumentException if an in-memory source has not been split yet and
     *                                  {@code num} is neither {@code -1} nor positive
     */
    public org.nd4j.linalg.dataset.DataSet next(int num) {
        ++this.batch;
        org.nd4j.linalg.dataset.DataSet next;
        if (this.iter != null) {
            next = this.nextFromIterator(num);
        } else if (num == -1) {
            next = this.nextWholeDataSet();
        } else {
            next = this.nextMiniBatch(num);
        }
        if (this.preProcessor != null) {
            // NOTE(review): with an in-memory source, the same DataSet objects (ds / batchedDS
            // entries) are returned every epoch. If preProcess mutates its argument in place, the
            // transformation is re-applied once per epoch — confirm this is intended.
            this.preProcessor.preProcess((DataSet) next);
        }
        return next;
    }

    /** Pulls the next batch from the source iterator, resetting it when an epoch ends. */
    private org.nd4j.linalg.dataset.DataSet nextFromIterator(int num) {
        if (this.async && !this.asyncWrapped) {
            // Wrap only once; wrapping on every call would nest async iterators indefinitely.
            this.iter = new AsyncDataSetIterator(this.iter, this.queueSize);
            this.asyncWrapped = true;
        }
        org.nd4j.linalg.dataset.DataSet next = num == -1
                ? (org.nd4j.linalg.dataset.DataSet) this.iter.next()
                : this.iter.next(num);
        if (!this.iter.hasNext()) {
            this.trackEpochs();
            this.lastBatch = this.batch;
            // Leave the source exhausted after the final epoch so hasNext() reports false.
            if (this.epochs < this.numEpochs) {
                this.iter.reset();
                this.batch = 0;
            }
        }
        return next;
    }

    /** Returns the whole in-memory DataSet, counting each call as one complete epoch. */
    private org.nd4j.linalg.dataset.DataSet nextWholeDataSet() {
        if (this.epochs < this.numEpochs) {
            this.trackEpochs();
            this.lastBatch = this.batch;
            this.batch = 0;
        }
        return this.ds;
    }

    /** Returns the next mini-batch of the in-memory DataSet, splitting it on first use. */
    private org.nd4j.linalg.dataset.DataSet nextMiniBatch(int num) {
        if (this.batchedDS.isEmpty()) {
            if (num <= 0) {
                throw new IllegalArgumentException(
                        "Batch size must be positive or -1, got " + num);
            }
            this.batchedDS = this.ds.batchBy(num);
        }
        // batch was already incremented for this call, so its 0-based index is batch - 1.
        org.nd4j.linalg.dataset.DataSet next = this.batchedDS.get(this.batch - 1);
        if (this.batch == this.batchedDS.size()) {
            this.trackEpochs();
            this.lastBatch = this.batch;
            if (this.epochs < this.numEpochs) {
                this.batch = 0;
            }
        }
        return next;
    }

    /** Records that an epoch has finished; the next {@link #hasNext()} call logs it. */
    public void trackEpochs() {
        ++this.epochs;
        this.newEpoch = true;
    }

    /** Equivalent to {@code next(-1)}. */
    public org.nd4j.linalg.dataset.DataSet next() {
        return this.next(-1);
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int totalExamples() {
        return this.iter.totalExamples();
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int inputColumns() {
        return this.iter.inputColumns();
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int totalOutcomes() {
        return this.iter.totalOutcomes();
    }

    /**
     * Restarts iteration from epoch 0.
     *
     * <p>The source iterator, if any, is reset. An in-memory source keeps its existing mini-batch
     * split.
     */
    public void reset() {
        this.epochs = 0;
        this.lastBatch = this.batch;
        this.batch = 0;
        if (this.iter != null) {
            this.iter.reset();
        }
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int batch() {
        return this.iter.batch();
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int cursor() {
        return this.iter.cursor();
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public int numExamples() {
        return this.iter.numExamples();
    }

    /** Sets the pre-processor applied to every DataSet returned by {@link #next(int)}. */
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public List<String> getLabels() {
        return this.iter.getLabels();
    }

    /**
     * Reports whether another batch is available in the current or a remaining epoch.
     *
     * <p>Logs progress once after each finished epoch.
     *
     * @return {@code false} once {@code numEpochs} epochs have completed, or when the source
     *         iterator has no data
     */
    public boolean hasNext() {
        if (this.newEpoch) {
            log.info("Epoch {}, number of batches completed {}", this.epochs, this.lastBatch);
            this.newEpoch = false;
        }
        if (this.epochs >= this.numEpochs) {
            return false;
        }
        if (this.iter == null) {
            // An unsplit DataSet (whole-DataSet mode, or not yet split) always has a next item.
            return this.batchedDS.isEmpty() || this.batch < this.batchedDS.size();
        }
        // next() resets the source at every non-final epoch boundary, so this is only false for
        // a source that is empty or was handed over already exhausted.
        return this.iter.hasNext();
    }

    /** Delegates to the source iterator; requires an iterator source. */
    public void remove() {
        this.iter.remove();
    }

    /** Returns the pre-processor set via {@link #setPreProcessor}, or {@code null}. */
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }
}

