/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.fetchers;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.zip.Adler32;
import java.util.zip.Checksum;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;

public class MnistDataFetcher
extends BaseDataFetcher {
    public static final int NUM_EXAMPLES = 60000;
    public static final int NUM_EXAMPLES_TEST = 10000;
    protected static final long CHECKSUM_TRAIN_FEATURES = 2094436111L;
    protected static final long CHECKSUM_TRAIN_LABELS = 4008842612L;
    protected static final long CHECKSUM_TEST_FEATURES = 2165396896L;
    protected static final long CHECKSUM_TEST_LABELS = 2212998611L;
    protected static final long[] CHECKSUMS_TRAIN = new long[]{2094436111L, 4008842612L};
    protected static final long[] CHECKSUMS_TEST = new long[]{2165396896L, 2212998611L};
    protected transient MnistManager man;
    protected boolean binarize = true;
    protected boolean train;
    protected int[] order;
    protected Random rng;
    protected boolean shuffle;
    protected boolean oneIndexed = false;
    protected boolean fOrder = false;
    protected boolean firstShuffle = true;
    protected final int numExamples;

    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis(), 60000);
    }

    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed, int numExamples) throws IOException {
        long[] checksums;
        String labels;
        String images;
        if (!this.mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }
        String MNIST_ROOT = DL4JResources.getDirectory((ResourceType)ResourceType.DATASET, (String)"MNIST").getAbsolutePath();
        if (train) {
            images = FilenameUtils.concat((String)MNIST_ROOT, (String)"train-images-idx3-ubyte");
            labels = FilenameUtils.concat((String)MNIST_ROOT, (String)"train-labels-idx1-ubyte");
            this.totalExamples = 60000;
            checksums = CHECKSUMS_TRAIN;
        } else {
            images = FilenameUtils.concat((String)MNIST_ROOT, (String)"t10k-images-idx3-ubyte");
            labels = FilenameUtils.concat((String)MNIST_ROOT, (String)"t10k-labels-idx1-ubyte");
            this.totalExamples = 10000;
            checksums = CHECKSUMS_TEST;
        }
        String[] files = new String[]{images, labels};
        try {
            this.man = new MnistManager(images, labels, train);
            this.validateFiles(files, checksums);
        }
        catch (Exception e) {
            try {
                FileUtils.deleteDirectory((File)new File(MNIST_ROOT));
            }
            catch (Exception exception) {
                // empty catch block
            }
            new MnistFetcher().downloadAndUntar();
            this.man = new MnistManager(images, labels, train);
            this.validateFiles(files, checksums);
        }
        this.numOutcomes = 10;
        this.binarize = binarize;
        this.cursor = 0;
        this.inputColumns = this.man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;
        this.order = train ? new int[60000] : new int[10000];
        for (int i = 0; i < this.order.length; ++i) {
            this.order[i] = i;
        }
        this.rng = new Random(rngSeed);
        this.numExamples = numExamples;
        this.reset();
    }

    private boolean mnistExists() {
        String MNIST_ROOT = DL4JResources.getDirectory((ResourceType)ResourceType.DATASET, (String)"MNIST").getAbsolutePath();
        File f = new File(MNIST_ROOT, "train-images-idx3-ubyte");
        if (!f.exists()) {
            return false;
        }
        f = new File(MNIST_ROOT, "train-labels-idx1-ubyte");
        if (!f.exists()) {
            return false;
        }
        f = new File(MNIST_ROOT, "t10k-images-idx3-ubyte");
        if (!f.exists()) {
            return false;
        }
        f = new File(MNIST_ROOT, "t10k-labels-idx1-ubyte");
        return f.exists();
    }

    private void validateFiles(String[] files, long[] checksums) {
        try {
            for (int i = 0; i < files.length; ++i) {
                long checksum;
                File f = new File(files[i]);
                Adler32 adler = new Adler32();
                long l = checksum = f.exists() ? FileUtils.checksum((File)f, (Checksum)adler).getValue() : -1L;
                if (f.exists() && checksum == checksums[i]) continue;
                throw new IllegalStateException("Failed checksum: expected " + checksums[i] + ", got " + checksum + " for file: " + f);
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public MnistDataFetcher() throws IOException {
        this(true);
    }

    public void fetch(int numExamples) {
        if (!this.hasMore()) {
            throw new IllegalStateException("Unable to get more; there are no more images");
        }
        float[][] featureData = new float[numExamples][0];
        float[][] labelData = new float[numExamples][0];
        int actualExamples = 0;
        byte[] working = null;
        int i = 0;
        while (i < numExamples && this.hasMore()) {
            byte[] img = this.man.readImageUnsafe(this.order[this.cursor]);
            if (this.fOrder) {
                if (working == null) {
                    working = new byte[784];
                }
                for (int j = 0; j < 784; ++j) {
                    working[j] = img[28 * (j % 28) + j / 28];
                }
                img = working;
            }
            int label = this.man.readLabel(this.order[this.cursor]);
            if (this.oneIndexed) {
                --label;
            }
            float[] featureVec = new float[img.length];
            featureData[actualExamples] = featureVec;
            labelData[actualExamples] = new float[this.numOutcomes];
            labelData[actualExamples][label] = 1.0f;
            for (int j = 0; j < img.length; ++j) {
                float v = img[j] & 0xFF;
                if (this.binarize) {
                    if (v > 30.0f) {
                        featureVec[j] = 1.0f;
                        continue;
                    }
                    featureVec[j] = 0.0f;
                    continue;
                }
                featureVec[j] = v / 255.0f;
            }
            ++actualExamples;
            ++i;
            ++this.cursor;
        }
        if (actualExamples < numExamples) {
            featureData = (float[][])Arrays.copyOfRange(featureData, 0, actualExamples);
            labelData = (float[][])Arrays.copyOfRange(labelData, 0, actualExamples);
        }
        INDArray features = Nd4j.create((float[][])featureData);
        INDArray labels = Nd4j.create((float[][])labelData);
        this.curr = new DataSet(features, labels);
    }

    public void reset() {
        this.cursor = 0;
        this.curr = null;
        if (this.shuffle) {
            if (this.train && this.numExamples < 60000 || !this.train && this.numExamples < 10000) {
                if (this.firstShuffle) {
                    MathUtils.shuffleArray((int[])this.order, (Random)this.rng);
                    this.firstShuffle = false;
                } else {
                    MathUtils.shuffleArraySubset((int[])this.order, (int)this.numExamples, (Random)this.rng);
                }
            } else {
                MathUtils.shuffleArray((int[])this.order, (Random)this.rng);
            }
        }
    }

    public DataSet next() {
        DataSet next = super.next();
        return next;
    }
}

