/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.factory;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.LayerFactory;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.ConvolutionDownSampleLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.layers.factory.ConvolutionLayerFactory;
import org.deeplearning4j.nn.layers.factory.DefaultLayerFactory;
import org.deeplearning4j.nn.layers.factory.LSTMLayerFactory;
import org.deeplearning4j.nn.layers.factory.PretrainLayerFactory;
import org.deeplearning4j.nn.layers.factory.RecursiveAutoEncoderLayerFactory;
import org.deeplearning4j.nn.layers.factory.SubsampleLayerFactory;

public class LayerFactories {
    public static LayerFactory getFactory(NeuralNetConfiguration conf) {
        return LayerFactories.getFactory(conf.getLayer());
    }

    public static LayerFactory getFactory(Layer layer) {
        Class<?> clazz = layer.getClass();
        if (clazz.equals(ConvolutionDownSampleLayer.class)) {
            return new ConvolutionLayerFactory(clazz);
        }
        if (clazz.equals(LSTM.class)) {
            return new LSTMLayerFactory(LSTM.class);
        }
        if (RecursiveAutoEncoder.class.isAssignableFrom(clazz)) {
            return new RecursiveAutoEncoderLayerFactory(RecursiveAutoEncoder.class);
        }
        if (BasePretrainNetwork.class.isAssignableFrom(clazz)) {
            return new PretrainLayerFactory(clazz);
        }
        if (ConvolutionLayer.class.isAssignableFrom(clazz)) {
            return new ConvolutionLayerFactory(clazz);
        }
        if (SubsamplingLayer.class.isAssignableFrom(clazz)) {
            return new SubsampleLayerFactory(clazz);
        }
        return new DefaultLayerFactory(clazz);
    }

    public static Layer.Type typeForFactory(NeuralNetConfiguration conf) {
        LayerFactory layerFactory = LayerFactories.getFactory(conf);
        if (layerFactory instanceof ConvolutionLayerFactory || layerFactory instanceof SubsampleLayerFactory) {
            return Layer.Type.CONVOLUTIONAL;
        }
        if (layerFactory instanceof LSTMLayerFactory) {
            return Layer.Type.RECURRENT;
        }
        if (layerFactory instanceof RecursiveAutoEncoderLayerFactory) {
            return Layer.Type.RECURSIVE;
        }
        if (layerFactory instanceof DefaultLayerFactory || layerFactory instanceof PretrainLayerFactory) {
            return Layer.Type.FEED_FORWARD;
        }
        throw new IllegalArgumentException("Unknown layer type");
    }
}

