package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.CapsuleUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/PrimaryCapsules.class */
/**
 * Primary capsule layer for capsule networks, implemented as a SameDiff layer.
 *
 * <p>The layer applies a 2D convolution with {@code capsuleDimensions * channels} output maps
 * (plus an optional bias), optionally followed by a ReLU or leaky ReLU. The result is reshaped to
 * {@code [-1, capsules, capsuleDimensions]} and squashed along dimension 2. The layer requires
 * CNN input and reports a recurrent output type of size {@code capsules} x
 * {@code capsuleDimensions}.
 */
public class PrimaryCapsules extends SameDiffLayer {
    private static final String WEIGHT_PARAM = "weight";
    private static final String BIAS_PARAM = "bias";

    private int[] kernelSize;
    private int[] stride;
    private int[] padding;
    private int[] dilation;
    /** Number of input channels; taken from the CNN input type in {@link #setNIn}. */
    private int inputChannels;
    /** Number of capsule channels; the convolution emits {@code capsuleDimensions * channels} maps. */
    private int channels;
    private boolean hasBias;
    /** Number of capsules; a value {@code <= 0} is replaced by the inferred count in {@link #setNIn}. */
    private int capsules;
    private int capsuleDimensions;
    private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
    private boolean useRelu = false;
    /** Leaky-ReLU slope; {@code 0.0} selects a plain ReLU when {@link #useRelu} is set. */
    private double leak = 0.0;

    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/PrimaryCapsules$Builder.class */
    /** Builder for {@link PrimaryCapsules}. */
    public static class Builder extends SameDiffLayer.Builder<Builder> {
        private int[] kernelSize = new int[]{9, 9};
        private int[] stride = new int[]{2, 2};
        private int[] padding = new int[]{0, 0};
        private int[] dilation = new int[]{1, 1};
        private int channels = 32;
        private boolean hasBias = true;
        private int capsules;
        private int capsuleDimensions;
        private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
        private boolean useRelu = false;
        private double leak = 0.0;

        /** Sets the kernel size; validated by {@code ValidationUtils.validate2NonNegative}. */
        public void setKernelSize(int... kernelSize) {
            this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
        }

        /** Sets the stride; validated by {@code ValidationUtils.validate2NonNegative}. */
        public void setStride(int... stride) {
            this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride");
        }

        /** Sets the padding; validated by {@code ValidationUtils.validate2NonNegative}. */
        public void setPadding(int... padding) {
            this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding");
        }

        /** Sets the dilation; validated by {@code ValidationUtils.validate2NonNegative}. */
        public void setDilation(int... dilation) {
            this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
        }

        /**
         * Full constructor; all other constructors delegate here.
         *
         * @param capsuleDimensions size of each capsule vector
         * @param channels          number of capsule channels
         * @param kernelSize        convolution kernel size (height, width)
         * @param stride            convolution stride
         * @param padding           convolution padding
         * @param dilation          convolution dilation
         * @param convolutionMode   convolution mode
         */
        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode) {
            this.capsuleDimensions = capsuleDimensions;
            this.channels = channels;
            setKernelSize(kernelSize);
            setStride(stride);
            setPadding(padding);
            setDilation(dilation);
            this.convolutionMode = convolutionMode;
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
            this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding) {
            this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride) {
            this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize) {
            this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels) {
            this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder kernelSize(int... kernelSize) {
            setKernelSize(kernelSize);
            return this;
        }

        public Builder stride(int... stride) {
            setStride(stride);
            return this;
        }

        public Builder padding(int... padding) {
            setPadding(padding);
            return this;
        }

        public Builder dilation(int... dilation) {
            setDilation(dilation);
            return this;
        }

        public Builder channels(int channels) {
            this.channels = channels;
            return this;
        }

        /** Alias for {@link #channels(int)}. */
        public Builder nOut(int nOut) {
            return channels(nOut);
        }

        public Builder capsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
            return this;
        }

        public Builder capsules(int capsules) {
            this.capsules = capsules;
            return this;
        }

