package org.deeplearning4j.nn.conf.graph;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/conf/graph/MergeVertex.class */
/**
 * Configuration for a parameter-free graph vertex that merges two or more inputs.
 *
 * <p>The output type sums the channel (CNN / CNN3D) or feature (FF / RNN) sizes of the inputs,
 * i.e. the inputs are combined along {@link #mergeAxis}, which is passed to the runtime
 * {@code org.deeplearning4j.nn.graph.vertex.impl.MergeVertex} created by {@link #instantiate}.
 */
public class MergeVertex extends GraphVertex {
    /**
     * Axis handed to the runtime vertex. Defaults to 1 and is overwritten by {@link #getOutputType}
     * according to the input data format: CNN (1 for NCHW, else 3), CNN3D (1 for NCDHW, else 4),
     * RNN (1 for NCW, else 2). Feed-forward inputs leave it unchanged.
     */
    protected int mergeAxis = 1;

    /**
     * Returns a copy of this configuration, including the current {@link #mergeAxis}.
     * (Decompiled name of {@code clone()}; the name is kept so it still overrides the parent.)
     */
    @Override
    /* renamed from: clone */
    public MergeVertex mo53clone() {
        MergeVertex copy = new MergeVertex();
        // mergeAxis may have been updated by getOutputType; a fresh instance would reset it to 1.
        copy.mergeAxis = this.mergeAxis;
        return copy;
    }

