/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.fetchers;

import au.com.bytecode.opencsv.CSV;
import au.com.bytecode.opencsv.CSVReadProc;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * Data fetcher backed by comma-separated text in which one column holds a class label and every
 * other column holds a numeric feature value.
 *
 * <p>All input is parsed eagerly during construction and kept in memory as a single merged
 * {@link DataSet}; {@link #fetch(int)} then hands out consecutive slices of it. Rows with no
 * values are ignored, as are rows whose column count differs from the first non-empty row.
 *
 * <p>Distinct labels are given outcome indices in the iteration order of a {@link HashSet}, so
 * the label-to-index mapping does not follow the order in which labels first appear.
 *
 * <p>Parsing mode differs by constructor: {@link #CSVDataFetcher(InputStream, int)} treats
 * {@code "} as the quote character, while every other constructor disables quoting.
 */
public class CSVDataFetcher extends BaseDataFetcher {
    /** Configured CSV parser. */
    private CSV csv;
    /** Source of the CSV text; fully consumed by {@link #init()}. */
    private InputStream is;
    /** Zero-based index of the column that holds the label. */
    private int labelColumn;
    /** Every parsed example merged into one data set. */
    private DataSet all;

    /**
     * Parses labelled CSV from a stream, using {@code ','} as separator and {@code "} as the
     * quote character, without skipping any leading lines.
     *
     * <p>This constructor does not itself close {@code is}.
     *
     * @param is          stream of CSV text
     * @param labelColumn zero-based index of the label column
     */
    public CSVDataFetcher(InputStream is, int labelColumn) {
        this.is = is;
        this.labelColumn = labelColumn;
        this.csv = CSV.skipLines(0).separator(',').quote('\"').create();
        this.init();
    }

    /**
     * Parses labelled CSV from a file with quoting disabled and no skipped lines; equivalent to
     * {@code new CSVDataFetcher(f, labelColumn, 0)}.
     *
     * @param f           CSV file to read
     * @param labelColumn zero-based index of the label column
     * @throws IOException if the file cannot be opened or closed
     */
    public CSVDataFetcher(File f, int labelColumn) throws IOException {
        this(f, labelColumn, 0);
    }

    /**
     * Parses labelled CSV from a stream using {@code ','} as separator and with quoting disabled.
     *
     * <p>This constructor does not itself close {@code is}.
     *
     * @param is          stream of CSV text
     * @param labelColumn zero-based index of the label column
     * @param skipLines   number of leading lines to skip (passed to {@code CSV.skipLines})
     */
    public CSVDataFetcher(InputStream is, int labelColumn, int skipLines) {
        this.is = is;
        this.labelColumn = labelColumn;
        this.csv = CSV.skipLines(skipLines).separator(',').noQuote().create();
        this.init();
    }

    /**
     * Parses labelled CSV from a file using {@code ','} as separator and with quoting disabled,
     * closing the file once it has been read.
     *
     * @param f           CSV file to read
     * @param labelColumn zero-based index of the label column
     * @param skipLines   number of leading lines to skip (passed to {@code CSV.skipLines})
     * @throws IOException if the file cannot be opened or closed
     */
    public CSVDataFetcher(File f, int labelColumn, int skipLines) throws IOException {
        this(new BufferedInputStream(new FileInputStream(f)), labelColumn, skipLines);
        // This constructor opened the stream and init() has already consumed it, so release the
        // file handle here. Closing an already-closed stream has no effect (Closeable contract).
        // NOTE(review): if parsing throws, the handle is still left open.
        this.is.close();
    }

    /**
     * Reads every row from {@link #is} and builds {@link #all}. Also sets {@code inputColumns},
     * {@code numOutcomes} and {@code totalExamples}.
     */
    private void init() {
        final Set<String> labels = new HashSet<String>();
        final List<String> rowLabels = new ArrayList<String>();
        final List<INDArray> features = new ArrayList<INDArray>();
        // Feature count fixed by the first non-empty row; -1 means "not seen yet".
        final AtomicInteger expectedFeatures = new AtomicInteger(-1);
        this.csv.read(this.is, new CSVReadProc() {

            @Override
            public void procRow(int rowIndex, String... values) {
                if (values.length < 1) {
                    return;
                }
                if (expectedFeatures.get() < 0) {
                    expectedFeatures.set(values.length - 1);
                    CSVDataFetcher.this.inputColumns = values.length - 1;
                } else if (values.length - 1 != expectedFeatures.get()) {
                    // Rows with a different width are silently dropped.
                    return;
                }
                Pair<INDArray, String> row = CSVDataFetcher.this.processRow(values);
                rowLabels.add(row.getSecond());
                labels.add(row.getSecond());
                features.add(row.getFirst());
            }
        });

        // Outcome index = position in the label set's iteration order. A map avoids a linear
        // indexOf search for every row.
        List<String> labelOrder = new ArrayList<String>(labels);
        Map<String, Integer> labelIndex = new HashMap<String, Integer>();
        for (int i = 0; i < labelOrder.size(); ++i) {
            labelIndex.put(labelOrder.get(i), i);
        }
        int numLabels = labels.size();

        List<DataSet> examples = new ArrayList<DataSet>(rowLabels.size());
        for (int i = 0; i < rowLabels.size(); ++i) {
            int outcome = labelIndex.get(rowLabels.get(i));
            examples.add(new DataSet(features.get(i), FeatureUtil.toOutcomeVector(outcome, numLabels)));
        }
        this.numOutcomes = numLabels;
        this.totalExamples = examples.size();
        this.all = DataSet.merge(examples);
    }

    /**
     * Splits one CSV row into its label and a {@code 1 x (data.length - 1)} feature row vector.
     *
     * <p>Double-quote characters are stripped from the label. Every non-label cell is parsed
     * with {@link Double#parseDouble}, which throws {@link NumberFormatException} for
     * non-numeric text. An out-of-range {@link #labelColumn} raises
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param data the row's cell values
     * @return pair of (feature row vector, label)
     */
    private Pair<INDArray, String> processRow(String[] data) {
        // Literal replacement: the former regex ".\"." removed "char, quote, char" triples
        // instead of stripping quotes.
        String label = data[this.labelColumn].replace("\"", "");
        double[] values = new double[data.length - 1];
        int next = 0;
        for (int col = 0; col < data.length; ++col) {
            if (col == this.labelColumn) {
                continue;
            }
            values[next++] = Double.parseDouble(data[col]);
        }
        INDArray row = Nd4j.create(values).reshape(1, values.length);
        return new Pair<INDArray, String>(row, label);
    }

    /**
     * Loads up to {@code numExamples} examples, starting at the current cursor, into the
     * current batch via {@code initializeCurrFromList}. The cursor then moves to the end of the
     * returned slice and never advances past the total number of examples.
     *
     * @param numExamples maximum number of examples to load
     */
    @Override
    public void fetch(int numExamples) {
        int total = this.all.numExamples();
        // Long arithmetic so a very large numExamples cannot overflow into a negative end.
        int end = (int) Math.min((long) this.cursor + numExamples, (long) total);
        this.initializeCurrFromList(this.all.asList().subList(this.cursor, end));
        this.cursor = end;
    }
}

