/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataSet
implements org.nd4j.linalg.dataset.api.DataSet {
    private static final long serialVersionUID = 1935520764586513365L;
    private static Logger log = LoggerFactory.getLogger(DataSet.class);
    private List<String> columnNames = new ArrayList<String>();
    private List<String> labelNames = new ArrayList<String>();
    private INDArray features;
    private INDArray labels;
    private String id = UUID.randomUUID().toString();

    public DataSet() {
        this(Nd4j.zeros(new int[]{1, 1}), Nd4j.zeros(new int[]{1, 1}));
    }

    public DataSet(INDArray first, INDArray second) {
        if (first.size(0) != second.size(0)) {
            throw new IllegalStateException("Invalid data transform; first and second do not have equal rows. First was " + first.size(0) + " second was " + second.size(0));
        }
        this.features = first;
        this.labels = second;
    }

    public static DataSet empty() {
        return new DataSet(Nd4j.zeros(new int[]{1, 1}), Nd4j.zeros(new int[]{1, 1}));
    }

    public static DataSet merge(List<DataSet> data, boolean clone) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet first = data.get(0);
        if (first.getFeatures().rank() == 3 && first.getLabels().rank() == 3) {
            return DataSet.mergeTimeSeries(data);
        }
        int numExamples = DataSet.totalExamples(data);
        INDArray in = Nd4j.create(numExamples, first.getFeatures().columns());
        INDArray out = Nd4j.create(numExamples, first.getLabels().columns());
        int count = 0;
        for (int i = 0; i < data.size(); ++i) {
            DataSet d1 = data.get(i);
            for (int j = 0; j < d1.numExamples(); ++j) {
                DataSet example = d1.get(j);
                in.putRow(count, clone ? example.getFeatures().dup() : example.getFeatures());
                out.putRow(count, clone ? example.getLabels().dup() : example.getLabels());
                ++count;
            }
        }
        return new DataSet(in, out);
    }

    private static DataSet mergeTimeSeries(List<DataSet> data) {
        DataSet first = data.get(0);
        int numExamples = DataSet.totalExamples(data);
        int nIn = first.getFeatureMatrix().size(1);
        int nOut = first.getLabels().size(1);
        int tsLength = first.getFeatureMatrix().size(2);
        INDArray in = Nd4j.create(numExamples, nIn, tsLength);
        INDArray out = Nd4j.create(numExamples, nOut, tsLength);
        int rowCount = 0;
        for (DataSet ds : data) {
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            int nEx = f.size(0);
            in.get(NDArrayIndex.interval(rowCount, rowCount + nEx), NDArrayIndex.all(), NDArrayIndex.all()).assign(f);
            out.get(NDArrayIndex.interval(rowCount, rowCount + nEx), NDArrayIndex.all(), NDArrayIndex.all()).assign(l);
            rowCount += nEx;
        }
        return new DataSet(in, out);
    }

    public static DataSet merge(List<DataSet> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        return DataSet.merge(data, false);
    }

    private static int totalExamples(Collection<DataSet> coll) {
        int count = 0;
        for (DataSet d : coll) {
            count += d.numExamples();
        }
        return count;
    }

    @Override
    public org.nd4j.linalg.dataset.api.DataSet getRange(int from, int to) {
        return new DataSet(this.features.get(NDArrayIndex.interval(from, to)), this.labels.get(NDArrayIndex.interval(from, to)));
    }

    @Override
    public void load(File from) {
        try {
            BufferedInputStream bis = new BufferedInputStream(new FileInputStream(from));
            DataInputStream dis = new DataInputStream(bis);
            this.features = Nd4j.read(dis);
            this.labels = Nd4j.read(dis);
            dis.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public void save(File to) {
        try {
            BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(to));
            DataOutputStream dis = new DataOutputStream(bos);
            Nd4j.write(this.getFeatureMatrix(), dis);
            Nd4j.write(this.getLabels(), dis);
            dis.flush();
            dis.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public DataSetIterator iterateWithMiniBatches() {
        return null;
    }

    @Override
    public String id() {
        return this.id;
    }

    @Override
    public INDArray getFeatures() {
        return this.features;
    }

    @Override
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    @Override
    public Map<Integer, Double> labelCounts() {
        HashMap<Integer, Double> ret = new HashMap<Integer, Double>();
        if (this.labels == null) {
            return ret;
        }
        for (int i = 0; i < this.labels.rows(); ++i) {
            INDArray label = this.labels.getRow(i);
            int maxIdx = Nd4j.getBlasWrapper().iamax(label);
            if (maxIdx < 0) {
                throw new IllegalStateException("Please check the iamax implementation for " + Nd4j.getBlasWrapper().getClass().getName());
            }
            if (ret.get(maxIdx) == null) {
                ret.put(maxIdx, 1.0);
                continue;
            }
            ret.put(maxIdx, (Double)ret.get(maxIdx) + 1.0);
        }
        return ret;
    }

    @Override
    public void apply(Condition condition, Function<Number, Number> function) {
        BooleanIndexing.applyWhere(this.getFeatureMatrix(), condition, function);
    }

    @Override
    public DataSet copy() {
        DataSet ret = new DataSet(this.getFeatures().dup(), this.getLabels().dup());
        ret.setColumnNames(this.getColumnNames());
        ret.setLabelNames(this.getLabelNames());
        return ret;
    }

    @Override
    public DataSet reshape(int rows, int cols) {
        DataSet ret = new DataSet(this.getFeatures().reshape(new int[]{rows, cols}), this.getLabels());
        return ret;
    }

    @Override
    public void multiplyBy(double num) {
        this.getFeatures().muli(Nd4j.scalar(num));
    }

    @Override
    public void divideBy(int num) {
        this.getFeatures().divi(Nd4j.scalar(num));
    }

    @Override
    public void shuffle() {
        long seed = System.currentTimeMillis();
        Nd4j.shuffle(this.getFeatureMatrix(), new java.util.Random(seed), 0);
        Nd4j.shuffle(this.getLabels(), new java.util.Random(seed), 0);
    }

    @Override
    public void squishToRange(double min, double max) {
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            double curr = (Double)this.getFeatures().getScalar(i).element();
            if (curr < min) {
                this.getFeatures().put(i, Nd4j.scalar(min));
                continue;
            }
            if (!(curr > max)) continue;
            this.getFeatures().put(i, Nd4j.scalar(max));
        }
    }

    @Override
    public void scaleMinAndMax(double min, double max) {
        FeatureUtil.scaleMinMax(min, max, this.getFeatureMatrix());
    }

    @Override
    public void scale() {
        FeatureUtil.scaleByMax(this.getFeatures());
    }

    @Override
    public void addFeatureVector(INDArray toAdd) {
        this.setFeatures(Nd4j.hstack(this.getFeatureMatrix(), toAdd));
    }

    @Override
    public void addFeatureVector(INDArray feature, int example) {
        this.getFeatures().putRow(example, feature);
    }

    @Override
    public void normalize() {
        FeatureUtil.normalizeMatrix(this.getFeatures());
    }

    @Override
    public void binarize() {
        this.binarize(0.0);
    }

    @Override
    public void binarize(double cutoff) {
        INDArray linear = this.getFeatureMatrix().linearView();
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            double curr = linear.getDouble(i);
            if (curr > cutoff) {
                this.getFeatures().putScalar(i, 1);
                continue;
            }
            this.getFeatures().putScalar(i, 0);
        }
    }

    @Override
    public void normalizeZeroMeanZeroUnitVariance() {
        INDArray columnMeans = this.getFeatures().mean(0);
        INDArray columnStds = this.getFeatureMatrix().std(0);
        this.setFeatures(this.getFeatures().subiRowVector(columnMeans));
        columnStds.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        this.setFeatures(this.getFeatures().diviRowVector(columnStds));
    }

    @Override
    public int numInputs() {
        return this.getFeatures().columns();
    }

    @Override
    public void validate() {
        if (this.getFeatures().size(0) != this.getLabels().size(0)) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    @Override
    public int outcome() {
        if (this.numExamples() > 1) {
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        }
        return Nd4j.getBlasWrapper().iamax(this.getLabels());
    }

    @Override
    public void setNewNumberOfLabels(int labels) {
        int examples = this.numExamples();
        INDArray newOutcomes = Nd4j.create(examples, labels);
        this.setLabels(newOutcomes);
    }

    @Override
    public void setOutcome(int example, int label) {
        if (example > this.numExamples()) {
            throw new IllegalArgumentException("No example at " + example);
        }
        if (label > this.numOutcomes() || label < 0) {
            throw new IllegalArgumentException("Illegal label");
        }
        INDArray outcome = FeatureUtil.toOutcomeVector(label, this.numOutcomes());
        this.getLabels().putRow(example, outcome);
    }

    @Override
    public DataSet get(int i) {
        if (i > this.numExamples() || i < 0) {
            throw new IllegalArgumentException("invalid example number");
        }
        if (i == 0 && this.numExamples() == 1) {
            return this;
        }
        return new DataSet(this.getFeatures().getRow(i), this.getLabels().getRow(i));
    }

    @Override
    public DataSet get(int[] i) {
        return new DataSet(this.getFeatures().getRows(i), this.getLabels().getRows(i));
    }

    @Override
    public List<DataSet> batchBy(int num) {
        ArrayList batched = Lists.newArrayList();
        for (List splitBatch : Lists.partition(this.asList(), (int)num)) {
            batched.add(DataSet.merge(splitBatch));
        }
        return batched;
    }

    @Override
    public DataSet filterBy(int[] labels) {
        List<DataSet> list = this.asList();
        ArrayList<DataSet> newList = new ArrayList<DataSet>();
        ArrayList<Integer> labelList = new ArrayList<Integer>();
        for (int i : labels) {
            labelList.add(i);
        }
        Object object = list.iterator();
        while (object.hasNext()) {
            DataSet d = (DataSet)object.next();
            int outcome = d.outcome();
            if (!labelList.contains(outcome)) continue;
            newList.add(d);
        }
        return DataSet.merge(newList);
    }

    @Override
    public void filterAndStrip(int[] labels) {
        int i;
        DataSet filtered = this.filterBy(labels);
        ArrayList<Integer> newLabels = new ArrayList<Integer>();
        HashMap<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        for (i = 0; i < labels.length; ++i) {
            labelMap.put(labels[i], i);
        }
        for (i = 0; i < filtered.numExamples(); ++i) {
            DataSet example = filtered.get(i);
            int o2 = example.outcome();
            Integer outcome = (Integer)labelMap.get(o2);
            newLabels.add(outcome);
        }
        INDArray newLabelMatrix = Nd4j.create(filtered.numExamples(), labels.length);
        if (newLabelMatrix.rows() != newLabels.size()) {
            throw new IllegalStateException("Inconsistent label sizes");
        }
        for (int i2 = 0; i2 < newLabelMatrix.rows(); ++i2) {
            Integer i22 = (Integer)newLabels.get(i2);
            if (i22 == null) {
                throw new IllegalStateException("Label not found on row " + i2);
            }
            INDArray newRow = FeatureUtil.toOutcomeVector(i22, labels.length);
            newLabelMatrix.putRow(i2, newRow);
        }
        this.setFeatures(filtered.getFeatures());
        this.setLabels(newLabelMatrix);
    }

    @Override
    public List<DataSet> dataSetBatches(int num) {
        List list = Lists.partition(this.asList(), (int)num);
        ArrayList<DataSet> ret = new ArrayList<DataSet>();
        for (List l : list) {
            ret.add(DataSet.merge(l));
        }
        return ret;
    }

    @Override
    public List<DataSet> sortAndBatchByNumLabels() {
        this.sortByLabel();
        return this.batchByNumLabels();
    }

    @Override
    public List<DataSet> batchByNumLabels() {
        return this.batchBy(this.numOutcomes());
    }

    @Override
    public List<DataSet> asList() {
        ArrayList<DataSet> list = new ArrayList<DataSet>(this.numExamples());
        for (int i = 0; i < this.numExamples(); ++i) {
            list.add(new DataSet(this.getFeatures().getRow(i), this.getLabels().getRow(i)));
        }
        return list;
    }

    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout, java.util.Random rng) {
        if (numHoldout >= this.numExamples()) {
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");
        }
        DataSet first = new DataSet(this.getFeatureMatrix().get(NDArrayIndex.interval(0, numHoldout)), this.getLabels().get(NDArrayIndex.interval(0, numHoldout)));
        DataSet second = new DataSet(this.getFeatureMatrix().get(NDArrayIndex.interval(numHoldout, this.numExamples())), this.getLabels().get(NDArrayIndex.interval(numHoldout, this.numExamples())));
        return new SplitTestAndTrain(first, second);
    }

    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout) {
        return this.splitTestAndTrain(numHoldout, new java.util.Random());
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    @Override
    public INDArray getFeatureMatrix() {
        return this.getFeatures();
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void sortByLabel() {
        void var6_11;
        Queue q;
        HashMap<Integer, ArrayDeque<DataSet>> map = new HashMap<Integer, ArrayDeque<DataSet>>();
        List<DataSet> data = this.asList();
        int numLabels = this.numOutcomes();
        int examples = this.numExamples();
        for (DataSet dataSet : data) {
            int label = dataSet.outcome();
            q = (ArrayDeque<DataSet>)map.get(label);
            if (q == null) {
                q = new ArrayDeque<DataSet>();
                map.put(label, (ArrayDeque<DataSet>)q);
            }
            q.add(dataSet);
        }
        for (Map.Entry entry : map.entrySet()) {
            log.info("Label " + entry + " has " + ((Queue)entry.getValue()).size() + " elements");
        }
        boolean optimal = true;
        boolean bl = false;
        while (var6_11 < examples) {
            if (optimal) {
                for (int j = 0; j < numLabels; ++j) {
                    q = (Queue)map.get(j);
                    if (q == null) {
                        optimal = false;
                    } else {
                        DataSet next = (DataSet)q.poll();
                        if (next != null) {
                            this.addRow(next, (int)var6_11);
                            ++var6_11;
                            continue;
                        }
                        optimal = false;
                    }
                    break;
                }
            } else {
                DataSet add = null;
                for (Queue q2 : map.values()) {
                    if (q2.isEmpty()) continue;
                    add = (DataSet)q2.poll();
                    break;
                }
                this.addRow(add, (int)var6_11);
            }
            ++var6_11;
        }
    }

    @Override
    public void addRow(DataSet d, int i) {
        if (i > this.numExamples() || d == null) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        this.getFeatures().putRow(i, d.getFeatures());
        this.getLabels().putRow(i, d.getLabels());
    }

    private int getLabel(DataSet data) {
        Float f = Float.valueOf(data.getLabels().maxNumber().floatValue());
        return f.intValue();
    }

    @Override
    public INDArray exampleSums() {
        return this.getFeatures().sum(1);
    }

    @Override
    public INDArray exampleMaxs() {
        return this.getFeatures().max(1);
    }

    @Override
    public INDArray exampleMeans() {
        return this.getFeatures().mean(1);
    }

    @Override
    public DataSet sample(int numSamples) {
        return this.sample(numSamples, Nd4j.getRandom());
    }

    @Override
    public DataSet sample(int numSamples, Random rng) {
        return this.sample(numSamples, rng, false);
    }

    @Override
    public DataSet sample(int numSamples, boolean withReplacement) {
        return this.sample(numSamples, Nd4j.getRandom(), withReplacement);
    }

    @Override
    public DataSet sample(int numSamples, Random rng, boolean withReplacement) {
        INDArray examples = Nd4j.create(numSamples, this.getFeatures().columns());
        INDArray outcomes = Nd4j.create(numSamples, this.numOutcomes());
        HashSet added = new HashSet();
        for (int i = 0; i < numSamples; ++i) {
            int picked = rng.nextInt(this.numExamples());
            if (!withReplacement) {
                while (added.contains(picked)) {
                    picked = rng.nextInt(this.numExamples());
                }
            }
            examples.putRow(i, this.get(picked).getFeatures());
            outcomes.putRow(i, this.get(picked).getLabels());
        }
        return new DataSet(examples, outcomes);
    }

    @Override
    public void roundToTheNearest(int roundTo) {
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            double curr = (Double)this.getFeatures().getScalar(i).element();
            this.getFeatures().put(i, Nd4j.scalar(MathUtils.roundDouble(curr, roundTo)));
        }
    }

    @Override
    public int numOutcomes() {
        return this.getLabels().columns();
    }

    @Override
    public int numExamples() {
        return this.getFeatures().size(0);
    }

    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("===========INPUT===================\n").append(this.getFeatures().toString().replaceAll(";", "\n")).append("\n=================OUTPUT==================\n").append(this.getLabels().toString().replaceAll(";", "\n"));
        return builder.toString();
    }

    @Override
    public List<String> getLabelNames() {
        return this.labelNames;
    }

    @Override
    public void setLabelNames(List<String> labelNames) {
        this.labelNames = labelNames;
    }

    @Override
    public List<String> getColumnNames() {
        return this.columnNames;
    }

    @Override
    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    @Override
    public SplitTestAndTrain splitTestAndTrain(double percentTrain) {
        int numPercent = (int)(percentTrain * (double)this.numExamples());
        return this.splitTestAndTrain(numPercent);
    }

    @Override
    public Iterator<DataSet> iterator() {
        return this.asList().iterator();
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DataSet)) {
            return false;
        }
        DataSet dataSet = (DataSet)o;
        if (this.getFeatures() != null ? !this.getFeatures().equals(dataSet.getFeatures()) : dataSet.getFeatures() != null) {
            return false;
        }
        return !(this.getLabels() == null ? dataSet.getLabels() != null : !this.getLabels().equals(dataSet.getLabels()));
    }

    public int hashCode() {
        int result = this.getFeatures() != null ? this.getFeatures().hashCode() : 0;
        result = 31 * result + (this.getLabels() != null ? this.getLabels().hashCode() : 0);
        return result;
    }
}

