package org.deeplearning4j.nn.layers.normalization;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.HelperUtils;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLocalResponseNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.class */
/**
 * Local response normalization (LRN) layer, normalizing each activation across neighbouring channels of a
 * 4D input.
 *
 * <p>For every channel {@code i} the built-in forward pass computes
 * {@code y_i = x_i * (k + alpha * sum_{|j - i| <= n/2} x_j^2)^(-beta)}, where {@code n/2} uses integer
 * division and the sum is truncated at the first/last channel. The channel axis is dimension 1 for
 * {@link CNN2DFormat#NCHW} and dimension 3 otherwise.
 *
 * <p>On the CUDA backend a {@link LocalResponseNormalizationHelper} is created (if possible) and tried
 * first. If it fails, the layer either rethrows or falls back to the built-in implementation, depending on
 * {@code isCudnnAllowFallback()}.
 */
public class LocalResponseNormalization extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
    private static final Logger log = LoggerFactory.getLogger(LocalResponseNormalization.class);

    /** Accelerated implementation; {@code null} when unavailable or unsupported for this configuration. */
    protected LocalResponseNormalizationHelper helper;

    /** Number of helper failures that were absorbed by falling back to the built-in implementation. */
    protected int helperCountFail;

    public static final String LOCAL_RESPONSE_NORM_CUDNN_HELPER_CLASS_NAME = "org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper";

    /**
     * Creates a copy of this layer with a cloned configuration and the same data type.
     *
     * <p>Decompiler note: renamed from {@code clone}, merged with bridge method.
     */
    public Layer m151clone() {
        return new LocalResponseNormalization(this.conf.m39clone(), this.dataType);
    }

    /**
     * @param neuralNetConfiguration configuration holding the LRN layer settings (k, n, alpha, beta, data format)
     * @param dataType               network data type, forwarded to the helper
     */
    public LocalResponseNormalization(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        this.helper = null;
        this.helperCountFail = 0;
        initializeHelper();
    }

    /**
     * Creates the accelerated helper when running on the CUDA backend.
     *
     * <p>The helper is discarded again if it reports that the configured (k, n, alpha, beta) are unsupported.
     */
    void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if ("CUDA".equalsIgnoreCase(backend)) {
            this.helper = (LocalResponseNormalizationHelper) HelperUtils.createHelper(
                    LOCAL_RESPONSE_NORM_CUDNN_HELPER_CLASS_NAME,
                    MKLDNNLocalResponseNormalizationHelper.class.getName(),
                    LocalResponseNormalizationHelper.class,
                    layerConf().getLayerName(),
                    this.dataType);
        }
        if (this.helper != null && !this.helper.checkSupported(layerConf().getK(), layerConf().getN(), layerConf().getAlpha(), layerConf().getBeta())) {
            log.debug("Removed helper {} as not supported (k={}, n={}, alpha={}, beta={})",
                    this.helper.getClass(), layerConf().getK(), layerConf().getN(), layerConf().getAlpha(), layerConf().getBeta());
            this.helper = null;
        }
    }

    /**
     * Returns the regularization score, which is always zero.
     *
     * <p>LRN has no parameters, so it never contributes a regularization penalty.
     */
    @Override
    public double calcRegularizationScore(boolean z) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /** Not supported: this layer has nothing to fit. */
    @Override
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Computes the gradient of the loss with respect to this layer's input.
     *
     * <p>The helper is tried first when active. The built-in path evaluates
     * {@code dx = eps * scale - 2 * alpha * beta * x * windowSum(y * eps) / unitScale}, where
     * {@code unitScale = k + alpha * windowSum(x^2)}, {@code scale = unitScale^(-beta)} and {@code y} is the
     * forward output. The built-in path returns an empty {@link Gradient} because the layer has no parameters.
     *
     * @param iNDArray          gradient of the loss with respect to this layer's output (epsilon)
     * @param layerWorkspaceMgr workspace manager used for the forward recomputation and the returned gradient
     * @return pair of the parameter gradient and the gradient with respect to the input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        double k = layerConf().getK();
        double n = layerConf().getN();
        double alpha = layerConf().getAlpha();
        double beta = layerConf().getBeta();

        if (isHelperActive()) {
            Pair<Gradient, INDArray> helperResult = null;
            try {
                helperResult = this.helper.backpropGradient(this.input, iNDArray, k, n, alpha, beta, layerWorkspaceMgr);
            } catch (ND4JOpProfilerException e) {
                // Profiler panics (NaN/Inf checks etc.) must reach the user unchanged.
                throw e;
            } catch (Throwable t) {
                // Out-of-memory in the helper is not retried with the built-in implementation.
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    throw t;
                }
                if (!layerConf().isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during LocalResponseNormalization CuDNN helper backprop - isCudnnAllowFallback() is set to false", t);
                }
                this.helperCountFail++;
                log.warn("CuDNN LocalResponseNormalization backprop execution failed - falling back on built-in implementation", t);
            }
            // A null result means the helper did not handle this call; use the built-in path.
            if (helperResult != null) {
                return helperResult;
            }
        }

        boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW;
        int halfN = ((int) n) / 2;

        // Recompute the forward pass, keeping the intermediates needed for the gradient.
        Triple<INDArray, INDArray, INDArray> forward = activateHelper(true, layerWorkspaceMgr, true);
        INDArray activations = forward.getFirst();
        INDArray unitScale = forward.getSecond();
        INDArray scale = forward.getThird();

        long channels = this.input.size(nchw ? 1 : 3);
        INDArray sumPart = sumOverChannelWindow(activations.mul(iNDArray), channels, halfN, nchw);

        // dx = eps * scale - 2 * alpha * beta * sumPart * x / unitScale   (arranged for in-place ops)
        INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(), iNDArray.shape(), iNDArray.ordering());
        Nd4j.getExecutioner().exec(new MulOp(iNDArray, scale, nextEpsilon));
        nextEpsilon.subi(sumPart.muli(this.input).divi(unitScale).muli(2.0 * alpha * beta));
        return new Pair<>(new DefaultGradient(), nextEpsilon);
    }

    /**
     * Runs the forward pass.
     *
     * @param z                 training flag, forwarded to the helper
     * @param layerWorkspaceMgr workspace manager used to allocate the activations
     * @return the normalized activations
     */
    @Override
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        return activateHelper(z, layerWorkspaceMgr, false).getFirst();
    }

    /**
     * Forward pass shared by {@link #activate} and {@link #backpropGradient}.
     *
     * @param training     training flag, forwarded to the helper
     * @param workspaceMgr workspace manager used to allocate the activations
     * @param forBackprop  when true, also return {@code unitScale = k + alpha * windowSum(x^2)} and
     *                     {@code scale = unitScale^(-beta)}
     * @return (activations, unitScale, scale). The last two are {@code null} unless {@code forBackprop} is true.
     */
    private Triple<INDArray, INDArray, INDArray> activateHelper(boolean training, LayerWorkspaceMgr workspaceMgr, boolean forBackprop) {
        assertInputSet(false);
        double k = layerConf().getK();
        double n = layerConf().getN();
        double alpha = layerConf().getAlpha();
        double beta = layerConf().getBeta();
        int halfN = ((int) n) / 2;

        // The helper yields only the activations. The built-in backprop also needs unitScale/scale,
        // so serving a forBackprop request from the helper would hand nulls to the gradient math.
        if (!forBackprop && isHelperActive()) {
            INDArray helperOut = null;
            try {
                helperOut = this.helper.activate(this.input, training, k, n, alpha, beta, workspaceMgr);
            } catch (ND4JOpProfilerException e) {
                // Profiler panics (NaN/Inf checks etc.) must reach the user unchanged.
                throw e;
            } catch (Throwable t) {
                // Out-of-memory in the helper is not retried with the built-in implementation.
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    throw t;
                }
                if (!layerConf().isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during LocalResponseNormalization CuDNN helper forward pass - isCudnnAllowFallback() is set to false", t);
                }
                this.helperCountFail++;
                log.warn("CuDNN LocalResponseNormalization forward pass execution failed - falling back on built-in implementation", t);
            }
            if (helperOut != null) {
                return new Triple<>(helperOut, null, null);
            }
        }

        boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW;
        long channels = this.input.size(nchw ? 1 : 3);
        INDArray sumPart = sumOverChannelWindow(this.input.mul(this.input), channels, halfN, nchw);

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, this.input.dataType(), this.input.shape(), this.input.ordering());
        if (forBackprop) {
            // unitScale = k + alpha * windowSum(x^2); y = x * unitScale^(-beta)
            INDArray unitScale = sumPart.mul(alpha).addi(k);
            INDArray scale = Transforms.pow(unitScale, -beta, true);
            Nd4j.getExecutioner().exec(new MulOp(this.input, scale, out));
            return new Triple<>(out, unitScale, scale);
        }
        // Same formula, evaluated in place in the output buffer since intermediates are not needed.
        sumPart.muli(alpha, out).addi(k);
        Transforms.pow(out, -beta, false);
        out.muli(this.input);
        return new Triple<>(out, null, null);
    }

    /** Whether the helper should be tried: it exists and has not failed, or fallback is disabled. */
    private boolean isHelperActive() {
        return this.helper != null && (this.helperCountFail == 0 || !layerConf().isCudnnAllowFallback());
    }

    /**
     * Sums {@code values} over a sliding window of channels.
     *
     * <p>Each channel {@code c} of the result holds {@code sum_{|d| <= halfN, 0 <= c + d < channels} values[c + d]}.
     * The input array is not modified.
     *
     * @param values   4D array to sum, with the channel axis given by {@code nchw}
     * @param channels size of the channel axis
     * @param halfN    window half-width
     * @param nchw     true if channels are dimension 1, false if dimension 3
     * @return a new array with the windowed sums
     */
    private static INDArray sumOverChannelWindow(INDArray values, long channels, int halfN, boolean nchw) {
        INDArray windowSum = values.dup();
        // Offsets >= channels pair no two channels, so they add nothing. Skipping them also avoids building
        // empty or inverted intervals when halfN >= channels.
        long maxOffset = Math.min((long) halfN, channels - 1);
        for (long d = 1; d <= maxOffset; d++) {
            // Channel c receives values[c - d] ...
            windowSum.put(channelRange(nchw, d, channels),
                    windowSum.get(channelRange(nchw, d, channels)).addi(values.get(channelRange(nchw, 0L, channels - d))));
            // ... and values[c + d]. Reads come from the untouched input, so update order does not matter.
            windowSum.put(channelRange(nchw, 0L, channels - d),
                    windowSum.get(channelRange(nchw, 0L, channels - d)).addi(values.get(channelRange(nchw, d, channels))));
        }
        return windowSum;
    }

    /**
     * Builds a fresh 4D index selecting channels {@code [from, to)} and everything on the other axes.
     *
     * <p>A new index array is created on every call rather than shared between get/put operations.
     */
    private static INDArrayIndex[] channelRange(boolean nchw, long from, long to) {
        INDArrayIndex range = NDArrayIndex.interval(from, to);
        if (nchw) {
            return new INDArrayIndex[]{NDArrayIndex.all(), range, NDArrayIndex.all(), NDArrayIndex.all()};
        }
        return new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), range};
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: the layer has no weight noise to clear. */
    @Override
    public void clearNoiseWeightParams() {
    }

    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }

    /** Returns {@code null}: this layer has no parameters. */
    @Override
    public INDArray params() {
        return null;
    }

    /** Returns {@code null} for any key: this layer has no parameters. */
    @Override
    public INDArray getParam(String str) {
        return params();
    }

    /** No-op: this layer has no parameters. */
    @Override
    public void setParams(INDArray iNDArray) {
    }
}
