/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.loader;

import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.bagofwords.vectorizer.BagOfWordsVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.deeplearning4j.util.ArchiveUtils;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Data fetcher backed by the newsgroups archive at {@link #NEWSGROUP_URL}.
 *
 * <p>On construction the archive is downloaded and unpacked into
 * {@code <user.home>/reuters} (unless that directory is already populated), every
 * sub-directory name of that directory is used as a label, and the text is vectorized
 * once with either a TF-IDF or a bag-of-words vectorizer. {@link #fetch(int)} then
 * serves slices of the vectorized data set.
 */
public class ReutersNewsGroupsLoader
extends BaseDataFetcher {
    /** Location of the newsgroups archive that is downloaded on first use. */
    public static final String NEWSGROUP_URL = "http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz";
    private static final Logger log = LoggerFactory.getLogger(ReutersNewsGroupsLoader.class);
    /** Name of the top-level directory contained in the archive. */
    private static final String ARCHIVE_DIR_NAME = "20news-18828";

    private TextVectorizer textVectorizer;
    /** Whether TF-IDF (true) or bag-of-words (false) vectorization was requested. */
    private boolean tfidf;
    /** Local directory holding one sub-directory per label. */
    private File reutersRootDir;
    /** The fully vectorized corpus, produced once in the constructor. */
    private DataSet load;

    /**
     * Downloads the corpus if necessary and vectorizes it.
     *
     * @param tfidf {@code true} to vectorize with {@link TfidfVectorizer},
     *              {@code false} to use {@link BagOfWordsVectorizer}
     * @throws IllegalStateException if the data directory cannot be created or listed,
     *                               or is empty after the download
     * @throws Exception if downloading, unpacking or vectorizing fails
     */
    public ReutersNewsGroupsLoader(boolean tfidf) throws Exception {
        // The original never stored the flag, leaving the field permanently false.
        this.tfidf = tfidf;
        this.getIfNotExists();
        LabelAwareFileSentenceIterator iter = new LabelAwareFileSentenceIterator(this.reutersRootDir);

        // listFiles() returns null on I/O errors or when the path is not a directory.
        File[] entries = this.reutersRootDir.listFiles();
        if (entries == null) {
            throw new IllegalStateException("Unable to list directory " + this.reutersRootDir.getAbsolutePath());
        }
        // Every sub-directory of the root is treated as one label.
        ArrayList<String> labels = new ArrayList<String>();
        for (File f : entries) {
            if (f.isDirectory()) {
                labels.add(f.getName());
            }
        }

        UimaTokenizerFactory tokenizerFactory = new UimaTokenizerFactory();
        if (tfidf) {
            this.textVectorizer = new TfidfVectorizer.Builder().iterate(iter).labels(labels).tokenize(tokenizerFactory).build();
        } else {
            this.textVectorizer = new BagOfWordsVectorizer.Builder().iterate(iter).labels(labels).tokenize(tokenizerFactory).build();
        }
        this.load = this.textVectorizer.vectorize();
    }

    /**
     * Ensures {@code <user.home>/reuters} exists and is populated, downloading and
     * unpacking the archive when it is missing or empty.
     *
     * <p>An existing, non-empty directory is assumed to hold a previous download and is
     * left untouched. An existing but empty directory (e.g. left behind by an interrupted
     * earlier run) triggers a fresh download instead of being silently accepted.
     *
     * @throws IllegalStateException if the path exists but is not a listable directory,
     *                               the directory cannot be created, or nothing was extracted
     * @throws Exception if the download or extraction fails
     */
    private void getIfNotExists() throws Exception {
        String home = System.getProperty("user.home");
        String rootDir = home + File.separator + "reuters";
        this.reutersRootDir = new File(rootDir);

        if (this.reutersRootDir.exists()) {
            if (!this.reutersRootDir.isDirectory()) {
                throw new IllegalStateException(rootDir + " exists but is not a directory");
            }
            String[] existing = this.reutersRootDir.list();
            if (existing == null) {
                throw new IllegalStateException("Unable to list directory " + rootDir);
            }
            if (existing.length > 0) {
                // Already populated: nothing to download.
                return;
            }
        } else if (!this.reutersRootDir.mkdirs()) {
            throw new IllegalStateException("Unable to create directory " + rootDir);
        }

        File rootTarFile = new File(this.reutersRootDir, ARCHIVE_DIR_NAME + ".tar.gz");
        log.info("Downloading {} to {}", NEWSGROUP_URL, rootTarFile.getAbsolutePath());
        try {
            // copyURLToFile creates or overwrites the destination, so no prior delete/create is needed.
            FileUtils.copyURLToFile(new URL(NEWSGROUP_URL), rootTarFile);
            ArchiveUtils.unzipFileTo(rootTarFile.getAbsolutePath(), this.reutersRootDir.getAbsolutePath());
        } finally {
            // Remove the archive even on failure so a partial download is not mistaken for data later.
            if (rootTarFile.exists() && !rootTarFile.delete()) {
                log.warn("Unable to delete {}", rootTarFile.getAbsolutePath());
            }
        }

        // Hoist the archive's top-level directory contents into the root so that the
        // label directories sit directly beneath it.
        File extractedDir = new File(this.reutersRootDir, ARCHIVE_DIR_NAME);
        FileUtils.copyDirectory(extractedDir, this.reutersRootDir);
        FileUtils.deleteDirectory(extractedDir);

        File[] extracted = this.reutersRootDir.listFiles();
        if (extracted == null || extracted.length == 0) {
            throw new IllegalStateException("No files found!");
        }
    }

    /**
     * Collects up to {@code numExamples} examples from the vectorized corpus, starting at
     * the inherited {@code cursor}, advances the cursor past them, and stores the merged
     * result in the inherited {@code curr} field.
     *
     * <p>NOTE(review): when the cursor is already at the end (or {@code numExamples} is
     * not positive) an empty list is passed to {@code DataSet.merge}; how that call
     * treats an empty list is not visible here — confirm callers check for remaining data.
     *
     * @param numExamples maximum number of examples to fetch
     */
    public void fetch(int numExamples) {
        ArrayList<DataSet> newData = new ArrayList<DataSet>();
        for (int grabbed = 0; grabbed < numExamples && this.cursor < this.load.numExamples(); ++grabbed) {
            newData.add(this.load.get(this.cursor));
            ++this.cursor;
        }
        this.curr = DataSet.merge(newData);
    }
}

