/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataSetIteratorSplitter {
    private static final Logger log = LoggerFactory.getLogger(DataSetIteratorSplitter.class);
    protected DataSetIterator backedIterator;
    protected final long totalExamples;
    protected final double ratio;
    protected final long numTrain;
    protected final long numTest;
    protected AtomicLong counter = new AtomicLong(0L);
    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;

    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double ratio) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        if (!(ratio > 0.0) || !(ratio < 1.0)) {
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
        }
        if (totalBatches < 0L) {
            throw new ND4JIllegalStateException("totalExamples number should be positive value");
        }
        if (!baseIterator.resetSupported()) {
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
        }
        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = ratio;
        this.numTrain = (long)((double)this.totalExamples * ratio);
        this.numTest = this.totalExamples - this.numTrain;
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public DataSetIterator getTrainIterator() {
        return new DataSetIterator(){

            public DataSet next(int i) {
                throw new UnsupportedOperationException();
            }

            public List<String> getLabels() {
                return DataSetIteratorSplitter.this.backedIterator.getLabels();
            }

            public int inputColumns() {
                return DataSetIteratorSplitter.this.backedIterator.inputColumns();
            }

            public void remove() {
                throw new UnsupportedOperationException();
            }

            public int totalOutcomes() {
                return DataSetIteratorSplitter.this.backedIterator.totalOutcomes();
            }

            public boolean resetSupported() {
                return DataSetIteratorSplitter.this.backedIterator.resetSupported();
            }

            public boolean asyncSupported() {
                return DataSetIteratorSplitter.this.backedIterator.asyncSupported();
            }

            public void reset() {
                DataSetIteratorSplitter.this.resetPending.set(true);
            }

            public int batch() {
                return DataSetIteratorSplitter.this.backedIterator.batch();
            }

            public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
                DataSetIteratorSplitter.this.backedIterator.setPreProcessor(dataSetPreProcessor);
            }

            public DataSetPreProcessor getPreProcessor() {
                return DataSetIteratorSplitter.this.backedIterator.getPreProcessor();
            }

            public boolean hasNext() {
                boolean state;
                if (DataSetIteratorSplitter.this.resetPending.get()) {
                    if (this.resetSupported()) {
                        DataSetIteratorSplitter.this.backedIterator.reset();
                        DataSetIteratorSplitter.this.counter.set(0L);
                        DataSetIteratorSplitter.this.resetPending.set(false);
                    } else {
                        throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
                    }
                }
                return (state = DataSetIteratorSplitter.this.backedIterator.hasNext()) && DataSetIteratorSplitter.this.counter.get() < DataSetIteratorSplitter.this.numTrain;
            }

            public DataSet next() {
                DataSetIteratorSplitter.this.counter.incrementAndGet();
                DataSet p = (DataSet)DataSetIteratorSplitter.this.backedIterator.next();
                if (DataSetIteratorSplitter.this.counter.get() == 1L && DataSetIteratorSplitter.this.firstTrain == null) {
                    DataSetIteratorSplitter.this.firstTrain = p.copy();
                    DataSetIteratorSplitter.this.firstTrain.detach();
                } else if (DataSetIteratorSplitter.this.counter.get() == 1L) {
                    boolean cnt = false;
                    if (!p.getFeatures().equalsWithEps((Object)DataSetIteratorSplitter.this.firstTrain.getFeatures(), 1.0E-5)) {
                        throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
                    }
                }
                return p;
            }
        };
    }

    public DataSetIterator getTestIterator() {
        return new DataSetIterator(){

            public DataSet next(int i) {
                throw new UnsupportedOperationException();
            }

            public List<String> getLabels() {
                return DataSetIteratorSplitter.this.backedIterator.getLabels();
            }

            public int inputColumns() {
                return DataSetIteratorSplitter.this.backedIterator.inputColumns();
            }

            public void remove() {
                throw new UnsupportedOperationException();
            }

            public int totalOutcomes() {
                return DataSetIteratorSplitter.this.backedIterator.totalOutcomes();
            }

            public boolean resetSupported() {
                return DataSetIteratorSplitter.this.backedIterator.resetSupported();
            }

            public boolean asyncSupported() {
                return DataSetIteratorSplitter.this.backedIterator.asyncSupported();
            }

            public void reset() {
                DataSetIteratorSplitter.this.resetPending.set(true);
            }

            public int batch() {
                return DataSetIteratorSplitter.this.backedIterator.batch();
            }

            public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
                DataSetIteratorSplitter.this.backedIterator.setPreProcessor(dataSetPreProcessor);
            }

            public DataSetPreProcessor getPreProcessor() {
                return DataSetIteratorSplitter.this.backedIterator.getPreProcessor();
            }

            public boolean hasNext() {
                boolean state = DataSetIteratorSplitter.this.backedIterator.hasNext();
                return state && DataSetIteratorSplitter.this.counter.get() < DataSetIteratorSplitter.this.numTrain + DataSetIteratorSplitter.this.numTest;
            }

            public DataSet next() {
                DataSetIteratorSplitter.this.counter.incrementAndGet();
                return (DataSet)DataSetIteratorSplitter.this.backedIterator.next();
            }
        };
    }
}

