/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Shared logic for standardizing normalizers.
 *
 * <p>Standardization subtracts the mean held in a {@link DistributionStats} and divides by its
 * standard deviation. Reverting multiplies by the standard deviation and adds the mean back.
 * Zero standard deviations are treated as 1 in both directions so the two operations stay
 * inverse to each other and never divide by zero.
 *
 * <p>Subclasses report through {@link #isFit()} whether statistics are available.
 */
abstract class AbstractNormalizerStandardize {
    AbstractNormalizerStandardize() {
    }

    /**
     * Verifies that this normalizer has been fit before it is used.
     *
     * @throws IllegalStateException if {@link #isFit()} returns {@code false}. This is a
     *     {@link RuntimeException} subtype, so callers catching {@code RuntimeException} are
     *     unaffected.
     */
    void assertIsFit() {
        if (!this.isFit()) {
            throw new IllegalStateException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
    }

    /**
     * Standardizes {@code theFeatures} in place, computing {@code (x - mean) / std}.
     *
     * <p>Rank-2 input uses row-vector operations. Any other rank uses broadcast ops along
     * dimension 1. Presumably that is the feature axis for higher-rank inputs; TODO confirm
     * against callers.
     *
     * @param theFeatures array to transform; it is overwritten with the result
     * @param stats source of the mean and standard deviation. Its standard deviation is not
     *     modified (a zero-filtered copy is used).
     */
    protected void preProcess(INDArray theFeatures, DistributionStats stats) {
        INDArray mean = stats.getMean();
        INDArray std = this.filteredStd(stats);
        if (theFeatures.rank() == 2) {
            theFeatures.subiRowVector(mean);
            theFeatures.diviRowVector(std);
        } else {
            // z == x: the broadcast ops write their result back into theFeatures.
            Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(theFeatures, mean, theFeatures, 1));
            Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(theFeatures, std, theFeatures, 1));
        }
    }

    /**
     * Reports whether this normalizer holds fitted statistics. {@link #assertIsFit()} refuses to
     * proceed while this returns {@code false}.
     *
     * @return {@code true} once the normalizer has been fit
     */
    protected abstract boolean isFit();

    /**
     * Undoes {@link #preProcess} in place, computing {@code x * std + mean}.
     *
     * <p>Zero standard deviations are replaced by 1 here too, matching the forward transform.
     * Rank-2 input uses row-vector operations; any other rank broadcasts along dimension 1.
     *
     * @param data array to restore; it is overwritten with the result
     * @param distribution source of the mean and standard deviation. Its standard deviation is
     *     not modified.
     */
    void revert(INDArray data, DistributionStats distribution) {
        INDArray mean = distribution.getMean();
        INDArray std = this.filteredStd(distribution);
        if (data.rank() == 2) {
            data.muliRowVector(std);
            data.addiRowVector(mean);
        } else {
            // z == x: the broadcast ops write their result back into data.
            Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(data, std, data, 1));
            Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(data, mean, data, 1));
        }
    }

    /**
     * Returns a copy of the standard deviation in {@code stats}, with entries matching
     * {@code Conditions.equals(0)} replaced by 1.0.
     *
     * <p>The copy keeps later divisions from dividing by zero.
     *
     * @param stats statistics whose standard deviation is read but not modified
     * @return a new array safe to divide by and to multiply with
     */
    private INDArray filteredStd(DistributionStats stats) {
        // dup() matters: replaceWhere writes in place. getStd() may hand back the array stored
        // in the statistics, and writing into it would silently change the fitted values.
        INDArray stdCopy = stats.getStd().dup();
        BooleanIndexing.replaceWhere(stdCopy, 1.0, Conditions.equals(0));
        return stdCopy;
    }
}

