/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.inputs;

import java.io.Serializable;
import java.util.Arrays;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonInclude(value=JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use=JsonTypeInfo.Id.CLASS, include=JsonTypeInfo.As.PROPERTY, property="@class")
public abstract class InputType
implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(InputType.class);

    /**
     * Format used by {@link #convolutional(long, long, long)} when the caller does not pass an
     * explicit {@link CNN2DFormat}. This is a mutable static (shared by every caller that uses this
     * class) and starts out as {@link CNN2DFormat#NCHW}.
     */
    private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW;

    /**
     * Returns the format currently applied by {@link #convolutional(long, long, long)}.
     *
     * @return the current default 2D convolutional data format
     */
    public static CNN2DFormat getDefaultCNN2DFormat() {
        final CNN2DFormat current = defaultCNN2DFormat;
        return current;
    }

    /**
     * Replaces the default format applied by {@link #convolutional(long, long, long)}.
     * No validation is performed; the value is stored as given.
     *
     * @param defaultCNN2DFormat the new default format
     */
    public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) {
        InputType.defaultCNN2DFormat = defaultCNN2DFormat;
    }

    /**
     * Returns the category of this input (feed-forward, recurrent, CNN, flattened CNN or 3D CNN).
     *
     * @return the {@link Type} constant for this input
     */
    @JsonIgnore
    public abstract Type getType();

    /**
     * Forces every concrete input type to provide a human-readable description.
     *
     * @return a description of this input type and its dimensions
     */
    @Override
    public abstract String toString();

    /**
     * Returns how many array elements one example of this input occupies (batch dimension excluded).
     * Implementations may throw {@link IllegalStateException} when a required dimension is unknown
     * (the recurrent type does so when its time series length is not set).
     *
     * @return number of elements per example
     */
    @JsonIgnore
    public abstract long arrayElementsPerExample();

    /**
     * Returns the array shape described by this input type.
     *
     * @param includeBatchDim if {@code true}, the implementations in this class prepend {@code -1L}
     *                        as a placeholder for the (unknown) minibatch dimension
     * @return the shape, with or without the leading batch placeholder
     */
    @JsonIgnore
    public abstract long[] getShape(boolean includeBatchDim);

    /**
     * Returns the per-example shape, i.e. {@link #getShape(boolean) getShape(false)}.
     *
     * @return the shape without the batch dimension
     */
    public long[] getShape() {
        return this.getShape(false);
    }

    /**
     * Creates a feed-forward input type without a time-distributed format.
     *
     * @param size number of features per example
     * @return a new {@link InputTypeFeedForward}
     */
    public static InputType feedForward(long size) {
        final DataFormat noTimeDistributedFormat = null;
        return feedForward(size, noTimeDistributedFormat);
    }

    /**
     * Creates a feed-forward input type.
     *
     * @param size                  number of features per example
     * @param timeDistributedFormat optional format (may be {@code null}); stored as given
     * @return a new {@link InputTypeFeedForward}
     */
    public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
        final InputTypeFeedForward created = new InputTypeFeedForward(size, timeDistributedFormat);
        return created;
    }

    /**
     * Creates a recurrent input type with an unset time series length ({@code -1}) in NCW format.
     *
     * @param size number of features per time step
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size) {
        return recurrent(size, -1L, RNNFormat.NCW);
    }

    /**
     * Creates a recurrent input type in NCW format.
     *
     * @param size             number of features per time step
     * @param timeSeriesLength number of time steps ({@code <= 0} means "not set")
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size, long timeSeriesLength) {
        return recurrent(size, timeSeriesLength, RNNFormat.NCW);
    }

    /**
     * Creates a recurrent input type with an unset time series length ({@code -1}).
     *
     * @param size   number of features per time step
     * @param format layout of the time-series array
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size, RNNFormat format) {
        return recurrent(size, -1L, format);
    }

    /**
     * Creates a recurrent input type; all other {@code recurrent} overloads delegate here.
     *
     * @param size             number of features per time step
     * @param timeSeriesLength number of time steps ({@code <= 0} means "not set")
     * @param format           layout of the time-series array
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format) {
        return new InputTypeRecurrent(size, timeSeriesLength, format);
    }

    /**
     * Creates a 2D convolutional input type using the current default format
     * (see {@link #getDefaultCNN2DFormat()}).
     *
     * @param height image height
     * @param width  image width
     * @param depth  number of channels
     * @return a new {@link InputTypeConvolutional}
     */
    public static InputType convolutional(long height, long width, long depth) {
        final CNN2DFormat fmt = InputType.getDefaultCNN2DFormat();
        return convolutional(height, width, depth, fmt);
    }

    /**
     * Creates a 2D convolutional input type with an explicit format.
     *
     * @param height image height
     * @param width  image width
     * @param depth  number of channels
     * @param format data layout; {@code null} falls back to NCHW inside the constructor
     * @return a new {@link InputTypeConvolutional}
     */
    public static InputType convolutional(long height, long width, long depth, CNN2DFormat format) {
        final InputTypeConvolutional created = new InputTypeConvolutional(height, width, depth, format);
        return created;
    }

    /**
     * Creates a 3D convolutional input type in {@code NDHWC} format.
     *
     * @deprecated the data format is implicit; use
     *             {@link #convolutional3D(Convolution3D.DataFormat, long, long, long, long)} instead.
     */
    @Deprecated
    public static InputType convolutional3D(long depth, long height, long width, long channels) {
        final Convolution3D.DataFormat legacyFormat = Convolution3D.DataFormat.NDHWC;
        return convolutional3D(legacyFormat, depth, height, width, channels);
    }

    /**
     * Creates a 3D convolutional input type.
     *
     * @param dataFormat layout of the 5D array
     * @param depth      depth dimension
     * @param height     height dimension
     * @param width      width dimension
     * @param channels   number of channels
     * @return a new {@link InputTypeConvolutional3D}
     */
    public static InputType convolutional3D(Convolution3D.DataFormat dataFormat, long depth, long height, long width, long channels) {
        final InputTypeConvolutional3D created = new InputTypeConvolutional3D(dataFormat, depth, height, width, channels);
        return created;
    }

    /**
     * Creates a flattened convolutional input type (one row of {@code height*width*depth} values per example).
     *
     * @param height image height
     * @param width  image width
     * @param depth  number of channels
     * @return a new {@link InputTypeConvolutionalFlat}
     */
    public static InputType convolutionalFlat(long height, long width, long depth) {
        final InputTypeConvolutionalFlat created = new InputTypeConvolutionalFlat(height, width, depth);
        return created;
    }

    /**
     * Infers an {@link InputType} from the shape of an example array.
     * Dimension 0 is treated as the minibatch. The remaining dimensions are read in channels-first order:
     * <ul>
     *   <li>rank 2 {@code [mb, size]} &rarr; feed-forward</li>
     *   <li>rank 3 {@code [mb, size, timeSeriesLength]} &rarr; recurrent, {@link RNNFormat#NCW}</li>
     *   <li>rank 4 {@code [mb, channels, height, width]} &rarr; convolutional, {@link CNN2DFormat#NCHW}</li>
     *   <li>rank 5 {@code [mb, channels, depth, height, width]} &rarr; 3D convolutional, NCDHW</li>
     * </ul>
     * The format is always stated explicitly so that the inferred type's {@link #getShape(boolean)}
     * matches the layout the dimensions were read from.
     *
     * @param inputArray example array whose shape is inspected
     * @return the inferred input type
     * @throws IllegalArgumentException if the array rank is not 2, 3, 4 or 5
     */
    public static InputType inferInputType(INDArray inputArray) {
        switch (inputArray.rank()) {
            case 2:
                return InputType.feedForward(inputArray.size(1));
            case 3:
                // No (int) casts: dimensions are long and must not be truncated.
                return InputType.recurrent(inputArray.size(1), inputArray.size(2), RNNFormat.NCW);
            case 4:
                // Dims are read channels-first, so NCHW must be explicit; the mutable default may be NHWC.
                return InputType.convolutional(inputArray.size(2), inputArray.size(3), inputArray.size(1), CNN2DFormat.NCHW);
            case 5:
                // Dims are read as [mb, c, d, h, w]; tagging them NDHWC (the deprecated overload's
                // implicit format) would make getShape report [d, h, w, c] instead.
                return InputType.convolutional3D(Convolution3D.DataFormat.NCDHW, inputArray.size(2), inputArray.size(3), inputArray.size(4), inputArray.size(1));
            default:
                throw new IllegalArgumentException("Cannot infer input type for array with shape: " + Arrays.toString(inputArray.shape()));
        }
    }

    /**
     * Applies {@link #inferInputType(INDArray)} to each array, preserving order.
     *
     * @param inputArrays example arrays
     * @return one inferred input type per array, in the same order
     */
    public static InputType[] inferInputTypes(INDArray ... inputArrays) {
        final InputType[] inferred = new InputType[inputArrays.length];
        int idx = 0;
        for (INDArray arr : inputArrays) {
            inferred[idx++] = InputType.inferInputType(arr);
        }
        return inferred;
    }

    /**
     * Convolutional input whose examples are stored flattened, with {@code height * width * depth}
     * values per example. {@link #getShape(boolean)} still reports the unflattened channels-first
     * shape {@code [depth, height, width]}.
     */
    public static class InputTypeConvolutionalFlat extends InputType {
        private long height;
        private long width;
        private long depth;

        /** No-arg constructor; all dimensions start at 0. */
        public InputTypeConvolutionalFlat() {
        }

        /**
         * @param height image height
         * @param width  image width
         * @param depth  number of channels
         */
        public InputTypeConvolutionalFlat(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("depth") long depth) {
            this.height = height;
            this.width = width;
            this.depth = depth;
        }

        @Override
        public Type getType() {
            return Type.CNNFlat;
        }

        /** @return {@code height * width * depth} */
        public long getFlattenedSize() {
            final long flat = height * width * depth;
            return flat;
        }

        /**
         * Returns the equivalent unflattened type, built with {@link InputType#convolutional(long, long, long)}
         * (and therefore with the current default CNN2D format).
         */
        public InputType getUnflattenedType() {
            return InputType.convolutional(height, width, depth);
        }

        @Override
        public String toString() {
            return new StringBuilder("InputTypeConvolutionalFlat(h=")
                    .append(height)
                    .append(",w=").append(width)
                    .append(",d=").append(depth)
                    .append(")")
                    .toString();
        }

        @Override
        public long arrayElementsPerExample() {
            final long elements = height * width * depth;
            return elements;
        }

        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (!includeBatchDim) {
                return new long[]{depth, height, width};
            }
            return new long[]{-1L, depth, height, width};
        }

        public long getHeight() {
            return height;
        }

        public long getWidth() {
            return width;
        }

        public long getDepth() {
            return depth;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setDepth(long depth) {
            this.depth = depth;
        }

        /** Value equality on height, width and depth (with the symmetric {@code canEqual} check). */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InputTypeConvolutionalFlat)) {
                return false;
            }
            final InputTypeConvolutionalFlat that = (InputTypeConvolutionalFlat) o;
            return that.canEqual(this)
                    && this.getHeight() == that.getHeight()
                    && this.getWidth() == that.getWidth()
                    && this.getDepth() == that.getDepth();
        }

        protected boolean canEqual(Object other) {
            return other instanceof InputTypeConvolutionalFlat;
        }

        @Override
        public int hashCode() {
            final long h = this.getHeight();
            final long w = this.getWidth();
            final long d = this.getDepth();
            int acc = 1;
            acc = 59 * acc + (int) (h ^ (h >>> 32));
            acc = 59 * acc + (int) (w ^ (w >>> 32));
            acc = 59 * acc + (int) (d ^ (d >>> 32));
            return acc;
        }
    }

    /**
     * 3D convolutional input. The dimension order reported by {@link #getShape(boolean)} depends on
     * {@link #getDataFormat()}: {@code NDHWC} gives {@code [depth, height, width, channels]}; any other
     * value (including {@code null}) gives {@code [channels, depth, height, width]}.
     */
    public static class InputTypeConvolutional3D extends InputType {
        private Convolution3D.DataFormat dataFormat;
        private long depth;
        private long height;
        private long width;
        private long channels;

        /** No-arg constructor; format is {@code null} and all dimensions start at 0. */
        public InputTypeConvolutional3D() {
        }

        /**
         * @param dataFormat layout of the 5D array
         * @param depth      depth dimension
         * @param height     height dimension
         * @param width      width dimension
         * @param channels   number of channels
         */
        public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat, @JsonProperty("depth") long depth, @JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) {
            this.dataFormat = dataFormat;
            this.depth = depth;
            this.height = height;
            this.width = width;
            this.channels = channels;
        }

        @Override
        public Type getType() {
            return Type.CNN3D;
        }

        @Override
        public String toString() {
            return new StringBuilder("InputTypeConvolutional3D(format=")
                    .append(dataFormat)
                    .append(",d=").append(depth)
                    .append(",h=").append(height)
                    .append(",w=").append(width)
                    .append(",c=").append(channels)
                    .append(")")
                    .toString();
        }

        @Override
        public long arrayElementsPerExample() {
            final long elements = height * width * depth * channels;
            return elements;
        }

        @Override
        public long[] getShape(boolean includeBatchDim) {
            final boolean channelsLast = dataFormat == Convolution3D.DataFormat.NDHWC;
            final long[] perExample = channelsLast
                    ? new long[]{depth, height, width, channels}
                    : new long[]{channels, depth, height, width};
            if (!includeBatchDim) {
                return perExample;
            }
            // Prepend the -1 placeholder for the minibatch dimension.
            final long[] withBatch = new long[perExample.length + 1];
            withBatch[0] = -1L;
            System.arraycopy(perExample, 0, withBatch, 1, perExample.length);
            return withBatch;
        }

        public Convolution3D.DataFormat getDataFormat() {
            return dataFormat;
        }

        public long getDepth() {
            return depth;
        }

        public long getHeight() {
            return height;
        }

        public long getWidth() {
            return width;
        }

        public long getChannels() {
            return channels;
        }

        public void setDataFormat(Convolution3D.DataFormat dataFormat) {
            this.dataFormat = dataFormat;
        }

        public void setDepth(long depth) {
            this.depth = depth;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setChannels(long channels) {
            this.channels = channels;
        }

        /** Value equality on all four dimensions and the (null-safe) data format. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InputTypeConvolutional3D)) {
                return false;
            }
            final InputTypeConvolutional3D that = (InputTypeConvolutional3D) o;
            if (!that.canEqual(this)
                    || this.getDepth() != that.getDepth()
                    || this.getHeight() != that.getHeight()
                    || this.getWidth() != that.getWidth()
                    || this.getChannels() != that.getChannels()) {
                return false;
            }
            final Convolution3D.DataFormat mine = this.getDataFormat();
            final Convolution3D.DataFormat theirs = that.getDataFormat();
            if (mine == null) {
                return theirs == null;
            }
            return mine.equals(theirs);
        }

        protected boolean canEqual(Object other) {
            return other instanceof InputTypeConvolutional3D;
        }

        @Override
        public int hashCode() {
            final long d = this.getDepth();
            final long h = this.getHeight();
            final long w = this.getWidth();
            final long c = this.getChannels();
            final Convolution3D.DataFormat fmt = this.getDataFormat();
            int acc = 1;
            acc = 59 * acc + (int) (d ^ (d >>> 32));
            acc = 59 * acc + (int) (h ^ (h >>> 32));
            acc = 59 * acc + (int) (w ^ (w >>> 32));
            acc = 59 * acc + (int) (c ^ (c >>> 32));
            acc = 59 * acc + (fmt != null ? fmt.hashCode() : 43);
            return acc;
        }
    }

    public static class InputTypeConvolutional
    extends InputType {
        private long height;
        private long width;
        private long channels;
        private CNN2DFormat format = CNN2DFormat.NCHW;

        public InputTypeConvolutional(@JsonProperty(value="height") long height, @JsonProperty(value="width") long width, @JsonProperty(value="channels") long channels, @JsonProperty(value="format") CNN2DFormat format) {
            if (height <= 0L) {
                OneTimeLogger.warn((Logger)log, (String)"Assigning height of 0. Normally this is not valid. Exceptions for this are generally relatedto model import and unknown dimensions", (Object[])new Object[0]);
            }
            if (width <= 0L) {
                OneTimeLogger.warn((Logger)log, (String)"Assigning height of 0. Normally this is not valid. Exceptions for this are generally relatedto model import and unknown dimensions", (Object[])new Object[0]);
            }
            if (width <= 0L) {
                OneTimeLogger.warn((Logger)log, (String)"Assigning width of 0. Normally this is not valid. Exceptions for this are generally relatedto model import and unknown dimensions", (Object[])new Object[0]);
            }
            if (channels <= 0L) {
                OneTimeLogger.warn((Logger)log, (String)"Assigning width of 0. Normally this is not valid. Exceptions for this are generally relatedto model import and unknown dimensions", (Object[])new Object[0]);
            }
            this.height = height;
            this.width = width;
            this.channels = channels;
            if (format != null) {
                this.format = format;
            }
        }

        public InputTypeConvolutional(long height, long width, long channels) {
            this(height, width, channels, CNN2DFormat.NCHW);
        }

        @Deprecated
        public long getDepth() {
            return this.channels;
        }

        @Deprecated
        public void setDepth(long depth) {
            this.channels = depth;
        }

        @Override
        public Type getType() {
            return Type.CNN;
        }

        @Override
        public String toString() {
            return "InputTypeConvolutional(h=" + this.height + ",w=" + this.width + ",c=" + this.channels + "," + this.format + ")";
        }

        @Override
        public long arrayElementsPerExample() {
            return this.height * this.width * this.channels;
        }

        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (this.format == CNN2DFormat.NCHW) {
                if (includeBatchDim) {
                    return new long[]{-1L, this.channels, this.height, this.width};
                }
                return new long[]{this.channels, this.height, this.width};
            }
            if (includeBatchDim) {
                return new long[]{-1L, this.height, this.width, this.channels};
            }
            return new long[]{this.height, this.width, this.channels};
        }

        public InputTypeConvolutional() {
        }

        public long getHeight() {
            return this.height;
        }

        public long getWidth() {
            return this.width;
        }

        public long getChannels() {
            return this.channels;
        }

        public CNN2DFormat getFormat() {
            return this.format;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setChannels(long channels) {
            this.channels = channels;
        }

        public void setFormat(CNN2DFormat format) {
            this.format = format;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InputTypeConvolutional)) {
                return false;
            }
            InputTypeConvolutional other = (InputTypeConvolutional)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getHeight() != other.getHeight()) {
                return false;
            }
            if (this.getWidth() != other.getWidth()) {
                return false;
            }
            if (this.getChannels() != other.getChannels()) {
                return false;
            }
            CNN2DFormat this$format = this.getFormat();
            CNN2DFormat other$format = other.getFormat();
            return !(this$format == null ? other$format != null : !this$format.equals(other$format));
        }

        protected boolean canEqual(Object other) {
            return other instanceof InputTypeConvolutional;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            long $height = this.getHeight();
            result = result * 59 + (int)($height >>> 32 ^ $height);
            long $width = this.getWidth();
            result = result * 59 + (int)($width >>> 32 ^ $width);
            long $channels = this.getChannels();
            result = result * 59 + (int)($channels >>> 32 ^ $channels);
            CNN2DFormat $format = this.getFormat();
            result = result * 59 + ($format == null ? 43 : $format.hashCode());
            return result;
        }
    }

    /**
     * Recurrent (time-series) input with {@code size} features per step. {@code timeSeriesLength <= 0}
     * means the length is not set. {@link #getShape(boolean)} reports {@code [size, timeSeriesLength]}
     * for {@link RNNFormat#NCW} and {@code [timeSeriesLength, size]} otherwise.
     */
    public static class InputTypeRecurrent
    extends InputType {
        private long size;
        private long timeSeriesLength;
        private RNNFormat format = RNNFormat.NCW;

        /** NCW input with an unset ({@code -1}) time series length. */
        public InputTypeRecurrent(long size) {
            this(size, -1L);
        }

        /** NCW input with the given time series length. */
        public InputTypeRecurrent(long size, long timeSeriesLength) {
            this(size, timeSeriesLength, RNNFormat.NCW);
        }

        /** Input in the given format with an unset ({@code -1}) time series length. */
        public InputTypeRecurrent(long size, RNNFormat format) {
            this(size, -1L, format);
        }

        /**
         * @param size             number of features per time step
         * @param timeSeriesLength number of time steps ({@code <= 0} means "not set")
         * @param format           array layout; {@code null} keeps the NCW default
         */
        public InputTypeRecurrent(@JsonProperty(value="size") long size, @JsonProperty(value="timeSeriesLength") long timeSeriesLength, @JsonProperty(value="format") RNNFormat format) {
            this.size = size;
            this.timeSeriesLength = timeSeriesLength;
            // Previously a null format (presumably e.g. JSON without a "format" property) overwrote the
            // NCW default, making getShape silently report the NWC layout. Mirror InputTypeConvolutional.
            if (format != null) {
                this.format = format;
            }
        }

        public InputTypeRecurrent() {
        }

        @Override
        public Type getType() {
            return Type.RNN;
        }

        @Override
        public String toString() {
            if (this.timeSeriesLength > 0L) {
                return "InputTypeRecurrent(" + this.size + ",timeSeriesLength=" + this.timeSeriesLength + ",format=" + this.format + ")";
            }
            return "InputTypeRecurrent(" + this.size + ",format=" + this.format + ")";
        }

        /**
         * @return {@code timeSeriesLength * size}
         * @throws IllegalStateException if the time series length is not set ({@code <= 0})
         */
        @Override
        public long arrayElementsPerExample() {
            if (this.timeSeriesLength <= 0L) {
                throw new IllegalStateException("Cannot calculate number of array elements per example: time series length is not set. Use InputType.recurrent(int size, int timeSeriesLength) instead?");
            }
            return this.timeSeriesLength * this.size;
        }

        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim) {
                if (this.format == RNNFormat.NCW) {
                    return new long[]{-1L, this.size, this.timeSeriesLength};
                }
                return new long[]{-1L, this.timeSeriesLength, this.size};
            }
            if (this.format == RNNFormat.NCW) {
                return new long[]{this.size, this.timeSeriesLength};
            }
            return new long[]{this.timeSeriesLength, this.size};
        }

        public long getSize() {
            return this.size;
        }

        public long getTimeSeriesLength() {
            return this.timeSeriesLength;
        }

        public RNNFormat getFormat() {
            return this.format;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InputTypeRecurrent)) {
                return false;
            }
            InputTypeRecurrent other = (InputTypeRecurrent) o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getSize() != other.getSize()) {
                return false;
            }
            if (this.getTimeSeriesLength() != other.getTimeSeriesLength()) {
                return false;
            }
            RNNFormat thisFormat = this.getFormat();
            RNNFormat otherFormat = other.getFormat();
            return thisFormat == null ? otherFormat == null : thisFormat.equals(otherFormat);
        }

        protected boolean canEqual(Object other) {
            return other instanceof InputTypeRecurrent;
        }

        @Override
        public int hashCode() {
            int result = 1;
            long s = this.getSize();
            result = result * 59 + (int) (s >>> 32 ^ s);
            long t = this.getTimeSeriesLength();
            result = result * 59 + (int) (t >>> 32 ^ t);
            RNNFormat fmt = this.getFormat();
            result = result * 59 + (fmt == null ? 43 : fmt.hashCode());
            return result;
        }
    }

    /**
     * Feed-forward input of {@code size} features per example, optionally tagged with a
     * time-distributed {@link DataFormat} (which may be {@code null}).
     */
    public static class InputTypeFeedForward extends InputType {
        private long size;
        private DataFormat timeDistributedFormat;

        /** No-arg constructor; size starts at 0 and the format is {@code null}. */
        public InputTypeFeedForward() {
        }

        /**
         * @param size                  number of features; non-positive values trigger a one-time warning
         * @param timeDistributedFormat optional format, stored as given
         */
        public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
            final boolean suspicious = size <= 0L;
            if (suspicious) {
                OneTimeLogger.warn(log, "Assigning a size of zero. This is normally only valid in model import cases with unknown dimensions.");
            }
            this.size = size;
            this.timeDistributedFormat = timeDistributedFormat;
        }

        @Override
        public Type getType() {
            return Type.FF;
        }

        @Override
        public String toString() {
            final StringBuilder sb = new StringBuilder("InputTypeFeedForward(").append(size);
            if (timeDistributedFormat != null) {
                sb.append(",").append(timeDistributedFormat);
            }
            return sb.append(")").toString();
        }

        @Override
        public long arrayElementsPerExample() {
            return size;
        }

        @Override
        public long[] getShape(boolean includeBatchDim) {
            return includeBatchDim ? new long[]{-1L, size} : new long[]{size};
        }

        public long getSize() {
            return size;
        }

        public DataFormat getTimeDistributedFormat() {
            return timeDistributedFormat;
        }

        /** Value equality on size and the (null-safe) time-distributed format. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InputTypeFeedForward)) {
                return false;
            }
            final InputTypeFeedForward that = (InputTypeFeedForward) o;
            if (!that.canEqual(this) || this.getSize() != that.getSize()) {
                return false;
            }
            final DataFormat mine = this.getTimeDistributedFormat();
            final DataFormat theirs = that.getTimeDistributedFormat();
            if (mine == null) {
                return theirs == null;
            }
            return mine.equals(theirs);
        }

        protected boolean canEqual(Object other) {
            return other instanceof InputTypeFeedForward;
        }

        @Override
        public int hashCode() {
            final long s = this.getSize();
            final DataFormat fmt = this.getTimeDistributedFormat();
            int acc = 1;
            acc = 59 * acc + (int) (s ^ (s >>> 32));
            acc = 59 * acc + (fmt != null ? fmt.hashCode() : 43);
            return acc;
        }
    }

    /**
     * Category of an {@link InputType}; each concrete subclass in this class returns one of these from getType().
     */
    public enum Type {
        /** Feed-forward input ({@link InputTypeFeedForward}). */
        FF,
        /** Recurrent / time-series input ({@link InputTypeRecurrent}). */
        RNN,
        /** 2D convolutional input ({@link InputTypeConvolutional}). */
        CNN,
        /** Flattened 2D convolutional input ({@link InputTypeConvolutionalFlat}). */
        CNNFlat,
        /** 3D convolutional input ({@link InputTypeConvolutional3D}). */
        CNN3D
    }
}

