/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.canova;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.canova.api.io.WritableConverter;
import org.canova.api.io.converters.SelfWritableConverter;
import org.canova.api.io.converters.WritableConverterException;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * {@link DataSetIterator} that builds mini-batches from a Canova {@link RecordReader}.
 *
 * <p>Each record becomes one example. Every non-empty column except the label column is a
 * feature. The label column, if configured, is converted either to a one-hot outcome vector
 * (classification) or to a scalar (regression). When no label column is configured, the feature
 * vector is also used as the label.
 *
 * <p>Records from a {@link SequenceRecordReader} are flattened: every element of a sequence
 * becomes its own example.
 *
 * <p>NOTE(review): the iterator keeps mutable iteration state without any synchronization; treat
 * it as single-threaded.
 */
public class RecordReaderDataSetIterator
implements DataSetIterator {
    /** Batch size used by the constructors that do not take one. */
    private static final int DEFAULT_BATCH_SIZE = 10;

    private final RecordReader recordReader;
    /** Applied to the label column before it is interpreted; may be {@code null}. */
    private final WritableConverter converter;
    private final int batchSize;
    /** Column holding the label; negative means "no label column". May be inferred lazily. */
    private int labelIndex;
    private final int numPossibleLabels;
    private final boolean regression;
    /** Set when a batch was requested but no records were available. */
    private boolean overshot = false;
    /** Remaining elements of the sequence currently being flattened (sequence readers only). */
    private Iterator<Collection<Writable>> sequenceIter;
    /** Most recently produced batch; also used to answer shape queries. */
    private org.nd4j.linalg.dataset.DataSet last;
    /** True when {@link #last} was read ahead by a shape query and has not been returned yet. */
    private boolean useCurrent = false;
    private DataSetPreProcessor preProcessor;

    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, -1);
    }

    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels);
    }

    public RecordReaderDataSetIterator(RecordReader recordReader) {
        this(recordReader, new SelfWritableConverter());
    }

    public RecordReaderDataSetIterator(RecordReader recordReader, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), DEFAULT_BATCH_SIZE, labelIndex, numPossibleLabels);
    }

    /**
     * Primary constructor.
     *
     * @param recordReader      source of records
     * @param converter         converter applied to the label value before use (may be {@code null})
     * @param batchSize         number of examples returned by {@link #next()}
     * @param labelIndex        column holding the label, or a negative value for none. If negative
     *                          while {@code numPossibleLabels >= 1}, the last column is used.
     * @param numPossibleLabels number of classes (must be {@code >= 1} when a label column is used)
     * @param regression        if true the label is emitted as a scalar instead of a one-hot vector
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndex, int numPossibleLabels, boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, false);
    }

    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter) {
        this(recordReader, converter, DEFAULT_BATCH_SIZE, -1, -1);
    }

    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int labelIndex, int numPossibleLabels) {
        this(recordReader, converter, DEFAULT_BATCH_SIZE, labelIndex, numPossibleLabels);
    }

    /**
     * Returns the next mini-batch of at most {@code num} examples.
     *
     * <p>If a batch was read ahead by {@link #inputColumns()} or {@link #totalOutcomes()}, that
     * batch is returned first. If no records remain, the previously returned batch is returned
     * again and the iterator is flagged as overshot.
     *
     * @param num maximum number of examples in the batch
     * @return the batch, with the configured pre-processor applied exactly once
     */
    public org.nd4j.linalg.dataset.DataSet next(int num) {
        if (useCurrent) {
            useCurrent = false;
            // The look-ahead batch was read without pre-processing, so it is pre-processed now, once.
            applyPreProcessor(last);
            return last;
        }
        org.nd4j.linalg.dataset.DataSet batch = readBatch(num);
        if (batch == null) {
            overshot = true;
            return last;
        }
        last = batch;
        applyPreProcessor(batch);
        return batch;
    }

    /**
     * Reads up to {@code num} records and stacks them row-wise into one batch.
     * Does not apply the pre-processor.
     *
     * @return the batch, or {@code null} if no record was available
     */
    private org.nd4j.linalg.dataset.DataSet readBatch(int num) {
        List<INDArray> inputs = new ArrayList<INDArray>();
        List<INDArray> labels = new ArrayList<INDArray>();
        // Check real record availability, not hasNext(): the overshot flag must not make us read past the end.
        for (int i = 0; i < num && hasMoreRecords(); ++i) {
            org.nd4j.linalg.dataset.DataSet example = getDataSet(nextRecord());
            inputs.add(example.getFeatureMatrix());
            labels.add(example.getLabels());
        }
        if (inputs.isEmpty()) {
            return null;
        }
        return new org.nd4j.linalg.dataset.DataSet(
                Nd4j.vstack(inputs.toArray(new INDArray[0])),
                Nd4j.vstack(labels.toArray(new INDArray[0])));
    }

    /** Returns true if another record can be produced from the current sequence or from the reader. */
    private boolean hasMoreRecords() {
        return (sequenceIter != null && sequenceIter.hasNext()) || recordReader.hasNext();
    }

    /**
     * Returns the next record. For sequence readers, a new sequence is pulled only once the
     * current one is exhausted.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // sequenceRecord() is consumed through a raw Collection
    private Collection<Writable> nextRecord() {
        if (!(recordReader instanceof SequenceRecordReader)) {
            return recordReader.next();
        }
        if (sequenceIter == null || !sequenceIter.hasNext()) {
            Collection sequence = ((SequenceRecordReader) recordReader).sequenceRecord();
            sequenceIter = sequence.iterator();
        }
        return sequenceIter.next();
    }

    /** Applies the configured pre-processor, if any, to {@code batch}. */
    private void applyPreProcessor(org.nd4j.linalg.dataset.DataSet batch) {
        if (preProcessor != null && batch != null) {
            preProcessor.preProcess((DataSet) batch);
        }
    }

    /**
     * Converts one record into a single-example data set.
     *
     * <p>If labels were requested ({@code numPossibleLabels >= 1}) without a label column, the
     * last column of this record becomes the label column. That choice is stored in
     * {@link #labelIndex} and reused for all later records.
     */
    private org.nd4j.linalg.dataset.DataSet getDataSet(Collection<Writable> record) {
        // Index-based access is needed; reuse the record directly when it already is a List.
        List<Writable> currList = record instanceof List ? (List<Writable>) record : new ArrayList<Writable>(record);
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = currList.size() - 1;
        }
        boolean hasLabel = labelIndex >= 0;
        INDArray label = null;
        INDArray featureVector = Nd4j.create(hasLabel ? currList.size() - 1 : currList.size());
        int featureCount = 0;
        for (int j = 0; j < currList.size(); ++j) {
            Writable current = currList.get(j);
            // Empty columns are skipped. Features after them shift left, leaving unwritten
            // slots at the end of featureVector (presumably zero-filled by Nd4j.create; confirm).
            if (current.toString().isEmpty()) {
                continue;
            }
            if (hasLabel && j == labelIndex) {
                label = toLabel(current);
                continue;
            }
            featureVector.putScalar(featureCount++, current.toDouble());
        }
        // NOTE(review): if the label column itself is empty, label stays null here.
        return new org.nd4j.linalg.dataset.DataSet(featureVector, hasLabel ? label : featureVector);
    }

    /**
     * Converts the raw label column into a label array.
     *
     * @return a scalar in regression mode, otherwise a one-hot outcome vector of length
     *         {@code numPossibleLabels}
     * @throws IllegalStateException if {@code numPossibleLabels < 1}
     */
    private INDArray toLabel(Writable raw) {
        Writable current = raw;
        if (converter != null) {
            try {
                current = converter.convert(current);
            } catch (WritableConverterException e) {
                // Best effort, as before: report the failure and fall back to the unconverted value.
                e.printStackTrace();
            }
        }
        if (numPossibleLabels < 1) {
            throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }
        if (regression) {
            return Nd4j.scalar(current.toDouble());
        }
        int curr = current.toInt();
        // NOTE(review): a class index >= numPossibleLabels is shifted down by one. This presumably
        // supports 1-based labels; indices further out of range are not corrected. Confirm the
        // intended label encoding.
        if (curr >= numPossibleLabels) {
            --curr;
        }
        return FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
    }

    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /** Number of feature columns, reading one batch ahead if nothing has been read yet. */
    public int inputColumns() {
        return peekBatch().numInputs();
    }

    /** Number of label columns, reading one batch ahead if nothing has been read yet. */
    public int totalOutcomes() {
        return peekBatch().numOutcomes();
    }

    /**
     * Returns the most recent batch. If none exists yet, reads one (without pre-processing) and
     * marks it to be returned by the next call to {@link #next(int)}.
     *
     * @throws IllegalStateException if the reader produced no records
     */
    private org.nd4j.linalg.dataset.DataSet peekBatch() {
        if (last == null) {
            org.nd4j.linalg.dataset.DataSet batch = readBatch(batchSize);
            if (batch == null) {
                throw new IllegalStateException("Cannot determine data shape: the record reader produced no records");
            }
            last = batch;
            useCurrent = true;
        }
        return last;
    }

    /**
     * Rewinds the underlying reader and discards any read-ahead or partially consumed sequence state.
     * The last batch is kept so shape queries keep working.
     */
    public void reset() {
        recordReader.reset();
        sequenceIter = null;
        overshot = false;
        useCurrent = false;
    }

    public int batch() {
        return this.batchSize;
    }

    public int cursor() {
        throw new UnsupportedOperationException();
    }

    public int numExamples() {
        throw new UnsupportedOperationException();
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Returns true while records remain, including unread elements of the current sequence.
     *
     * <p>NOTE(review): this also stays true after {@code next} was called with no data left (the
     * overshot flag) until {@link #reset()}. That conflicts with the Iterator contract but is
     * preserved for compatibility.
     */
    @Override
    public boolean hasNext() {
        return hasMoreRecords() || this.overshot;
    }

    @Override
    public org.nd4j.linalg.dataset.DataSet next() {
        return this.next(this.batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }
}

