/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasActivation;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasBatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasConvolution;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasDense;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasFlatten;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasGlobalPooling;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLstm;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasMerge;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasPooling;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasZeroPadding;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasLayer {
    // Logger used by both the static mapping helpers and instance parsing code.
    private static final Logger log = LoggerFactory.getLogger(KerasLayer.class);
    // Top-level key that holds a layer's Keras class name.
    public static final String LAYER_FIELD_CLASS_NAME = "class_name";
    // Keras layer class names (values of the "class_name" field) recognized by this importer.
    public static final String LAYER_CLASS_NAME_ACTIVATION = "Activation";
    public static final String LAYER_CLASS_NAME_INPUT = "InputLayer";
    public static final String LAYER_CLASS_NAME_DROPOUT = "Dropout";
    public static final String LAYER_CLASS_NAME_DENSE = "Dense";
    public static final String LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE = "TimeDistributedDense";
    public static final String LAYER_CLASS_NAME_LSTM = "LSTM";
    public static final String LAYER_CLASS_NAME_CONVOLUTION_1D = "Convolution1D";
    public static final String LAYER_CLASS_NAME_CONVOLUTION_2D = "Convolution2D";
    public static final String LAYER_CLASS_NAME_MAX_POOLING_1D = "MaxPooling1D";
    public static final String LAYER_CLASS_NAME_MAX_POOLING_2D = "MaxPooling2D";
    public static final String LAYER_CLASS_NAME_AVERAGE_POOLING_1D = "AveragePooling1D";
    public static final String LAYER_CLASS_NAME_AVERAGE_POOLING_2D = "AveragePooling2D";
    public static final String LAYER_CLASS_NAME_ZERO_PADDING_1D = "ZeroPadding1D";
    public static final String LAYER_CLASS_NAME_ZERO_PADDING_2D = "ZeroPadding2D";
    public static final String LAYER_CLASS_NAME_FLATTEN = "Flatten";
    public static final String LAYER_CLASS_NAME_MERGE = "Merge";
    public static final String LAYER_CLASS_NAME_BATCHNORMALIZATION = "BatchNormalization";
    public static final String LAYER_CLASS_NAME_TIME_DISTRIBUTED = "TimeDistributed";
    public static final String LAYER_CLASS_NAME_EMBEDDING = "Embedding";
    public static final String LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D = "GlobalMaxPooling1D";
    public static final String LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D = "GlobalMaxPooling2D";
    public static final String LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D = "GlobalAveragePooling1D";
    public static final String LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D = "GlobalAveragePooling2D";
    // Field names read from a layer configuration: "config" and "inbound_nodes" are top-level keys,
    // the rest are looked up inside the nested "config" map.
    public static final String LAYER_FIELD_CONFIG = "config";
    public static final String LAYER_FIELD_NAME = "name";
    public static final String LAYER_FIELD_BATCH_INPUT_SHAPE = "batch_input_shape";
    public static final String LAYER_FIELD_INBOUND_NODES = "inbound_nodes";
    public static final String LAYER_FIELD_DROPOUT = "dropout";
    public static final String LAYER_FIELD_DROPOUT_W = "dropout_W";
    public static final String LAYER_FIELD_OUTPUT_DIM = "output_dim";
    public static final String LAYER_FIELD_NB_FILTER = "nb_filter";
    public static final String LAYER_FIELD_NB_ROW = "nb_row";
    public static final String LAYER_FIELD_NB_COL = "nb_col";
    public static final String LAYER_FIELD_POOL_SIZE = "pool_size";
    public static final String LAYER_FIELD_SUBSAMPLE = "subsample";
    public static final String LAYER_FIELD_STRIDES = "strides";
    public static final String LAYER_FIELD_BORDER_MODE = "border_mode";
    // Values of the "border_mode" field.
    public static final String LAYER_BORDER_MODE_SAME = "same";
    public static final String LAYER_BORDER_MODE_VALID = "valid";
    public static final String LAYER_BORDER_MODE_FULL = "full";
    // Regularizer fields and the regularization type keys inside them.
    public static final String LAYER_FIELD_W_REGULARIZER = "W_regularizer";
    public static final String LAYER_FIELD_B_REGULARIZER = "b_regularizer";
    public static final String REGULARIZATION_TYPE_L1 = "l1";
    public static final String REGULARIZATION_TYPE_L2 = "l2";
    // Weight-initializer field and its Keras values; only a subset is mapped by mapWeightInitialization.
    public static final String LAYER_FIELD_INIT = "init";
    public static final String INIT_UNIFORM = "uniform";
    public static final String INIT_ZERO = "zero";
    public static final String INIT_GLOROT_NORMAL = "glorot_normal";
    public static final String INIT_GLOROT_UNIFORM = "glorot_uniform";
    public static final String INIT_HE_NORMAL = "he_normal";
    public static final String INIT_HE_UNIFORM = "he_uniform";
    public static final String INIT_LECUN_UNIFORM = "lecun_uniform";
    public static final String INIT_NORMAL = "normal";
    public static final String INIT_ORTHOGONAL = "orthogonal";
    public static final String INIT_IDENTITY = "identity";
    // Activation field and the Keras activation names handled by mapActivation.
    public static final String LAYER_FIELD_ACTIVATION = "activation";
    public static final String KERAS_ACTIVATION_SOFTMAX = "softmax";
    public static final String KERAS_ACTIVATION_SOFTPLUS = "softplus";
    public static final String KERAS_ACTIVATION_SOFTSIGN = "softsign";
    public static final String KERAS_ACTIVATION_RELU = "relu";
    public static final String KERAS_ACTIVATION_TANH = "tanh";
    public static final String KERAS_ACTIVATION_SIGMOID = "sigmoid";
    public static final String KERAS_ACTIVATION_HARD_SIGMOID = "hard_sigmoid";
    public static final String KERAS_ACTIVATION_LINEAR = "linear";
    // Dimension-ordering field and its values (Theano vs TensorFlow layout).
    public static final String LAYER_FIELD_DIM_ORDERING = "dim_ordering";
    public static final String DIM_ORDERING_THEANO = "th";
    public static final String DIM_ORDERING_TENSORFLOW = "tf";
    // Keras loss-function names (including short aliases) handled by mapLossFunction.
    public static final String KERAS_LOSS_MEAN_SQUARED_ERROR = "mean_squared_error";
    public static final String KERAS_LOSS_MSE = "mse";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_ERROR = "mean_absolute_error";
    public static final String KERAS_LOSS_MAE = "mae";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR = "mean_absolute_percentage_error";
    public static final String KERAS_LOSS_MAPE = "mape";
    public static final String KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR = "mean_squared_logarithmic_error";
    public static final String KERAS_LOSS_MSLE = "msle";
    public static final String KERAS_LOSS_SQUARED_HINGE = "squared_hinge";
    public static final String KERAS_LOSS_HINGE = "hinge";
    public static final String KERAS_LOSS_BINARY_CROSSENTROPY = "binary_crossentropy";
    public static final String KERAS_LOSS_CATEGORICAL_CROSSENTROPY = "categorical_crossentropy";
    public static final String KERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY = "sparse_categorical_crossentropy";
    public static final String KERAS_LOSS_KULLBACK_LEIBLER_DIVERGENCE = "kullback_leibler_divergence";
    public static final String KERAS_LOSS_KLD = "kld";
    public static final String KERAS_LOSS_POISSON = "poisson";
    public static final String KERAS_LOSS_COSINE_PROXIMITY = "cosine_proximity";
    // Key of the wrapped layer inside a TimeDistributed layer's "config" map.
    public static final String LAYER_FIELD_LAYER = "layer";
    // Registry of user-supplied layer implementations keyed by Keras class name; consulted by
    // getKerasLayerFromConfig for unknown class names. Plain HashMap, so registration is not thread-safe.
    public static final Map<String, Class<? extends KerasLayer>> customLayers = new HashMap<String, Class<? extends KerasLayer>>();
    // Keras class name and layer name parsed from the configuration.
    protected String className;
    protected String layerName;
    // Input shape without the leading batch dimension; null when batch_input_shape is absent.
    protected int[] inputShape;
    protected DimOrder dimOrder;
    // Names of the layers feeding into this one.
    protected List<String> inboundLayerNames;
    // Resulting DL4J layer or graph vertex (at most one is expected to be set by subclasses).
    protected Layer layer;
    protected GraphVertex vertex;
    protected Map<String, INDArray> weights;
    protected double weightL1Regularization = 0.0;
    protected double weightL2Regularization = 0.0;
    // Retain probability (1 - Keras dropout); 1.0 means no dropout.
    protected double dropout = 1.0;

    /**
     * Builds the Keras layer wrapper for a layer configuration without enforcing training-only settings.
     *
     * @param layerConfig layer configuration map parsed from the Keras model JSON
     * @return the matching Keras layer wrapper
     * @throws InvalidKerasConfigurationException     if the configuration is malformed
     * @throws UnsupportedKerasConfigurationException if the layer type is not supported
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        final boolean enforceTrainingConfig = false;
        return getKerasLayerFromConfig(layerConfig, enforceTrainingConfig);
    }

    /**
     * Builds the Keras layer wrapper matching the {@code class_name} of a layer configuration.
     *
     * <p>A {@code TimeDistributed} wrapper is unwrapped first (via {@link #getTimeDistributedLayerConfig(Map)}),
     * so the returned object describes the wrapped layer. Class names without a built-in mapping are
     * looked up in {@link #customLayers}.
     *
     * @param layerConfig           layer configuration map (modified in place for TimeDistributed layers)
     * @param enforceTrainingConfig whether unsupported training-only settings should raise an error
     * @return the matching Keras layer wrapper
     * @throws InvalidKerasConfigurationException     if the configuration is malformed, or a custom layer rejects it
     * @throws UnsupportedKerasConfigurationException if the layer type is neither built-in nor registered
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String layerClassName = KerasLayer.getClassNameFromConfig(layerConfig);
        if (LAYER_CLASS_NAME_TIME_DISTRIBUTED.equals(layerClassName)) {
            layerConfig = KerasLayer.getTimeDistributedLayerConfig(layerConfig);
            layerClassName = KerasLayer.getClassNameFromConfig(layerConfig);
        }
        if (layerClassName == null) {
            // A present-but-null class_name would otherwise NPE inside the string switch below.
            throw new InvalidKerasConfigurationException("Keras layer class name is missing");
        }
        switch (layerClassName) {
            case LAYER_CLASS_NAME_ACTIVATION:
                return new KerasActivation(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_DROPOUT:
                return new KerasDropout(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_DENSE:
            case LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE:
                return new KerasDense(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_LSTM:
                return new KerasLstm(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_CONVOLUTION_2D:
                return new KerasConvolution(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_MAX_POOLING_2D:
            case LAYER_CLASS_NAME_AVERAGE_POOLING_2D:
                return new KerasPooling(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D:
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D:
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D:
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D:
                return new KerasGlobalPooling(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_BATCHNORMALIZATION:
                return new KerasBatchNormalization(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_EMBEDDING:
                return new KerasEmbedding(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_INPUT:
                return new KerasInput(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_MERGE:
                return new KerasMerge(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_FLATTEN:
                return new KerasFlatten(layerConfig, enforceTrainingConfig);
            case LAYER_CLASS_NAME_ZERO_PADDING_2D:
                return new KerasZeroPadding(layerConfig, enforceTrainingConfig);
            default:
                return KerasLayer.instantiateCustomLayer(layerClassName, layerConfig, enforceTrainingConfig);
        }
    }

    /**
     * Instantiates a layer class registered via {@link #registerCustomLayer(String, Class)}.
     *
     * <p>Prefers a public {@code (Map, boolean)} constructor so {@code enforceTrainingConfig} is honored,
     * and falls back to the legacy public {@code (Map)} constructor.
     *
     * @throws UnsupportedKerasConfigurationException if no class is registered under {@code layerClassName},
     *         or the custom constructor throws it
     * @throws InvalidKerasConfigurationException     if the custom constructor throws it
     */
    private static KerasLayer instantiateCustomLayer(String layerClassName, Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Class<? extends KerasLayer> customClass = customLayers.get(layerClassName);
        if (customClass == null) {
            throw new UnsupportedKerasConfigurationException("Unsupported keras layer type " + layerClassName);
        }
        try {
            Constructor<? extends KerasLayer> constructor;
            Object[] args;
            try {
                constructor = customClass.getConstructor(Map.class, boolean.class);
                args = new Object[] {layerConfig, enforceTrainingConfig};
            } catch (NoSuchMethodException noTrainingFlagConstructor) {
                // Legacy custom layers only expose a (Map) constructor.
                constructor = customClass.getConstructor(Map.class);
                args = new Object[] {layerConfig};
            }
            return constructor.newInstance(args);
        } catch (ReflectiveOperationException e) {
            // Surface configuration errors raised by the custom constructor (wrapped by reflection) unchanged.
            Throwable cause = e.getCause();
            if (cause instanceof InvalidKerasConfigurationException) {
                throw (InvalidKerasConfigurationException) cause;
            }
            if (cause instanceof UnsupportedKerasConfigurationException) {
                throw (UnsupportedKerasConfigurationException) cause;
            }
            throw new RuntimeException("Could not instantiate custom Keras layer " + layerClassName + " (" + customClass.getName() + ")", e);
        }
    }

    /**
     * Registers a user-supplied layer implementation for a Keras class name that has no built-in mapping.
     *
     * @param layerName   Keras class name (value of {@code class_name}) to associate with the implementation
     * @param configClass layer implementation to instantiate for that class name
     */
    public static void registerCustomLayer(String layerName, Class<? extends KerasLayer> configClass) {
        KerasLayer.customLayers.put(layerName, configClass);
    }

    /**
     * Creates an empty layer with no configuration; the dimension order is NONE and the inbound list is empty.
     */
    protected KerasLayer() {
        this.inboundLayerNames = new ArrayList<String>();
        this.dimOrder = DimOrder.NONE;
        this.className = null;
        this.layerName = null;
        this.inputShape = null;
        this.weights = null;
        this.vertex = null;
        this.layer = null;
    }

    /**
     * Parses a layer configuration, enforcing training-related settings.
     *
     * @param layerConfig layer configuration map
     * @throws InvalidKerasConfigurationException     if the configuration is malformed
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported settings
     */
    protected KerasLayer(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, /* enforceTrainingConfig = */ true);
    }

    /**
     * Parses the settings shared by all Keras layers from a layer configuration.
     *
     * <p>NOTE(review): this constructor calls overridable methods ({@code getLayerNameFromConfig},
     * {@code getDropoutFromConfig}) before subclass fields are initialized; overrides must not rely on
     * subclass state.
     *
     * @param layerConfig           layer configuration map
     * @param enforceTrainingConfig whether unsupported training-only settings should raise an error
     * @throws InvalidKerasConfigurationException     if the class name or layer name is missing, or the config is malformed
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported settings
     */
    protected KerasLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.className = KerasLayer.getClassNameFromConfig(layerConfig);
        if (this.className == null) {
            throw new InvalidKerasConfigurationException("Keras layer class name is missing");
        }
        this.layerName = this.getLayerNameFromConfig(layerConfig);
        if (this.layerName == null) {
            // Previously reported "class name is missing", which pointed at the wrong field.
            throw new InvalidKerasConfigurationException("Keras layer name is missing");
        }
        this.inputShape = this.getInputShapeFromConfig(layerConfig);
        this.dimOrder = this.getDimOrderFromConfig(layerConfig);
        this.inboundLayerNames = KerasLayer.getInboundLayerNamesFromConfig(layerConfig);
        this.layer = null;
        this.vertex = null;
        this.weights = null;
        this.weightL1Regularization = KerasLayer.getWeightL1RegularizationFromConfig(layerConfig, enforceTrainingConfig);
        this.weightL2Regularization = KerasLayer.getWeightL2RegularizationFromConfig(layerConfig, enforceTrainingConfig);
        this.dropout = this.getDropoutFromConfig(layerConfig);
        KerasLayer.checkForUnsupportedConfigurations(layerConfig, enforceTrainingConfig);
    }

    /**
     * Runs the checks for layer settings this importer may not support (bias regularization and
     * the contents of the weight/bias regularizer maps).
     *
     * @param layerConfig           layer configuration map
     * @param enforceTrainingConfig forwarded to each individual check
     */
    public static void checkForUnsupportedConfigurations(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        // Return values are discarded; presumably these are invoked so they throw on unsupported settings.
        KerasLayer.getBiasL1RegularizationFromConfig(layerConfig, enforceTrainingConfig);
        KerasLayer.getBiasL2RegularizationFromConfig(layerConfig, enforceTrainingConfig);
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        String[] regularizerFields = {LAYER_FIELD_W_REGULARIZER, LAYER_FIELD_B_REGULARIZER};
        for (String field : regularizerFields) {
            if (inner.containsKey(field)) {
                KerasLayer.checkForUnknownRegularizer((Map) inner.get(field), enforceTrainingConfig);
            }
        }
    }

    /** @return the Keras class name parsed from the configuration */
    public String getClassName() {
        return className;
    }

    /** @return the Keras layer name parsed from the configuration */
    public String getLayerName() {
        return layerName;
    }

    /**
     * Returns a defensive copy of the input shape (without the batch dimension).
     *
     * @return a copy of the input shape, or {@code null} if none was configured
     */
    public int[] getInputShape() {
        return this.inputShape == null ? null : this.inputShape.clone();
    }

    /** @return the dimension ordering of this layer */
    public DimOrder getDimOrder() {
        return dimOrder;
    }

    /** @param dimOrder the dimension ordering to assign to this layer */
    public void setDimOrder(DimOrder dimOrder) {
        DimOrder newOrder = dimOrder;
        this.dimOrder = newOrder;
    }

    /**
     * Returns the live (mutable) list of inbound layer names, creating an empty list on first use if unset.
     *
     * @return the inbound layer names; never null
     */
    public List<String> getInboundLayerNames() {
        if (this.inboundLayerNames != null) {
            return this.inboundLayerNames;
        }
        this.inboundLayerNames = new ArrayList<String>();
        return this.inboundLayerNames;
    }

    /**
     * Replaces the inbound layer names with a copy of the given list.
     *
     * @param inboundLayerNames names to copy
     */
    public void setInboundLayerNames(List<String> inboundLayerNames) {
        List<String> copy = new ArrayList<String>(inboundLayerNames);
        this.inboundLayerNames = copy;
    }

    /**
     * Number of parameter arrays this layer carries; the base implementation has none.
     *
     * @return 0
     */
    public int getNumParams() {
        final int noParams = 0;
        return noParams;
    }

    /**
     * Reports whether any L1/L2 weight penalty or dropout is configured.
     *
     * @return true if L1 or L2 is positive, or the retain probability is below 1.0
     */
    public boolean usesRegularization() {
        boolean hasWeightPenalty = this.weightL1Regularization > 0.0 || this.weightL2Regularization > 0.0;
        boolean hasDropout = this.dropout < 1.0;
        return hasWeightPenalty || hasDropout;
    }

    /**
     * Stores the layer weights. The base implementation ignores them; layers with parameters
     * presumably override this.
     *
     * @param weights map from parameter name to weight array (unused here)
     */
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        // Intentionally a no-op in the base class.
    }

    /**
     * Copies the stored Keras weights into the parameters of a DL4J layer.
     *
     * <p>Does nothing when {@link #getNumParams()} is 0. Otherwise the stored weight names must match
     * the DL4J layer's parameter names exactly.
     *
     * @param layer DL4J layer whose parameters are overwritten
     * @throws InvalidKerasConfigurationException if no weights are stored, a DL4J parameter has no stored
     *         weight, or a stored weight has no matching DL4J parameter
     */
    public void copyWeightsToLayer(org.deeplearning4j.nn.api.Layer layer) throws InvalidKerasConfigurationException {
        if (this.getNumParams() > 0) {
            String dl4jLayerName = layer.conf().getLayer().getLayerName();
            String kerasLayerName = this.getLayerName();
            String msg = "Error when attempting to copy weights from Keras layer " + kerasLayerName + " to DL4J layer " + dl4jLayerName;
            if (this.weights == null) {
                throw new InvalidKerasConfigurationException(msg + " (weights is null)");
            }
            // DL4J parameters that have no Keras weight to copy from.
            HashSet<String> missingFromKeras = new HashSet<String>(layer.paramTable().keySet());
            missingFromKeras.removeAll(this.weights.keySet());
            if (!missingFromKeras.isEmpty()) {
                throw new InvalidKerasConfigurationException(msg + " (no stored weights for parameter " + missingFromKeras.iterator().next() + ")");
            }
            // Keras weights that have no DL4J parameter to copy into.
            HashSet<String> unknownToDl4j = new HashSet<String>(this.weights.keySet());
            unknownToDl4j.removeAll(layer.paramTable().keySet());
            if (!unknownToDl4j.isEmpty()) {
                throw new InvalidKerasConfigurationException(msg + " (found no parameter named " + unknownToDl4j.iterator().next() + ")");
            }
            for (String paramName : layer.paramTable().keySet()) {
                layer.setParam(paramName, this.weights.get(paramName));
            }
        }
    }

    /** @return true if this Keras layer maps to a DL4J layer */
    public boolean isLayer() {
        return layer != null;
    }

    /** @return the corresponding DL4J layer, or null if this Keras layer does not map to one */
    public Layer getLayer() {
        return layer;
    }

    /** @return true if this Keras layer maps to a DL4J graph vertex */
    public boolean isVertex() {
        return vertex != null;
    }

    /** @return the corresponding DL4J graph vertex, or null if this Keras layer does not map to one */
    public GraphVertex getVertex() {
        return vertex;
    }

    /**
     * Whether this Keras layer is represented purely as a DL4J input preprocessor.
     *
     * @return false in the base class
     */
    public boolean isInputPreProcessor() {
        final boolean preprocessorOnly = false;
        return preprocessorOnly;
    }

    /**
     * Returns the preprocessor the DL4J layer needs for the given input type.
     *
     * @param inputType input type(s) of this layer; exactly one is required when a DL4J layer is present
     * @return the preprocessor, or null if there is no DL4J layer or none is needed
     * @throws InvalidKerasConfigurationException if a DL4J layer is present and not exactly one input type is given
     */
    public InputPreProcessor getInputPreprocessor(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (this.layer == null) {
            return null;
        }
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras layer of type \"" + this.className + "\" accepts only one input");
        }
        if (inputType.length == 0) {
            // Previously failed with ArrayIndexOutOfBoundsException on inputType[0].
            throw new InvalidKerasConfigurationException("Keras layer of type \"" + this.className + "\" requires an input type to determine its preprocessor");
        }
        return this.layer.getPreProcessorForInputType(inputType[0]);
    }

    /**
     * Computes the output type of this layer from its input type(s). Not supported by the base class.
     *
     * @param inputType input type(s) of this layer
     * @return never returns normally in the base class
     * @throws UnsupportedOperationException always, in the base class
     */
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String message = "Cannot determine output type for Keras layer of type " + this.className;
        throw new UnsupportedOperationException(message);
    }

    /**
     * Checks whether this layer can feed into another layer: it must map to a DL4J layer, vertex,
     * or preprocessor, or be an input layer.
     *
     * @return true if this layer is a valid inbound layer
     */
    public boolean isValidInboundLayer() throws InvalidKerasConfigurationException {
        if (this.getLayer() != null || this.getVertex() != null) {
            return true;
        }
        if (this.getInputPreprocessor(new InputType[0]) != null) {
            return true;
        }
        return this.className.equals(LAYER_CLASS_NAME_INPUT);
    }

    /**
     * Maps a Keras activation name to the equivalent DL4J activation.
     *
     * @param kerasActivation Keras activation name (e.g. {@code "relu"})
     * @return a new DL4J activation instance
     * @throws UnsupportedKerasConfigurationException if the name is null or unknown
     */
    public static IActivation mapActivation(String kerasActivation) throws UnsupportedKerasConfigurationException {
        if (kerasActivation == null) {
            // Guard: a string switch on null would throw NullPointerException.
            throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + kerasActivation);
        }
        // Declared as IActivation: the decompiled ActivationSoftmax local could not hold the other activations.
        IActivation dl4jActivation;
        switch (kerasActivation) {
            case KERAS_ACTIVATION_SOFTMAX:
                dl4jActivation = new ActivationSoftmax();
                break;
            case KERAS_ACTIVATION_SOFTPLUS:
                dl4jActivation = new ActivationSoftPlus();
                break;
            case KERAS_ACTIVATION_SOFTSIGN:
                dl4jActivation = new ActivationSoftSign();
                break;
            case KERAS_ACTIVATION_RELU:
                dl4jActivation = new ActivationReLU();
                break;
            case KERAS_ACTIVATION_TANH:
                dl4jActivation = new ActivationTanH();
                break;
            case KERAS_ACTIVATION_SIGMOID:
                dl4jActivation = new ActivationSigmoid();
                break;
            case KERAS_ACTIVATION_HARD_SIGMOID:
                dl4jActivation = new ActivationHardSigmoid();
                break;
            case KERAS_ACTIVATION_LINEAR:
                dl4jActivation = new ActivationIdentity();
                break;
            default:
                throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + kerasActivation);
        }
        return dl4jActivation;
    }

    /**
     * Maps a Keras weight-initializer name to a DL4J {@link WeightInit}.
     *
     * @param kerasInit Keras initializer name; null selects XAVIER
     * @return the DL4J weight initialization scheme
     * @throws UnsupportedKerasConfigurationException if the name has no DL4J mapping
     */
    public static WeightInit mapWeightInitialization(String kerasInit) throws UnsupportedKerasConfigurationException {
        if (kerasInit == null) {
            return WeightInit.XAVIER;
        }
        switch (kerasInit) {
            case INIT_GLOROT_NORMAL:
                return WeightInit.XAVIER;
            case INIT_GLOROT_UNIFORM:
                return WeightInit.XAVIER_UNIFORM;
            case INIT_HE_NORMAL:
                return WeightInit.RELU;
            case INIT_HE_UNIFORM:
                return WeightInit.RELU_UNIFORM;
            case INIT_ZERO:
                return WeightInit.ZERO;
            default:
                throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
        }
    }

    /**
     * Maps a Keras loss name (full name or short alias) to a DL4J loss function.
     *
     * <p>{@code sparse_categorical_crossentropy} is approximated by multiclass cross entropy, with a warning.
     *
     * @param kerasLoss Keras loss name
     * @return the DL4J loss function
     * @throws UnsupportedKerasConfigurationException if the name is unknown
     */
    public static LossFunctions.LossFunction mapLossFunction(String kerasLoss) throws UnsupportedKerasConfigurationException {
        switch (kerasLoss) {
            case KERAS_LOSS_MEAN_SQUARED_ERROR:
            case KERAS_LOSS_MSE:
                return LossFunctions.LossFunction.SQUARED_LOSS;
            case KERAS_LOSS_MEAN_ABSOLUTE_ERROR:
            case KERAS_LOSS_MAE:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
            case KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR:
            case KERAS_LOSS_MAPE:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
            case KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR:
            case KERAS_LOSS_MSLE:
                return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
            case KERAS_LOSS_SQUARED_HINGE:
                return LossFunctions.LossFunction.SQUARED_HINGE;
            case KERAS_LOSS_HINGE:
                return LossFunctions.LossFunction.HINGE;
            case KERAS_LOSS_BINARY_CROSSENTROPY:
                return LossFunctions.LossFunction.XENT;
            case KERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY:
                log.warn("Sparse cross entropy not implemented, using multiclass cross entropy instead.");
                // Deliberate fall-through to the categorical mapping.
            case KERAS_LOSS_CATEGORICAL_CROSSENTROPY:
                return LossFunctions.LossFunction.MCXENT;
            case KERAS_LOSS_KULLBACK_LEIBLER_DIVERGENCE:
            case KERAS_LOSS_KLD:
                return LossFunctions.LossFunction.KL_DIVERGENCE;
            case KERAS_LOSS_POISSON:
                return LossFunctions.LossFunction.POISSON;
            case KERAS_LOSS_COSINE_PROXIMITY:
                return LossFunctions.LossFunction.COSINE_PROXIMITY;
            default:
                throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
        }
    }

    /**
     * Maps a Keras pooling layer class name to a DL4J pooling type (MAX or AVG).
     *
     * @param className Keras pooling layer class name
     * @return the DL4J pooling type
     * @throws UnsupportedKerasConfigurationException if the class name is not a supported pooling layer
     */
    public static PoolingType mapPoolingType(String className) throws UnsupportedKerasConfigurationException {
        switch (className) {
            case LAYER_CLASS_NAME_MAX_POOLING_2D:
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D:
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D:
                return PoolingType.MAX;
            case LAYER_CLASS_NAME_AVERAGE_POOLING_2D:
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D:
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D:
                return PoolingType.AVG;
            default:
                throw new UnsupportedKerasConfigurationException("Unsupported Keras pooling layer " + className);
        }
    }

    /**
     * Returns the array dimensions a Keras global pooling layer reduces over:
     * {2} for the 1D variants, {2, 3} for the 2D variants.
     *
     * @param className Keras global pooling layer class name
     * @return a new array of pooling dimensions
     * @throws UnsupportedKerasConfigurationException if the class name is not a supported global pooling layer
     */
    public static int[] mapPoolingDimensions(String className) throws UnsupportedKerasConfigurationException {
        switch (className) {
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D:
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D:
                return new int[] {2};
            case LAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D:
            case LAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D:
                return new int[] {2, 3};
            default:
                throw new UnsupportedKerasConfigurationException("Unsupported Keras pooling layer " + className);
        }
    }

    /**
     * Reads the top-level {@code class_name} field of a layer configuration.
     *
     * @param layerConfig layer configuration map
     * @return the class name (may be null if the key maps to null)
     * @throws InvalidKerasConfigurationException if the key is absent
     */
    public static String getClassNameFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        if (layerConfig.containsKey(LAYER_FIELD_CLASS_NAME)) {
            return (String) layerConfig.get(LAYER_FIELD_CLASS_NAME);
        }
        throw new InvalidKerasConfigurationException("Field class_name missing from layer config");
    }

    /**
     * Rewrites a {@code TimeDistributed} layer configuration in place so it describes the wrapped layer.
     *
     * <p>The top-level {@code class_name} becomes the wrapped layer's class name, and the wrapped layer's
     * inner config is merged into the wrapper's config. The wrapper's {@code name} is kept, so references
     * by name to the wrapper layer still resolve. (Previously the top-level {@code name} was set to the
     * inner class name, and the merge replaced the wrapper's name with the inner layer's name.)
     *
     * @param layerConfig TimeDistributed layer configuration (modified in place)
     * @return {@code layerConfig}, after modification
     * @throws InvalidKerasConfigurationException if required fields are missing or this is not a TimeDistributed layer
     */
    @SuppressWarnings("unchecked")
    public static Map<String, Object> getTimeDistributedLayerConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        if (!layerConfig.containsKey(LAYER_FIELD_CLASS_NAME)) {
            throw new InvalidKerasConfigurationException("Field class_name missing from layer config");
        }
        if (!LAYER_CLASS_NAME_TIME_DISTRIBUTED.equals(layerConfig.get(LAYER_FIELD_CLASS_NAME))) {
            throw new InvalidKerasConfigurationException("Expected TimeDistributed layer, found " + layerConfig.get(LAYER_FIELD_CLASS_NAME));
        }
        if (!layerConfig.containsKey(LAYER_FIELD_CONFIG)) {
            throw new InvalidKerasConfigurationException("Field config missing from layer config");
        }
        Map<String, Object> outerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        Object wrapped = outerConfig.get(LAYER_FIELD_LAYER);
        if (!(wrapped instanceof Map)) {
            throw new InvalidKerasConfigurationException("Field layer missing from TimeDistributed layer config");
        }
        Map<String, Object> innerLayer = (Map<String, Object>) wrapped;
        // Remember the wrapper's name before the merge below overwrites it.
        Object wrapperName = outerConfig.get(LAYER_FIELD_NAME);
        layerConfig.put(LAYER_FIELD_CLASS_NAME, innerLayer.get(LAYER_FIELD_CLASS_NAME));
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(innerLayer);
        outerConfig.putAll(innerConfig);
        if (wrapperName != null) {
            outerConfig.put(LAYER_FIELD_NAME, wrapperName);
        }
        outerConfig.remove(LAYER_FIELD_LAYER);
        return layerConfig;
    }

    /**
     * Returns the nested {@code config} map of a layer configuration (the live map, not a copy).
     *
     * @param layerConfig layer configuration map
     * @return the inner configuration map
     * @throws InvalidKerasConfigurationException if the {@code config} key is absent
     */
    @SuppressWarnings("unchecked")
    public static Map<String, Object> getInnerLayerConfigFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        if (layerConfig.containsKey(LAYER_FIELD_CONFIG)) {
            return (Map<String, Object>) layerConfig.get(LAYER_FIELD_CONFIG);
        }
        throw new InvalidKerasConfigurationException("Field config missing from layer config");
    }

    /**
     * Reads the layer name from the inner {@code config} map.
     *
     * @param layerConfig layer configuration map
     * @return the layer name (may be null if the key maps to null)
     * @throws InvalidKerasConfigurationException if {@code config} or its {@code name} key is absent
     */
    protected String getLayerNameFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (inner.containsKey(LAYER_FIELD_NAME)) {
            return (String) inner.get(LAYER_FIELD_NAME);
        }
        throw new InvalidKerasConfigurationException("Field name missing from layer config");
    }

    /**
     * Reads {@code batch_input_shape} from the inner config and drops its leading batch dimension.
     *
     * <p>Null (unknown) dimensions become 0. Any numeric JSON type is accepted, not just Integer.
     *
     * @param layerConfig layer configuration map
     * @return the input shape, or null if {@code batch_input_shape} is absent or null
     * @throws InvalidKerasConfigurationException if {@code batch_input_shape} is empty
     */
    private int[] getInputShapeFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_BATCH_INPUT_SHAPE)) {
            return null;
        }
        List<?> batchInputShape = (List<?>) innerConfig.get(LAYER_FIELD_BATCH_INPUT_SHAPE);
        if (batchInputShape == null) {
            return null;
        }
        if (batchInputShape.isEmpty()) {
            // Previously produced NegativeArraySizeException.
            throw new InvalidKerasConfigurationException("Field batch_input_shape must include the batch dimension");
        }
        int[] inputShape = new int[batchInputShape.size() - 1];
        for (int i = 1; i < batchInputShape.size(); ++i) {
            Object dim = batchInputShape.get(i);
            inputShape[i - 1] = dim == null ? 0 : ((Number) dim).intValue();
        }
        return inputShape;
    }

    /**
     * Reads {@code dim_ordering} from the inner config.
     *
     * @param layerConfig layer configuration map
     * @return TENSORFLOW for "tf", THEANO for "th", otherwise NONE (unknown values are logged)
     * @throws InvalidKerasConfigurationException if the inner {@code config} map is absent
     */
    private DimOrder getDimOrderFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_DIM_ORDERING)) {
            return DimOrder.NONE;
        }
        String dimOrderStr = (String) innerConfig.get(LAYER_FIELD_DIM_ORDERING);
        if (DIM_ORDERING_TENSORFLOW.equals(dimOrderStr)) {
            return DimOrder.TENSORFLOW;
        }
        if (DIM_ORDERING_THEANO.equals(dimOrderStr)) {
            return DimOrder.THEANO;
        }
        // Log the offending string; the old code logged the NONE default instead.
        log.warn("Keras layer has unknown Keras dimension order: {}", dimOrderStr);
        return DimOrder.NONE;
    }

    /**
     * Extracts inbound layer names from the top-level {@code inbound_nodes} field.
     *
     * <p>Only the first node group is read; the name of each node is its first element.
     *
     * @param layerConfig layer configuration map
     * @return a new mutable list of inbound layer names (empty if none)
     */
    public static List<String> getInboundLayerNamesFromConfig(Map<String, Object> layerConfig) {
        ArrayList<String> names = new ArrayList<String>();
        if (!layerConfig.containsKey(LAYER_FIELD_INBOUND_NODES)) {
            return names;
        }
        List<?> nodeGroups = (List<?>) layerConfig.get(LAYER_FIELD_INBOUND_NODES);
        if (nodeGroups.isEmpty()) {
            return names;
        }
        List<?> firstGroup = (List<?>) nodeGroups.get(0);
        for (Object node : firstGroup) {
            names.add((String) ((List<?>) node).get(0));
        }
        return names;
    }

    /**
     * Reads the number of outputs from {@code output_dim}, falling back to {@code nb_filter}.
     *
     * @param layerConfig layer configuration map
     * @return the number of outputs
     * @throws InvalidKerasConfigurationException if neither field is present
     */
    public static int getNOutFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (inner.containsKey(LAYER_FIELD_OUTPUT_DIM)) {
            return (Integer) inner.get(LAYER_FIELD_OUTPUT_DIM);
        }
        if (inner.containsKey(LAYER_FIELD_NB_FILTER)) {
            return (Integer) inner.get(LAYER_FIELD_NB_FILTER);
        }
        throw new InvalidKerasConfigurationException("Could not determine number of outputs for layer: no output_dim or nb_filter field found");
    }

    /**
     * Reads the dropout rate ({@code dropout}, else {@code dropout_W}) and converts it to a DL4J
     * retain probability (1 - rate).
     *
     * @param layerConfig layer configuration map
     * @return the retain probability; 1.0 if neither field is present
     * @throws InvalidKerasConfigurationException if the field value is not numeric
     */
    protected double getDropoutFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        double dropout = 1.0;
        if (innerConfig.containsKey(LAYER_FIELD_DROPOUT)) {
            dropout = 1.0 - KerasLayer.dropoutFieldAsDouble(innerConfig, LAYER_FIELD_DROPOUT);
        } else if (innerConfig.containsKey(LAYER_FIELD_DROPOUT_W)) {
            dropout = 1.0 - KerasLayer.dropoutFieldAsDouble(innerConfig, LAYER_FIELD_DROPOUT_W);
        }
        return dropout;
    }

    /**
     * Reads a numeric dropout field. Integer values such as {@code 0} are accepted, which the
     * previous {@code (Double)} cast rejected with ClassCastException.
     */
    private static double dropoutFieldAsDouble(Map<String, Object> innerConfig, String field) throws InvalidKerasConfigurationException {
        Object value = innerConfig.get(field);
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException("Field " + field + " must be numeric, found " + value);
        }
        return ((Number) value).doubleValue();
    }

    /**
     * Maps the Keras activation named in a layer's config to a DL4J activation.
     *
     * @param layerConfig layer configuration map
     * @return activation produced by {@code KerasLayer.mapActivation}
     * @throws InvalidKerasConfigurationException if the activation field is missing
     * @throws UnsupportedKerasConfigurationException if the activation cannot be mapped
     */
    protected IActivation getActivationFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (inner.containsKey(LAYER_FIELD_ACTIVATION)) {
            String kerasActivation = (String) inner.get(LAYER_FIELD_ACTIVATION);
            return KerasLayer.mapActivation(kerasActivation);
        }
        throw new InvalidKerasConfigurationException("Keras layer is missing activation field");
    }

    /**
     * Maps the Keras weight initializer of a layer to a DL4J {@link WeightInit}.
     * <p>
     * Unknown initializers are rethrown when {@code enforceTrainingConfig} is set. Otherwise a
     * warning is logged and {@code WeightInit.XAVIER} is used.
     *
     * @param layerConfig           layer configuration map
     * @param enforceTrainingConfig whether unmappable initializers should be fatal
     * @return mapped weight initializer
     * @throws InvalidKerasConfigurationException if the init field is missing
     * @throws UnsupportedKerasConfigurationException if the initializer is unknown and enforcement is on
     */
    protected WeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!inner.containsKey(LAYER_FIELD_INIT)) {
            throw new InvalidKerasConfigurationException("Keras layer is missing init field");
        }
        String kerasInit = (String) inner.get(LAYER_FIELD_INIT);
        try {
            return KerasLayer.mapWeightInitialization(kerasInit);
        } catch (UnsupportedKerasConfigurationException e) {
            if (!enforceTrainingConfig) {
                log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
                return WeightInit.XAVIER;
            }
            throw e;
        }
    }

    /**
     * Reads the L1 coefficient of a layer's weight regularizer.
     *
     * @param layerConfig   layer configuration map
     * @param willBeTrained not used by this method
     * @return the L1 coefficient, or {@code 0.0} when no weight regularizer / L1 entry exists
     */
    public static double getWeightL1RegularizationFromConfig(Map<String, Object> layerConfig, boolean willBeTrained) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!inner.containsKey(LAYER_FIELD_W_REGULARIZER)) {
            return 0.0;
        }
        Map<?, ?> regularizer = (Map<?, ?>) inner.get(LAYER_FIELD_W_REGULARIZER);
        if (regularizer == null || !regularizer.containsKey(REGULARIZATION_TYPE_L1)) {
            return 0.0;
        }
        return (Double) regularizer.get(REGULARIZATION_TYPE_L1);
    }

    /**
     * Reads the L2 coefficient of a layer's weight regularizer.
     *
     * @param layerConfig   layer configuration map
     * @param willBeTrained not used by this method
     * @return the L2 coefficient, or {@code 0.0} when no weight regularizer / L2 entry exists
     */
    public static double getWeightL2RegularizationFromConfig(Map<String, Object> layerConfig, boolean willBeTrained) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!inner.containsKey(LAYER_FIELD_W_REGULARIZER)) {
            return 0.0;
        }
        Map<?, ?> regularizer = (Map<?, ?>) inner.get(LAYER_FIELD_W_REGULARIZER);
        if (regularizer == null || !regularizer.containsKey(REGULARIZATION_TYPE_L2)) {
            return 0.0;
        }
        return (Double) regularizer.get(REGULARIZATION_TYPE_L2);
    }

    /**
     * Validates the bias regularizer's L1 setting. Bias L1 regularization is not supported.
     * <p>
     * The method throws only when the bias regularizer carries an L1 coefficient that is not
     * numerically zero. A config can list an {@code l1} entry of {@code 0.0} next to a non-zero
     * {@code l2} entry. Such a zero coefficient means "no L1 penalty", so it must not be rejected.
     * A present but non-numeric (or null) value is still rejected, as before.
     *
     * @param layerConfig   layer configuration map
     * @param willBeTrained not used by this method
     * @return always {@code 0.0}
     * @throws UnsupportedKerasConfigurationException if a non-zero bias L1 coefficient is configured
     * @throws InvalidKerasConfigurationException if the inner layer config cannot be obtained
     */
    public static double getBiasL1RegularizationFromConfig(Map<String, Object> layerConfig, boolean willBeTrained) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (innerConfig.containsKey(LAYER_FIELD_B_REGULARIZER)) {
            Map<?, ?> regularizerConfig = (Map<?, ?>) innerConfig.get(LAYER_FIELD_B_REGULARIZER);
            if (regularizerConfig != null && regularizerConfig.containsKey(REGULARIZATION_TYPE_L1)) {
                Object l1 = regularizerConfig.get(REGULARIZATION_TYPE_L1);
                boolean isZero = l1 instanceof Number && ((Number) l1).doubleValue() == 0.0;
                if (!isZero) {
                    throw new UnsupportedKerasConfigurationException("L1 regularization for bias parameter not supported");
                }
            }
        }
        return 0.0;
    }

    /**
     * Validates the bias regularizer's L2 setting. Bias L2 regularization is not supported.
     * <p>
     * The method throws only when the bias regularizer carries an L2 coefficient that is not
     * numerically zero. An {@code l2} entry of {@code 0.0} means "no L2 penalty" and is accepted.
     * A present but non-numeric (or null) value is still rejected, as before.
     *
     * @param layerConfig   layer configuration map
     * @param willBeTrained not used by this method
     * @return always {@code 0.0}
     * @throws UnsupportedKerasConfigurationException if a non-zero bias L2 coefficient is configured
     * @throws InvalidKerasConfigurationException if the inner layer config cannot be obtained
     */
    private static double getBiasL2RegularizationFromConfig(Map<String, Object> layerConfig, boolean willBeTrained) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (innerConfig.containsKey(LAYER_FIELD_B_REGULARIZER)) {
            Map<?, ?> regularizerConfig = (Map<?, ?>) innerConfig.get(LAYER_FIELD_B_REGULARIZER);
            if (regularizerConfig != null && regularizerConfig.containsKey(REGULARIZATION_TYPE_L2)) {
                Object l2 = regularizerConfig.get(REGULARIZATION_TYPE_L2);
                boolean isZero = l2 instanceof Number && ((Number) l2).doubleValue() == 0.0;
                if (!isZero) {
                    throw new UnsupportedKerasConfigurationException("L2 regularization for bias parameter not supported");
                }
            }
        }
        return 0.0;
    }

    /**
     * Reports regularizer fields other than L1, L2 and the name field.
     * <p>
     * With {@code enforceTrainingConfig} set, the first unknown field raises an exception.
     * Otherwise each unknown field is logged as a warning and ignored. A null config is a no-op.
     *
     * @param regularizerConfig     regularizer configuration map (may be null)
     * @param enforceTrainingConfig whether unknown fields should be fatal
     * @throws UnsupportedKerasConfigurationException on an unknown field when enforcement is on
     */
    private static void checkForUnknownRegularizer(Map<String, Object> regularizerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException {
        if (regularizerConfig == null) {
            return;
        }
        for (String key : regularizerConfig.keySet()) {
            boolean known = key.equals(REGULARIZATION_TYPE_L1)
                    || key.equals(REGULARIZATION_TYPE_L2)
                    || key.equals(LAYER_FIELD_NAME);
            if (known) {
                continue;
            }
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException("Unknown regularization field " + key);
            }
            log.warn("Ignoring unknown regularization field " + key);
        }
    }

    /**
     * Reads a layer's stride from the subsample field, falling back to the strides field.
     *
     * @param layerConfig layer configuration map
     * @return stride array converted via {@code ArrayUtil.toArray}
     * @throws InvalidKerasConfigurationException if neither field is present
     */
    public static int[] getStrideFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        String strideField;
        if (inner.containsKey(LAYER_FIELD_SUBSAMPLE)) {
            strideField = LAYER_FIELD_SUBSAMPLE;
        } else if (inner.containsKey(LAYER_FIELD_STRIDES)) {
            strideField = LAYER_FIELD_STRIDES;
        } else {
            throw new InvalidKerasConfigurationException("Could not determine layer stride: no subsample or strides field found");
        }
        return ArrayUtil.toArray((List) inner.get(strideField));
    }

    /**
     * Reads a layer's kernel size.
     * <p>
     * When both the row and column fields are present they form a two-element kernel size.
     * Otherwise the pool-size list is used.
     *
     * @param layerConfig layer configuration map
     * @return kernel size array converted via {@code ArrayUtil.toArray}
     * @throws InvalidKerasConfigurationException if no usable field is present
     */
    public static int[] getKernelSizeFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        boolean hasRowsAndCols = inner.containsKey(LAYER_FIELD_NB_ROW) && inner.containsKey(LAYER_FIELD_NB_COL);
        if (hasRowsAndCols) {
            ArrayList<Integer> rowsAndCols = new ArrayList<Integer>();
            rowsAndCols.add((Integer) inner.get(LAYER_FIELD_NB_ROW));
            rowsAndCols.add((Integer) inner.get(LAYER_FIELD_NB_COL));
            return ArrayUtil.toArray(rowsAndCols);
        }
        if (inner.containsKey(LAYER_FIELD_POOL_SIZE)) {
            return ArrayUtil.toArray((List) inner.get(LAYER_FIELD_POOL_SIZE));
        }
        throw new InvalidKerasConfigurationException("Could not determine kernel size: no nb_row, nb_col, or pool_size field found");
    }

    /**
     * Maps a Keras border mode to a DL4J {@link ConvolutionMode}.
     * <p>
     * {@code "same"} maps to {@code Same}. Both {@code "valid"} and {@code "full"} map to
     * {@code Truncate}. Presumably callers add the extra padding "full" needs separately
     * (TODO confirm).
     *
     * @param layerConfig layer configuration map
     * @return the matching convolution mode
     * @throws InvalidKerasConfigurationException if the border mode field is missing
     * @throws UnsupportedKerasConfigurationException for any other border mode
     */
    public static ConvolutionMode getConvolutionModeFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> inner = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!inner.containsKey(LAYER_FIELD_BORDER_MODE)) {
            throw new InvalidKerasConfigurationException("Could not determine convolution border mode: no border_mode field found");
        }
        String borderMode = (String) inner.get(LAYER_FIELD_BORDER_MODE);
        // A string switch is kept on purpose: a null border mode still fails with NPE here.
        switch (borderMode) {
            case "same":
                return ConvolutionMode.Same;
            case "valid":
            case "full":
                return ConvolutionMode.Truncate;
            default:
                throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
        }
    }

    /**
     * Computes the explicit padding implied by a layer's Keras border mode.
     * <p>
     * For border mode {@code "full"} the padding is the kernel size minus one in every dimension.
     * The kernel size comes from {@code getKernelSizeFromConfig}. For any other border mode the
     * method returns {@code null}.
     * <p>
     * The previous check compared the value with {@code ==} against the field-name constant. It
     * tested the wrong string, and by reference, so the "full" branch was effectively never taken.
     *
     * @param layerConfig layer configuration map
     * @return per-dimension padding for "full" mode, otherwise {@code null}
     * @throws InvalidKerasConfigurationException if the border mode field is missing or the kernel size cannot be determined
     * @throws UnsupportedKerasConfigurationException if the inner layer config is unsupported
     */
    public int[] getPaddingFromBorderModeConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayer.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_BORDER_MODE)) {
            throw new InvalidKerasConfigurationException("Could not determine convolution border mode: no border_mode field found");
        }
        String borderMode = (String) innerConfig.get(LAYER_FIELD_BORDER_MODE);
        // Same literal as the "full" case in getConvolutionModeFromConfig; compared by content.
        if (!"full".equals(borderMode)) {
            return null;
        }
        int[] padding = KerasLayer.getKernelSizeFromConfig(layerConfig);
        for (int i = 0; i < padding.length; i++) {
            padding[i] = padding[i] - 1;
        }
        return padding;
    }

    /**
     * Dimension-ordering convention of a Keras layer.
     * <p>
     * The dim-ordering config value {@code "tf"} maps to {@link #TENSORFLOW} and {@code "th"} maps
     * to {@link #THEANO}. {@link #NONE} presumably means no ordering was determined (TODO confirm).
     */
    public static enum DimOrder {
        NONE, THEANO, TENSORFLOW
    }
}

