package org.deeplearning4j.nn.layers.util;

import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;

/* loaded from: input_file:org/deeplearning4j/nn/layers/util/MaskLayer.class */
/**
 * Layer that multiplies its input (forward pass) and its incoming epsilon (backward pass) by the
 * layer's current mask array, if one is set.
 *
 * <p>Accepted input ranks and mask layouts:
 * <ul>
 *   <li>rank 2 input: the mask must be a column vector (or scalar) whose {@code size(0)} equals the
 *       input's {@code size(0)};</li>
 *   <li>rank 3 input {@code [minibatch, size, sequenceLength]}: the mask must be rank 2 with shape
 *       {@code [minibatch, sequenceLength]};</li>
 *   <li>rank 4 input: the mask must be rank 4, and each mask dimension must either equal the
 *       corresponding input dimension or be 1.</li>
 * </ul>
 *
 * <p>When no mask is set, the array is passed through unchanged (only leveraged to the requested
 * workspace). The backward pass always reports an empty gradient.
 */
public class MaskLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.util.MaskLayer> {
    /** Empty gradient returned from every backward pass. */
    private final Gradient emptyGradient = new DefaultGradient();

    public MaskLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Cloning is not supported for this layer.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code clone()} because it was
     * merged with a bridge method. The name is kept as-is so existing references still resolve.
     *
     * @throws UnsupportedOperationException always
     */
    public Layer m163clone() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public void clearNoiseWeightParams() {
        // Nothing to clear: this layer holds no weight noise.
    }

    /**
     * Applies the current mask to the incoming epsilon.
     *
     * @param epsilon      gradient flowing back from the next layer
     * @param workspaceMgr workspace manager used to place the result
     * @return the (empty) gradient of this layer, paired with the masked epsilon
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray maskedEpsilon = applyMask(epsilon, this.maskArray, workspaceMgr, ArrayType.ACTIVATION_GRAD);
        return new Pair<>(this.emptyGradient, maskedEpsilon);
    }

    /**
     * Applies the current mask to the layer input.
     *
     * @param training     unused by this layer
     * @param workspaceMgr workspace manager used to place the result
     * @return the masked input, or the input itself (leveraged) when no mask is set
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return applyMask(this.input, this.maskArray, workspaceMgr, ArrayType.ACTIVATIONS);
    }

    /**
     * Multiplies {@code input} by {@code mask}, dispatching on the input's rank.
     *
     * @throws IllegalStateException if the input rank is not 2-4, or the mask shape is incompatible
     */
    private static INDArray applyMask(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (mask == null) {
            return workspaceMgr.leverageTo(type, input);
        }
        switch (input.rank()) {
            case 2:
                return applyMask2d(input, mask, workspaceMgr, type);
            case 3:
                return applyMask3d(input, mask, workspaceMgr, type);
            case 4:
                return applyMask4d(input, mask, workspaceMgr, type);
            default:
                throw new IllegalStateException("Expected rank 2 to 4 input. Got rank " + input.rank() + " with shape " + Arrays.toString(input.shape()));
        }
    }

    /** Masks 2d input row-wise with a column-vector (or scalar) mask. */
    private static INDArray applyMask2d(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (!mask.isColumnVectorOrScalar() || mask.size(0) != input.size(0)) {
            throw new IllegalStateException("Expected column vector for mask with 2d input, with same size(0) as input. Got mask with shape: " + Arrays.toString(mask.shape()) + ", input shape = " + Arrays.toString(input.shape()));
        }
        return workspaceMgr.leverageTo(type, input.mulColumnVector(mask));
    }

    /** Masks 3d [minibatch, size, sequenceLength] input with a [minibatch, sequenceLength] mask. */
    private static INDArray applyMask3d(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (mask.rank() != 2 || input.size(0) != mask.size(0) || input.size(2) != mask.size(1)) {
            throw new IllegalStateException("With 3d (time series) input with shape [minibatch, size, sequenceLength]=" + Arrays.toString(input.shape()) + ", expected 2d mask array with shape [minibatch, sequenceLength]. Got mask with shape: " + Arrays.toString(mask.shape()));
        }
        INDArray output = workspaceMgr.createUninitialized(type, input.dataType(), input.shape(), 'f');
        // Mask dimensions 0 and 1 line up with input dimensions 0 (minibatch) and 2 (time).
        Broadcast.mul(input, mask, output, new int[]{0, 2});
        return output;
    }

    /**
     * Masks 4d input with a rank-4 mask, broadcasting over every dimension where the mask has
     * size 1.
     */
    private static INDArray applyMask4d(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (mask.rank() != 4) {
            throw incompatible4dMask(input, mask);
        }
        // Collect the dimensions along which the mask varies. These are the dimensions it shares
        // with the input. Every other mask dimension must be 1 so it can be broadcast.
        int[] broadcastDims = new int[4];
        int count = 0;
        for (int dim = 0; dim < 4; dim++) {
            long inputSize = input.size(dim);
            long maskSize = mask.size(dim);
            if (inputSize == maskSize) {
                broadcastDims[count++] = dim;
            } else if (maskSize != 1) {
                throw incompatible4dMask(input, mask);
            }
        }
        broadcastDims = Arrays.copyOf(broadcastDims, count);

        INDArray output = workspaceMgr.createUninitialized(type, input.dataType(), input.shape(), 'c');
        Broadcast.mul(input, mask, output, broadcastDims);
        return output;
    }

    /** Builds the error raised when a mask cannot be broadcast against 4d input. */
    private static IllegalStateException incompatible4dMask(INDArray input, INDArray mask) {
        return new IllegalStateException("With 4d input with shape " + Arrays.toString(input.shape())
                + ", expected 4d mask array where each dimension either equals the input's dimension or is 1."
                + " Got mask with shape: " + Arrays.toString(mask.shape()));
    }
}
