/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Triple;

public class RandomMultiDataSetIterator
implements MultiDataSetIterator {
    private final int numMiniBatches;
    private final List<Triple<long[], Character, Values>> features;
    private final List<Triple<long[], Character, Values>> labels;
    private MultiDataSetPreProcessor preProcessor;
    private int position;

    public RandomMultiDataSetIterator(int numMiniBatches, @NonNull List<Triple<long[], Character, Values>> features, @NonNull List<Triple<long[], Character, Values>> labels) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        Preconditions.checkArgument((numMiniBatches > 0 ? 1 : 0) != 0, (String)"Number of minibatches must be positive: got %s", (int)numMiniBatches);
        Preconditions.checkArgument((features.size() > 0 ? 1 : 0) != 0, (String)"No features defined");
        Preconditions.checkArgument((labels.size() > 0 ? 1 : 0) != 0, (String)"No labels defined");
        this.numMiniBatches = numMiniBatches;
        this.features = features;
        this.labels = labels;
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next(int i) {
        return this.next();
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    public void reset() {
        this.position = 0;
    }

    public boolean hasNext() {
        return this.position < this.numMiniBatches;
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        Triple<long[], Character, Values> t;
        int i;
        if (!this.hasNext()) {
            throw new NoSuchElementException("No next element");
        }
        INDArray[] f = new INDArray[this.features.size()];
        INDArray[] l = new INDArray[this.labels.size()];
        for (i = 0; i < f.length; ++i) {
            t = this.features.get(i);
            f[i] = RandomMultiDataSetIterator.generate((long[])t.getFirst(), ((Character)t.getSecond()).charValue(), (Values)((Object)t.getThird()));
        }
        for (i = 0; i < l.length; ++i) {
            t = this.labels.get(i);
            l[i] = RandomMultiDataSetIterator.generate((long[])t.getFirst(), ((Character)t.getSecond()).charValue(), (Values)((Object)t.getThird()));
        }
        ++this.position;
        MultiDataSet mds = new MultiDataSet(f, l);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((org.nd4j.linalg.dataset.api.MultiDataSet)mds);
        }
        return mds;
    }

    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    public static INDArray generate(long[] shape, Values values) {
        return RandomMultiDataSetIterator.generate(shape, Nd4j.order().charValue(), values);
    }

    public static INDArray generate(long[] shape, char order, Values values) {
        switch (values) {
            case RANDOM_UNIFORM: {
                return Nd4j.rand((INDArray)Nd4j.createUninitialized((long[])shape, (char)order));
            }
            case RANDOM_NORMAL: {
                return Nd4j.randn((INDArray)Nd4j.createUninitialized((long[])shape, (char)order));
            }
            case ONE_HOT: {
                Random r = new Random(Nd4j.getRandom().nextLong());
                INDArray out = Nd4j.create((long[])shape, (char)order);
                if (shape.length == 1) {
                    out.putScalar((long)r.nextInt((int)shape[0]), 1.0);
                } else if (shape.length == 2) {
                    int i = 0;
                    while ((long)i < shape[0]) {
                        out.putScalar((long)i, (long)r.nextInt((int)shape[1]), 1.0);
                        ++i;
                    }
                } else if (shape.length == 3) {
                    int i = 0;
                    while ((long)i < shape[0]) {
                        int j = 0;
                        while ((long)j < shape[2]) {
                            out.putScalar((long)i, (long)r.nextInt((int)shape[1]), (long)j, 1.0);
                            ++j;
                        }
                        ++i;
                    }
                } else if (shape.length == 4) {
                    int i = 0;
                    while ((long)i < shape[0]) {
                        int j = 0;
                        while ((long)j < shape[2]) {
                            int k = 0;
                            while ((long)k < shape[3]) {
                                out.putScalar((long)i, (long)r.nextInt((int)shape[1]), (long)j, (long)k, 1.0);
                                ++k;
                            }
                            ++j;
                        }
                        ++i;
                    }
                } else if (shape.length == 5) {
                    int i = 0;
                    while ((long)i < shape[0]) {
                        int j = 0;
                        while ((long)j < shape[2]) {
                            int k = 0;
                            while ((long)k < shape[3]) {
                                int l = 0;
                                while ((long)l < shape[4]) {
                                    out.putScalar(new int[]{i, r.nextInt((int)shape[1]), j, k, l++}, 1.0);
                                }
                                ++k;
                            }
                            ++j;
                        }
                        ++i;
                    }
                } else {
                    throw new RuntimeException("Not supported: rank 6+ arrays. Shape: " + Arrays.toString(shape));
                }
                return out;
            }
            case ZEROS: {
                return Nd4j.create((long[])shape, (char)order);
            }
            case ONES: {
                return Nd4j.createUninitialized((long[])shape, (char)order).assign((Number)1.0);
            }
            case BINARY: {
                return Nd4j.getExecutioner().exec((RandomOp)new BernoulliDistribution(Nd4j.createUninitialized((long[])shape, (char)order), 0.5));
            }
            case INTEGER_0_10: {
                return Transforms.floor((INDArray)Nd4j.rand((long[])shape).muli((Number)10), (boolean)false);
            }
            case INTEGER_0_100: {
                return Transforms.floor((INDArray)Nd4j.rand((long[])shape).muli((Number)100), (boolean)false);
            }
            case INTEGER_0_1000: {
                return Transforms.floor((INDArray)Nd4j.rand((long[])shape).muli((Number)1000), (boolean)false);
            }
            case INTEGER_0_10000: {
                return Transforms.floor((INDArray)Nd4j.rand((long[])shape).muli((Number)10000), (boolean)false);
            }
            case INTEGER_0_100000: {
                return Transforms.floor((INDArray)Nd4j.rand((long[])shape).muli((Number)100000), (boolean)false);
            }
        }
        throw new RuntimeException("Unknown enum value: " + (Object)((Object)values));
    }

    public MultiDataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    public static class Builder {
        private int numMiniBatches;
        private List<Triple<long[], Character, Values>> features = new ArrayList<Triple<long[], Character, Values>>();
        private List<Triple<long[], Character, Values>> labels = new ArrayList<Triple<long[], Character, Values>>();

        public Builder(int numMiniBatches) {
            this.numMiniBatches = numMiniBatches;
        }

        public Builder addFeatures(long[] shape, Values values) {
            return this.addFeatures(shape, 'c', values);
        }

        public Builder addFeatures(long[] shape, char order, Values values) {
            this.features.add((Triple<long[], Character, Values>)new Triple((Object)shape, (Object)Character.valueOf(order), (Object)values));
            return this;
        }

        public Builder addLabels(long[] shape, Values values) {
            return this.addLabels(shape, 'c', values);
        }

        public Builder addLabels(long[] shape, char order, Values values) {
            this.labels.add((Triple<long[], Character, Values>)new Triple((Object)shape, (Object)Character.valueOf(order), (Object)values));
            return this;
        }

        public RandomMultiDataSetIterator build() {
            return new RandomMultiDataSetIterator(this.numMiniBatches, this.features, this.labels);
        }
    }

    public static enum Values {
        RANDOM_UNIFORM,
        RANDOM_NORMAL,
        ONE_HOT,
        ZEROS,
        ONES,
        BINARY,
        INTEGER_0_10,
        INTEGER_0_100,
        INTEGER_0_1000,
        INTEGER_0_10000,
        INTEGER_0_100000;

    }
}

