/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch normalization layer implementation for rank-2 activations (statistics
 * taken over dimension 0) and rank-4 activations (statistics taken over
 * dimensions 0, 2 and 3, i.e. one mean/variance per index of dimension 1).
 *
 * <p>During training the forward pass normalizes with the statistics of the
 * current batch and updates the "mean" and "var" parameters; otherwise the
 * stored "mean" and "var" parameters are used. When the configuration locks
 * gamma and beta, the fixed scalar values from the configuration are applied
 * instead of the "gamma" and "beta" parameters.
 *
 * <p>If a {@link BatchNormalizationHelper} could be loaded (only attempted when
 * the backend reports itself as CUDA), the forward and backward passes are first
 * offered to the helper; the built-in implementation is used when the helper is
 * absent or returns {@code null}.
 */
public class BatchNormalization
extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    private static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);
    /** Optional accelerated implementation; {@code null} when unavailable or unsupported. */
    BatchNormalizationHelper helper = null;
    protected int index = 0;
    protected List<TrainingListener> listeners = new ArrayList<TrainingListener>();
    /** sqrt(variance + eps) computed by the most recent built-in forward pass. */
    protected INDArray std;
    /** Input minus mean, cached by the most recent built-in forward pass for use in backprop. */
    protected INDArray xMu;
    /** Normalized input ({@code xMu / std}), cached by the most recent built-in forward pass. */
    protected INDArray xHat;

    /**
     * Creates the layer and attempts to load the accelerated helper.
     *
     * @param conf the network configuration holding this layer's configuration
     */
    public BatchNormalization(NeuralNetConfiguration conf) {
        super(conf);
        this.initializeHelper();
    }

    /**
     * Tries to instantiate the cuDNN helper by reflection when the backend is CUDA.
     * Any failure is logged and leaves the layer on the built-in implementation.
     * A helper that reports the configured eps as unsupported is discarded.
     */
    void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if ("CUDA".equalsIgnoreCase(backend)) {
            try {
                this.helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper").asSubclass(BatchNormalizationHelper.class).newInstance();
                log.debug("CudnnBatchNormalizationHelper successfully initialized");
                if (!this.helper.checkSupported(this.layerConf().getEps())) {
                    this.helper = null;
                }
            }
            catch (Throwable t) {
                // A missing helper class is the expected "module not on classpath" case; anything else is worth a warning.
                if (!(t instanceof ClassNotFoundException)) {
                    log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
                }
                OneTimeLogger.info(log, "cuDNN not found: use cuDNN for better GPU performance by including the deeplearning4j-cuda module. For more information, please refer to: https://deeplearning4j.org/cudnn", t);
            }
        }
    }

    /** Always returns 0.0: this layer contributes no L2 term. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    /** Always returns 0.0: this layer contributes no L1 term. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /**
     * Computes the gradients for gamma and beta and the epsilon to pass to the
     * previous layer. The "mean" and "var" gradient views are always set to zero
     * by the built-in path, since those parameters are updated in the forward pass.
     *
     * <p>The built-in path relies on {@link #xMu}, {@link #xHat} and {@link #std}
     * cached by the preceding call to {@link #preOutput}; note that the rank-4
     * branch modifies {@code xMu} in place.
     *
     * @param epsilon      gradient with respect to this layer's output; rank 2 or 4
     * @param workspaceMgr workspace manager used to place the returned epsilon
     * @return the parameter gradients and the gradient with respect to the input
     * @throws IllegalStateException if {@code epsilon} is neither rank 2 nor rank 4
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        long[] shape = this.getShape(epsilon);
        long batchSize = epsilon.size(0);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = this.layerConf();
        INDArray gamma = null;
        INDArray dGlobalMeanView = this.gradientViews.get("mean");
        INDArray dGlobalVarView = this.gradientViews.get("var");
        INDArray dGammaView;
        INDArray dBetaView;
        if (layerConf.isLockGammaBeta()) {
            // No gamma/beta parameters exist when locked; use scratch arrays as gradient targets.
            long[] tempShape = new long[]{1L, shape[1]};
            dGammaView = Nd4j.createUninitialized(tempShape, 'c');
            dBetaView = Nd4j.createUninitialized(tempShape, 'c');
        } else {
            gamma = this.getParam("gamma");
            dGammaView = this.gradientViews.get("gamma");
            dBetaView = this.gradientViews.get("beta");
        }

        if (this.helper != null) {
            if (layerConf.isLockGammaBeta()) {
                gamma = Nd4j.valueArrayOf(new long[]{1L, shape[1]}, layerConf.getGamma());
            }
            INDArray in;
            INDArray eps;
            if (this.input.rank() == 2) {
                // The helper is given rank-4 arrays, so present 2d data as [n, c, 1, 1].
                in = this.input.reshape(this.input.ordering(), new long[]{this.input.size(0), this.input.size(1), 1L, 1L});
                eps = epsilon.reshape(epsilon.ordering(), new long[]{epsilon.size(0), epsilon.size(1), 1L, 1L});
            } else {
                in = this.input;
                eps = epsilon;
            }
            Pair<Gradient, INDArray> ret = this.helper.backpropGradient(in, eps, ArrayUtil.toInts(shape), gamma, dGammaView, dBetaView, layerConf.getEps(), workspaceMgr);
            if (ret != null) {
                ret.getFirst().setGradientFor("mean", dGlobalMeanView);
                ret.getFirst().setGradientFor("var", dGlobalVarView);
                if (this.input.rank() == 2) {
                    INDArray e = ret.getSecond();
                    ret.setSecond(e.reshape(e.ordering(), new long[]{e.size(0), e.size(1)}));
                }
                return ret;
            }
            // A null result means the helper declined; fall through to the built-in implementation.
        }

        INDArray dGamma;
        INDArray dBeta;
        INDArray nextEpsilon;
        if (epsilon.rank() == 2) {
            dBeta = epsilon.sum(new int[]{0});
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0});
            INDArray dxhat = layerConf.isLockGammaBeta() ? epsilon.mul(layerConf.getGamma()) : epsilon.mulRowVector(gamma);
            INDArray dLdVar = dxhat.mul(this.xMu).sum(new int[]{0}).muli(-0.5).muli(Transforms.pow(this.std, -3.0, true));
            INDArray dxmu1 = dxhat.sum(new int[]{0}).divi(this.std).negi();
            INDArray dxmu2 = this.xMu.sum(new int[]{0}).muli(-2.0 / (double) batchSize).muli(dLdVar);
            INDArray dLdmu = dxmu1.addi(dxmu2);
            // In-place ops below reuse dxhat, xMu, dLdVar and dLdmu as scratch; order matters.
            nextEpsilon = dxhat.diviRowVector(this.std)
                    .addi(this.xMu.muliRowVector(dLdVar.muli(2.0 / (double) batchSize)))
                    .addiRowVector(dLdmu.muli(1.0 / (double) batchSize));
        } else if (epsilon.rank() == 4) {
            dBeta = epsilon.sum(new int[]{0, 2, 3});
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0, 2, 3});
            INDArray dxhat = layerConf.isLockGammaBeta()
                    ? epsilon.mul(layerConf.getGamma())
                    : Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(epsilon, gamma, Nd4j.createUninitialized(epsilon.shape(), epsilon.ordering()), new int[]{1}));
            INDArray dLdVar = dxhat.mul(this.xMu).sum(new int[]{0, 2, 3}).muli(-0.5).muli(Transforms.pow(this.std, -3.0, true));
            // Number of values contributing to each per-channel statistic.
            long effectiveBatchSize = this.input.size(0) * this.input.size(2) * this.input.size(3);
            INDArray dxmu1 = dxhat.sum(new int[]{0, 2, 3}).divi(this.std).negi();
            INDArray dxmu2 = this.xMu.sum(new int[]{0, 2, 3}).muli(-2.0 / (double) effectiveBatchSize).muli(dLdVar);
            INDArray dLdmu = dxmu1.addi(dxmu2);
            // Both broadcast ops write into their first operand (dxhat and xMu respectively).
            INDArray dLdx = Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastDivOp(dxhat, this.std, dxhat, new int[]{1}))
                    .addi(Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(this.xMu, dLdVar.muli(2.0 / (double) effectiveBatchSize), this.xMu, new int[]{1})));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(dLdx, dLdmu.muli(1.0 / (double) effectiveBatchSize), dLdx, new int[]{1}));
            nextEpsilon = dLdx;
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }

        DefaultGradient retGradient = new DefaultGradient();
        dGammaView.assign(dGamma);
        dBetaView.assign(dBeta);
        retGradient.setGradientFor("gamma", dGammaView);
        retGradient.setGradientFor("beta", dBetaView);
        dGlobalMeanView.assign(0);
        dGlobalVarView.assign(0);
        retGradient.setGradientFor("mean", dGlobalMeanView);
        retGradient.setGradientFor("var", dGlobalVarView);

        nextEpsilon = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, nextEpsilon);
        return new Pair<Gradient, INDArray>(retGradient, nextEpsilon);
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Runs the forward pass on the layer's current input.
     *
     * @param training     {@code true} to use batch statistics and update the stored mean/variance
     * @param workspaceMgr workspace manager used for intermediate and output arrays
     * @return the normalized (and scaled/shifted) activations
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        return this.preOutput(this.input, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST, workspaceMgr);
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    /**
     * Forward pass: normalizes {@code x} and applies gamma and beta.
     *
     * <p>In {@code TRAIN} mode the mean and (biased) variance of {@code x} are used
     * and the "mean"/"var" parameters are updated afterwards: as a decayed moving
     * average when the configuration's minibatch flag is set, otherwise by direct
     * assignment. In any other mode the stored "mean"/"var" parameters are used.
     * The built-in path caches {@link #std}, {@link #xMu} and {@link #xHat}.
     *
     * @param x            activations to normalize; rank 2 or 4
     * @param training     training mode selecting batch or stored statistics
     * @param workspaceMgr workspace manager used for intermediate and output arrays
     * @return the activations, leveraged to the {@code ACTIVATIONS} array type on the built-in path
     * @throws IllegalStateException if {@code x} has an unsupported rank
     */
    public INDArray preOutput(INDArray x, Layer.TrainingMode training, LayerWorkspaceMgr workspaceMgr) {
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = this.layerConf();
        long[] shape = this.getShape(x);
        INDArray gamma = null;
        INDArray beta = null;
        INDArray globalMeanView = this.getParam("mean");
        INDArray globalVarView = this.getParam("var");
        if (layerConf.isLockGammaBeta()) {
            // The helper takes gamma/beta as arrays, so materialize the fixed scalars for it.
            // Decide on the rank of the array actually being processed, not the stored input.
            if (this.helper != null && x.rank() == 4) {
                long[] gammaBetaShape = new long[]{1L, layerConf.getNOut()};
                gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getGamma());
                beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getBeta());
            }
        } else {
            gamma = this.getParam("gamma");
            beta = this.getParam("beta");
        }

        if (this.helper != null) {
            INDArray in = x;
            if (x.rank() == 2) {
                in = x.reshape(x.ordering(), new long[]{in.size(0), in.size(1), 1L, 1L});
            }
            double decay = layerConf.getDecay();
            INDArray ret = this.helper.preOutput(in, training == Layer.TrainingMode.TRAIN, ArrayUtil.toInts(shape), gamma, beta, globalMeanView, globalVarView, decay, layerConf.getEps(), workspaceMgr);
            if (ret != null) {
                // Undo the [n, c, 1, 1] reshape applied above for rank-2 data.
                if (x.rank() == 2) {
                    return ret.reshape(ret.ordering(), new long[]{ret.size(0), ret.size(1)});
                }
                return ret;
            }
            // A null result means the helper declined; fall through to the built-in implementation.
        }

        INDArray mean;
        INDArray variance;
        if (training == Layer.TrainingMode.TRAIN) {
            switch (x.rank()) {
                case 2: {
                    mean = x.mean(new int[]{0});
                    variance = x.var(false, new int[]{0});
                    break;
                }
                case 4: {
                    mean = x.mean(new int[]{0, 2, 3});
                    variance = x.var(false, new int[]{0, 2, 3});
                    break;
                }
                default: {
                    throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported " + this.layerId());
                }
            }
        } else {
            mean = this.getParam("mean");
            variance = this.getParam("var");
        }
        // Work on a copy so that neither the batch variance nor the stored "var" parameter is modified here.
        this.std = Transforms.sqrt(workspaceMgr.dup(ArrayType.INPUT, variance).addi(layerConf.getEps()), false);

        INDArray activations;
        if (x.rank() == 2) {
            this.xMu = workspaceMgr.leverageTo(ArrayType.INPUT, x.subRowVector(mean));
            this.xHat = workspaceMgr.leverageTo(ArrayType.INPUT, this.xMu.divRowVector(this.std));
            if (layerConf.isLockGammaBeta()) {
                activations = this.applyLockedGammaBeta(layerConf.getGamma(), layerConf.getBeta());
            } else {
                activations = this.xHat.mulRowVector(gamma).addiRowVector(beta);
            }
        } else if (x.rank() == 4) {
            if (!Shape.strideDescendingCAscendingF(x)) {
                x = x.dup();
            }
            this.xMu = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering());
            this.xMu = Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastSubOp(x, mean, this.xMu, new int[]{1}));
            this.xHat = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering());
            this.xHat = Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{1}));
            if (layerConf.isLockGammaBeta()) {
                activations = this.applyLockedGammaBeta(layerConf.getGamma(), layerConf.getBeta());
            } else {
                activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.shape(), x.ordering());
                activations = Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(this.xHat, gamma, activations, new int[]{1}));
                activations = Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(activations, beta, activations, new int[]{1}));
            }
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }

        if (training == Layer.TrainingMode.TRAIN) {
            // mean/variance are scaled in place here; they are not read again afterwards.
            if (layerConf.isMinibatch()) {
                double decay = layerConf.getDecay();
                globalMeanView.muli(decay).addi(mean.muli(1.0 - decay));
                globalVarView.muli(decay).addi(variance.muli(1.0 - decay));
            } else {
                globalMeanView.assign(mean);
                globalVarView.assign(variance);
            }
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activations);
    }

    /**
     * Applies the fixed (locked) gamma and beta scalars to the cached {@link #xHat}.
     * The multiply/add is skipped only for the identity transform (gamma == 1 and
     * beta == 0), in which case {@code xHat} itself is returned; any other
     * combination yields a new array {@code xHat * g + b}.
     */
    private INDArray applyLockedGammaBeta(double g, double b) {
        if (g != 1.0 || b != 0.0) {
            return this.xHat.mul(g).addi(b);
        }
        return this.xHat;
    }

    @Override
    public Collection<TrainingListener> getListeners() {
        return this.listeners;
    }

    /** Replaces the current listeners with a mutable copy of the given ones. */
    @Override
    public void setListeners(TrainingListener ... listeners) {
        this.listeners = new ArrayList<TrainingListener>(Arrays.asList(listeners));
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Returns the accelerated helper, or {@code null} if none is in use. */
    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }

    /**
     * Returns the {@code [1, n]} shape used for this layer's per-feature arrays.
     * For rank 2 and rank 4 input, {@code n} is the size of dimension 1; for rank 3
     * input it is the product of the sizes of dimensions 1 and 2.
     *
     * @param x the array whose feature shape is required
     * @return a two-element shape array {@code {1, n}}
     * @throws IllegalArgumentException for rank-3 input whose dimension 0 is greater than 1
     *                                  while dimensions 1 and 2 already account for its full length
     * @throws IllegalStateException    if the rank is not 2, 3 or 4
     */
    public long[] getShape(INDArray x) {
        if (x.rank() == 2 || x.rank() == 4) {
            return new long[]{1L, x.size(1)};
        }
        if (x.rank() == 3) {
            long wDim = x.size(1);
            long hdim = x.size(2);
            if (x.size(0) > 1L && wDim * hdim == x.length()) {
                throw new IllegalArgumentException("Illegal input for batch size " + this.layerId());
            }
            return new long[]{1L, wDim * hdim};
        }
        throw new IllegalStateException("Unable to process input of rank " + x.rank() + " " + this.layerId());
    }
}