        public Builder hasBias(boolean hasBias) {
            this.hasBias = hasBias;
            return this;
        }

        public Builder convolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
            return this;
        }

        public Builder useReLU(boolean useRelu) {
            this.useRelu = useRelu;
            return this;
        }

        public Builder useReLU() {
            return useReLU(true);
        }

        /** Enables the activation and sets the leaky-ReLU slope ({@code 0.0} = plain ReLU). */
        public Builder useLeakyReLU(double leak) {
            this.useRelu = true;
            this.leak = leak;
            return this;
        }

        @Override // org.deeplearning4j.nn.conf.layers.Layer.Builder
        @SuppressWarnings("unchecked") // the generic return type is dictated by Layer.Builder
        public <E extends Layer> E build() {
            return (E) new PrimaryCapsules(this);
        }

        public int[] getKernelSize() {
            return this.kernelSize;
        }

        public int[] getStride() {
            return this.stride;
        }

        public int[] getPadding() {
            return this.padding;
        }

        public int[] getDilation() {
            return this.dilation;
        }

        public int getChannels() {
            return this.channels;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public int getCapsules() {
            return this.capsules;
        }

        public int getCapsuleDimensions() {
            return this.capsuleDimensions;
        }

        public ConvolutionMode getConvolutionMode() {
            return this.convolutionMode;
        }

        public boolean isUseRelu() {
            return this.useRelu;
        }

        public double getLeak() {
            return this.leak;
        }

        public void setChannels(int channels) {
            this.channels = channels;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        public void setCapsules(int capsules) {
            this.capsules = capsules;
        }

        public void setCapsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
        }

        public void setConvolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
        }

        public void setUseRelu(boolean useRelu) {
            this.useRelu = useRelu;
        }

        public void setLeak(double leak) {
            this.leak = leak;
        }
    }

    /**
     * Creates the layer from a builder.
     *
     * @throws IllegalArgumentException if {@code capsuleDimensions} or {@code channels} is not
     *                                  positive, or {@code capsules} is negative
     */
    public PrimaryCapsules(Builder builder) {
        super(builder);
        this.kernelSize = builder.kernelSize;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.dilation = builder.dilation;
        this.channels = builder.channels;
        this.hasBias = builder.hasBias;
        this.capsules = builder.capsules;
        this.capsuleDimensions = builder.capsuleDimensions;
        this.convolutionMode = builder.convolutionMode;
        this.useRelu = builder.useRelu;
        this.leak = builder.leak;
        if (this.capsuleDimensions <= 0 || this.channels <= 0) {
            throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \"" + this.layerName + "\"): capsuleDimensions and channels must be > 0.  Got: " + this.capsuleDimensions + ", " + this.channels);
        }
        if (this.capsules < 0) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): capsules must be >= 0 if set.  Got: " + this.capsules);
        }
    }

    /** No-arg constructor (e.g. for deserialization); fields keep their declared defaults. */
    public PrimaryCapsules() {
    }

    /**
     * Builds the forward graph: conv2d (with bias if enabled), optional (leaky) ReLU, reshape to
     * {@code [-1, capsules, capsuleDimensions]}, then squash along dimension 2.
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        Conv2DConfig config = Conv2DConfig.builder()
                .kH(this.kernelSize[0]).kW(this.kernelSize[1])
                .sH(this.stride[0]).sW(this.stride[1])
                .pH(this.padding[0]).pW(this.padding[1])
                .dH(this.dilation[0]).dW(this.dilation[1])
                .paddingMode(ConvolutionMode.mapToMode(this.convolutionMode))
                .build();

        SDVariable conv = this.hasBias
                ? sameDiff.cnn.conv2d(layerInput, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), config)
                : sameDiff.cnn.conv2d(layerInput, paramTable.get(WEIGHT_PARAM), config);

        if (this.useRelu) {
            // A zero leak means a standard ReLU (cutoff 0); otherwise use the leaky variant.
            conv = this.leak == 0.0 ? sameDiff.nn.relu(conv, 0.0) : sameDiff.nn.leakyRelu(conv, this.leak);
        }

        SDVariable grouped = conv.reshape(new int[]{-1, this.capsules, this.capsuleDimensions});
        return CapsuleUtils.squash(sameDiff, grouped, 2);
    }

    /**
     * Declares the convolution weight {@code [kH, kW, inputChannels, capsuleDimensions*channels]}
     * and, if enabled, a bias of length {@code capsuleDimensions*channels}.
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer
    public void defineParameters(SDLayerParams params) {
        params.clear();
        // Widen before multiplying so large configurations cannot overflow int.
        long outputMaps = (long) this.capsuleDimensions * this.channels;
        params.addWeightParam(WEIGHT_PARAM, this.kernelSize[0], this.kernelSize[1], this.inputChannels, outputMaps);
        if (this.hasBias) {
            params.addBiasParam(BIAS_PARAM, outputMaps);
        }
    }

    /**
     * Initializes parameters in place: the bias is set to zero and the weight is filled via
     * {@link WeightInitUtil} using convolution-style fan-in/fan-out.
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer
    public void initializeParameters(Map<String, INDArray> params) {
        // try-with-resources guarantees the workspace scope is closed even if initialization throws.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray param = entry.getValue();
                if (BIAS_PARAM.equals(entry.getKey())) {
                    param.assign(0);
                } else if (WEIGHT_PARAM.equals(entry.getKey())) {
                    // Computed in double: the previous int division truncated the fan-out and the
                    // int products could overflow for large layers.
                    double fanIn = (double) this.inputChannels * this.kernelSize[0] * this.kernelSize[1];
                    double fanOut = (double) this.capsuleDimensions * this.channels * this.kernelSize[0] * this.kernelSize[1]
                            / ((double) this.stride[0] * this.stride[1]);
                    WeightInitUtil.initWeights(fanIn, fanOut, param.shape(), this.weightInit, (Distribution) null, 'c', param);
                }
            }
        }
    }

    /**
     * Returns {@code recurrent(capsules, capsuleDimensions)}, inferring the capsule count from the
     * convolution output when it has not been set.
     *
     * @throws IllegalStateException if the input type is not CNN
     */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public InputType getOutputType(int layerIndex, InputType inputType) {
        requireCnnInput(inputType);
        if (this.capsules > 0) {
            return InputType.recurrent(this.capsules, this.capsuleDimensions);
        }
        return InputType.recurrent((int) inferCapsules(inputType), this.capsuleDimensions);
    }

    /**
     * Records the input channel count and, when {@code capsules <= 0} or {@code override} is set,
     * infers the capsule count from the convolution output.
     *
     * @throws IllegalStateException if the input type is not CNN
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer, org.deeplearning4j.nn.conf.layers.Layer
    public void setNIn(InputType inputType, boolean override) {
        requireCnnInput(inputType);
        this.inputChannels = (int) ((InputType.InputTypeConvolutional) inputType).getChannels();
        if (this.capsules <= 0 || override) {
            this.capsules = (int) inferCapsules(inputType);
        }
    }

    /** Throws if {@code inputType} is null or not a CNN input type. */
    private void requireCnnInput(InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \"" + this.layerName + "\"): expect CNN input.  Got: " + inputType);
        }
    }

    /** Number of capsules produced by the convolution: (channels * height * width) / capsuleDimensions. */
    private long inferCapsules(InputType inputType) {
        InputType.InputTypeConvolutional out = (InputType.InputTypeConvolutional) InputTypeUtil.getOutputTypeCnnLayers(
                inputType, this.kernelSize, this.stride, this.padding, this.dilation, this.convolutionMode,
                (long) this.capsuleDimensions * this.channels, -1L, getLayerName(), PrimaryCapsules.class);
        return (out.getChannels() * out.getHeight() * out.getWidth()) / this.capsuleDimensions;
    }

    public int[] getKernelSize() {
        return this.kernelSize;
    }

    public int[] getStride() {
        return this.stride;
    }

    public int[] getPadding() {
        return this.padding;
    }

    public int[] getDilation() {
        return this.dilation;
    }

    public int getInputChannels() {
        return this.inputChannels;
    }

    public int getChannels() {
        return this.channels;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public int getCapsules() {
        return this.capsules;
    }

    public int getCapsuleDimensions() {
        return this.capsuleDimensions;
    }

    public ConvolutionMode getConvolutionMode() {
        return this.convolutionMode;
    }

    public boolean isUseRelu() {
        return this.useRelu;
    }

    public double getLeak() {
        return this.leak;
    }

    public void setKernelSize(int[] kernelSize) {
        this.kernelSize = kernelSize;
    }

    public void setStride(int[] stride) {
        this.stride = stride;
    }

    public void setPadding(int[] padding) {
        this.padding = padding;
    }

    public void setDilation(int[] dilation) {
        this.dilation = dilation;
    }

    public void setInputChannels(int inputChannels) {
        this.inputChannels = inputChannels;
    }

    public void setChannels(int channels) {
        this.channels = channels;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setCapsules(int capsules) {
        this.capsules = capsules;
    }

    public void setCapsuleDimensions(int capsuleDimensions) {
        this.capsuleDimensions = capsuleDimensions;
    }

    public void setConvolutionMode(ConvolutionMode convolutionMode) {
        this.convolutionMode = convolutionMode;
    }

    public void setUseRelu(boolean useRelu) {
        this.useRelu = useRelu;
    }

    public void setLeak(double leak) {
        this.leak = leak;
    }

    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer, org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer, org.deeplearning4j.nn.conf.layers.Layer
    public String toString() {
        return "PrimaryCapsules(kernelSize=" + Arrays.toString(getKernelSize()) + ", stride=" + Arrays.toString(getStride()) + ", padding=" + Arrays.toString(getPadding()) + ", dilation=" + Arrays.toString(getDilation()) + ", inputChannels=" + getInputChannels() + ", channels=" + getChannels() + ", hasBias=" + isHasBias() + ", capsules=" + getCapsules() + ", capsuleDimensions=" + getCapsuleDimensions() + ", convolutionMode=" + getConvolutionMode() + ", useRelu=" + isUseRelu() + ", leak=" + getLeak() + ")";
    }

    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer, org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer, org.deeplearning4j.nn.conf.layers.Layer
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof PrimaryCapsules)) {
            return false;
        }
        PrimaryCapsules other = (PrimaryCapsules) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return getInputChannels() == other.getInputChannels()
                && getChannels() == other.getChannels()
                && isHasBias() == other.isHasBias()
                && getCapsules() == other.getCapsules()
                && getCapsuleDimensions() == other.getCapsuleDimensions()
                && isUseRelu() == other.isUseRelu()
                && Double.compare(getLeak(), other.getLeak()) == 0
                && Arrays.equals(getKernelSize(), other.getKernelSize())
                && Arrays.equals(getStride(), other.getStride())
                && Arrays.equals(getPadding(), other.getPadding())
                && Arrays.equals(getDilation(), other.getDilation())
                && (getConvolutionMode() == null
                        ? other.getConvolutionMode() == null
                        : getConvolutionMode().equals(other.getConvolutionMode()));
    }

    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer, org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer, org.deeplearning4j.nn.conf.layers.Layer
    protected boolean canEqual(Object obj) {
        return obj instanceof PrimaryCapsules;
    }

    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer, org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer, org.deeplearning4j.nn.conf.layers.Layer
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + getInputChannels();
        result = result * prime + getChannels();
        result = result * prime + (isHasBias() ? 79 : 97);
        result = result * prime + getCapsules();
        result = result * prime + getCapsuleDimensions();
        result = result * prime + (isUseRelu() ? 79 : 97);
        result = result * prime + Double.hashCode(getLeak());
        result = result * prime + Arrays.hashCode(getKernelSize());
        result = result * prime + Arrays.hashCode(getStride());
        result = result * prime + Arrays.hashCode(getPadding());
        result = result * prime + Arrays.hashCode(getDilation());
        ConvolutionMode mode = getConvolutionMode();
        result = result * prime + (mode == null ? 43 : mode.hashCode());
        return result;
    }
}
