/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.datavec;

import java.beans.ConstructorProperties;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.records.Record;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class RecordReaderMultiDataSetIterator
implements MultiDataSetIterator {
    private int batchSize;
    private AlignmentMode alignmentMode;
    private Map<String, RecordReader> recordReaders = new HashMap<String, RecordReader>();
    private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<String, SequenceRecordReader>();
    private List<SubsetDetails> inputs = new ArrayList<SubsetDetails>();
    private List<SubsetDetails> outputs = new ArrayList<SubsetDetails>();
    private boolean collectMetaData = false;
    private MultiDataSetPreProcessor preProcessor;

    private RecordReaderMultiDataSetIterator(Builder builder) {
        this.batchSize = builder.batchSize;
        this.alignmentMode = builder.alignmentMode;
        this.recordReaders = builder.recordReaders;
        this.sequenceRecordReaders = builder.sequenceRecordReaders;
        this.inputs.addAll(builder.inputs);
        this.outputs.addAll(builder.outputs);
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        return this.next(this.batchSize);
    }

    public void remove() {
        throw new UnsupportedOperationException("Remove not supported");
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
        RecordMetaDataComposableMap map;
        Record r;
        int i;
        ArrayList<List> writables;
        RecordReader rr;
        if (!this.hasNext()) {
            throw new NoSuchElementException("No next elements");
        }
        HashMap<String, List<List<Writable>>> nextRRVals = new HashMap<String, List<List<Writable>>>();
        HashMap<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<String, List<List<List<Writable>>>>();
        ArrayList<RecordMetaDataComposableMap> nextMetas = this.collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null;
        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            rr = entry.getValue();
            writables = new ArrayList<List>(num);
            for (i = 0; i < num && rr.hasNext(); ++i) {
                List record;
                if (this.collectMetaData) {
                    r = rr.nextRecord();
                    record = r.getRecord();
                    if (nextMetas.size() <= i) {
                        nextMetas.add(new RecordMetaDataComposableMap(new HashMap()));
                    }
                    map = (RecordMetaDataComposableMap)nextMetas.get(i);
                    map.getMeta().put(entry.getKey(), r.getMetaData());
                } else {
                    record = rr.next();
                }
                writables.add(record);
            }
            nextRRVals.put(entry.getKey(), writables);
        }
        for (Map.Entry<String, RecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            rr = (SequenceRecordReader)entry.getValue();
            writables = new ArrayList(num);
            for (i = 0; i < num && rr.hasNext(); ++i) {
                List sequence;
                if (this.collectMetaData) {
                    r = rr.nextSequence();
                    sequence = r.getSequenceRecord();
                    if (nextMetas.size() <= i) {
                        nextMetas.add(new RecordMetaDataComposableMap(new HashMap()));
                    }
                    map = (RecordMetaDataComposableMap)nextMetas.get(i);
                    map.getMeta().put(entry.getKey(), r.getMetaData());
                } else {
                    sequence = rr.sequenceRecord();
                }
                writables.add(sequence);
            }
            nextSeqRRVals.put(entry.getKey(), writables);
        }
        return this.nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
    }

    /*
     * WARNING - void declaration
     */
    private org.nd4j.linalg.dataset.api.MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals, Map<String, List<List<List<Writable>>>> nextSeqRRVals, List<RecordMetaDataComposableMap> nextMetas) {
        void var8_22;
        int n;
        int minExamples = Integer.MAX_VALUE;
        for (List<List<Writable>> list : nextRRVals.values()) {
            minExamples = Math.min(minExamples, list.size());
        }
        for (List<List<Object>> list : nextSeqRRVals.values()) {
            minExamples = Math.min(minExamples, list.size());
        }
        if (minExamples == Integer.MAX_VALUE) {
            throw new RuntimeException("Error occurred during data set generation: no readers?");
        }
        int[] longestSequence = null;
        if (this.alignmentMode == AlignmentMode.ALIGN_END) {
            longestSequence = new int[minExamples];
            for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
                List<List<List<Writable>>> list = entry.getValue();
                for (int i = 0; i < list.size() && i < minExamples; ++i) {
                    longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
                }
            }
        }
        int n2 = -1;
        if (this.alignmentMode != AlignmentMode.EQUAL_LENGTH) {
            for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
                List<List<List<Writable>>> list = entry.getValue();
                for (Object c : list) {
                    n = Math.max(n, c.size());
                }
            }
        }
        INDArray[] iNDArrayArray = new INDArray[this.inputs.size()];
        INDArray[] iNDArrayArray2 = new INDArray[this.inputs.size()];
        boolean inputMasks = false;
        int i = 0;
        for (SubsetDetails d : this.inputs) {
            List<List<List<Writable>>> list;
            if (nextRRVals.containsKey(d.readerName)) {
                list = nextRRVals.get(d.readerName);
                iNDArrayArray[i] = this.convertWritables(list, minExamples, d);
            } else {
                list = nextSeqRRVals.get(d.readerName);
                Iterator<SubsetDetails> p = this.convertWritablesSequence(list, minExamples, n, d, longestSequence);
                iNDArrayArray[i] = (INDArray)p.getFirst();
                iNDArrayArray2[i] = (INDArray)p.getSecond();
                if (iNDArrayArray2[i] != null) {
                    inputMasks = true;
                }
            }
            ++i;
        }
        if (!inputMasks) {
            Object var8_21 = null;
        }
        INDArray[] outputArrs = new INDArray[this.outputs.size()];
        INDArray[] outputArrMasks = new INDArray[this.outputs.size()];
        boolean outputMasks = false;
        i = 0;
        for (SubsetDetails d : this.outputs) {
            List<List<List<Writable>>> list;
            if (nextRRVals.containsKey(d.readerName)) {
                list = nextRRVals.get(d.readerName);
                outputArrs[i] = this.convertWritables(list, minExamples, d);
            } else {
                list = nextSeqRRVals.get(d.readerName);
                Pair<INDArray, INDArray> p = this.convertWritablesSequence(list, minExamples, n, d, longestSequence);
                outputArrs[i] = (INDArray)p.getFirst();
                outputArrMasks[i] = (INDArray)p.getSecond();
                if (outputArrMasks[i] != null) {
                    outputMasks = true;
                }
            }
            ++i;
        }
        if (!outputMasks) {
            outputArrMasks = null;
        }
        MultiDataSet mds = new MultiDataSet(iNDArrayArray, outputArrs, (INDArray[])var8_22, outputArrMasks);
        if (this.collectMetaData) {
            mds.setExampleMetaData(nextMetas);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((org.nd4j.linalg.dataset.api.MultiDataSet)mds);
        }
        return mds;
    }

    private INDArray convertWritables(List<List<Writable>> list, int minValues, SubsetDetails details) {
        INDArray arr;
        int[] shape;
        INDArray temp;
        if (details.entireReader) {
            if (list.get(0).size() == 1 && list.get(0).get(0) instanceof NDArrayWritable) {
                temp = ((NDArrayWritable)list.get(0).get(0)).get();
                shape = ArrayUtils.clone((int[])temp.shape());
                shape[0] = minValues;
                arr = Nd4j.create((int[])shape);
            } else {
                arr = Nd4j.create((int)minValues, (int)list.get(0).size());
            }
        } else if (details.oneHot) {
            arr = Nd4j.zeros((int)minValues, (int)details.oneHotNumClasses);
        } else if (details.subsetStart == details.subsetEndInclusive && list.get(0).get(details.subsetStart) instanceof NDArrayWritable) {
            temp = ((NDArrayWritable)list.get(0).get(details.subsetStart)).get();
            shape = ArrayUtils.clone((int[])temp.shape());
            shape[0] = minValues;
            arr = Nd4j.create((int[])shape);
        } else {
            arr = Nd4j.create((int)minValues, (int)(details.subsetEndInclusive - details.subsetStart + 1));
        }
        for (int i = 0; i < minValues; ++i) {
            List<Writable> c = list.get(i);
            if (details.entireReader) {
                int j = 0;
                for (Writable w : c) {
                    try {
                        arr.putScalar(i, j, w.toDouble());
                    }
                    catch (UnsupportedOperationException e) {
                        if (w instanceof NDArrayWritable) {
                            this.putExample(arr, ((NDArrayWritable)w).get(), i);
                        }
                        throw e;
                    }
                    ++j;
                }
                continue;
            }
            if (details.oneHot) {
                Writable w = c.get(details.subsetStart);
                arr.putScalar(i, w.toInt(), 1.0);
                continue;
            }
            if (details.subsetStart == details.subsetEndInclusive && c.get(details.subsetStart) instanceof NDArrayWritable) {
                this.putExample(arr, ((NDArrayWritable)c.get(details.subsetStart)).get(), i);
                continue;
            }
            Iterator<Writable> iter = c.iterator();
            for (int j = 0; j < details.subsetStart; ++j) {
                iter.next();
            }
            int k = 0;
            for (int j = details.subsetStart; j <= details.subsetEndInclusive; ++j) {
                Writable w = iter.next();
                try {
                    arr.putScalar(i, k, w.toDouble());
                }
                catch (UnsupportedOperationException e) {
                    if (w instanceof NDArrayWritable) {
                        this.putExample(arr, ((NDArrayWritable)w).get(), i);
                    }
                    throw e;
                }
                ++k;
            }
        }
        return arr;
    }

    private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
        switch (arr.rank()) {
            case 2: {
                arr.put(new INDArrayIndex[]{NDArrayIndex.point((int)exampleIdx), NDArrayIndex.all()}, singleExample);
                break;
            }
            case 3: {
                arr.put(new INDArrayIndex[]{NDArrayIndex.point((int)exampleIdx), NDArrayIndex.all(), NDArrayIndex.all()}, singleExample);
                break;
            }
            case 4: {
                arr.put(new INDArrayIndex[]{NDArrayIndex.point((int)exampleIdx), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}, singleExample);
                break;
            }
            default: {
                throw new RuntimeException("Unexpected rank: " + arr.rank());
            }
        }
    }

    private Pair<INDArray, INDArray> convertWritablesSequence(List<List<List<Writable>>> list, int minValues, int maxTSLength, SubsetDetails details, int[] longestSequence) {
        INDArray arr;
        if (maxTSLength == -1) {
            maxTSLength = list.get(0).size();
        }
        if (details.entireReader) {
            int size = list.get(0).iterator().next().size();
            arr = Nd4j.create((int[])new int[]{minValues, size, maxTSLength}, (char)'f');
        } else {
            arr = details.oneHot ? Nd4j.create((int[])new int[]{minValues, details.oneHotNumClasses, maxTSLength}, (char)'f') : Nd4j.create((int[])new int[]{minValues, details.subsetEndInclusive - details.subsetStart + 1, maxTSLength}, (char)'f');
        }
        boolean needMaskArray = false;
        for (List<List<Writable>> c : list) {
            if (c.size() >= maxTSLength) continue;
            needMaskArray = true;
        }
        INDArray maskArray = needMaskArray ? Nd4j.ones((int)minValues, (int)maxTSLength) : null;
        for (int i = 0; i < minValues; ++i) {
            List<List<Writable>> sequence = list.get(i);
            int startOffset = this.alignmentMode == AlignmentMode.ALIGN_START || this.alignmentMode == AlignmentMode.EQUAL_LENGTH ? 0 : longestSequence[i] - sequence.size();
            int t = 0;
            for (List<Writable> timeStep : sequence) {
                int j;
                Iterator<Writable> iter;
                int k = startOffset + t++;
                if (details.entireReader) {
                    iter = timeStep.iterator();
                    j = 0;
                    while (iter.hasNext()) {
                        Writable w = iter.next();
                        try {
                            arr.putScalar(i, j, k, w.toDouble());
                        }
                        catch (UnsupportedOperationException e) {
                            if (w instanceof NDArrayWritable) {
                                arr.get(new INDArrayIndex[]{NDArrayIndex.point((int)i), NDArrayIndex.all(), NDArrayIndex.point((int)k)}).putRow(0, ((NDArrayWritable)w).get());
                            }
                            throw e;
                        }
                        ++j;
                    }
                    continue;
                }
                if (details.oneHot) {
                    Writable w = null;
                    if (timeStep instanceof List) {
                        w = timeStep.get(details.subsetStart);
                    } else {
                        Iterator<Writable> iter2 = timeStep.iterator();
                        for (int x = 0; x <= details.subsetStart; ++x) {
                            w = iter2.next();
                        }
                    }
                    int classIdx = w.toInt();
                    arr.putScalar(i, classIdx, k, 1.0);
                    continue;
                }
                iter = timeStep.iterator();
                for (j = 0; j < details.subsetStart; ++j) {
                    iter.next();
                }
                int l = 0;
                for (int j2 = details.subsetStart; j2 <= details.subsetEndInclusive; ++j2) {
                    Writable w = iter.next();
                    try {
                        arr.putScalar(i, l++, k, w.toDouble());
                        continue;
                    }
                    catch (UnsupportedOperationException e) {
                        if (w instanceof NDArrayWritable) {
                            arr.get(new INDArrayIndex[]{NDArrayIndex.point((int)i), NDArrayIndex.all(), NDArrayIndex.point((int)k)}).putRow(0, ((NDArrayWritable)w).get().get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)details.subsetStart, (int)(details.subsetEndInclusive + 1))}));
                            continue;
                        }
                        throw e;
                    }
                }
            }
            if (!needMaskArray) continue;
            if (this.alignmentMode == AlignmentMode.ALIGN_END) {
                for (int t2 = 0; t2 < startOffset; ++t2) {
                    maskArray.putScalar(i, t2, 0.0);
                }
            }
            if (this.alignmentMode != AlignmentMode.ALIGN_START) continue;
            for (int t2 = t; t2 < maxTSLength; ++t2) {
                maskArray.putScalar(i, t2, 0.0);
            }
        }
        return new Pair((Object)arr, (Object)maskArray);
    }

    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    public void reset() {
        for (RecordReader recordReader : this.recordReaders.values()) {
            recordReader.reset();
        }
        for (SequenceRecordReader sequenceRecordReader : this.sequenceRecordReaders.values()) {
            sequenceRecordReader.reset();
        }
    }

    public boolean hasNext() {
        for (RecordReader recordReader : this.recordReaders.values()) {
            if (recordReader.hasNext()) continue;
            return false;
        }
        for (SequenceRecordReader sequenceRecordReader : this.sequenceRecordReaders.values()) {
            if (sequenceRecordReader.hasNext()) continue;
            return false;
        }
        return true;
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return this.loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        Record r;
        ArrayList<List> writables;
        Object fromMeta;
        Object m2;
        ArrayList thisRRMeta;
        RecordReader rr;
        HashMap<String, List<List<Writable>>> nextRRVals = new HashMap<String, List<List<Writable>>>();
        HashMap<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<String, List<List<List<Writable>>>>();
        ArrayList<RecordMetaDataComposableMap> nextMetas = this.collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null;
        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            rr = entry.getValue();
            thisRRMeta = new ArrayList();
            for (RecordMetaData m : list) {
                m2 = (RecordMetaDataComposableMap)m;
                thisRRMeta.add(m2.getMeta().get(entry.getKey()));
            }
            fromMeta = rr.loadFromMetaData(thisRRMeta);
            writables = new ArrayList<List>(list.size());
            m2 = fromMeta.iterator();
            while (m2.hasNext()) {
                r = (Record)m2.next();
                writables.add(r.getRecord());
            }
            nextRRVals.put(entry.getKey(), writables);
        }
        for (Map.Entry<String, RecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            rr = (SequenceRecordReader)entry.getValue();
            thisRRMeta = new ArrayList();
            for (RecordMetaData m : list) {
                m2 = (RecordMetaDataComposableMap)m;
                thisRRMeta.add(m2.getMeta().get(entry.getKey()));
            }
            fromMeta = rr.loadSequenceFromMetaData(thisRRMeta);
            writables = new ArrayList(list.size());
            Iterator iterator = fromMeta.iterator();
            while (iterator.hasNext()) {
                r = (SequenceRecord)iterator.next();
                writables.add(r.getSequenceRecord());
            }
            nextSeqRRVals.put(entry.getKey(), writables);
        }
        return this.nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
    }

    public boolean isCollectMetaData() {
        return this.collectMetaData;
    }

    public void setCollectMetaData(boolean collectMetaData) {
        this.collectMetaData = collectMetaData;
    }

    private static class SubsetDetails {
        private final String readerName;
        private final boolean entireReader;
        private final boolean oneHot;
        private final int oneHotNumClasses;
        private final int subsetStart;
        private final int subsetEndInclusive;

        @ConstructorProperties(value={"readerName", "entireReader", "oneHot", "oneHotNumClasses", "subsetStart", "subsetEndInclusive"})
        public SubsetDetails(String readerName, boolean entireReader, boolean oneHot, int oneHotNumClasses, int subsetStart, int subsetEndInclusive) {
            this.readerName = readerName;
            this.entireReader = entireReader;
            this.oneHot = oneHot;
            this.oneHotNumClasses = oneHotNumClasses;
            this.subsetStart = subsetStart;
            this.subsetEndInclusive = subsetEndInclusive;
        }
    }

    public static class Builder {
        private int batchSize;
        private AlignmentMode alignmentMode = AlignmentMode.EQUAL_LENGTH;
        private Map<String, RecordReader> recordReaders = new HashMap<String, RecordReader>();
        private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<String, SequenceRecordReader>();
        private List<SubsetDetails> inputs = new ArrayList<SubsetDetails>();
        private List<SubsetDetails> outputs = new ArrayList<SubsetDetails>();

        public Builder(int batchSize) {
            this.batchSize = batchSize;
        }

        public Builder addReader(String readerName, RecordReader recordReader) {
            this.recordReaders.put(readerName, recordReader);
            return this;
        }

        public Builder addSequenceReader(String seqReaderName, SequenceRecordReader seqRecordReader) {
            this.sequenceRecordReaders.put(seqReaderName, seqRecordReader);
            return this;
        }

        public Builder sequenceAlignmentMode(AlignmentMode alignmentMode) {
            this.alignmentMode = alignmentMode;
            return this;
        }

        public Builder addInput(String readerName) {
            this.inputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        public Builder addInput(String readerName, int columnFirst, int columnLast) {
            this.inputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        public Builder addInputOneHot(String readerName, int column, int numClasses) {
            this.inputs.add(new SubsetDetails(readerName, false, true, numClasses, column, -1));
            return this;
        }

        public Builder addOutput(String readerName) {
            this.outputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        public Builder addOutput(String readerName, int columnFirst, int columnLast) {
            this.outputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        public Builder addOutputOneHot(String readerName, int column, int numClasses) {
            this.outputs.add(new SubsetDetails(readerName, false, true, numClasses, column, -1));
            return this;
        }

        public RecordReaderMultiDataSetIterator build() {
            if (this.recordReaders.isEmpty() && this.sequenceRecordReaders.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no readers");
            }
            if (this.batchSize <= 0) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with batch size <= 0");
            }
            if (this.inputs.isEmpty() && this.outputs.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no inputs/outputs");
            }
            for (SubsetDetails ssd : this.inputs) {
                if (this.recordReaders.containsKey(ssd.readerName) || this.sequenceRecordReaders.containsKey(ssd.readerName)) continue;
                throw new IllegalStateException("Invalid input name: \"" + ssd.readerName + "\" - no reader found with this name");
            }
            for (SubsetDetails ssd : this.outputs) {
                if (this.recordReaders.containsKey(ssd.readerName) || this.sequenceRecordReaders.containsKey(ssd.readerName)) continue;
                throw new IllegalStateException("Invalid output name: \"" + ssd.readerName + "\" - no reader found with this name");
            }
            return new RecordReaderMultiDataSetIterator(this);
        }
    }

    public static enum AlignmentMode {
        EQUAL_LENGTH,
        ALIGN_START,
        ALIGN_END;

    }
}

