/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Pre-processor that standardizes the feature matrix of a {@link DataSet} column by column:
 * every feature value {@code x} in column {@code j} becomes {@code (x - mean[j]) / std[j]}.
 *
 * <p>The statistics must first be computed with {@link #fit(DataSet)} or
 * {@link #fit(DataSetIterator)}, or restored with {@link #load(File, File)}. Both fit methods add
 * {@code Nd4j.EPS_THRESHOLD} to every standard deviation so that constant columns do not cause a
 * division by zero.
 *
 * <p>No synchronization is performed; fitting replaces the internal statistics in place.
 */
public class NormalizerStandardize
implements DataSetPreProcessor {
    private static final String NOT_FIT_MESSAGE =
            "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";

    private INDArray mean;
    private INDArray std;
    /** Number of examples seen by the most recent fit. */
    private int runningTotal = 0;

    /**
     * Computes the per-column mean and standard deviation of the features of a single data set,
     * replacing any previously fitted statistics.
     *
     * @param dataSet the data whose feature matrix is used for fitting
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.mean = features.mean(0);
        // NOTE(review): INDArray.std(0) is presumably the bias-corrected (n - 1) estimate, whereas
        // fit(DataSetIterator) divides by n; confirm whether the two paths are meant to agree.
        this.std = features.std(0);
        this.std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        this.runningTotal = dataSet.numExamples();
    }

    /**
     * Computes the per-column mean and standard deviation over every batch the iterator yields,
     * then resets the iterator. Batch statistics are merged with the pairwise (Chan et al.)
     * update, so batches may have different sizes. The resulting deviation divides the summed
     * squared deviations by the total number of examples.
     *
     * <p>The new statistics replace the old ones only once every batch has been consumed.
     *
     * @param iterator source of batches; it is reset after fitting
     * @throws IllegalArgumentException if the iterator produced no examples (previously fitted
     *     statistics are left unchanged in that case)
     */
    public void fit(DataSetIterator iterator) {
        INDArray runningMean = null;
        INDArray sumSquaredDeviations = null;
        int count = 0;
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            int batchCount = batch.numExamples();
            if (batchCount == 0) {
                continue;
            }
            INDArray features = batch.getFeatureMatrix();
            INDArray batchMean = features.mean(0);
            // Sum of squared deviations from this batch's own mean: exact for any batch size,
            // including a single example (where it is simply zero).
            INDArray batchM2 = Transforms.pow(features.subRowVector(batchMean), 2).sum(0);
            if (runningMean == null) {
                runningMean = batchMean;
                sumSquaredDeviations = batchM2;
                count = batchCount;
                continue;
            }
            int combined = count + batchCount;
            INDArray delta = batchMean.sub(runningMean);
            // Pairwise merge: M2 = M2_a + M2_b + delta^2 * n_a * n_b / (n_a + n_b).
            INDArray crossTerm = Transforms.pow(delta, 2).muli((double) count * batchCount / combined);
            sumSquaredDeviations = sumSquaredDeviations.add(batchM2).addi(crossTerm);
            // mean = mean_a + delta * n_b / (n_a + n_b)
            runningMean = runningMean.add(delta.mul((double) batchCount / combined));
            count = combined;
        }
        iterator.reset();
        if (runningMean == null) {
            throw new IllegalArgumentException(
                    "Cannot fit NormalizerStandardize: the iterator produced no examples");
        }
        INDArray fittedStd = Transforms.sqrt(sumSquaredDeviations.divi(count));
        fittedStd.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        this.mean = runningMean;
        this.std = fittedStd;
        this.runningTotal = count;
    }

    /**
     * Replaces the features of {@code toPreProcess} with {@code (features - mean) / std}, using
     * row-vector broadcasting over the fitted statistics.
     *
     * @throws IllegalStateException if the normalizer has not been fit or loaded
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        INDArray fittedMean = requireFitted(this.mean);
        INDArray fittedStd = requireFitted(this.std);
        // subRowVector returns a fresh array, so the in-place divide does not touch the input.
        toPreProcess.setFeatures(toPreProcess.getFeatures().subRowVector(fittedMean).diviRowVector(fittedStd));
    }

    /** Same as {@link #preProcess(DataSet)}. */
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /**
     * Calls {@link #preProcess(DataSet)} on every data set the iterator returns, then resets the
     * iterator.
     *
     * <p>NOTE(review): whether the normalized features are visible on a later pass depends on
     * whether the iterator hands out the same DataSet objects again; confirm for the iterator used.
     */
    public void transform(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.preProcess((DataSet) toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Undoes {@link #preProcess(DataSet)}: replaces the features with
     * {@code features * std + mean}.
     *
     * @throws IllegalStateException if the normalizer has not been fit or loaded
     */
    public void revertPreProcess(DataSet toPreProcess) {
        INDArray fittedMean = requireFitted(this.mean);
        INDArray fittedStd = requireFitted(this.std);
        // mulRowVector returns a fresh array, so the in-place add does not touch the input.
        toPreProcess.setFeatures(toPreProcess.getFeatures().mulRowVector(fittedStd).addiRowVector(fittedMean));
    }

    /** Same as {@link #revertPreProcess(DataSet)}. */
    public void revert(DataSet toPreProcess) {
        this.revertPreProcess(toPreProcess);
    }

    /**
     * Calls {@link #revertPreProcess(DataSet)} on every data set the iterator returns, then
     * resets the iterator.
     */
    public void revert(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.revertPreProcess((DataSet) toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Returns the fitted per-column mean (the internal array, not a copy).
     *
     * @throws IllegalStateException if no mean has been fit or loaded
     */
    public INDArray getMean() {
        return requireFitted(this.mean);
    }

    /**
     * Returns the fitted per-column standard deviation (the internal array, not a copy).
     *
     * @throws IllegalStateException if no standard deviation has been fit or loaded
     */
    public INDArray getStd() {
        return requireFitted(this.std);
    }

    /**
     * Restores statistics previously written by {@link #save(File, File)}. Both files are read
     * before either field is replaced, so a failed read leaves the current statistics intact.
     *
     * @throws IOException if either file cannot be read
     */
    public void load(File mean, File std) throws IOException {
        INDArray loadedMean = Nd4j.readBinary(mean);
        INDArray loadedStd = Nd4j.readBinary(std);
        this.mean = loadedMean;
        this.std = loadedStd;
    }

    /**
     * Writes the fitted mean and standard deviation to the given files with
     * {@code Nd4j.saveBinary}.
     *
     * @throws IllegalStateException if the normalizer has not been fit or loaded (nothing is
     *     written in that case)
     * @throws IOException if either file cannot be written
     */
    public void save(File mean, File std) throws IOException {
        INDArray fittedMean = requireFitted(this.mean);
        INDArray fittedStd = requireFitted(this.std);
        Nd4j.saveBinary(fittedMean, mean);
        Nd4j.saveBinary(fittedStd, std);
    }

    /** Returns {@code value}, or throws the standard "not fit" error if it is {@code null}. */
    private static INDArray requireFitted(INDArray value) {
        if (value == null) {
            throw new IllegalStateException(NOT_FIT_MESSAGE);
        }
        return value;
    }
}

