/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.mkldnn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * {@link BatchNormalizationHelper} that runs batch normalization through the {@code BatchNorm}
 * custom op (forward pass) and the {@code "batchnorm_bp"} custom op (backward pass).
 *
 * <p>Only {@link DataType#FLOAT} inputs are handled; for any other data type both
 * {@link #preOutput} and {@link #backpropGradient} return {@code null}.
 *
 * <p>The instance keeps mutable state without any synchronization: a lazily created
 * {@link OpContext} that is reused across forward passes, and the per-channel batch mean and
 * variance computed by the most recent training-mode forward pass, which the backward pass reads.
 */
public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
    /** Reduction dimensions used for the batch statistics of rank-2 input. */
    private static final int[] RANK2_DIMS = {0};
    /** Reduction dimensions used for the batch statistics of any input that is not rank 2. */
    private static final int[] RANK4_DIMS = {0, 2, 3};

    /** Forward-pass op context; created and configured on the first {@link #preOutput} call, then reused. */
    protected OpContext context;
    /** Batch mean from the most recent training-mode forward pass (length {@code x.size(1)}). */
    private INDArray meanCache;
    /** Batch variance from the most recent training-mode forward pass (length {@code x.size(1)}). */
    private INDArray varCache;

    /**
     * Creates the helper.
     *
     * @param dataType not used by this implementation
     */
    public MKLDNNBatchNormHelper(DataType dataType) {
    }

    /**
     * Reports whether this helper can be used.
     *
     * @param eps            not consulted
     * @param fixedGammaBeta whether gamma/beta are fixed; the helper is unsupported when {@code true}
     * @return {@code true} only if gamma/beta are not fixed and
     *         {@link BaseMKLDNNHelper#mklDnnEnabled()} returns {@code true}
     */
    @Override
    public boolean checkSupported(double eps, boolean fixedGammaBeta) {
        return !fixedGammaBeta && BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Runs the {@code "batchnorm_bp"} op using the batch statistics cached by the last
     * training-mode {@link #preOutput} call.
     *
     * @param input        layer input from the forward pass
     * @param epsilon      gradient with respect to the layer output
     * @param shape        not used by this implementation
     * @param gamma        scale parameters, may be {@code null}; flattened to 1-D before use
     * @param beta         offset parameters, may be {@code null}; flattened to 1-D before use
     * @param dGammaView   array receiving the gamma gradient, may be {@code null}
     * @param dBetaView    array receiving the beta gradient; must be non-null when
     *                     {@code dGammaView} is non-null
     * @param eps          epsilon passed to the op as its floating point argument
     * @param workspaceMgr used to allocate the returned activation gradient and working memory
     * @return the gradient (keys {@code "gamma"} and {@code "beta"}) paired with the gradient with
     *         respect to the input, or {@code null} if {@code input} is not {@link DataType#FLOAT}
     * @throws IllegalStateException    if no training-mode forward pass has populated the
     *                                  mean/variance caches yet
     * @throws IllegalArgumentException if {@code dGammaView} is supplied without {@code dBetaView}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
        if (input.dataType() != DataType.FLOAT) {
            return null;
        }
        // Fail with a descriptive error instead of an NPE deep inside op setup.
        if (this.meanCache == null || this.varCache == null) {
            throw new IllegalStateException("Cannot backprop: mean/variance caches are not populated. "
                    + "preOutput must be called with training=true before backpropGradient");
        }
        if (dGammaView != null && dBetaView == null) {
            throw new IllegalArgumentException("dBetaView must not be null when dGammaView is provided");
        }

        // Input order: input, mean, variance, [gamma], [beta], epsilon.
        ArrayList<INDArray> args = new ArrayList<>();
        args.add(input);
        args.add(this.meanCache);
        args.add(this.varCache);
        if (gamma != null) {
            args.add(flatten(gamma));
        }
        if (beta != null) {
            args.add(flatten(beta));
        }
        args.add(epsilon);

        // Integer args: gamma-present flag, beta-present flag, and a constant 1
        // (NOTE(review): the trailing 1 is presumably the channel axis - confirm against the op definition).
        DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
                .addInputs(args.toArray(new INDArray[0]))
                .addIntegerArguments(new int[]{gamma == null ? 0 : 1, beta == null ? 0 : 1, 1})
                .addFloatingPointArguments(new Double[]{eps})
                .build();

        INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
        // Gradients w.r.t. mean and variance are required as op outputs but are discarded afterwards.
        INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, this.meanCache.dataType(), this.meanCache.shape());
        INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, this.varCache.dataType(), this.varCache.shape());
        op.setOutputArgument(0, epsAtInput);
        op.setOutputArgument(1, dLdm);
        op.setOutputArgument(2, dLdv);
        if (dGammaView != null) {
            // The reshaped arrays are handed to the op as outputs so results land in the caller's arrays
            // (assumes reshape returns a view of the same buffer here - TODO confirm).
            op.setOutputArgument(3, flatten(dGammaView));
            op.setOutputArgument(4, flatten(dBetaView));
        }
        Nd4j.exec(op);

        Gradient g = new DefaultGradient();
        g.setGradientFor("gamma", dGammaView);
        g.setGradientFor("beta", dBetaView);
        return new Pair<>(g, epsAtInput);
    }

    /**
     * Runs the forward pass through the {@code BatchNorm} op.
     *
     * <p>In training mode the per-channel mean and variance of {@code x} are computed into the
     * internal caches (allocated once, outside of any workspace) and used for normalization; the
     * supplied {@code mean} and {@code var} are ignored. Otherwise the supplied {@code mean} and
     * {@code var} are flattened to 1-D and used directly.
     *
     * <p>The op context is created on the first call and its integer/floating point arguments
     * (gamma/beta presence flags and {@code eps}) are set only at that time; later calls reuse
     * them unchanged.
     *
     * @param x            input activations
     * @param training     whether to normalize with batch statistics computed from {@code x}
     * @param shape        not used by this implementation
     * @param gamma        scale parameters; passed to the op only when {@code beta} is also non-null
     * @param beta         offset parameters; passed to the op only when {@code gamma} is also non-null
     * @param mean         mean used when {@code training} is {@code false}
     * @param var          variance used when {@code training} is {@code false}
     * @param decay        not used by this implementation
     * @param eps          epsilon passed to the op when the context is first created
     * @param workspaceMgr used to allocate the output activations
     * @return the normalized activations, or {@code null} if {@code x} is not {@link DataType#FLOAT}
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
        if (x.dataType() != DataType.FLOAT) {
            return null;
        }
        if (this.context == null) {
            this.context = Nd4j.getExecutioner().buildContext();
            this.context.setIArguments(new long[]{ArrayUtil.fromBoolean(gamma != null), ArrayUtil.fromBoolean(beta != null), 1L});
            this.context.setTArguments(new double[]{eps});
        }

        INDArray m;
        INDArray v;
        if (training) {
            if (this.meanCache == null) {
                // Allocate the caches detached from any workspace so they survive until backprop.
                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.meanCache = Nd4j.createUninitialized(x.dataType(), new long[]{x.size(1)});
                    this.varCache = Nd4j.createUninitialized(x.dataType(), new long[]{x.size(1)});
                }
            }
            int[] dims = x.rank() == 2 ? RANK2_DIMS : RANK4_DIMS;
            x.mean(this.meanCache, dims);
            // NOTE(review): the 'false' flag presumably disables bias correction - confirm against Variance.
            Nd4j.exec(new Variance(x, this.varCache, false, dims));
            m = this.meanCache;
            v = this.varCache;
        } else {
            m = flatten(mean);
            v = flatten(var);
        }

        // The context is reused between calls, so drop the arrays from the previous invocation.
        this.context.getInputArrays().clear();
        this.context.getOutputArrays().clear();
        this.context.setInputArray(0, x);
        this.context.setInputArray(1, m);
        this.context.setInputArray(2, v);
        if (gamma != null && beta != null) {
            this.context.setInputArray(3, gamma.rank() == 2 ? flatten(gamma) : gamma);
            this.context.setInputArray(4, beta.rank() == 2 ? flatten(beta) : beta);
        }

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), x.shape());
        this.context.setOutputArray(0, out);
        Nd4j.exec(new BatchNorm(), this.context);
        return out;
    }

    /**
     * Returns the cached batch mean.
     *
     * @param dataType not consulted; the cache is returned in whatever type it was created with
     * @return the mean from the last training-mode forward pass, or {@code null} if there has been none
     */
    @Override
    public INDArray getMeanCache(DataType dataType) {
        return this.meanCache;
    }

    /**
     * Returns the cached batch variance.
     *
     * @param dataType not consulted; the cache is returned in whatever type it was created with
     * @return the variance from the last training-mode forward pass, or {@code null} if there has been none
     */
    @Override
    public INDArray getVarCache(DataType dataType) {
        return this.varCache;
    }

    /**
     * Reports helper memory use.
     *
     * @return always an empty map
     */
    @Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /** Reshapes {@code array} to a rank-1 array containing all of its elements. */
    private static INDArray flatten(INDArray array) {
        return array.reshape(new long[]{array.length()});
    }
}

