package org.deeplearning4j.nn.conf.preprocessor;

import java.util.Arrays;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.class */
/**
 * Preprocessor that flattens 5d CNN3D activations into 2d feed-forward activations, and reshapes
 * 2d gradients back to 5d during backprop.
 *
 * <p>Activations are expected as {@code [minibatch, channels, depth, height, width]} when
 * {@link #isNCDHW()} is true, or {@code [minibatch, depth, height, width, channels]} otherwise.
 * They are reshaped in 'c' order to {@code [minibatch, channels * depth * height * width]};
 * {@link #backprop} applies the inverse 'c'-order reshape using the configured dimensions.
 */
public class Cnn3DToFeedForwardPreProcessor implements InputPreProcessor {
    protected long inputDepth;
    protected long inputHeight;
    protected long inputWidth;
    protected long numChannels;
    /** True for NCDHW layout (channels at dimension 1), false for NDHWC (channels last). */
    protected boolean isNCDHW = true;

    /**
     * Canonical constructor, also used for JSON deserialization.
     *
     * @param inputDepth  depth of the 3d activations
     * @param inputHeight height of the 3d activations
     * @param inputWidth  width of the 3d activations
     * @param numChannels number of channels of the 3d activations
     * @param isNCDHW     true for NCDHW layout, false for NDHWC
     */
    @JsonCreator
    public Cnn3DToFeedForwardPreProcessor(@JsonProperty("inputDepth") long inputDepth,
                                          @JsonProperty("inputHeight") long inputHeight,
                                          @JsonProperty("inputWidth") long inputWidth,
                                          @JsonProperty("numChannels") long numChannels,
                                          @JsonProperty("isNCDHW") boolean isNCDHW) {
        this.inputDepth = inputDepth;
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
        this.isNCDHW = isNCDHW;
    }

    /**
     * Creates a single-channel, NCDHW preprocessor.
     *
     * @param inputDepth  depth of the 3d activations
     * @param inputHeight height of the 3d activations
     * @param inputWidth  width of the 3d activations
     */
    public Cnn3DToFeedForwardPreProcessor(int inputDepth, int inputHeight, int inputWidth) {
        this(inputDepth, inputHeight, inputWidth, 1L, true);
    }

    /**
     * Creates a preprocessor whose layout is derived from a {@link Convolution3D.DataFormat}.
     *
     * @param inputDepth  depth of the 3d activations
     * @param inputHeight height of the 3d activations
     * @param inputWidth  width of the 3d activations
     * @param numChannels number of channels of the 3d activations
     * @param dataFormat  NCDHW selects channels-first; any other value selects channels-last
     */
    public Cnn3DToFeedForwardPreProcessor(int inputDepth, int inputHeight, int inputWidth, int numChannels,
                                          Convolution3D.DataFormat dataFormat) {
        this(inputDepth, inputHeight, inputWidth, numChannels, dataFormat == Convolution3D.DataFormat.NCDHW);
    }

    /** Creates a preprocessor with all dimensions 0 and NCDHW layout. */
    public Cnn3DToFeedForwardPreProcessor() {
    }

