/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.multilayer;

import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.MultiLayerUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultiLayerNetwork
implements Serializable,
Classifier,
Layer {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);
    /** Layer instances, index-aligned with the layer-wise configuration; populated by {@code init()}. */
    protected Layer[] layers;
    /** Current network input. */
    protected INDArray input;
    /** Current labels; pushed to the output layer before backprop. */
    protected INDArray labels;
    /** Set to {@code true} by {@code init()} once the layers have been created. */
    protected boolean initCalled = false;
    /** Iteration listeners handed to each layer at creation time. */
    private Collection<IterationListener> listeners = new ArrayList<IterationListener>();
    /** Network-level configuration; the constructor clones the configuration at index 0. */
    protected NeuralNetConfiguration defaultConfiguration;
    /** Per-layer configurations plus network-level settings (pretrain, backprop, damping factor, ...). */
    protected MultiLayerConfiguration layerWiseConfigurations;
    /** Most recently computed gradient; returned by {@code gradient()}. */
    protected Gradient gradient;
    /** Network score (presumably the last computed loss; TODO confirm against the scoring code). */
    protected double score;
    /** Flattened parameter cache; when non-null it is what {@code params()} returns. */
    private INDArray params;
    /** Mask array used by {@code backPropGradient2()}; created by {@code initMask()}. */
    protected INDArray mask;
    /** Layer index (not used in the portion of the class reviewed here). */
    protected int layerIndex;
    /** Optimizer wrapper, created lazily in {@code fit}; transient, so not serialized. */
    protected transient Solver solver;

    /**
     * Creates a network from a multi-layer configuration.
     * <p>
     * No layers are created here; call {@link #init()} before using the network.
     * The default configuration is a clone of the configuration at index 0.
     *
     * @param conf the layer-wise configuration of this network
     */
    public MultiLayerNetwork(MultiLayerConfiguration conf) {
        NeuralNetConfiguration firstLayerConf = conf.getConf(0);
        this.layerWiseConfigurations = conf;
        this.defaultConfiguration = firstLayerConf.clone();
    }

    /**
     * Creates and initializes a network from a JSON configuration, then loads
     * the supplied parameters into it.
     *
     * @param conf   JSON representation of a {@link MultiLayerConfiguration}
     * @param params parameters passed to {@code setParameters}
     */
    public MultiLayerNetwork(String conf, INDArray params) {
        this(MultiLayerConfiguration.fromJson(conf));
        init();
        setParameters(params);
    }

    /**
     * Creates and initializes a network from a configuration, then loads the
     * supplied parameters into it.
     *
     * @param conf   the layer-wise configuration of this network
     * @param params parameters passed to {@code setParameters}
     */
    public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) {
        this(conf);
        init();
        setParameters(params);
    }

    /**
     * Fills in any missing configuration state with defaults: an empty
     * multi-layer configuration, a layer array sized by {@code getnLayers()},
     * and a default {@link NeuralNetConfiguration}. Existing values are left
     * untouched.
     */
    protected void intializeConfigurations() {
        if (layerWiseConfigurations == null) {
            layerWiseConfigurations = new MultiLayerConfiguration.Builder().build();
        }

        // Must come after the configuration default above: getnLayers() is evaluated here.
        if (layers == null) {
            layers = new Layer[getnLayers()];
        }

        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
    }

    /**
     * Layer-wise pretraining driven by a data set iterator.
     * <p>
     * Does nothing unless pretraining is enabled in the configuration. For each
     * layer the iterator is consumed completely: the first layer is fitted on
     * the (optionally pre-processed) features, every later layer on the
     * activations obtained by feeding the features through the preceding
     * layers. The iterator is reset after each layer.
     *
     * @param iter source of training batches
     */
    public void pretrain(DataSetIterator iter) {
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }

        for (int layerIdx = 0; layerIdx < getnLayers(); layerIdx++) {
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();

                if (layerIdx == 0) {
                    INDArray features;
                    if (getLayerWiseConfigurations().getInputPreProcess(layerIdx) != null) {
                        features = getLayerWiseConfigurations().getInputPreProcess(layerIdx)
                                .preProcess(batch.getFeatureMatrix(), layers[layerIdx]);
                    } else {
                        features = batch.getFeatureMatrix();
                    }
                    setInput(features);
                    if (getInput() == null || getLayers() == null) {
                        initializeLayers(input());
                    }
                    getLayers()[layerIdx].fit(input());
                    log.info("Training on layer " + (layerIdx + 1) + " with " + input().slices() + " examples");
                } else {
                    // Propagate the raw features through layers 0 .. layerIdx - 1.
                    INDArray propagated = batch.getFeatureMatrix();
                    for (int prev = 0; prev < layerIdx; prev++) {
                        propagated = activationFromPrevLayer(prev, propagated, true);
                    }
                    log.info("Training on layer " + (layerIdx + 1) + " with " + propagated.slices() + " examples");
                    getLayers()[layerIdx].fit(propagated);
                }
            }
            iter.reset();
        }
    }

    /**
     * Layer-wise pretraining on a single input array.
     * <p>
     * Does nothing unless pretraining is enabled. Every layer except the last
     * one is fitted: the first on the (optionally pre-processed) input, each
     * following layer on the activation of the layer before it.
     *
     * @param input the training examples
     */
    public void pretrain(INDArray input) {
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }

        INDArray current = null;
        for (int layerIdx = 0; layerIdx < getnLayers() - 1; layerIdx++) {
            if (layerIdx != 0) {
                current = activationFromPrevLayer(layerIdx - 1, current, true);
            } else if (getLayerWiseConfigurations().getInputPreProcess(layerIdx) != null) {
                current = getLayerWiseConfigurations().getInputPreProcess(layerIdx).preProcess(input, layers[layerIdx]);
            } else {
                current = input;
            }
            log.info("Training on layer " + (layerIdx + 1) + " with " + current.slices() + " examples");
            getLayers()[layerIdx].fit(current);
        }
    }

    /**
     * @return the number of slices of the current input
     */
    @Override
    public int batchSize() {
        INDArray currentInput = this.input;
        return currentInput.slices();
    }

    /**
     * @return the network-level default configuration
     */
    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    /**
     * Not supported on a multi-layer network.
     *
     * @param conf ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the current network input (may be {@code null} if none was set)
     */
    @Override
    public INDArray input() {
        return input;
    }

    /**
     * Intentionally a no-op: this implementation performs no input validation.
     */
    @Override
    public void validateInput() {
    }

    /**
     * Not supported on a multi-layer network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        throw new UnsupportedOperationException();
    }

    /**
     * Looks up a parameter by a network-level key of the form
     * {@code <layerIndex>_<layerParamKey>}, e.g. {@code "0_W"}. The key is
     * split at the first underscore.
     *
     * @param param the prefixed parameter key
     * @return the parameter array held by the addressed layer
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override
    public INDArray getParam(String param) {
        int separatorPos = param.indexOf("_");
        if (separatorPos == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + param + "\"");
        }
        String layerPart = param.substring(0, separatorPos);
        int layerNumber = Integer.parseInt(layerPart);
        String layerParamKey = param.substring(separatorPos + 1);
        Layer target = layers[layerNumber];
        return target.getParam(layerParamKey);
    }

    /**
     * Not supported on a multi-layer network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds a combined parameter table for the whole network. Each layer's
     * entries are re-keyed as {@code <layerIndex>_<layerParamKey>}; iteration
     * order follows the layer order and, within a layer, the order of that
     * layer's own table.
     *
     * @return a new insertion-ordered map of all parameters
     */
    @Override
    public Map<String, INDArray> paramTable() {
        Map<String, INDArray> combined = new LinkedHashMap<String, INDArray>();
        int layerNumber = 0;
        for (Layer layer : layers) {
            for (Map.Entry<String, INDArray> layerEntry : layer.paramTable().entrySet()) {
                combined.put(layerNumber + "_" + layerEntry.getKey(), layerEntry.getValue());
            }
            layerNumber++;
        }
        return combined;
    }

    /**
     * Not supported on a multi-layer network.
     *
     * @param paramTable ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        throw new UnsupportedOperationException();
    }

    /**
     * Sets a parameter addressed by a network-level key of the form
     * {@code <layerIndex>_<layerParamKey>}. The key is split at the first
     * underscore.
     *
     * @param key the prefixed parameter key
     * @param val the value handed to the addressed layer
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override
    public void setParam(String key, INDArray val) {
        int separatorPos = key.indexOf("_");
        if (separatorPos == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        }
        String layerPart = key.substring(0, separatorPos);
        int layerNumber = Integer.parseInt(layerPart);
        String layerParamKey = key.substring(separatorPos + 1);
        Layer target = layers[layerNumber];
        target.setParam(layerParamKey, val);
    }

    /**
     * @return the layer-wise configuration of this network
     */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /**
     * Replaces the layer-wise configuration. Only the reference is swapped;
     * already-created layers are not rebuilt by this call.
     *
     * @param layerWiseConfigurations the new configuration
     */
    public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) {
        MultiLayerConfiguration replacement = layerWiseConfigurations;
        this.layerWiseConfigurations = replacement;
    }

    /**
     * Stores the given input, propagates its first-dimension size as the
     * mini-batch size, and runs {@link #init()} if the layers have not been
     * created yet.
     *
     * @param input the network input; must not be {@code null}
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public void initializeLayers(INDArray input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }

        this.input = input;
        // input is known to be non-null past the guard above.
        setInputMiniBatchSize(input.size(0));

        if (initCalled) {
            return;
        }
        init();
    }

    /**
     * Builds the network's layers from the layer-wise configuration.
     * <p>
     * Steps: fill in missing configuration defaults; return early if already
     * initialized; create one layer per configuration via its layer factory
     * (only when the layer array is missing or its first slot is empty);
     * rebuild the default configuration's variable list with
     * {@code <layerIndex>_<variable>} names; finally redistribute parameters
     * if the configuration asks for it.
     *
     * @throws IllegalStateException if fewer than one layer is configured
     */
    public void init() {
        if (layerWiseConfigurations == null || layers == null) {
            intializeConfigurations();
        }
        if (initCalled) {
            return;
        }
        if (getnLayers() < 1) {
            throw new IllegalStateException("Unable to createComplex network neuralNets; number specified is less than 1");
        }

        boolean layersMissing = layers == null || layers[0] == null;
        if (layersMissing) {
            if (layers == null) {
                layers = new Layer[getnLayers()];
            }
            for (int layerIdx = 0; layerIdx < getnLayers(); layerIdx++) {
                NeuralNetConfiguration layerConf = layerWiseConfigurations.getConf(layerIdx);
                layers[layerIdx] = LayerFactories.getFactory(layerConf).create(layerConf, listeners, layerIdx);
            }
            initCalled = true;
            initMask();
        }

        // Re-register every layer variable under its network-level name.
        defaultConfiguration.clearVariables();
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            for (String variable : layers[layerIdx].conf().variables()) {
                defaultConfiguration.addVariable(layerIdx + "_" + variable);
            }
        }

        if (getLayerWiseConfigurations().isRedistributeParams()) {
            reDistributeParams();
        }
    }

    /**
     * @return the result of {@code activate()} on the last layer
     */
    @Override
    public INDArray activate() {
        Layer[] all = getLayers();
        Layer last = all[all.length - 1];
        return last.activate();
    }

    /**
     * @param layer index of the layer to activate
     * @return the result of {@code activate()} on that layer
     */
    public INDArray activate(int layer) {
        Layer target = getLayers()[layer];
        return target.activate();
    }

    /**
     * Not supported on a multi-layer network; use {@code activate(int, INDArray)}
     * or one of the {@code feedForward} methods instead.
     *
     * @param input ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(INDArray input) {
        throw new UnsupportedOperationException();
    }

    /**
     * @param layer index of the layer to activate
     * @param input input handed to that layer
     * @return the result of {@code activate(input)} on that layer
     */
    public INDArray activate(int layer, INDArray input) {
        Layer target = getLayers()[layer];
        return target.activate(input);
    }

    /**
     * Not supported on a multi-layer network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activationMean() {
        throw new UnsupportedOperationException();
    }

    /**
     * Flattens all layer parameters into one 'f'-ordered array, caches it in
     * {@code this.params}, and hands every layer the sub-range of that array
     * matching its own parameter count.
     * <p>
     * Layers whose sub-range comes back empty are skipped and left unchanged.
     */
    public void reDistributeParams() {
        List<INDArray> perLayerParams = new ArrayList<INDArray>(this.layers.length);
        for (Layer l : this.layers) {
            perLayerParams.add(l.params());
        }
        this.params = Nd4j.toFlattened((char)'f', perLayerParams);

        // The list is not needed past this point, so the second pass only slices
        // the flattened array (no further l.params() calls are necessary).
        int idx = 0;
        for (Layer l : this.layers) {
            int range = l.numParams();
            INDArray get = this.params.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)idx, (int)(range + idx))});
            if (get.length() < 1) {
                continue;
            }
            l.setParams(get);
            idx += range;
        }
    }

    /**
     * Primes the network with a data set: sets the features as input, runs a
     * feed-forward pass, stores the labels, and (when the output layer is a
     * {@link BaseOutputLayer}) pushes the labels to it.
     *
     * @param data the data set providing features and labels
     */
    public void initialize(org.nd4j.linalg.dataset.DataSet data) {
        setInput(data.getFeatureMatrix());
        feedForward(getInput());
        this.labels = data.getLabels();

        if (!(getOutputLayer() instanceof BaseOutputLayer)) {
            return;
        }
        ((BaseOutputLayer) getOutputLayer()).setLabels(this.labels);
    }

    /**
     * Computes the pre-output of layer {@code curr} for the given input,
     * applying that layer's input pre-processor first when one is configured.
     *
     * @param curr     index of the layer to evaluate
     * @param input    input to the layer (before pre-processing)
     * @param training flag forwarded to the layer's {@code preOutput}
     * @return the layer's pre-output
     */
    public INDArray zFromPrevLayer(int curr, INDArray input, boolean training) {
        INDArray layerInput = input;
        if (getLayerWiseConfigurations().getInputPreProcess(curr) != null) {
            layerInput = getLayerWiseConfigurations().getInputPreProcess(curr).preProcess(layerInput, layers[curr]);
        }
        return layers[curr].preOutput(layerInput, training);
    }

    /**
     * Computes the activation of layer {@code curr} for the given input,
     * applying that layer's input pre-processor first when one is configured.
     *
     * @param curr     index of the layer to evaluate
     * @param input    input to the layer (before pre-processing)
     * @param training flag forwarded to the layer's {@code activate}
     * @return the layer's activation
     */
    public INDArray activationFromPrevLayer(int curr, INDArray input, boolean training) {
        INDArray layerInput = input;
        if (getLayerWiseConfigurations().getInputPreProcess(curr) != null) {
            layerInput = getLayerWiseConfigurations().getInputPreProcess(curr).preProcess(layerInput, layers[curr]);
        }
        return layers[curr].activate(layerInput, training);
    }

    /**
     * Chains {@link #zFromPrevLayer} through all layers starting from the
     * current input. Each layer's pre-output is fed as input to the next.
     *
     * @param training flag forwarded to every layer
     * @return the input followed by one pre-output per layer
     */
    public List<INDArray> computeZ(boolean training) {
        List<INDArray> preOutputs = new ArrayList<INDArray>();
        INDArray flowing = this.input;
        preOutputs.add(flowing);
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            flowing = zFromPrevLayer(layerIdx, flowing, training);
            preOutputs.add(flowing);
        }
        return preOutputs;
    }

    /**
     * Sets the network input (pre-processed with the first layer's input
     * pre-processor when one is configured) and delegates to
     * {@link #computeZ(boolean)}.
     *
     * @param input    the network input; must not be {@code null}
     * @param training flag forwarded to every layer
     * @return the stored input followed by one pre-output per layer
     * @throws IllegalStateException if {@code input} is {@code null}
     */
    public List<INDArray> computeZ(INDArray input, boolean training) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }

        INDArray networkInput = input;
        if (getLayerWiseConfigurations().getInputPreProcess(0) != null) {
            networkInput = getLayerWiseConfigurations().getInputPreProcess(0).preProcess(input, layers[0]);
        }
        setInput(networkInput);
        return computeZ(training);
    }

    /**
     * Sets the given input and runs a feed-forward pass.
     *
     * @param input the network input
     * @param train flag forwarded to every layer
     * @return the input followed by one activation per layer
     */
    public List<INDArray> feedForward(INDArray input, boolean train) {
        setInput(input);
        List<INDArray> activations = feedForward(train);
        return activations;
    }

    /**
     * Feeds the current input through all layers via
     * {@link #activationFromPrevLayer}.
     *
     * @param train flag forwarded to every layer
     * @return the input followed by one activation per layer
     */
    public List<INDArray> feedForward(boolean train) {
        List<INDArray> collected = new ArrayList<INDArray>();
        INDArray flowing = this.input;
        collected.add(flowing);
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            flowing = activationFromPrevLayer(layerIdx, flowing, train);
            collected.add(flowing);
        }
        return collected;
    }

    /**
     * Feed-forward pass on the current input with the training flag off.
     *
     * @return the input followed by one activation per layer
     */
    public List<INDArray> feedForward() {
        final boolean train = false;
        return feedForward(train);
    }

    /**
     * Runs two passes over the layers starting from the current input and
     * collects, per layer, the activation (first pass) and the derivative of
     * the activation function at the layer's pre-output (second pass).
     * <p>
     * Softmax layers are evaluated along dimension 1. The last derivative is
     * appended a second time, so both lists have {@code layers.length + 1}
     * entries (the activation list starts with the input).
     *
     * @param training flag forwarded to every layer's pre-output computation
     * @return pair of (activations, derivatives)
     */
    public Pair<List<INDArray>, List<INDArray>> feedForwardActivationsAndDerivatives(boolean training) {
        int i;
        INDArray currInput = this.input;
        ArrayList<INDArray> activations = new ArrayList<INDArray>();
        // Typed as INDArray (not Object) so the list matches the declared return type.
        ArrayList<INDArray> derivatives = new ArrayList<INDArray>();
        activations.add(currInput);
        for (i = 0; i < this.layers.length; ++i) {
            currInput = this.zFromPrevLayer(i, currInput, training);
            if (this.layers[i].conf().getLayer().getActivationFunction().equals("softmax")) {
                activations.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", currInput.dup()), new int[]{1}));
                continue;
            }
            // NOTE(review): no dup() here, unlike the softmax branch - presumably the transform
            // runs in place so currInput carries the activation into the next layer; confirm.
            activations.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), currInput)));
        }
        // NOTE(review): the second pass recomputes every pre-output; here the transform is applied
        // to a copy, so the next layer is fed the raw pre-output. Confirm this is intended.
        currInput = this.input;
        for (i = 0; i < this.layers.length; ++i) {
            currInput = this.zFromPrevLayer(i, currInput, training);
            INDArray dup = currInput.dup();
            if (this.layers[i].conf().getLayer().getActivationFunction().equals("softmax")) {
                derivatives.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), dup).derivative(), new int[]{1}));
                continue;
            }
            derivatives.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), dup).derivative()));
        }
        // Duplicate the last derivative so both lists have the same length.
        derivatives.add(derivatives.get(this.layers.length - 1));
        return new Pair<List<INDArray>, List<INDArray>>(activations, derivatives);
    }

    /**
     * Sets the network input (pre-processed with the first layer's input
     * pre-processor when one is configured) and runs a feed-forward pass with
     * the training flag off.
     *
     * @param input the network input; must not be {@code null}
     * @return the stored input followed by one activation per layer
     * @throws IllegalStateException if {@code input} is {@code null}
     */
    public List<INDArray> feedForward(INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }

        INDArray networkInput = input;
        if (getLayerWiseConfigurations().getInputPreProcess(0) != null) {
            networkInput = getLayerWiseConfigurations().getInputPreProcess(0).preProcess(input, layers[0]);
        }
        setInput(networkInput);
        return feedForward();
    }

    /**
     * @return the most recently stored gradient (may be {@code null})
     */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * @return the current gradient paired with the current score
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        double currentScore = score();
        return new Pair<Gradient, Double>(currentGradient, currentScore);
    }

    /**
     * Computes one delta array per layer from the result of
     * {@code feedForwardR(activations, v)}.
     * <p>
     * When the default configuration constrains gradients to unit norm, a
     * delta is divided by its 2-norm, but only if the sum of its elements is
     * positive.
     *
     * @param v vector handed to {@code feedForwardR}
     * @return {@code getnLayers()} delta arrays, in layer order
     */
    protected List<INDArray> computeDeltasR(INDArray v) {
        int i;
        ArrayList<INDArray> deltaRet = new ArrayList<INDArray>();
        // One extra slot is allocated; only the first getnLayers() entries are filled and returned.
        INDArray[] deltas = new INDArray[this.getnLayers() + 1];
        List<INDArray> activations = this.feedForward();
        List<INDArray> rActivations = this.feedForwardR(activations, v);
        ArrayList<INDArray> weights = new ArrayList<INDArray>();
        ArrayList<INDArray> biases = new ArrayList<INDArray>();
        ArrayList<String> activationFunctions = new ArrayList<String>();
        // Snapshot each layer's "W"/"b" parameters and activation function name.
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getParam("W"));
            biases.add(this.getLayers()[j].getParam("b"));
            activationFunctions.add(this.getLayers()[j].conf().getLayer().getActivationFunction());
        }
        // Start from the last R-activation, scaled in place by the number of input slices.
        INDArray rix = rActivations.get(rActivations.size() - 1).divi((Number)this.input.slices());
        LinAlgExceptions.assertValidNum((INDArray)rix);
        // Walk the layers backwards: delta_i = activations_i^T * rix, then propagate rix one layer down.
        for (i = this.getnLayers() - 1; i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(rix);
            if (i <= 0) continue;
            // Multiply by (W_i + b_i)^T and by the derivative of layer (i - 1)'s activation function
            // evaluated at activations_i.
            rix = rix.mmul(((INDArray)weights.get(i)).addRowVector((INDArray)biases.get(i)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String)activationFunctions.get(i - 1), activations.get(i)).derivative()));
        }
        for (i = 0; i < deltas.length - 1; ++i) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                int[] nArray = new int[]{Integer.MAX_VALUE};
                double sum = deltas[i].sum(nArray).getDouble(0);
                // Only normalize when the element sum is positive; otherwise keep the delta as is.
                if (sum > 0.0) {
                    deltaRet.add(deltas[i].div(deltas[i].norm2(new int[]{Integer.MAX_VALUE})));
                } else {
                    deltaRet.add(deltas[i]);
                }
            } else {
                deltaRet.add(deltas[i]);
            }
            LinAlgExceptions.assertValidNum((INDArray)((INDArray)deltaRet.get(i)));
        }
        return deltaRet;
    }

    /**
     * Adjusts the damping factor based on the reduction ratio {@code rho}:
     * multiplied by {@code boost} when {@code rho} is below 0.25 or NaN, by
     * {@code decrease} when it is above 0.75, and left untouched otherwise.
     *
     * @param rho      the reduction ratio
     * @param boost    multiplier applied for a poor (or NaN) ratio
     * @param decrease multiplier applied for a good ratio
     */
    public void dampingUpdate(double rho, double boost, double decrease) {
        final double multiplier;
        if (rho < 0.25 || Double.isNaN(rho)) {
            multiplier = boost;
        } else if (rho > 0.75) {
            multiplier = decrease;
        } else {
            return;
        }
        layerWiseConfigurations.setDampingFactor(getLayerWiseConfigurations().getDampingFactor() * multiplier);
    }

    /**
     * Computes the reduction ratio {@code (currScore - score) / denom}, where
     * {@code denom} is derived from {@code getBackPropRGradient(p)} evaluated
     * with the damping factor temporarily set to zero.
     * <p>
     * The original damping factor is always restored, even if the computation
     * throws.
     *
     * @param p         the step / direction vector
     * @param currScore the score before the step
     * @param score     the score after the step
     * @param gradient  the gradient used in the denominator
     * @return {@link Double#NEGATIVE_INFINITY} if {@code score > currScore},
     *         otherwise the computed ratio
     */
    public double reductionRatio(INDArray p, double currScore, double score, INDArray gradient) {
        double currentDamp = this.layerWiseConfigurations.getDampingFactor();
        double rho;
        this.layerWiseConfigurations.setDampingFactor(0.0);
        try {
            INDArray denom = this.getBackPropRGradient(p);
            // NOTE(review): the result of sum(...) is discarded; only the in-place muli calls
            // modify denom. Confirm this is the intended computation.
            denom.muli((Number)0.5).muli(p.mul(denom)).sum(new int[]{0});
            denom.subi(gradient.mul(p).sum(new int[]{0}));
            rho = (currScore - score) / (Double)denom.getScalar(0).element();
        } finally {
            // Restore the caller-visible damping factor no matter how the block exits.
            this.layerWiseConfigurations.setDampingFactor(currentDamp);
        }
        if (score - currScore > 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        return rho;
    }

    /**
     * Computes, for every layer, a pair of (delta, pre-conditioner) arrays
     * from a feed-forward pass and the current labels.
     * <p>
     * When the default configuration constrains gradients to unit norm, each
     * delta is divided in place by its 2-norm.
     *
     * @return one (delta, pre-conditioner) pair per layer, in layer order
     */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        int i;
        ArrayList<Pair<INDArray, INDArray>> deltaRet = new ArrayList<Pair<INDArray, INDArray>>();
        List<INDArray> activations = this.feedForward();
        // activations holds the input plus one entry per layer, hence size() - 1 slots.
        INDArray[] deltas = new INDArray[activations.size() - 1];
        INDArray[] preCons = new INDArray[activations.size() - 1];
        // Output error: (final activation - labels) / number of label slices.
        INDArray ix = activations.get(activations.size() - 1).sub(this.labels).div((Number)this.labels.slices());
        ArrayList<INDArray> weights = new ArrayList<INDArray>();
        ArrayList<INDArray> biases = new ArrayList<INDArray>();
        ArrayList<String> activationFunctions = new ArrayList<String>();
        // Snapshot each layer's "W"/"b" parameters and activation function name.
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getParam("W"));
            biases.add(this.getLayers()[j].getParam("b"));
            activationFunctions.add(this.getLayers()[j].conf().getLayer().getActivationFunction());
        }
        // Walk the layers backwards, propagating the error ix one layer down per step.
        for (i = weights.size() - 1; i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(ix);
            // Pre-conditioner: (activations_i^T)^2 * ix^2, scaled by the number of label slices.
            preCons[i] = Transforms.pow((INDArray)activations.get(i).transpose(), (Number)2).mmul(Transforms.pow((INDArray)ix, (Number)2)).muli((Number)this.labels.slices());
            if (i <= 0) continue;
            ix = ix.mmul(((INDArray)weights.get(i)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String)activationFunctions.get(i - 1), activations.get(i)).derivative()));
        }
        for (i = 0; i < deltas.length; ++i) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                // divi normalizes the delta in place.
                deltaRet.add(new Pair<INDArray, INDArray>(deltas[i].divi(deltas[i].norm2(new int[]{Integer.MAX_VALUE})), preCons[i]));
                continue;
            }
            deltaRet.add(new Pair<INDArray, INDArray>(deltas[i], preCons[i]));
        }
        return deltaRet;
    }

    /**
     * Packs the result of {@code backPropGradientR(v)} into a single array.
     *
     * @param v vector forwarded to {@code backPropGradientR}
     * @return the packed result
     */
    public INDArray getBackPropRGradient(INDArray v) {
        return pack(backPropGradientR(v));
    }

    /**
     * Runs {@link #backPropGradient2()} and packs its two halves separately.
     *
     * @return pair of (packed gradients, packed pre-conditioners)
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<INDArray, INDArray>> gradients = new ArrayList<Pair<INDArray, INDArray>>();
        List<Pair<INDArray, INDArray>> preConditioners = new ArrayList<Pair<INDArray, INDArray>>();

        for (Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> layerDelta : backPropGradient2()) {
            gradients.add(layerDelta.getFirst());
            preConditioners.add(layerDelta.getSecond());
        }

        INDArray packedGradients = pack(gradients);
        INDArray packedPreConditioners = pack(preConditioners);
        return new Pair<INDArray, INDArray>(packedGradients, packedPreConditioners);
    }

    /**
     * Creates a copy of this network by reflectively invoking the
     * {@code (MultiLayerConfiguration)} constructor of the runtime class and
     * then calling {@code update(this)} on the new instance.
     *
     * @return the new network
     * @throws IllegalStateException if construction or the update fails; the
     *         underlying exception is attached as the cause
     */
    @Override
    public MultiLayerNetwork clone() {
        MultiLayerNetwork ret;
        try {
            Constructor<?> constructor = this.getClass().getDeclaredConstructor(MultiLayerConfiguration.class);
            ret = (MultiLayerNetwork)constructor.newInstance(this.getLayerWiseConfigurations());
            ret.update(this);
        }
        catch (Exception e) {
            // Keep the original failure as the cause so it is not lost to callers.
            throw new IllegalStateException("Unable to clone network", e);
        }
        return ret;
    }

    /**
     * Returns the network parameters as one flat array: the cached array when
     * present, otherwise a fresh 'f'-ordered flattening of every layer's
     * parameters (which is not cached).
     *
     * @return the flattened parameters
     */
    @Override
    public INDArray params() {
        INDArray cached = this.params;
        if (cached != null) {
            return cached;
        }

        ArrayList<INDArray> perLayer = new ArrayList<INDArray>();
        Layer[] all = getLayers();
        for (int layerIdx = 0; layerIdx < all.length; layerIdx++) {
            perLayer.add(all[layerIdx].params());
        }
        return Nd4j.toFlattened((char)'f', perLayer);
    }

    /**
     * Distributes a flat parameter array over the layers: each layer receives
     * the next {@code numParams()} entries of row 0. The cached flat array is
     * replaced only if a cache already exists.
     *
     * @param params flat parameters covering all layers in order
     */
    @Override
    public void setParams(INDArray params) {
        // Only overwrite an existing cache; never create one here.
        if (this.params != null) {
            this.params = params;
        }

        int offset = 0;
        for (int layerIdx = 0; layerIdx < getLayers().length; layerIdx++) {
            Layer target = getLayer(layerIdx);
            int count = target.numParams();
            int end = count + offset;
            INDArray slice = params.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)offset, (int)end)});
            target.setParams(slice);
            offset += count;
        }
    }

    /**
     * @return the sum of {@code numParams()} over all layers
     */
    @Override
    public int numParams() {
        int total = 0;
        for (Layer layer : layers) {
            total += layer.numParams();
        }
        return total;
    }

    /**
     * Alias for {@link #params()}.
     *
     * @return the flattened network parameters
     */
    public INDArray pack() {
        INDArray flattened = params();
        return flattened;
    }

    /**
     * Flattens a list of array pairs into a single array, taking each pair's
     * first element and then its second, in list order.
     *
     * @param layers the pairs to flatten
     * @return the flattened array
     */
    public INDArray pack(List<Pair<INDArray, INDArray>> layers) {
        ArrayList<INDArray> pieces = new ArrayList<INDArray>();
        for (Pair<INDArray, INDArray> pair : layers) {
            INDArray first = pair.getFirst();
            INDArray second = pair.getSecond();
            pieces.add(first);
            pieces.add(second);
        }
        return Nd4j.toFlattened(pieces);
    }

    /**
     * Computes the F1 score for a data set by delegating to
     * {@code f1Score(features, labels)}.
     *
     * @param data the data set to score
     * @return the F1 score
     */
    @Override
    public double f1Score(DataSet data) {
        INDArray features = data.getFeatureMatrix();
        INDArray expected = data.getLabels();
        return f1Score(features, expected);
    }

    /**
     * Splits a flat array into one (weights, bias) pair per layer. Each layer
     * consumes as many entries as its "W" and "b" parameters hold together,
     * weights first, and the pieces are reshaped to the slices/columns of the
     * layer's current "W" and "b".
     *
     * @param param the flat array; reshaped to a single row if necessary
     * @return the per-layer (weights, bias) pairs, in layer order
     * @throws IllegalStateException if an extracted portion has an unexpected length
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray param) {
        INDArray flat = param.slices() != 1 ? param.reshape(1, param.length()) : param;

        List<Pair<INDArray, INDArray>> unpacked = new ArrayList<Pair<INDArray, INDArray>>();
        int offset = 0;
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            INDArray weights = layers[layerIdx].getParam("W");
            INDArray bias = layers[layerIdx].getParam("b");
            int weightLength = weights.length();
            int biasLength = bias.length();
            int layerLength = weightLength + biasLength;

            INDArray layerSlice = flat.get(new INDArrayIndex[]{NDArrayIndex.interval((int)offset, (int)(offset + layerLength))});
            INDArray weightPortion = layerSlice.get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)weightLength)});
            INDArray biasPortion = layerSlice.get(new INDArrayIndex[]{NDArrayIndex.interval((int)weightLength, (int)layerSlice.length())});

            // Sanity check: the two portions together must cover the layer exactly.
            if (weightPortion.length() + biasPortion.length() != layerLength) {
                if (biasPortion.length() != biasLength) {
                    throw new IllegalStateException("Hidden bias on layer " + layerIdx + " was off");
                }
                if (weightPortion.length() != weightLength) {
                    throw new IllegalStateException("Weight portion on layer " + layerIdx + " was off");
                }
            }

            INDArray shapedWeights = weightPortion.reshape(weights.slices(), weights.columns());
            INDArray shapedBias = biasPortion.reshape(bias.slices(), bias.columns());
            unpacked.add(new Pair<INDArray, INDArray>(shapedWeights, shapedBias));
            offset += layerLength;
        }
        return unpacked;
    }

    /**
     * Builds, per layer, the gradient pair and the pre-conditioner pair from
     * {@link #computeDeltas2()}.
     * <p>
     * Each delta is paired with its column means (used as the bias change),
     * everything is packed into flat arrays, an L2 / damping term is added
     * when the output layer uses drop-connect or dropout, and the flat arrays
     * are unpacked back into per-layer pairs.
     *
     * @return per layer: ((weight gradient, bias gradient), (weight pre-conditioner, bias pre-conditioner))
     * @throws IllegalStateException if a delta's length does not match the
     *         corresponding layer's "W" or "b" parameter length
     */
    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        List<Pair<INDArray, INDArray>> deltas = this.computeDeltas2();
        ArrayList<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> list = new ArrayList<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>>();
        ArrayList<Pair<INDArray, INDArray>> grad = new ArrayList<Pair<INDArray, INDArray>>();
        ArrayList<Pair<INDArray, INDArray>> preCon = new ArrayList<Pair<INDArray, INDArray>>();
        for (int l = 0; l < deltas.size(); ++l) {
            INDArray gradientChange = deltas.get(l).getFirst();
            INDArray preConGradientChange = deltas.get(l).getSecond();
            if (l < this.layers.length && gradientChange.length() != this.layers[l].getParam("W").length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            // Column means of the delta / pre-conditioner serve as the bias parts.
            INDArray deltaColumnSums = deltas.get(l).getFirst().mean(new int[]{0});
            INDArray preConColumnSums = deltas.get(l).getSecond().mean(new int[]{0});
            grad.add(new Pair<INDArray, INDArray>(gradientChange, deltaColumnSums));
            preCon.add(new Pair<INDArray, INDArray>(preConGradientChange, preConColumnSums));
            if (l < this.layers.length && deltaColumnSums.length() != this.layers[l].getParam("b").length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            // NOTE(review): this check only fires when l equals the number of layers, i.e. when
            // computeDeltas2() returned more entries than there are layers - confirm it is reachable.
            if (l != this.getLayers().length || deltaColumnSums.length() == this.getOutputLayer().getParam("b").length()) continue;
            throw new IllegalStateException("Bias change not equal to weight change");
        }
        INDArray g = this.pack(grad);
        INDArray con = this.pack(preCon);
        INDArray theta = this.params();
        if (this.getOutputLayer().conf().isUseDropConnect() || this.getOutputLayer().conf().getLayer().getDropOut() > 0.0) {
            if (this.mask == null) {
                this.initMask();
            }
            // Add the masked L2 term to the gradient (in place).
            g.addi(theta.mul((Number)this.defaultConfiguration.getL2()).muli(this.mask));
            // Pre-conditioner addition: (mask * L2 + dampingFactor) ^ 0.75.
            INDArray conAdd = Transforms.pow((INDArray)this.mask.mul((Number)this.defaultConfiguration.getL2()).add(Nd4j.valueArrayOf((int)g.slices(), (int)g.columns(), (double)this.layerWiseConfigurations.getDampingFactor())), (Number)0.75);
            con.addi(conAdd);
        }
        // Split the flat arrays back into per-layer (weights, bias) pairs.
        List<Pair<INDArray, INDArray>> gUnpacked = this.unPack(g);
        List<Pair<INDArray, INDArray>> conUnpacked = this.unPack(con);
        for (int i = 0; i < gUnpacked.size(); ++i) {
            list.add(new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(gUnpacked.get(i), conUnpacked.get(i)));
        }
        return list;
    }

    /**
     * Trains the network from a data set iterator.
     * <p>
     * If pretraining is enabled: pretrains layer by layer, then fine-tunes on
     * every batch. If backprop is enabled: for every batch either runs
     * truncated BPTT or sets input and labels and runs the (lazily created)
     * solver. Either loop stops at the first batch whose features or labels
     * are {@code null}.
     *
     * @param iter source of training batches
     */
    @Override
    public void fit(DataSetIterator iter) {
        if (layerWiseConfigurations.isPretrain()) {
            pretrain(iter);
            iter.reset();
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();
                if (batch.getFeatureMatrix() == null || batch.getLabels() == null) {
                    break;
                }
                setInput(batch.getFeatureMatrix());
                setLabels(batch.getLabels());
                finetune();
            }
        }

        if (layerWiseConfigurations.isBackprop()) {
            iter.reset();
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();
                if (batch.getFeatureMatrix() == null || batch.getLabels() == null) {
                    break;
                }

                if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(batch.getFeatureMatrix(), batch.getLabels());
                } else {
                    setInput(batch.getFeatureMatrix());
                    setLabels(batch.getLabels());
                    if (solver == null) {
                        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    }
                    solver.optimize();
                }
            }
        }
    }

    /**
     * Computes the gradient of the whole network by backpropagation and stores
     * it in {@code this.gradient}, keyed as {@code <layerIndex>_<variable>}.
     * <p>
     * The gradient is reset first. If the last layer is not a
     * {@link BaseOutputLayer}, a warning is logged and the (empty) gradient is
     * left in place. Otherwise the labels are pushed to the output layer, its
     * gradient is computed, and the error is propagated backwards through the
     * remaining layers, passing through each layer's input pre-processor when
     * one is configured.
     *
     * @throws IllegalStateException if no labels are set, or if the output
     *         layer's weights are configured to be initialized to zero
     */
    protected void backprop() {
        this.gradient = new DefaultGradient();
        if (!(this.getOutputLayer() instanceof BaseOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
            return;
        }
        BaseOutputLayer outputLayer = (BaseOutputLayer)this.getOutputLayer();
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        if (outputLayer.conf().getLayer().getWeightInit() == WeightInit.ZERO) {
            throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
        }
        outputLayer.setLabels(this.labels);
        int numLayers = this.getnLayers();

        // Ordered (key, gradient) entries: earlier layers first, each layer's variables in their own order.
        LinkedList<Pair<String, INDArray>> gradientList = new LinkedList<Pair<String, INDArray>>();

        // Output layer first.
        Pair<Gradient, INDArray> currPair = outputLayer.backpropGradient(null);
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String multiGradientKey = String.valueOf(numLayers - 1) + "_" + entry.getKey();
            gradientList.addLast(new Pair<String, INDArray>(multiGradientKey, entry.getValue()));
        }
        if (this.getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null) {
            currPair = new Pair<Gradient, INDArray>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), this.layers[numLayers - 1]));
        }

        // Remaining layers, last to first.
        for (int j = numLayers - 2; j >= 0; --j) {
            Layer currLayer = this.getLayer(j);
            currPair = currLayer.backpropGradient(currPair.getSecond());
            // Reverse the layer's entries here so that prepending them one by one below
            // restores their original order at the front of gradientList.
            LinkedList<Pair<String, INDArray>> tempList = new LinkedList<Pair<String, INDArray>>();
            for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
                String multiGradientKey = String.valueOf(j) + "_" + entry.getKey();
                tempList.addFirst(new Pair<String, INDArray>(multiGradientKey, entry.getValue()));
            }
            for (Pair<String, INDArray> pair : tempList) {
                gradientList.addFirst(pair);
            }
            if (this.getLayerWiseConfigurations().getInputPreProcess(j) == null) {
                continue;
            }
            currPair = new Pair<Gradient, INDArray>(currPair.getFirst(), this.getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), this.layers[j]));
        }

        for (Pair<String, INDArray> pair : gradientList) {
            this.gradient.setGradientFor(pair.getFirst(), pair.getSecond());
        }
    }

    /**
     * Trains the network with truncated backpropagation through time by splitting the
     * time series (dimension 2 of {@code input} and {@code labels}) into consecutive
     * windows of the configured forward length and optimizing on each window in turn.
     *
     * <p>The recurrent state is cleared before the first window and after the last one;
     * between windows the stored TBPTT state is carried over via
     * {@link #updateRnnStateWithTBPTTState()}.
     *
     * <p>The method logs a warning and returns without training when the input or labels
     * are not rank 3, when their time-series lengths differ, when the configured forward
     * length is not positive, or when it exceeds the time-series length.
     *
     * <p>If the time-series length is not a multiple of the forward length, the leftover
     * time steps are trained as one final, shorter window instead of being dropped.
     *
     * @param input  input array; dimension 2 is used as the time-series length
     * @param labels label array; must have the same size along dimension 2 as {@code input}
     */
    protected void doTruncatedBPTT(INDArray input, INDArray labels) {
        if (input.rank() != 3 || labels.rank() != 3) {
            log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength]");
            return;
        }
        if (input.size(2) != labels.size(2)) {
            log.warn("Input and label time series have different lengths: {} input length, {} label length", (Object)input.size(2), (Object)labels.size(2));
            return;
        }
        int fwdLen = this.layerWiseConfigurations.getTbpttFwdLength();
        int timeSeriesLength = input.size(2);
        // Validate before dividing: a zero forward length would otherwise throw ArithmeticException.
        if (fwdLen <= 0) {
            log.warn("Cannot do TBPTT: Truncated BPTT forward length must be positive, but was {}", (Object)Integer.valueOf(fwdLen));
            return;
        }
        if (fwdLen > timeSeriesLength) {
            log.warn("Cannot do TBPTT: Truncated BPTT forward length > input time series length.");
            return;
        }
        int nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0) {
            // One extra, shorter window covers the trailing time steps.
            ++nSubsets;
        }
        this.rnnClearPreviousState();
        for (int i = 0; i < nSubsets; ++i) {
            int startTimeIdx = i * fwdLen;
            int endTimeIdx = Math.min(startTimeIdx + fwdLen, timeSeriesLength);
            INDArray inputSubset = input.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
            INDArray labelSubset = labels.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
            this.setInput(inputSubset);
            this.setLabels(labelSubset);
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
            }
            this.solver.optimize();
            // Carry the state reached at the end of this window into the next one.
            this.updateRnnStateWithTBPTTState();
        }
        this.rnnClearPreviousState();
    }

    /**
     * Copies the stored TBPTT state of every recurrent layer into its "previous state".
     * Nested {@link MultiLayerNetwork} layers are handled recursively; all other layer
     * types are left untouched.
     */
    protected void updateRnnStateWithTBPTTState() {
        for (Layer layer : this.layers) {
            if (layer instanceof BaseRecurrentLayer) {
                BaseRecurrentLayer recurrent = (BaseRecurrentLayer)layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork)layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Computes the gradient for one truncated-BPTT window and stores it in
     * {@code this.gradient}.
     *
     * <p>Works like {@link #backprop()}, except that layers which are instances of
     * {@link BaseRecurrentLayer} are asked for a truncated gradient using the configured
     * TBPTT backward length. Gradients are keyed as {@code "<layerIndex>_<variableName>"}
     * and registered in ascending layer order.
     *
     * <p>If the final layer is not a {@link BaseOutputLayer} a warning is logged and the
     * method returns with an empty gradient.
     *
     * @throws IllegalStateException if no labels are set, or if the output layer uses
     *                               {@link WeightInit#ZERO}
     */
    protected void truncatedBPTTGradient() {
        this.gradient = new DefaultGradient();
        Layer lastLayer = this.getOutputLayer();
        if (!(lastLayer instanceof BaseOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop (truncated BPTT) without an output layer.");
            return;
        }
        BaseOutputLayer outputLayer = (BaseOutputLayer)lastLayer;
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        if (WeightInit.ZERO == outputLayer.conf().getLayer().getWeightInit()) {
            throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
        }
        outputLayer.setLabels(this.labels);

        int lastIdx = this.getnLayers() - 1;
        LinkedList<Pair<String, INDArray>> ordered = new LinkedList<Pair<String, INDArray>>();

        // Output layer: gradients go to the tail so lower layers can be prepended later.
        Pair<Gradient, INDArray> current = outputLayer.backpropGradient(null);
        for (Map.Entry<String, INDArray> e : current.getFirst().gradientForVariable().entrySet()) {
            ordered.addLast(new Pair<String, INDArray>(lastIdx + "_" + e.getKey(), e.getValue()));
        }
        if (this.getLayerWiseConfigurations().getInputPreProcess(lastIdx) != null) {
            INDArray epsilon = this.layerWiseConfigurations.getInputPreProcess(lastIdx).backprop(current.getSecond(), this.layers[lastIdx]);
            current = new Pair<Gradient, INDArray>(current.getFirst(), epsilon);
        }

        for (int idx = lastIdx - 1; idx >= 0; --idx) {
            Layer layer = this.getLayer(idx);
            if (layer instanceof BaseRecurrentLayer) {
                // Recurrent layers only backpropagate over the configured number of time steps.
                current = ((BaseRecurrentLayer)layer).tbpttBackpropGradient(current.getSecond(), this.layerWiseConfigurations.getTbpttBackLength());
            } else {
                current = layer.backpropGradient(current.getSecond());
            }
            List<Pair<String, INDArray>> layerGrads = new ArrayList<Pair<String, INDArray>>();
            for (Map.Entry<String, INDArray> e : current.getFirst().gradientForVariable().entrySet()) {
                layerGrads.add(new Pair<String, INDArray>(idx + "_" + e.getKey(), e.getValue()));
            }
            // Prepend this layer's gradients as a group, preserving their own order.
            ordered.addAll(0, layerGrads);
            if (this.getLayerWiseConfigurations().getInputPreProcess(idx) != null) {
                INDArray epsilon = this.getLayerWiseConfigurations().getInputPreProcess(idx).backprop(current.getSecond(), this.layers[idx]);
                current = new Pair<Gradient, INDArray>(current.getFirst(), epsilon);
            }
        }

        for (Pair<String, INDArray> g : ordered) {
            this.gradient.setGradientFor(g.getFirst(), g.getSecond());
        }
    }

    /**
     * @return the iteration listeners currently held by this network (the stored
     *         collection itself, not a copy)
     */
    @Override
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * Stores the given listeners on this network and forwards them to every layer.
     * If the layers have not been created yet, {@code init()} is called first.
     *
     * @param listeners listeners to attach to the network and each of its layers
     */
    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (this.layers == null) {
            this.init();
        }
        Layer[] targets = this.layers;
        for (int i = 0; i < targets.length; ++i) {
            targets[i].setListeners(listeners);
        }
    }

    /**
     * Varargs convenience overload: copies the listeners into a new list and delegates
     * to {@link #setListeners(Collection)}.
     *
     * @param listeners listeners to attach
     */
    @Override
    public void setListeners(IterationListener ... listeners) {
        ArrayList<IterationListener> collected = new ArrayList<IterationListener>(listeners.length);
        for (int i = 0; i < listeners.length; ++i) {
            collected.add(listeners[i]);
        }
        this.setListeners(collected);
    }

    /**
     * Runs a forward pass and then fits the output layer on its own input against the
     * current labels.
     *
     * <p>If the last layer is not a {@link BaseOutputLayer} a warning is logged and
     * nothing else happens.
     *
     * @throws IllegalStateException         if no labels are set
     * @throws UnsupportedOperationException if the output layer is configured with
     *                                       {@link OptimizationAlgorithm#HESSIAN_FREE}
     */
    public void finetune() {
        Layer lastLayer = this.getOutputLayer();
        if (!(lastLayer instanceof BaseOutputLayer)) {
            log.warn("Output layer not instance of output layer returning.");
            return;
        }
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        log.info("Finetune phase");
        BaseOutputLayer outputLayer = (BaseOutputLayer)this.getOutputLayer();
        if (OptimizationAlgorithm.HESSIAN_FREE == outputLayer.conf().getOptimizationAlgo()) {
            throw new UnsupportedOperationException();
        }
        this.feedForward();
        outputLayer.fit(outputLayer.input(), this.labels);
    }

    /**
     * Returns, for each slice of {@code d}, the index of the largest network output
     * as reported by the BLAS {@code iamax} routine.
     *
     * @param d input examples
     * @return one predicted index per slice of {@code d}
     */
    @Override
    public int[] predict(INDArray d) {
        INDArray networkOutput = this.output(d);
        int[] predictions = new int[d.slices()];
        if (!d.isRowVector()) {
            for (int row = 0; row < predictions.length; ++row) {
                predictions[row] = Nd4j.getBlasWrapper().iamax(networkOutput.getRow(row));
            }
        } else {
            // Single example: the whole output is one row.
            predictions[0] = Nd4j.getBlasWrapper().iamax(networkOutput);
        }
        return predictions;
    }

    /**
     * Feeds the examples forward and asks the output layer for label probabilities,
     * passing it the last activation returned by the forward pass.
     *
     * @param examples input examples
     * @return the output layer's label probabilities
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> activations = this.feedForward(examples);
        BaseOutputLayer outputLayer = (BaseOutputLayer)this.getOutputLayer();
        INDArray lastActivation = activations.get(activations.size() - 1);
        return outputLayer.labelProbabilities(lastActivation);
    }

    /**
     * Fits the network on the given data and labels.
     *
     * <p>Copies of {@code data} and {@code labels} are stored as the current input and
     * labels. If pretraining is enabled, {@code pretrain} and {@link #finetune()} run
     * first. If backprop is enabled, either truncated BPTT is performed or the solver
     * (created lazily) is asked to optimize.
     *
     * @param data   training features
     * @param labels training labels
     */
    @Override
    public void fit(INDArray data, INDArray labels) {
        this.setInput(data.dup());
        this.setLabels(labels.dup());
        if (this.layerWiseConfigurations.isPretrain()) {
            this.pretrain(data);
            this.finetune();
        }
        if (!this.layerWiseConfigurations.isBackprop()) {
            return;
        }
        if (BackpropType.TruncatedBPTT == this.layerWiseConfigurations.getBackpropType()) {
            this.doTruncatedBPTT(data, labels);
            return;
        }
        if (this.solver == null) {
            this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
        }
        this.solver.optimize();
    }

    /**
     * Unsupervised fit: stores a copy of {@code data} as the network input and then
     * pretrains on {@code data}.
     *
     * @param data training features
     */
    @Override
    public void fit(INDArray data) {
        INDArray inputCopy = data.dup();
        this.setInput(inputCopy);
        this.pretrain(data);
    }

    /**
     * Delegates to {@code pretrain} with the given input.
     *
     * @param input data to pretrain on
     */
    @Override
    public void iterate(INDArray input) {
        pretrain(input);
    }

    /**
     * Fits the network on the feature matrix and labels of the given data set.
     *
     * @param data data set providing features and labels
     */
    @Override
    public void fit(DataSet data) {
        INDArray features = data.getFeatureMatrix();
        INDArray targets = data.getLabels();
        this.fit(features, targets);
    }

    /**
     * Fits the network on examples whose labels are given as integer class indices.
     * The indices are converted to an outcome matrix whose width is the output layer's
     * configured {@code nOut}.
     *
     * @param examples training features
     * @param labels   class index for each example
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        OutputLayer outputConf = (OutputLayer)this.getOutputLayer().conf().getLayer();
        INDArray outcomes = FeatureUtil.toOutcomeMatrix((int[])labels, (int)outputConf.getNOut());
        this.fit(examples, outcomes);
    }

    /**
     * Computes the network output, translating the training mode into the boolean flag
     * used by {@link #output(INDArray, boolean)}.
     *
     * @param input network input
     * @param train {@code TRAIN} maps to {@code true}; anything else to {@code false}
     * @return the last activation of the forward pass
     */
    public INDArray output(INDArray input, Layer.TrainingMode train) {
        boolean training = Layer.TrainingMode.TRAIN == train;
        return this.output(input, training);
    }

    /**
     * Feeds the input forward and returns the last activation.
     *
     * @param input network input
     * @param train flag forwarded unchanged to {@code feedForward}
     * @return the final element of the activation list
     */
    public INDArray output(INDArray input, boolean train) {
        List<INDArray> layerOutputs = this.feedForward(input, train);
        int lastIdx = layerOutputs.size() - 1;
        return layerOutputs.get(lastIdx);
    }

    /**
     * Computes the network output using {@link Layer.TrainingMode#TRAIN}.
     *
     * <p>NOTE(review): this runs the forward pass in training mode even though it is
     * used for prediction (see {@link #predict(INDArray)}); confirm whether
     * {@code TEST} is the intended mode here.
     *
     * @param input network input
     * @return the last activation of the forward pass
     */
    public INDArray output(INDArray input) {
        Layer.TrainingMode mode = Layer.TrainingMode.TRAIN;
        return this.output(input, mode);
    }

    /**
     * Feeds {@code x} forward and returns the activation stored at position
     * {@code layerNum - 1} of the resulting activation list.
     *
     * @param x        network input
     * @param layerNum one-based position into the activation list
     * @return the selected activation
     */
    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> activations = this.feedForward(x);
        int position = layerNum - 1;
        return activations.get(position);
    }

    /**
     * Logs, at info level, a single line listing every layer configuration together
     * with its zero-based position.
     */
    public void printConfiguration() {
        StringBuilder summary = new StringBuilder();
        int layerNo = 0;
        for (NeuralNetConfiguration layerConf : this.getLayerWiseConfigurations().getConfs()) {
            summary.append(" Layer ").append(layerNo++).append(" conf ").append(layerConf);
        }
        log.info(summary.toString());
    }

    /**
     * Copies the default configuration, input, labels and layer array from another
     * network into this one. The layer array is cloned with {@code ArrayUtils.clone},
     * so the array is new but the layer objects are shared.
     *
     * @param network network to copy state from
     */
    public void update(MultiLayerNetwork network) {
        this.defaultConfiguration = network.defaultConfiguration;
        // setInput runs before the layers are replaced, as it may initialize them.
        this.setInput(network.input);
        this.labels = network.labels;
        this.layers = ArrayUtils.clone(network.layers);
    }

    /**
     * Evaluates the network on the given input and labels and returns the F1 score
     * reported by {@link Evaluation}. The labels are also stored on the network.
     *
     * @param input  evaluation features
     * @param labels expected labels
     * @return the F1 score
     */
    @Override
    public double f1Score(INDArray input, INDArray labels) {
        this.feedForward(input);
        this.setLabels(labels);
        Evaluation evaluation = new Evaluation();
        INDArray predicted = this.labelProbabilities(input);
        evaluation.eval(labels, predicted);
        return evaluation.f1();
    }

    /**
     * @return the number of columns of the currently stored labels
     */
    @Override
    public int numLabels() {
        return labels.columns();
    }

    /**
     * Scores the network on a data set: feeds the features forward, stores the labels,
     * lets the output layer compute its score with the network's L1 and L2 terms, and
     * caches that score.
     *
     * <p>If the last layer is not a {@link BaseOutputLayer} a warning is logged and
     * {@code 0.0} is returned.
     *
     * @param data data set to score against
     * @return the network score, or {@code 0.0} when there is no output layer
     */
    public double score(org.nd4j.linalg.dataset.DataSet data) {
        this.feedForward(data.getFeatureMatrix());
        this.setLabels(data.getLabels());
        Layer lastLayer = this.getOutputLayer();
        if (lastLayer instanceof BaseOutputLayer) {
            BaseOutputLayer outputLayer = (BaseOutputLayer)this.getOutputLayer();
            outputLayer.setLabels(data.getLabels());
            outputLayer.computeScore(this.calcL1(), this.calcL2());
            this.score = outputLayer.score();
            return this.score();
        }
        log.warn("Cannot calculate score wrt labels without an OutputLayer");
        return 0.0;
    }

    /**
     * Fits the network on the input and labels it currently holds.
     */
    @Override
    public void fit() {
        fit(input, labels);
    }

    /**
     * No-op: the body is empty, so both arguments are ignored.
     *
     * @param gradient  unused
     * @param paramType unused
     */
    @Override
    public void update(INDArray gradient, String paramType) {
    }

    /**
     * @return the most recently stored score; this getter does not recompute it
     */
    @Override
    public double score() {
        return score;
    }

    /**
     * Computes the gradient and then the score for the current input and labels.
     *
     * <p>With truncated BPTT configured, activations are computed from the stored RNN
     * state and {@link #truncatedBPTTGradient()} is used; otherwise a plain forward pass
     * is followed by {@link #backprop()}. The score is then taken from the output
     * layer, including the network's L1 and L2 terms. The last layer is cast to
     * {@link BaseOutputLayer} without a prior type check.
     */
    @Override
    public void computeGradientAndScore() {
        boolean truncated = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        if (!truncated) {
            this.feedForward();
            this.backprop();
        } else {
            this.rnnActivateUsingStoredState(this.getInput(), true, true);
            this.truncatedBPTTGradient();
        }
        BaseOutputLayer outputLayer = (BaseOutputLayer)this.getOutputLayer();
        this.score = outputLayer.computeScore(this.calcL1(), this.calcL2());
    }

    /**
     * No-op: the body is empty, so the argument is ignored.
     *
     * @param accum unused
     */
    @Override
    public void accumulateScore(double accum) {
    }

    /**
     * Clears every layer and drops the network's references to its input, labels and
     * solver.
     */
    @Override
    public void clear() {
        Layer[] all = this.layers;
        for (int i = 0; i < all.length; ++i) {
            all[i].clear();
        }
        this.input = null;
        this.labels = null;
        this.solver = null;
    }

    /**
     * Returns the stored score plus an L2 penalty computed from the given parameters.
     *
     * <p>The network parameters are temporarily replaced by {@code param} and are
     * restored afterwards, even if computing the penalty throws. The penalty is
     * {@code 0.5 * L2 * sum((mask * param)^2)}, using the default configuration's L2
     * coefficient. The mask is created lazily via {@code initMask()} when it has not
     * been set, matching {@link #backPropGradientR(INDArray)}; previously an unset mask
     * caused a {@link NullPointerException} here.
     *
     * @param param parameter vector to evaluate the penalty for
     * @return stored score plus the L2 regularization cost
     */
    public double score(INDArray param) {
        if (this.mask == null) {
            this.initMask();
        }
        INDArray params = this.params();
        try {
            this.setParameters(param);
            double ret = this.score();
            double regCost = 0.5 * this.defaultConfiguration.getL2() * (Double)Transforms.pow((INDArray)this.mask.mul(param), (Number)2).sum(new int[]{Integer.MAX_VALUE}).element();
            return ret + regCost;
        } finally {
            // Always put the original parameters back so a failure cannot leave the
            // network holding the trial parameters.
            this.setParameters(params);
        }
    }

    /**
     * Not supported for an arbitrary {@link Layer}; always throws.
     *
     * @param layer     unused
     * @param batchSize unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }

    /**
     * Merges another network into this one by merging each layer with the layer at the
     * same index in {@code network}.
     *
     * <p>The output layer is the last element of the layer array (see
     * {@link #getOutputLayer()}), so the loop already covers it. It is deliberately not
     * merged a second time: doing so would apply the other network's output layer
     * twice.
     *
     * @param network   network to merge into this one; must have the same number of layers
     * @param batchSize batch size forwarded to each layer's {@code merge}
     * @throws IllegalArgumentException if the two networks have a different number of layers
     */
    public void merge(MultiLayerNetwork network, int batchSize) {
        if (network.layers.length != this.layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        // Iterate over the same array the length check was made on.
        for (int i = 0; i < this.layers.length; ++i) {
            this.layers[i].merge(network.layers[i], batchSize);
        }
    }

    /**
     * Stores the network input. If the layers have not been created yet they are
     * initialized from the input, and for a non-null input the mini-batch size of all
     * layers is set to the input's size along dimension 0.
     *
     * @param input new network input; may be {@code null}
     */
    @Override
    public void setInput(INDArray input) {
        this.input = input;
        if (this.layers == null) {
            this.initializeLayers(this.getInput());
        }
        if (input == null) {
            return;
        }
        this.setInputMiniBatchSize(input.size(0));
    }

    /**
     * Sets the mask to a single row of ones whose length equals the length of the
     * array returned by {@code pack()}.
     */
    private void initMask() {
        int packedLength = (int)this.pack().length();
        this.setMask(Nd4j.ones((int)1, packedLength));
    }

    /**
     * @return the last layer of the layer array
     */
    public Layer getOutputLayer() {
        Layer[] all = this.getLayers();
        return all[all.length - 1];
    }

    /**
     * Distributes a flat parameter row vector over the layers. Each layer receives the
     * next {@code numParams()} values of row 0, in layer order;
     * {@link SubsamplingLayer} instances are skipped and consume no values.
     *
     * @param params flat parameter array; row 0 is read
     * @throws IllegalStateException if the slice selected for a layer is empty
     */
    public void setParameters(INDArray params) {
        int offset = 0;
        Layer[] all = this.getLayers();
        for (int i = 0; i < all.length; ++i) {
            Layer layer = this.getLayer(i);
            if (!(layer instanceof SubsamplingLayer)) {
                int count = layer.numParams();
                INDArray slice = params.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)offset, (int)(count + offset))});
                if (slice.length() < 1) {
                    throw new IllegalStateException("Unable to retrieve layer. No params found (length was 0");
                }
                layer.setParams(slice);
                offset += count;
            }
        }
    }

    /**
     * Builds a list of "R" arrays, one more than the number of layers, from the given
     * activations and the vector {@code v}.
     *
     * <p>The list starts with a zero matrix sized like the current input (slices x
     * columns). For each layer {@code i} the next element is computed from the previous
     * R element, the layer's weight matrix, the activation {@code acts.get(i)}, the
     * unpacked weight/bias pair of {@code v} for that layer, and the derivative of the
     * layer's activation function evaluated on {@code acts.get(i + 1)}.
     *
     * <p>NOTE(review): presumably the R-operator forward pass used by Hessian-free
     * optimization; confirm against the caller.
     *
     * @param acts activations; must contain at least {@code layers.length + 1} entries
     * @param v    packed vector that is unpacked into per-layer (first, second) pairs
     * @return the list of R arrays
     */
    public List<INDArray> feedForwardR(List<INDArray> acts, INDArray v) {
        ArrayList<INDArray> R = new ArrayList<INDArray>();
        // R[0]: zeros with the same slices/columns as the network input.
        R.add(Nd4j.zeros((int)this.input.slices(), (int)this.input.columns()));
        List<Pair<INDArray, INDArray>> vWvB = this.unPack(v);
        List<INDArray> W = MultiLayerUtil.weightMatrices(this);
        for (int i = 0; i < this.layers.length; ++i) {
            // Name of the layer's activation function, used to build its derivative op.
            String derivative = this.getLayers()[i].conf().getLayer().getActivationFunction();
            // R[i+1] = (R[i] * W[i] + acts[i] * (vW[i] + vB[i] as row vector)) .* f'(acts[i+1]).
            // The addi/muli/addiRowVector calls are the in-place variants (assumed to
            // modify their receiver — TODO confirm for the unpacked vW array).
            R.add(((INDArray)R.get(i)).mmul(W.get(i)).addi(acts.get(i).mmul(vWvB.get(i).getFirst().addiRowVector(vWvB.get(i).getSecond()))).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(derivative, acts.get(i + 1)).derivative())));
        }
        return R;
    }

    /**
     * Builds per-layer (weight change, bias change) pairs from the deltas returned by
     * {@code computeDeltasR(v)}, packs them, adds {@code mask * L2 * v} and
     * {@code v * dampingFactor}, and unpacks the result again.
     *
     * <p>The mask is created lazily if it has not been set. The bias change of a layer
     * is the mean of its delta along dimension 0.
     *
     * @param v vector passed to {@code computeDeltasR} and used in the added terms
     * @return the unpacked per-layer pairs
     * @throws IllegalStateException if a delta's length differs from the layer's
     *                               {@code "W"} parameter, or its dimension-0 mean
     *                               differs in length from the {@code "b"} parameter
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray v) {
        if (this.mask == null) {
            this.initMask();
        }
        List<INDArray> deltas = this.computeDeltasR(v);
        ArrayList<Pair<INDArray, INDArray>> perLayer = new ArrayList<Pair<INDArray, INDArray>>();
        for (int idx = 0; idx < this.getnLayers(); ++idx) {
            INDArray weightChange = deltas.get(idx);
            if (weightChange.length() != this.getLayers()[idx].getParam("W").length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray biasChange = deltas.get(idx).mean(new int[]{0});
            if (biasChange.length() != this.layers[idx].getParam("b").length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            perLayer.add(new Pair<INDArray, INDArray>(weightChange, biasChange));
        }
        INDArray packed = this.pack(perLayer);
        INDArray l2Term = this.mask.mul((Number)this.defaultConfiguration.getL2()).muli(v);
        packed = packed.addi(l2Term);
        INDArray dampingTerm = v.mul((Number)this.layerWiseConfigurations.getDampingFactor());
        packed = packed.addi(dampingTerm);
        return this.unPack(packed);
    }

    /**
     * @return the labels currently stored on the network; may be {@code null}
     */
    public INDArray getLabels() {
        return labels;
    }

    /**
     * @return the input currently stored on the network; may be {@code null}
     */
    public INDArray getInput() {
        return input;
    }

    /**
     * Stores the given labels on the network without copying them.
     *
     * @param labels labels to store
     */
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /**
     * @return the number of layer configurations held by the layer-wise configuration
     */
    public int getnLayers() {
        return layerWiseConfigurations.getConfs().size();
    }

    /**
     * @return the internal layer array itself (not a copy)
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * @param i zero-based layer index
     * @return the layer at index {@code i}
     */
    public Layer getLayer(int i) {
        return layers[i];
    }

    /**
     * Replaces the layer array with the given one; the array is stored as-is.
     *
     * @param layers new layer array
     */
    public void setLayers(Layer[] layers) {
        this.layers = layers;
    }

    /**
     * @return the current mask; may be {@code null} if none has been set
     */
    public INDArray getMask() {
        return mask;
    }

    /**
     * Stores the given mask on the network without copying it.
     *
     * @param mask mask to store
     */
    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param errorSignal unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient error(INDArray errorSignal) {
        throw new UnsupportedOperationException();
    }

    /**
     * @return always {@link Layer.Type#MULTILAYER}
     */
    @Override
    public Layer.Type type() {
        return Layer.Type.MULTILAYER;
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param input unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray derivativeActivation(INDArray input) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param layerError unused
     * @param activation unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray activation) {
        throw new UnsupportedOperationException();
    }

    /**
     * Pushes {@code x} through every layer except the last using {@code activate}, then
     * returns the last layer's {@code preOutput} for the result. Where an input
     * pre-processor is configured for a layer, it is applied before that layer.
     *
     * @param x network input
     * @return the pre-output of the final layer
     */
    @Override
    public INDArray preOutput(INDArray x) {
        INDArray current = x;
        for (int idx = 0; idx < this.layers.length - 1; ++idx) {
            if (this.getLayerWiseConfigurations().getInputPreProcess(idx) != null) {
                current = this.getLayerWiseConfigurations().getInputPreProcess(idx).preProcess(current, this.layers[idx]);
            }
            current = this.layers[idx].activate(current);
        }
        int lastIdx = this.layers.length - 1;
        if (this.getLayerWiseConfigurations().getInputPreProcess(lastIdx) != null) {
            current = this.getLayerWiseConfigurations().getInputPreProcess(lastIdx).preProcess(current, this.layers[lastIdx]);
        }
        return this.layers[lastIdx].preOutput(current);
    }

    /**
     * Delegates to {@link #preOutput(INDArray, boolean)}, passing {@code true} only for
     * {@link Layer.TrainingMode#TRAIN}.
     *
     * @param x        network input
     * @param training training mode
     * @return whatever the boolean overload returns
     */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        boolean isTraining = Layer.TrainingMode.TRAIN == training;
        return this.preOutput(x, isTraining);
    }

    /**
     * Delegates to {@link #activate(boolean)}, passing {@code true} only for
     * {@link Layer.TrainingMode#TRAIN}.
     *
     * @param training training mode
     * @return whatever the boolean overload returns
     */
    @Override
    public INDArray activate(Layer.TrainingMode training) {
        boolean isTraining = Layer.TrainingMode.TRAIN == training;
        return this.activate(isTraining);
    }

    /**
     * Delegates to {@link #activate(INDArray, boolean)}, passing {@code true} only for
     * {@link Layer.TrainingMode#TRAIN}.
     *
     * @param input    network input
     * @param training training mode
     * @return whatever the boolean overload returns
     */
    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        boolean isTraining = Layer.TrainingMode.TRAIN == training;
        return this.activate(input, isTraining);
    }

    /**
     * Not supported by this network; always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not implemented for the network as a whole; always throws.
     *
     * @param epsilon unused
     * @throws UnsupportedOperationException always, with the message "Not yet implemented"
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Stores the index of this network when it is used as a layer.
     *
     * @param index layer index to store
     */
    @Override
    public void setIndex(int index) {
        layerIndex = index;
    }

    /**
     * @return the stored layer index
     */
    @Override
    public int getIndex() {
        return layerIndex;
    }

    /**
     * @return the sum of {@code calcL2()} over all layers, accumulated in layer order
     */
    @Override
    public double calcL2() {
        double total = 0.0;
        for (Layer layer : this.layers) {
            total += layer.calcL2();
        }
        return total;
    }

    /**
     * @return the sum of {@code calcL1()} over all layers, accumulated in layer order
     */
    @Override
    public double calcL1() {
        double total = 0.0;
        for (Layer layer : this.layers) {
            total += layer.calcL1();
        }
        return total;
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param gradient unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param x        unused
     * @param training unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param training unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(boolean training) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this network; always throws.
     *
     * @param input    unused
     * @param training unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        throw new UnsupportedOperationException();
    }

    /**
     * Forwards the mini-batch size to every layer. Does nothing while the layers have
     * not been created.
     *
     * @param size mini-batch size to set on each layer
     */
    @Override
    public void setInputMiniBatchSize(int size) {
        Layer[] all = this.layers;
        if (all == null) {
            return;
        }
        for (int i = 0; i < all.length; ++i) {
            all[i].setInputMiniBatchSize(size);
        }
    }

    /**
     * @return the mini-batch size reported by the first layer
     */
    @Override
    public int getInputMiniBatchSize() {
        Layer first = this.layers[0];
        return first.getInputMiniBatchSize();
    }

    /**
     * Runs the input through all layers one step at a time: recurrent layers and nested
     * networks use their own {@code rnnTimeStep}, every other layer uses
     * {@code activate(input, false)}. Configured input pre-processors are applied
     * before their layer.
     *
     * <p>If the original input was rank 2, the result is rank 3 and the last layer's
     * type is {@code RECURRENT}, the tensor at index 0 along dimensions {1, 0} is
     * returned instead of the full result.
     *
     * @param input network input; its size along dimension 0 is used as mini-batch size
     * @return the output of the last layer, possibly reduced as described above
     */
    public INDArray rnnTimeStep(INDArray input) {
        this.setInputMiniBatchSize(input.size(0));
        boolean inputIs2d = input.rank() == 2;
        INDArray current = input;
        for (int idx = 0; idx < this.layers.length; ++idx) {
            if (this.getLayerWiseConfigurations().getInputPreProcess(idx) != null) {
                current = this.getLayerWiseConfigurations().getInputPreProcess(idx).preProcess(current, this.layers[idx]);
            }
            if (this.layers[idx] instanceof BaseRecurrentLayer) {
                current = ((BaseRecurrentLayer)this.layers[idx]).rnnTimeStep(current);
            } else if (this.layers[idx] instanceof MultiLayerNetwork) {
                current = ((MultiLayerNetwork)this.layers[idx]).rnnTimeStep(current);
            } else {
                current = this.layers[idx].activate(current, false);
            }
        }
        boolean lastIsRecurrent = inputIs2d && current.rank() == 3 && this.layers[this.layers.length - 1].type() == Layer.Type.RECURRENT;
        if (lastIsRecurrent) {
            return current.tensorAlongDimension(0, new int[]{1, 0});
        }
        return current;
    }

    /**
     * Returns the stored previous state of a recurrent layer.
     *
     * @param layer zero-based layer index
     * @return the state map reported by that layer
     * @throws IllegalArgumentException if the index is out of range or the layer is not
     *                                  a {@link BaseRecurrentLayer}
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        if (!(layer >= 0 && layer < this.layers.length)) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer target = this.layers[layer];
        if (!(target instanceof BaseRecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        BaseRecurrentLayer recurrent = (BaseRecurrentLayer)target;
        return recurrent.rnnGetPreviousState();
    }

    /**
     * Sets the stored previous state of a recurrent layer.
     *
     * @param layer zero-based layer index
     * @param state state map to hand to that layer
     * @throws IllegalArgumentException if the index is out of range or the layer is not
     *                                  a {@link BaseRecurrentLayer}
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        if (!(layer >= 0 && layer < this.layers.length)) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer target = this.layers[layer];
        if (!(target instanceof BaseRecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        ((BaseRecurrentLayer)target).rnnSetPreviousState(state);
    }

    /**
     * Clears the stored previous state of every recurrent layer, recursing into nested
     * {@link MultiLayerNetwork} layers. Does nothing while the layers have not been
     * created.
     */
    public void rnnClearPreviousState() {
        if (this.layers == null) {
            return;
        }
        for (Layer layer : this.layers) {
            if (layer instanceof BaseRecurrentLayer) {
                ((BaseRecurrentLayer)layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork)layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Runs a forward pass in which recurrent layers and nested networks use
     * {@code rnnActivateUsingStoredState}; every other layer uses
     * {@code activate(input, training)}. Configured input pre-processors are applied
     * before their layer.
     *
     * @param input             network input; its size along dimension 0 is used as
     *                          mini-batch size
     * @param training          flag forwarded to each layer
     * @param storeLastForTBPTT flag forwarded to recurrent layers and nested networks
     * @return the input followed by the output of each layer, in layer order; for a
     *         nested network only its last activation is recorded
     */
    public List<INDArray> rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        this.setInputMiniBatchSize(input.size(0));
        List<INDArray> collected = new ArrayList<INDArray>();
        INDArray current = input;
        collected.add(current);
        for (int idx = 0; idx < this.layers.length; ++idx) {
            if (this.getLayerWiseConfigurations().getInputPreProcess(idx) != null) {
                current = this.getLayerWiseConfigurations().getInputPreProcess(idx).preProcess(current, this.layers[idx]);
            }
            Layer layer = this.layers[idx];
            if (layer instanceof BaseRecurrentLayer) {
                current = ((BaseRecurrentLayer)layer).rnnActivateUsingStoredState(current, training, storeLastForTBPTT);
            } else if (layer instanceof MultiLayerNetwork) {
                List<INDArray> nested = ((MultiLayerNetwork)layer).rnnActivateUsingStoredState(current, training, storeLastForTBPTT);
                current = nested.get(nested.size() - 1);
            } else {
                current = layer.activate(current, training);
            }
            collected.add(current);
        }
        return collected;
    }
}

