/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.canova;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * A {@link DataSetIterator} that builds time-series mini-batches from two
 * {@link SequenceRecordReader}s: one supplying feature sequences and one supplying the matching
 * label sequences. The two readers are consumed in lockstep, so the i-th feature sequence is
 * paired with the i-th label sequence.
 *
 * <p>Each sequence is a collection of time steps, and each time step is a collection of
 * {@link Writable} values. Batches are returned with feature and label arrays of shape
 * {@code [examples, columns, timeSteps]}. The shape of a batch is taken from its first sequence,
 * so all sequences in one batch are expected to have the same length and width (no padding or
 * masking is performed).
 *
 * <p>Labels are handled in one of two ways:
 * <ul>
 *   <li>With {@code regression == true}, label values are copied as-is.</li>
 *   <li>Otherwise, the first value of each label time step is read as a class index and one-hot
 *       encoded into {@code numPossibleLabels} columns.</li>
 * </ul>
 */
public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
    private final SequenceRecordReader recordReader;
    private final SequenceRecordReader labelsReader;
    private final int miniBatchSize;
    private final boolean regression;
    private final int numPossibleLabels;
    /** Total number of sequences read from the readers so far. */
    private int cursor = 0;
    /** Feature columns per time step; -1 until the first batch has been read. */
    private int inputColumns = -1;
    /** Label columns per time step; -1 until the first batch has been read. */
    private int totalOutcomes = -1;
    /** True when {@link #stored} holds a batch that was read ahead and not yet returned. */
    private boolean useStored = false;
    /** Batch read ahead by {@link #preLoad()}; stored WITHOUT preprocessing applied. */
    private org.nd4j.linalg.dataset.DataSet stored = null;
    private DataSetPreProcessor preProcessor;

    /**
     * @param featuresReader    reader supplying feature sequences
     * @param labels            reader supplying label sequences, aligned one-to-one with
     *                          {@code featuresReader}
     * @param miniBatchSize     maximum number of sequences per batch returned by {@link #next()}
     * @param numPossibleLabels number of classes used for one-hot labels (not used when
     *                          {@code regression} is true)
     * @param regression        if true, label values are used as-is; otherwise the first value of
     *                          each label time step is treated as a class index
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels, boolean regression) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /**
     * Returns true while a read-ahead batch is pending or the features reader has more sequences.
     * The pending batch must be counted, otherwise a batch consumed by {@link #inputColumns()} or
     * {@link #totalOutcomes()} could be silently dropped.
     */
    @Override
    public boolean hasNext() {
        return useStored || recordReader.hasNext();
    }

    /** Returns the next mini-batch of at most {@code miniBatchSize} sequences. */
    @Override
    public org.nd4j.linalg.dataset.DataSet next() {
        return next(miniBatchSize);
    }

    /**
     * Returns the next mini-batch of at most {@code num} sequences, with the preprocessor (if any)
     * applied exactly once.
     *
     * <p>If a batch was already read ahead to determine {@link #inputColumns()} or
     * {@link #totalOutcomes()}, that batch is returned instead and {@code num} is ignored.
     *
     * @param num maximum number of sequences to read
     * @return the batch, with arrays shaped {@code [examples, columns, timeSteps]}
     * @throws NoSuchElementException   if no sequences remain
     * @throws IllegalArgumentException if a new batch must be read and {@code num <= 0}
     */
    @Override
    public org.nd4j.linalg.dataset.DataSet next(int num) {
        org.nd4j.linalg.dataset.DataSet ds;
        if (useStored) {
            ds = stored;
            stored = null;
            useStored = false;
        } else {
            ds = loadBatch(num);
        }
        // Stored batches are kept unprocessed, so preprocessing happens here exactly once.
        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }
        return ds;
    }

    /**
     * Reads up to {@code num} aligned feature/label sequences and packs them into a DataSet.
     * No preprocessing is applied.
     */
    private org.nd4j.linalg.dataset.DataSet loadBatch(int num) {
        if (!recordReader.hasNext()) {
            throw new NoSuchElementException();
        }
        if (num <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, got " + num);
        }
        List<INDArray> featureList = new ArrayList<>(num);
        List<INDArray> labelList = new ArrayList<>(num);
        // Check the reader directly: hasNext() also reports pending stored batches.
        for (int i = 0; i < num && recordReader.hasNext(); ++i) {
            Collection<Collection<Writable>> featureSequence = recordReader.sequenceRecord();
            Collection<Collection<Writable>> labelSequence = labelsReader.sequenceRecord();
            featureList.add(getFeatures(featureSequence));
            labelList.add(getLabels(labelSequence));
        }
        INDArray featuresOut = toTimeSeriesBatch(featureList);
        INDArray labelsOut = toTimeSeriesBatch(labelList);
        cursor += featureList.size();
        if (inputColumns == -1) {
            inputColumns = featuresOut.size(1);
        }
        if (totalOutcomes == -1) {
            totalOutcomes = labelsOut.size(1);
        }
        return new org.nd4j.linalg.dataset.DataSet(featuresOut, labelsOut);
    }

    /**
     * Packs per-example {@code [timeSteps, columns]} matrices into one
     * {@code [examples, columns, timeSteps]} array. The first matrix determines the shape.
     */
    private static INDArray toTimeSeriesBatch(List<INDArray> perExample) {
        INDArray first = perExample.get(0);
        INDArray out = Nd4j.create(new int[]{perExample.size(), first.size(1), first.size(0)});
        for (int i = 0; i < perExample.size(); ++i) {
            // Each slice is a [columns, timeSteps] view. Permute it to [timeSteps, columns] so it
            // matches the source element-for-element, instead of being filled in linear order.
            out.tensorAlongDimension(i, new int[]{1, 2}).permutei(1, 0).assign(perExample.get(i));
        }
        return out;
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public int totalExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Returns the number of feature columns per time step. If no batch has been read yet, one
     * batch is read ahead and kept for the next call to {@link #next(int)}.
     */
    @Override
    public int inputColumns() {
        if (inputColumns == -1) {
            preLoad();
        }
        return inputColumns;
    }

    /**
     * Returns the number of label columns per time step. If no batch has been read yet, one batch
     * is read ahead and kept for the next call to {@link #next(int)}.
     */
    @Override
    public int totalOutcomes() {
        if (totalOutcomes == -1) {
            preLoad();
        }
        return totalOutcomes;
    }

    /**
     * Reads one batch ahead without preprocessing. {@link #loadBatch(int)} records
     * inputColumns/totalOutcomes as a side effect.
     */
    private void preLoad() {
        stored = loadBatch(miniBatchSize);
        useStored = true;
    }

    /** Does nothing: neither the readers nor the cursor are rewound by this implementation. */
    @Override
    public void reset() {
    }

    @Override
    public int batch() {
        return miniBatchSize;
    }

    @Override
    public int cursor() {
        return cursor;
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public int numExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /**
     * Converts one feature sequence to a {@code [timeSteps, columns]} matrix (row = time step).
     * The column count comes from the first time step.
     *
     * @throws IllegalStateException if the sequence has no time steps
     */
    private INDArray getFeatures(Collection<Collection<Writable>> features) {
        if (features.isEmpty()) {
            throw new IllegalStateException("Feature sequence has no time steps");
        }
        INDArray out = null;
        int step = 0;
        for (Collection<Writable> timeStep : features) {
            if (out == null) {
                out = Nd4j.create(new int[]{features.size(), timeStep.size()});
            }
            int col = 0;
            for (Writable value : timeStep) {
                out.put(step, col++, value.toDouble());
            }
            ++step;
        }
        return out;
    }

    /**
     * Converts one label sequence to a {@code [timeSteps, columns]} matrix (row = time step).
     * Regression labels are copied as-is, with the width taken from the first time step.
     * Classification labels read the first value of each step as a class index and one-hot
     * encode it into {@code numPossibleLabels} columns.
     *
     * @throws IllegalStateException if the sequence has no time steps
     */
    private INDArray getLabels(Collection<Collection<Writable>> labels) {
        if (labels.isEmpty()) {
            throw new IllegalStateException("Label sequence has no time steps");
        }
        INDArray out = null;
        int step = 0;
        for (Collection<Writable> timeStep : labels) {
            if (out == null) {
                int columns = regression ? timeStep.size() : numPossibleLabels;
                out = Nd4j.create(new int[]{labels.size(), columns});
            }
            if (regression) {
                int col = 0;
                for (Writable value : timeStep) {
                    // Row = time step, column = output, matching getFeatures and the one-hot branch.
                    out.put(step, col++, value.toDouble());
                }
            } else {
                int classIndex = timeStep.iterator().next().toInt();
                out.getRow(step).assign(FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels));
            }
            ++step;
        }
        return out;
    }
}