    /**
     * Flattens rank-5 activations to {@code [minibatch, channels * depth * height * width]}.
     * Rank-2 input is returned unchanged.
     *
     * @throws IllegalStateException if the input is neither rank 2 nor rank 5, or if its channel
     *                               dimension does not match {@link #getNumChannels()}
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (input.rank() == 2) {
            // Already flattened: nothing to do.
            return input;
        }
        if (input.rank() != 5) {
            throw new IllegalStateException("Invalid input array: expected rank 5 CNN3D activations "
                    + "(or rank 2 already-flattened activations), but got rank " + input.rank()
                    + " array with shape " + Arrays.toString(input.shape()));
        }
        int channelDim = isNCDHW ? 1 : 4;
        if (input.size(channelDim) != numChannels) {
            throw new IllegalStateException("Invalid input array: expected shape in format "
                    + "[minibatch, channels, depth, height, width] or [minibatch, depth, height, width, channels]"
                    + " for numChannels: " + numChannels + ", inputDepth " + inputDepth
                    + ", inputHeight " + inputHeight + " and inputWidth " + inputWidth
                    + ", but got " + Arrays.toString(input.shape()));
        }
        if (!Shape.hasDefaultStridesForShape(input)) {
            // Make sure the 'c'-order reshape below operates on a default-strided, c-ordered copy.
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
        }
        long[] shape = input.shape();
        INDArray flattened = input.reshape('c', new long[]{shape[0], shape[1] * shape[2] * shape[3] * shape[4]});
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, flattened);
    }

    /**
     * Reshapes 2d gradients back to the 5d layout selected by {@link #isNCDHW()}.
     * Rank-5 gradients are passed through (after ensuring default strides).
     *
     * @throws IllegalArgumentException if the gradient is neither rank 2 nor rank 5, or if its
     *                                  column count does not equal channels * depth * height * width
     */
    @Override
    public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (!Shape.hasDefaultStridesForShape(epsilons)) {
            epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
        }
        if (epsilons.rank() == 5) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons);
        }
        if (epsilons.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expected rank 2 (or rank 5) gradient array, but got rank "
                    + epsilons.rank() + " array with shape " + Arrays.toString(epsilons.shape()));
        }
        long expectedColumns = inputDepth * inputHeight * inputWidth * numChannels;
        if (epsilons.columns() != expectedColumns) {
            throw new IllegalArgumentException("Invalid input: expect output to have depth: " + inputDepth
                    + ", height: " + inputHeight + ", width: " + inputWidth + " and channels: " + numChannels
                    + ", i.e. [" + epsilons.rows() + ", " + expectedColumns + "] but was instead "
                    + Arrays.toString(epsilons.shape()));
        }
        INDArray reshaped = isNCDHW
                ? epsilons.reshape('c', new long[]{epsilons.size(0), numChannels, inputDepth, inputHeight, inputWidth})
                : epsilons.reshape('c', new long[]{epsilons.size(0), inputDepth, inputHeight, inputWidth, numChannels});
        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped);
    }

    /** Returns a shallow copy; all state is primitive, so the copy is fully independent. */
    @Override
    public Cnn3DToFeedForwardPreProcessor clone() {
        try {
            return (Cnn3DToFeedForwardPreProcessor) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Equivalent to {@link #clone()}; retained for backward compatibility.
     *
     * @deprecated use {@link #clone()} instead
     */
    @Deprecated
    public Cnn3DToFeedForwardPreProcessor m98clone() {
        return clone();
    }

    /**
     * Computes the feed-forward output type: size = channels * depth * height * width.
     *
     * @throws IllegalStateException if {@code inputType} is null or not of type CNN3D
     */
    @Override
    public InputType getOutputType(InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
            throw new IllegalStateException("Invalid input type: Expected input of type CNN3D, got " + inputType);
        }
        InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType;
        return InputType.feedForward(conv3d.getChannels() * conv3d.getDepth() * conv3d.getHeight() * conv3d.getWidth());
    }

    /** Passes the mask array and mask state through unchanged. */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        return new Pair<>(maskArray, currentMaskState);
    }

    public long getInputDepth() {
        return this.inputDepth;
    }

    public long getInputHeight() {
        return this.inputHeight;
    }

    public long getInputWidth() {
        return this.inputWidth;
    }

    public long getNumChannels() {
        return this.numChannels;
    }

    public boolean isNCDHW() {
        return this.isNCDHW;
    }

    public void setInputDepth(long inputDepth) {
        this.inputDepth = inputDepth;
    }

    public void setInputHeight(long inputHeight) {
        this.inputHeight = inputHeight;
    }

    public void setInputWidth(long inputWidth) {
        this.inputWidth = inputWidth;
    }

    public void setNumChannels(long numChannels) {
        this.numChannels = numChannels;
    }

    public void setNCDHW(boolean isNCDHW) {
        this.isNCDHW = isNCDHW;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Cnn3DToFeedForwardPreProcessor)) {
            return false;
        }
        Cnn3DToFeedForwardPreProcessor other = (Cnn3DToFeedForwardPreProcessor) obj;
        return other.canEqual(this)
                && getInputDepth() == other.getInputDepth()
                && getInputHeight() == other.getInputHeight()
                && getInputWidth() == other.getInputWidth()
                && getNumChannels() == other.getNumChannels()
                && isNCDHW() == other.isNCDHW();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof Cnn3DToFeedForwardPreProcessor;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Long.hashCode(getInputDepth());
        result = result * prime + Long.hashCode(getInputHeight());
        result = result * prime + Long.hashCode(getInputWidth());
        result = result * prime + Long.hashCode(getNumChannels());
        result = result * prime + (isNCDHW() ? 79 : 97);
        return result;
    }

    @Override
    public String toString() {
        return "Cnn3DToFeedForwardPreProcessor(inputDepth=" + getInputDepth() + ", inputHeight=" + getInputHeight()
                + ", inputWidth=" + getInputWidth() + ", numChannels=" + getNumChannels()
                + ", isNCDHW=" + isNCDHW() + ")";
    }
}
