/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.factory;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.LayerFactory;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GRU;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.ActivationLayer;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
import org.deeplearning4j.nn.layers.feedforward.rbm.RBM;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Default {@link LayerFactory} that maps a layer <em>configuration</em> type
 * ({@code org.deeplearning4j.nn.conf.layers.*}) to the corresponding layer
 * <em>implementation</em> ({@code org.deeplearning4j.nn.layers.*}).
 *
 * <p>The created layer is wired up with its listeners, index, parameter view array,
 * parameter table and configuration before being returned.
 */
public class DefaultLayerFactory
implements LayerFactory {
    /**
     * Prototype instance of the configuration class passed to the constructor. In this
     * class it is used for type dispatch in {@link #getInstance}, for equality and
     * hashing, and in the "unknown layer type" error message.
     */
    protected Layer layerConfig;

    /**
     * Creates a factory for the given layer configuration class.
     *
     * @param layerConfig configuration class to instantiate via its no-argument constructor
     * @throws RuntimeException if the configuration class cannot be instantiated reflectively
     */
    public DefaultLayerFactory(Class<? extends Layer> layerConfig) {
        try {
            // Constructor.newInstance replaces the deprecated Class.newInstance, which could
            // rethrow undeclared checked exceptions from the constructor unwrapped.
            this.layerConfig = layerConfig.getDeclaredConstructor().newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException("Unable to instantiate layer configuration " + layerConfig.getName(), e);
        }
    }

    /**
     * Creates the layer implementation matching this factory's configuration type and
     * initializes it.
     *
     * <p>The order of operations is: listeners, index, parameter view array, parameter
     * table (built by {@link #getParams}), then configuration.
     *
     * @param conf               network configuration for the layer
     * @param iterationListeners listeners to attach to the layer
     * @param index              index of the layer
     * @param layerParamsView    parameter view array handed to the layer and to the initializer
     * @param initializeParams   passed through to the parameter initializer
     * @return the created layer, cast to the caller-requested type
     * @throws RuntimeException if the configuration type is not supported
     */
    @Override
    public <E extends org.deeplearning4j.nn.api.Layer> E create(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners, int index, INDArray layerParamsView, boolean initializeParams) {
        org.deeplearning4j.nn.api.Layer ret = this.getInstance(conf);
        ret.setListeners(iterationListeners);
        ret.setIndex(index);
        ret.setParamsViewArray(layerParamsView);
        ret.setParamTable(this.getParams(conf, layerParamsView, initializeParams));
        ret.setConf(conf);
        // E is chosen by the caller; the cast cannot be checked at runtime due to erasure.
        @SuppressWarnings("unchecked")
        E typed = (E) ret;
        return typed;
    }

    /**
     * Instantiates the layer implementation corresponding to {@link #layerConfig}'s type.
     *
     * <p>Checks use {@code instanceof} and run in the order below, so if one configuration
     * type ever extended another, the first matching check would win.
     *
     * @param conf network configuration passed to the implementation's constructor
     * @return a new, not yet initialized layer implementation
     * @throws RuntimeException if the configuration type is not one of the supported types
     */
    protected org.deeplearning4j.nn.api.Layer getInstance(NeuralNetConfiguration conf) {
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.DenseLayer) {
            return new DenseLayer(conf);
        }
        if (this.layerConfig instanceof AutoEncoder) {
            return new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.RBM) {
            return new RBM(conf);
        }
        if (this.layerConfig instanceof GravesLSTM) {
            return new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) {
            return new GravesBidirectionalLSTM(conf);
        }
        if (this.layerConfig instanceof GRU) {
            return new org.deeplearning4j.nn.layers.recurrent.GRU(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.OutputLayer) {
            return new OutputLayer(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.RnnOutputLayer) {
            return new RnnOutputLayer(conf);
        }
        if (this.layerConfig instanceof ConvolutionLayer) {
            return new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.SubsamplingLayer) {
            return new SubsamplingLayer(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.BatchNormalization) {
            return new BatchNormalization(conf);
        }
        if (this.layerConfig instanceof LocalResponseNormalization) {
            return new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf);
        }
        if (this.layerConfig instanceof EmbeddingLayer) {
            return new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf);
        }
        if (this.layerConfig instanceof org.deeplearning4j.nn.conf.layers.ActivationLayer) {
            return new ActivationLayer(conf);
        }
        throw new RuntimeException("unknown layer type: " + this.layerConfig);
    }

    /**
     * Builds the parameter table for a layer using {@link #initializer()}.
     *
     * @param conf             network configuration for the layer
     * @param paramsView       parameter view array handed to the initializer
     * @param initializeParams passed through to the initializer
     * @return a synchronized, insertion-ordered map populated by the initializer
     */
    protected Map<String, INDArray> getParams(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        ParamInitializer init = this.initializer();
        // LinkedHashMap keeps the order in which the initializer inserts parameters.
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        init.init(params, conf, paramsView, initializeParams);
        return params;
    }

    /**
     * Returns the parameter initializer used by {@link #getParams}.
     *
     * @return a new {@link DefaultParamInitializer}
     */
    @Override
    public ParamInitializer initializer() {
        return new DefaultParamInitializer();
    }

    /**
     * Two factories are equal when their configuration prototypes are equal (or both null).
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DefaultLayerFactory)) {
            return false;
        }
        DefaultLayerFactory that = (DefaultLayerFactory) o;
        return this.layerConfig == null ? that.layerConfig == null : this.layerConfig.equals(that.layerConfig);
    }

    /**
     * Hash code consistent with {@link #equals}: derived solely from the configuration prototype.
     */
    @Override
    public int hashCode() {
        return this.layerConfig != null ? this.layerConfig.hashCode() : 0;
    }
}