    /** All MergeVertex instances compare equal; {@link #mergeAxis} is not part of equality. */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof MergeVertex;
    }

    /** Constant hash, consistent with {@link #equals(Object)} treating all instances as equal. */
    @Override
    public int hashCode() {
        return 433682566;
    }

    /** This vertex has no trainable parameters. */
    @Override
    public long numParams(boolean z) {
        return 0L;
    }

    /** At least two inputs are required. */
    @Override
    public int minVertexInputs() {
        return 2;
    }

    /** No practical upper bound on the number of inputs. */
    @Override
    public int maxVertexInputs() {
        return Integer.MAX_VALUE;
    }

    @Override
    public String toString() {
        return "MergeVertex()";
    }

    /**
     * Creates the runtime merge vertex using the current {@link #mergeAxis}.
     *
     * <p>NOTE(review): mergeAxis is only format-aware after {@link #getOutputType} has been called
     * for this configuration; presumably the graph builder does this first — confirm.
     */
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph computationGraph, String str, int i, INDArray iNDArray, boolean z, DataType dataType) {
        return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(computationGraph, str, i, dataType, this.mergeAxis);
    }

    /**
     * Computes the merged output type and, as a side effect, sets {@link #mergeAxis} for CNN, CNN3D
     * and RNN inputs according to their data format.
     *
     * @param layerIndex   index of this vertex (unused)
     * @param vertexInputs input types; a single input is returned unchanged
     * @return the merged type: spatial dimensions of the first input with summed channels (CNN /
     *         CNN3D), or summed sizes (FF / RNN; -1 if any input size is unknown)
     * @throws InvalidInputTypeException if no inputs are given, if input kinds differ, if CNN data is
     *         in flattened format, or if spatial dimensions differ
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs == null || vertexInputs.length == 0) {
            throw new InvalidInputTypeException("Invalid input: MergeVertex requires at least one input type");
        }
        if (vertexInputs.length == 1) {
            return vertexInputs[0];
        }
        InputType first = vertexInputs[0];
        if (first.getType() == InputType.Type.CNNFlat) {
            throw new InvalidInputTypeException("Invalid input: MergeVertex cannot currently merge CNN data in flattened format. Got: " + Arrays.toString(vertexInputs));
        }
        if (first.getType() == InputType.Type.CNN3D) {
            return mergeConvolutional3D((InputType.InputTypeConvolutional3D) first, vertexInputs);
        }
        if (first.getType() == InputType.Type.CNN) {
            return mergeConvolutional2D((InputType.InputTypeConvolutional) first, vertexInputs);
        }
        return mergeFeedForwardOrRecurrent(first, vertexInputs);
    }

    /** Merges CNN3D inputs: same depth/height/width required, channels summed, input format kept. */
    private InputType mergeConvolutional3D(InputType.InputTypeConvolutional3D first, InputType[] vertexInputs) {
        long depth = first.getDepth();
        long width = first.getWidth();
        long height = first.getHeight();
        long channelSum = first.getChannels();
        for (int i = 1; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != InputType.Type.CNN3D) {
                throw new InvalidInputTypeException("Invalid input: MergeVertex cannot process activations of different types: first type = " + InputType.Type.CNN3D + ", input type " + (i + 1) + " = " + vertexInputs[i].getType());
            }
            InputType.InputTypeConvolutional3D other = (InputType.InputTypeConvolutional3D) vertexInputs[i];
            long otherDepth = other.getDepth();
            long otherWidth = other.getWidth();
            long otherHeight = other.getHeight();
            if (depth != otherDepth || width != otherWidth || height != otherHeight) {
                throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN3D activations of different depth/width/heights: first [depth,width,height] = [" + depth + "," + width + "," + height + "], input " + i + " = [" + otherDepth + "," + otherWidth + "," + otherHeight + "]");
            }
            channelSum += other.getChannels();
        }
        Convolution3D.DataFormat dataFormat = first.getDataFormat();
        if (dataFormat == null) {
            // Fall back to the format this method always reported previously.
            dataFormat = Convolution3D.DataFormat.NDHWC;
        }
        // Merge along the channel dimension of the actual layout: N C D H W -> 1, N D H W C -> 4.
        this.mergeAxis = dataFormat == Convolution3D.DataFormat.NCDHW ? 1 : 4;
        return InputType.convolutional3D(dataFormat, depth, height, width, channelSum);
    }

    /** Merges 2D CNN inputs: same height/width required, channels summed, first input's format kept. */
    private InputType mergeConvolutional2D(InputType.InputTypeConvolutional first, InputType[] vertexInputs) {
        CNN2DFormat format = first.getFormat();
        long channels = first.getChannels();
        long width = first.getWidth();
        long height = first.getHeight();
        long channelSum = channels;
        for (int i = 1; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != InputType.Type.CNN) {
                throw new InvalidInputTypeException("Invalid input: MergeVertex cannot process activations of different types: first type = " + InputType.Type.CNN + ", input type " + (i + 1) + " = " + vertexInputs[i].getType());
            }
            InputType.InputTypeConvolutional other = (InputType.InputTypeConvolutional) vertexInputs[i];
            long otherChannels = other.getChannels();
            long otherWidth = other.getWidth();
            long otherHeight = other.getHeight();
            if (width != otherWidth || height != otherHeight) {
                throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN activations of different width/heights:first [channels,width,height] = [" + channels + "," + width + "," + height + "], input " + i + " = [" + otherChannels + "," + otherWidth + "," + otherHeight + "]");
            }
            channelSum += otherChannels;
        }
        // Channel axis: 1 for NCHW, 3 otherwise (NHWC).
        this.mergeAxis = format == CNN2DFormat.NCHW ? 1 : 3;
        return InputType.convolutional(height, width, channelSum, format);
    }

    /**
     * Merges FF or RNN inputs by summing their sizes. If any input size is unknown (<= 0), the result
     * size is -1. For RNN the time-series length is taken from the first input and the format from
     * the last one.
     */
    private InputType mergeFeedForwardOrRecurrent(InputType first, InputType[] vertexInputs) {
        long sizeSum = 0L;
        boolean sizeKnown = true;
        InputType.Type type = null;
        RNNFormat rnnFormat = null;
        for (int i = 0; i < vertexInputs.length; i++) {
            InputType current = vertexInputs[i];
            if (current.getType() != first.getType()) {
                throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge activations of different types: first type = " + first.getType() + ", input type " + (i + 1) + " = " + current.getType());
            }
            long size;
            switch (current.getType()) {
                case FF:
                    size = ((InputType.InputTypeFeedForward) current).getSize();
                    type = InputType.Type.FF;
                    break;
                case RNN:
                    size = ((InputType.InputTypeRecurrent) current).getSize();
                    rnnFormat = ((InputType.InputTypeRecurrent) current).getFormat();
                    this.mergeAxis = rnnFormat == RNNFormat.NCW ? 1 : 2;
                    type = InputType.Type.RNN;
                    break;
                default:
                    throw new IllegalStateException("Unknown input type: " + current);
            }
            // Once any size is unknown the total stays unknown; later known sizes must not
            // be added to a -1 sentinel.
            if (size <= 0) {
                sizeKnown = false;
            } else {
                sizeSum += size;
            }
        }
        long outSize = sizeKnown && sizeSum > 0 ? sizeSum : -1L;
        if (type == InputType.Type.FF) {
            return InputType.feedForward(outSize);
        }
        return InputType.recurrent(outSize, ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(), rnnFormat);
    }

    /**
     * Builds a memory report with all-zero standard/working/cache memory.
     * Calls {@link #getOutputType}, so it may also update {@link #mergeAxis}.
     */
    @Override
    public MemoryReport getMemoryReport(InputType... inputTypeArr) {
        return new LayerMemoryReport.Builder(null, MergeVertex.class, inputTypeArr[0], getOutputType(-1, inputTypeArr)).standardMemory(0L, 0L).workingMemory(0L, 0L, 0L, 0L).cacheMemory(0L, 0L).build();
    }
}
