public abstract class InputType extends Object implements Serializable
See ComputationGraphConfiguration.GraphBuilder.setInputTypes(InputType...) and ComputationGraphConfiguration.addPreProcessors(InputType...).

| Modifier and Type | Class and Description |
|---|---|
| static class | InputType.InputTypeConvolutional |
| static class | InputType.InputTypeConvolutional3D |
| static class | InputType.InputTypeConvolutionalFlat |
| static class | InputType.InputTypeFeedForward |
| static class | InputType.InputTypeRecurrent |
| static class | InputType.Type: The type of activations in/out of a given GraphVertex. FF: standard feed-forward data (2d minibatch, 1d per example). RNN: recurrent neural network time series data (3d minibatch, [miniBatchSize, size, timeSeriesLength]). CNN: 2D convolutional neural network data (4d minibatch, [miniBatchSize, channels, height, width]). CNNFlat: flattened 2D convolutional data (2d minibatch, [miniBatchSize, height * width * channels]). CNN3D: 3D convolutional neural network data (5d minibatch, [miniBatchSize, channels, depth, height, width]). |
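To make the ranks in the InputType.Type description concrete, here is a minimal, self-contained sketch. The enum `ActivationKind` is a hypothetical stand-in written for illustration, not DL4J's own `InputType.Type`; it only records the rank of the minibatch array for each kind and derives the per-example rank by dropping the minibatch dimension.

```java
// Hypothetical mirror of the InputType.Type values (illustrative only, not the
// DL4J enum). Each constant records the rank of the minibatch array that
// carries activations of that kind.
enum ActivationKind {
    FF(2),      // [miniBatchSize, size]
    RNN(3),     // [miniBatchSize, size, timeSeriesLength]
    CNN(4),     // [miniBatchSize, channels, height, width]
    CNNFlat(2), // [miniBatchSize, height * width * channels]
    CNN3D(5);   // [miniBatchSize, channels, depth, height, width]

    final int minibatchRank;

    ActivationKind(int minibatchRank) {
        this.minibatchRank = minibatchRank;
    }

    // Rank of a single example: the leading minibatch dimension is dropped.
    int perExampleRank() {
        return minibatchRank - 1;
    }
}
```

Note that FF and CNNFlat share the same rank, so an array's rank alone cannot tell flattened convolutional data apart from ordinary feed-forward data.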
| Constructor and Description |
|---|
| InputType() |
| Modifier and Type | Method and Description |
|---|---|
| abstract int | arrayElementsPerExample() |
| static InputType | convolutional(int height, int width, int depth): Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width]. |
| static InputType | convolutional3D(int depth, int height, int width, int channels): Input type for 3D convolutional (CNN3D) data, that is 5d with shape [miniBatchSize, channels, depth, height, width]. |
| static InputType | convolutionalFlat(int height, int width, int depth): Input type for convolutional (CNN) data, where the data is in flattened (row vector) format, [miniBatchSize, height * width * channels]. |
| static InputType | feedForward(int size): InputType for feed-forward network data. |
| int[] | getShape(): Returns the shape of this InputType, without the minibatch dimension. |
| abstract int[] | getShape(boolean includeBatchDim): Returns the shape of this InputType. |
| abstract InputType.Type | getType() |
| static InputType | inferInputType(org.nd4j.linalg.api.ndarray.INDArray inputArray) |
| static InputType[] | inferInputTypes(org.nd4j.linalg.api.ndarray.INDArray... inputArrays) |
| static InputType | recurrent(int size): InputType for recurrent neural network (time series) data. |
| static InputType | recurrent(int size, int timeSeriesLength): InputType for recurrent neural network (time series) data. |
| abstract String | toString() |
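The factory methods in the table above differ mainly in the per-example shape they describe. The sketch below is a hypothetical, self-contained illustration (the class `FactoryShapeSketch` is not part of DL4J, and it returns plain arrays rather than InputType objects). It assumes channels-first layouts, including [channels, depth, height, width] for CNN3D, and uses -1 for a time series length that is not specified; both are assumptions of the sketch.

```java
import java.util.Arrays;

// Hypothetical sketch (not the DL4J implementation): per-example shapes, i.e.
// without the minibatch dimension, for each factory method.
// -1 marks a dimension that is not known in advance (an assumption here).
final class FactoryShapeSketch {
    private FactoryShapeSketch() {}

    static int[] feedForward(int size) {
        return new int[] {size};
    }

    static int[] recurrent(int size) {
        return new int[] {size, -1}; // time series length left unspecified
    }

    static int[] recurrent(int size, int timeSeriesLength) {
        return new int[] {size, timeSeriesLength};
    }

    // Arguments are (height, width, depth), but the layout is channels-first.
    static int[] convolutional(int height, int width, int depth) {
        return new int[] {depth, height, width};
    }

    // Same data as convolutional(...), flattened into a single row vector.
    static int[] convolutionalFlat(int height, int width, int depth) {
        return new int[] {height * width * depth};
    }

    // Assumed channels-first 3D layout: [channels, depth, height, width].
    static int[] convolutional3D(int depth, int height, int width, int channels) {
        return new int[] {channels, depth, height, width};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(convolutional(28, 28, 1)));     // [1, 28, 28]
        System.out.println(Arrays.toString(convolutionalFlat(28, 28, 1))); // [784]
    }
}
```

The argument order of convolutional(height, width, depth) does not match the channels-first order of the resulting shape, which is an easy place to transpose dimensions by mistake.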
public abstract InputType.Type getType()
public abstract int arrayElementsPerExample()
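arrayElementsPerExample() presumably reports how many values a single example occupies, which is the product of the per-example shape. The sketch below illustrates that arithmetic on plain arrays. The class `ElementsPerExampleSketch` is hypothetical, and throwing when a dimension is unknown (-1) is an assumption of the sketch, not documented DL4J behavior.

```java
// Hypothetical sketch: number of values in one example = product of the
// per-example shape. Rejecting unknown (-1) dimensions is an assumption.
final class ElementsPerExampleSketch {
    private ElementsPerExampleSketch() {}

    static int elementsPerExample(int[] perExampleShape) {
        int product = 1;
        for (int dim : perExampleShape) {
            if (dim < 0) {
                throw new IllegalStateException("Shape has an unknown dimension; element count is not fixed");
            }
            // multiplyExact fails loudly on int overflow instead of wrapping.
            product = Math.multiplyExact(product, dim);
        }
        return product;
    }

    public static void main(String[] args) {
        // A 28x28 single-channel CNN input, per example [channels, height, width].
        System.out.println(elementsPerExample(new int[] {1, 28, 28})); // 784
        // A recurrent input with 10 features over 5 time steps.
        System.out.println(elementsPerExample(new int[] {10, 5}));     // 50
    }
}
```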
public abstract int[] getShape(boolean includeBatchDim)
Parameters:
includeBatchDim - Whether to include the minibatch dimension in the returned shape array

public int[] getShape()
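The two getShape variants can be pictured as one method with a flag, where the no-argument form leaves out the minibatch dimension. This is a hypothetical sketch (`ShapeWithBatchSketch` is not a DL4J class). Because an InputType does not fix the minibatch size, the sketch assumes -1 as a placeholder for it when includeBatchDim is true.

```java
import java.util.Arrays;

// Hypothetical sketch of the getShape(boolean) / getShape() pair. The
// no-argument form delegates with includeBatchDim = false; -1 stands in for
// the unknown minibatch size (an assumption of this sketch).
final class ShapeWithBatchSketch {
    private final int[] perExampleShape;

    ShapeWithBatchSketch(int... perExampleShape) {
        this.perExampleShape = perExampleShape.clone();
    }

    int[] getShape(boolean includeBatchDim) {
        if (!includeBatchDim) {
            return perExampleShape.clone();
        }
        int[] withBatch = new int[perExampleShape.length + 1];
        withBatch[0] = -1; // placeholder: minibatch size is not part of the type
        System.arraycopy(perExampleShape, 0, withBatch, 1, perExampleShape.length);
        return withBatch;
    }

    int[] getShape() {
        return getShape(false);
    }

    public static void main(String[] args) {
        ShapeWithBatchSketch cnn = new ShapeWithBatchSketch(3, 32, 32);
        System.out.println(Arrays.toString(cnn.getShape()));     // [3, 32, 32]
        System.out.println(Arrays.toString(cnn.getShape(true))); // [-1, 3, 32, 32]
    }
}
```

Returning defensive copies keeps callers from mutating the stored shape through the returned array.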
public static InputType feedForward(int size)
Parameters:
size - The size of the activations

public static InputType recurrent(int size)
Parameters:
size - The size of the activations

public static InputType recurrent(int size, int timeSeriesLength)
Parameters:
size - The size of the activations
timeSeriesLength - Length of the input time series

public static InputType convolutional(int height, int width, int depth)
Parameters:
height - Height of the input
width - Width of the input
depth - Depth, or number of channels
See Also:
convolutionalFlat(int, int, int)

public static InputType convolutional3D(int depth, int height, int width, int channels)
Parameters:
depth - Depth of the input
height - Height of the input
width - Width of the input
channels - Number of channels of the input

public static InputType convolutionalFlat(int height, int width, int depth)
Parameters:
height - Height of the (unflattened) data represented by this input type
width - Width of the (unflattened) data represented by this input type
depth - Depth of the (unflattened) data represented by this input type
See Also:
convolutional(int, int, int)

public static InputType inferInputType(org.nd4j.linalg.api.ndarray.INDArray inputArray)

public static InputType[] inferInputTypes(org.nd4j.linalg.api.ndarray.INDArray... inputArrays)
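One plausible way to infer an input type from an array is to dispatch on its rank, following the layouts listed for InputType.Type. The sketch below works on a plain `long[]` shape instead of an INDArray so that it runs without ND4J; the class `InferInputTypeSketch` and its string output are hypothetical. It assumes the minibatch dimension comes first, channels-first layouts, and that rank 2 is read as feed-forward, since flattened CNN data has the same rank and cannot be distinguished.

```java
// Hypothetical sketch of rank-based inference on a plain shape array
// (minibatch dimension first). Rank 2 is treated as feed-forward, because
// flattened CNN data has the same rank (an assumption of this sketch).
final class InferInputTypeSketch {
    private InferInputTypeSketch() {}

    static String infer(long[] arrayShape) {
        switch (arrayShape.length) {
            case 2:
                return "FF(size=" + arrayShape[1] + ")";
            case 3:
                return "RNN(size=" + arrayShape[1] + ", timeSeriesLength=" + arrayShape[2] + ")";
            case 4:
                return "CNN(channels=" + arrayShape[1] + ", height=" + arrayShape[2]
                        + ", width=" + arrayShape[3] + ")";
            case 5:
                return "CNN3D(channels=" + arrayShape[1] + ", depth=" + arrayShape[2]
                        + ", height=" + arrayShape[3] + ", width=" + arrayShape[4] + ")";
            default:
                throw new IllegalArgumentException("Unsupported array rank: " + arrayShape.length);
        }
    }

    // Counterpart of the varargs form: one inferred type per input array.
    static String[] inferAll(long[]... arrayShapes) {
        String[] result = new String[arrayShapes.length];
        for (int i = 0; i < arrayShapes.length; i++) {
            result[i] = infer(arrayShapes[i]);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(infer(new long[] {32, 784}));    // FF(size=784)
        System.out.println(infer(new long[] {32, 10, 50})); // RNN(size=10, timeSeriesLength=50)
    }
}
```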
Copyright © 2018. All rights reserved.