package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.class */
/**
 * Wrapper layer that derives a per-timestep mask from its input and installs it on the wrapped layer.
 *
 * <p>A timestep gets mask value 0 when every feature at that timestep equals {@code maskingValue},
 * and mask value 1 otherwise. The mask is recomputed on every {@code activate} call and passed to
 * the wrapped layer via {@code setMaskArray} before delegating the forward pass.
 *
 * <p>Note: because the comparison is exact equality, a {@code NaN} masking value never matches
 * anything, so no timestep would be masked in that case.
 */
public class MaskZeroLayer extends BaseWrapperLayer {
    private static final long serialVersionUID = -7369482676002469854L;

    /** Feature value that, when shared by all features of a timestep, marks that timestep as masked. */
    private final double maskingValue;

    /**
     * Creates a masking wrapper around {@code layer}.
     *
     * @param layer        the layer to wrap; must not be null
     * @param maskingValue value that identifies masked timesteps
     * @throws NullPointerException if {@code layer} is null
     */
    public MaskZeroLayer(@NonNull Layer layer, double maskingValue) {
        super(layer);
        if (layer == null) {
            throw new NullPointerException("underlying is marked non-null but is null");
        }
        this.maskingValue = maskingValue;
    }

    /** Reports this layer as {@link Layer.Type#RECURRENT}. */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Delegates backpropagation unchanged to the wrapped layer. Presumably the wrapped layer
     * applies the mask installed during the forward pass — TODO confirm in the underlying layer.
     */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return this.underlying.backpropGradient(epsilon, workspaceMgr);
    }

    /**
     * Computes the mask from the currently set input ({@code input()}), installs it on the wrapped
     * layer, then runs the wrapped layer's forward pass.
     */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        setMaskFromInput(input());
        return this.underlying.activate(training, workspaceMgr);
    }

    /**
     * Computes the mask from {@code input}, installs it on the wrapped layer, then runs the wrapped
     * layer's forward pass on {@code input}.
     */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setMaskFromInput(input);
        return this.underlying.activate(input, training, workspaceMgr);
    }

    /**
     * Builds a {@code [batch, time]} mask from a rank-3 input and sets it on the wrapped layer.
     *
     * @throws IllegalArgumentException if {@code input} is not rank 3
     */
    private void setMaskFromInput(INDArray input) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], got shape " + Arrays.toString(input.shape()) + " instead");
        }
        INDArray ncw = input;
        if (this.underlying instanceof BaseRecurrentLayer
                && ((BaseRecurrentLayer) this.underlying).getDataFormat() == RNNFormat.NWC) {
            // Bring NWC input into NCW layout so that dimension 1 is the feature axis.
            ncw = input.permute(new int[]{0, 2, 1});
        }
        // 1 where a feature differs from maskingValue, 0 where it matches (in the input's dtype).
        INDArray differs = ncw.neq(this.maskingValue).castTo(ncw.dataType());
        // A timestep is kept iff ANY feature differs. Taking the max over the feature axis avoids
        // counting matches and comparing that count to the feature size, which is inexact in
        // low-precision dtypes (BFLOAT16 cannot represent 257, FLOAT16 cannot represent 2049).
        INDArray mask = differs.max(1);
        this.underlying.setMaskArray(mask.detach());
    }

    /** Returns the wrapped layer's parameter count; this wrapper adds no parameters. */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Model
    public long numParams() {
        return this.underlying.numParams();
    }

    /**
     * Forwards the incoming mask to the wrapped layer, but always reports a {@code null} output mask
     * together with the unchanged {@code currentMaskState}. The wrapped layer's own returned mask is
     * intentionally discarded.
     */
    @Override // org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer, org.deeplearning4j.nn.api.Layer
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.underlying.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        return new Pair<>((Object) null, currentMaskState);
    }
}
