/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.iterator;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Column-wise standardization of {@link DataSet} features.
 *
 * <p>After {@code fit(...)} (or {@link #load(File, File)}) has provided a per-column mean and
 * standard deviation, {@link #transform(DataSet)} rewrites every feature value as
 * {@code (x - mean) / std}.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class StandardScaler {
    /** Per-column mean of the fitted features, or {@code null} before fitting. */
    private INDArray mean;
    /** Per-column standard deviation of the fitted features, or {@code null} before fitting. */
    private INDArray std;
    /** Number of examples consumed by the most recent fit. */
    private int runningTotal = 0;

    /**
     * Fits the scaler to a single data set, replacing any previously fitted statistics.
     *
     * @param dataSet the data whose feature matrix supplies the column statistics
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.mean = features.mean(0);
        this.std = features.std(0);
        this.runningTotal = dataSet.numExamples();
    }

    /**
     * Fits the scaler by streaming over every batch of {@code iterator}, then resets the iterator.
     *
     * <p>Each batch's mean and sum of squared deviations are merged into the running totals with
     * the pairwise update of Chan, Golub and LeVeque. The result therefore does not depend on how
     * the data is split into batches. Statistics from any previous fit are discarded first. If the
     * iterator yields no examples, the scaler is left unfitted ({@link #getMean()} returns
     * {@code null}).
     *
     * @param iterator source of the training batches; it is reset once fitting completes
     */
    public void fit(DataSetIterator iterator) {
        // Start from a clean slate: stale state from an earlier fit would otherwise be mixed in
        // (or, after fit(DataSet), trigger a division by a zero running total).
        this.mean = null;
        this.std = null;
        this.runningTotal = 0;
        INDArray sumSquaredDeviations = null;

        while (iterator.hasNext()) {
            DataSet next = (DataSet) iterator.next();
            int batchSize = next.numExamples();
            if (batchSize == 0) {
                continue;
            }
            INDArray features = next.getFeatureMatrix();
            INDArray batchMean = features.mean(0);
            INDArray centered = features.subRowVector(batchMean);
            INDArray batchM2 = centered.mul(centered).sum(0);

            if (this.mean == null) {
                this.mean = batchMean;
                sumSquaredDeviations = batchM2;
            } else {
                int total = this.runningTotal + batchSize;
                INDArray delta = batchMean.sub(this.mean);
                // Weights are computed in double arithmetic to avoid int overflow/truncation.
                this.mean = this.mean.add(delta.mul((double) batchSize / total));
                sumSquaredDeviations = sumSquaredDeviations.add(batchM2)
                        .addi(delta.mul(delta).muli((double) this.runningTotal * batchSize / total));
            }
            this.runningTotal += batchSize;
        }
        iterator.reset();

        if (this.mean != null) {
            // n - 1 denominator, intended to match the bias-corrected std(0) used by fit(DataSet).
            // NOTE(review): confirm INDArray.std's default bias correction for the nd4j version in use.
            int denominator = this.runningTotal > 1 ? this.runningTotal - 1 : 1;
            this.std = Transforms.sqrt(sumSquaredDeviations.div(denominator));
        }
    }

    /**
     * Loads previously saved statistics, replacing the current ones.
     *
     * @param mean file holding the binary-serialized mean row vector
     * @param std file holding the binary-serialized standard-deviation row vector
     * @throws IOException if either file cannot be read
     */
    public void load(File mean, File std) throws IOException {
        this.mean = Nd4j.readBinary(mean);
        this.std = Nd4j.readBinary(std);
    }

    /**
     * Writes the fitted statistics to the given files in ND4J binary format.
     *
     * @param mean destination for the mean row vector
     * @param std destination for the standard-deviation row vector
     * @throws IOException if either file cannot be written
     * @throws IllegalStateException if the scaler has not been fitted or loaded
     */
    public void save(File mean, File std) throws IOException {
        requireFitted();
        Nd4j.saveBinary(this.mean, mean);
        Nd4j.saveBinary(this.std, std);
    }

    /**
     * Standardizes the features of {@code dataSet} in place: subtracts the mean row vector from
     * every row, then divides every row by the standard-deviation row vector.
     *
     * <p>NOTE(review): a column whose fitted std is 0 (a constant feature) is divided by zero here
     * and will presumably become non-finite; consider flooring std if that matters to callers.
     *
     * @param dataSet the data set whose features are overwritten
     * @throws IllegalStateException if the scaler has not been fitted or loaded
     */
    public void transform(DataSet dataSet) {
        requireFitted();
        dataSet.setFeatures(dataSet.getFeatures().subiRowVector(this.mean));
        dataSet.setFeatures(dataSet.getFeatures().diviRowVector(this.std));
    }

    /** Returns the fitted per-column mean, or {@code null} if not yet fitted or loaded. */
    public INDArray getMean() {
        return this.mean;
    }

    /** Returns the fitted per-column standard deviation, or {@code null} if not yet fitted or loaded. */
    public INDArray getStd() {
        return this.std;
    }

    /** Fails fast with a clear message when statistics are missing. */
    private void requireFitted() {
        if (this.mean == null || this.std == null) {
            throw new IllegalStateException(
                    "StandardScaler has not been fitted; call fit(...) or load(...) first");
        }
    }
}

