/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.multilayer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.OneTimeLogger;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultiLayerNetwork
implements Serializable,
Classifier,
Layer,
NeuralNetwork {
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);
    /** Instantiated layers of the network, indexed by layer position. */
    protected Layer[] layers;
    /** Layers keyed by their configured layer name (populated in init). */
    protected LinkedHashMap<String, Layer> layerMap = new LinkedHashMap<String, Layer>();
    /** Current network input (features). */
    protected INDArray input;
    /** Current labels. */
    protected INDArray labels;
    /** Set once init() has instantiated the layers. */
    protected boolean initCalled = false;
    private Collection<IterationListener> listeners = new ArrayList<IterationListener>();
    private Collection<TrainingListener> trainingListeners = new ArrayList<TrainingListener>();
    /** Network-level default configuration. */
    protected NeuralNetConfiguration defaultConfiguration;
    /** Full per-layer configuration of the network. */
    protected MultiLayerConfiguration layerWiseConfigurations;
    protected Gradient gradient;
    protected INDArray epsilon;
    protected double score;
    protected boolean initDone = false;
    /** Single row vector holding all parameters; layers receive views into it during init. */
    protected INDArray flattenedParams;
    /** Single row vector holding all gradients; layers receive views into it in initGradientsView. */
    protected transient INDArray flattenedGradients;
    /** Per-thread value recorded by setLastEtlTime; read back as 0 when never set. */
    protected ThreadLocal<Long> lastEtlTime = new ThreadLocal<Long>();
    protected INDArray mask;
    protected int layerIndex;
    protected transient Solver solver;
    protected static final String workspaceExternal = "LOOP_EXTERNAL";
    protected static final String workspaceFeedForward = "LOOP_FF";
    protected static final String workspaceBackProp = "LOOP_BP";
    public static final String workspaceTBPTT = "LOOP_TBPTT";
    protected static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.3).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyLearning(LearningPolicy.OVER_TIME).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected static final WorkspaceConfiguration workspaceConfigurationTBPTT = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();

    public MultiLayerNetwork(MultiLayerConfiguration conf) {
        this.layerWiseConfigurations = conf;
        this.defaultConfiguration = conf.getConf(0).clone();
    }

    /**
     * Propagates a cache mode to every layer.
     *
     * @param mode cache mode to apply; {@code null} is treated as {@link CacheMode#NONE}
     */
    @Override
    public void setCacheMode(CacheMode mode) {
        CacheMode effectiveMode = (mode != null) ? mode : CacheMode.NONE;
        for (int i = 0; i < layers.length; i++) {
            layers[i].setCacheMode(effectiveMode);
        }
    }

    /**
     * Records the last ETL time for the calling thread only (the value is stored thread-locally).
     *
     * @param time the value to record; units are whatever the caller uses
     */
    public void setLastEtlTime(long time) {
        Long boxed = time;
        lastEtlTime.set(boxed);
    }

    /**
     * Returns the last ETL time recorded by the calling thread.
     *
     * @return the recorded value, or {@code 0} if this thread never recorded one
     */
    public long getLastEtlTime() {
        Long stored = lastEtlTime.get();
        if (stored == null) {
            return 0L;
        }
        return stored;
    }

    /**
     * Builds a network from a JSON configuration, initializes it, and then installs the
     * supplied parameters.
     *
     * @param conf   JSON form of a {@link MultiLayerConfiguration}
     * @param params parameters to set once the layers exist
     */
    public MultiLayerNetwork(String conf, INDArray params) {
        this(MultiLayerConfiguration.fromJson(conf));
        init();
        setParameters(params);
    }

    /**
     * Builds a network from a configuration, initializes it, and then installs the
     * supplied parameters.
     *
     * @param conf   network configuration
     * @param params parameters to set once the layers exist
     */
    public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) {
        this(conf);
        init();
        setParameters(params);
    }

    /**
     * Fills in defaults for any missing configuration state. Values that are already set are
     * left untouched:
     * <ul>
     *   <li>an empty {@link MultiLayerConfiguration};</li>
     *   <li>a layer array sized by {@code getnLayers()};</li>
     *   <li>a default {@link NeuralNetConfiguration}.</li>
     * </ul>
     */
    protected void intializeConfigurations() {
        if (layerWiseConfigurations == null) {
            layerWiseConfigurations = new MultiLayerConfiguration.Builder().build();
        }
        if (layers == null) {
            int layerCount = getnLayers();
            layers = new Layer[layerCount];
        }
        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
    }

    /**
     * Runs layerwise unsupervised pretraining over every layer, in order, using the given
     * iterator. The gradient view is initialized first if needed. Nothing else happens when
     * pretraining is disabled in the configuration.
     *
     * @param iter source of training data
     */
    public void pretrain(DataSetIterator iter) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (layerWiseConfigurations.isPretrain()) {
            for (int layerIdx = 0; layerIdx < getnLayers(); layerIdx++) {
                pretrainLayer(layerIdx, iter);
            }
        }
    }

    /**
     * Runs unsupervised pretraining of a single layer over every example in the iterator.
     *
     * <p>Returns without doing anything when pretraining is disabled or when the layer is not
     * a pretrain layer. If the iterator is exhausted and supports reset, it is reset first.
     *
     * <p>Each minibatch runs inside a cache workspace scope and an external workspace scope.
     * Dummy workspaces are used when the training workspace mode is NONE. Both scopes are
     * always closed, and any exception from training now propagates to the caller. The
     * decompiled version discarded it via {@code continue} in a {@code finally} block.
     *
     * @param layerIdx index of the layer to pretrain
     * @param iter     source of training data
     * @throws IllegalArgumentException if {@code layerIdx} is not a valid layer index
     */
    public void pretrainLayer(int layerIdx, DataSetIterator iter) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }
        if (layerIdx >= layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")");
        }
        Layer layer = layers[layerIdx];
        if (!layer.isPretrainLayer()) {
            return;
        }
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        boolean noWorkspaces = layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
        MemoryWorkspace workspace = noWorkspaces
                ? new DummyWorkspace()
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationExternal, workspaceExternal);
        MemoryWorkspace cache = noWorkspaces
                ? new DummyWorkspace()
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationCache, "LOOP_CACHE");
        log.info("Starting unsupervised training on layer " + layerIdx);
        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet) iter.next();
            // Resources close in reverse order (ws, then wsCache), matching the original nesting.
            try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                 MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                input = next.getFeatureMatrix();
                pretrainLayer(layerIdx, input);
            }
        }
    }

    /**
     * Runs unsupervised pretraining of a single layer on one batch of features.
     *
     * <p>The features are forwarded through layers {@code 0 .. layerIdx-1} in training mode
     * and then passed to the target layer's {@code fit}. When {@code layerIdx == 0} and an
     * input preprocessor is configured for layer 0, it is applied to {@code features}. The
     * previous code wrongly applied it to the {@code input} field, which may be stale or
     * null when this method is called directly.
     *
     * <p>The layer's {@code pretrain} flag is reset to false even if fitting throws.
     *
     * @param layerIdx index of the layer to pretrain
     * @param features input features for the network
     * @throws IllegalArgumentException if {@code layerIdx} is not a valid layer index
     */
    public void pretrainLayer(int layerIdx, INDArray features) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }
        if (layerIdx >= layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")");
        }
        INDArray layerInput = features;
        if (layerIdx == 0 && getLayerWiseConfigurations().getInputPreProcess(0) != null) {
            layerInput = getLayerWiseConfigurations().getInputPreProcess(0).preProcess(features, features.size(0));
        }
        Layer layer = layers[layerIdx];
        if (!layer.isPretrainLayer()) {
            return;
        }
        layer.conf().setPretrain(true);

        WorkspaceMode mode = layerWiseConfigurations.getTrainingWorkspaceMode();
        MemoryWorkspace workspace;
        MemoryWorkspace pretrainWorkspace;
        if (mode == WorkspaceMode.NONE) {
            workspace = new DummyWorkspace();
            pretrainWorkspace = new DummyWorkspace();
        } else if (mode == WorkspaceMode.SINGLE) {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
            pretrainWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
        } else {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspaceFeedForward);
            pretrainWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, "LOOP_PTR");
        }

        try (MemoryWorkspace wsP = pretrainWorkspace.notifyScopeEntered()) {
            for (int j = 0; j < layerIdx; j++) {
                try (MemoryWorkspace wsFF = workspace.notifyScopeEntered()) {
                    boolean ptrWorkspaceExists = Nd4j.getWorkspaceManager().checkIfWorkspaceExists("LOOP_PTR");
                    INDArray activation = activationFromPrevLayer(j, layerInput, true);
                    // Move the activation into LOOP_PTR (when present) so it outlives the per-layer scope.
                    layerInput = ptrWorkspaceExists ? activation.leverageTo("LOOP_PTR") : activation;
                }
            }
            layer.fit(layerInput);
        } finally {
            layer.conf().setPretrain(false);
        }
    }

    /**
     * Pretrains every layer in order on a single batch of input.
     *
     * <p>The final layer is skipped when it is an {@link IOutputLayer}. Each layer is fitted on
     * the activations of the previous layer, computed in training mode. Layer 0 is fitted on
     * the (optionally preprocessed) input. Each layer's {@code pretrain} flag is reset to
     * false even if its fit throws.
     *
     * @param input input features for the network
     * @deprecated use {@link #pretrain(DataSetIterator)} or {@link #pretrainLayer(int, INDArray)}
     */
    @Deprecated
    public void pretrain(INDArray input) {
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }
        if (flattenedGradients == null) {
            initGradientsView();
        }
        WorkspaceMode mode = layerWiseConfigurations.getTrainingWorkspaceMode();
        MemoryWorkspace workspace;
        MemoryWorkspace pretrainWorkspace;
        if (mode == WorkspaceMode.NONE) {
            workspace = new DummyWorkspace();
            pretrainWorkspace = new DummyWorkspace();
        } else if (mode == WorkspaceMode.SINGLE) {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
            pretrainWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
        } else {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspaceFeedForward);
            pretrainWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, "LOOP_PTR");
        }
        int miniBatchSize = input.size(0);
        INDArray layerInput = null;
        int nPretrainLayers = getnLayers();
        if (getLayer(getnLayers() - 1) instanceof IOutputLayer) {
            --nPretrainLayers;
        }
        try (MemoryWorkspace wsP = pretrainWorkspace.notifyScopeEntered()) {
            for (int i = 0; i < nPretrainLayers; i++) {
                try (MemoryWorkspace wsFF = workspace.notifyScopeEntered()) {
                    Layer layer = getLayer(i);
                    if (i == 0) {
                        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i);
                        INDArray first = preProcessor != null ? preProcessor.preProcess(input, miniBatchSize) : input;
                        layerInput = first.leverageTo("LOOP_PTR");
                    } else {
                        layerInput = activationFromPrevLayer(i - 1, layerInput, true).leverageTo("LOOP_PTR");
                    }
                    layer.conf().setPretrain(true);
                    try {
                        layer.fit(layerInput);
                    } finally {
                        layer.conf().setPretrain(false);
                    }
                }
            }
        }
    }

    /** Returns the size of dimension 0 of the current input. */
    @Override
    public int batchSize() {
        INDArray currentInput = input;
        return currentInput.size(0);
    }

    /** Returns the network-level default configuration. */
    @Override
    public NeuralNetConfiguration conf() {
        NeuralNetConfiguration result = defaultConfiguration;
        return result;
    }

    /**
     * Not supported for a multi-layer network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Returns the current network input (may be {@code null} if none was set). */
    @Override
    public INDArray input() {
        INDArray current = input;
        return current;
    }

    /** No-op: this implementation performs no input validation. */
    @Override
    public void validateInput() {
    }

    /**
     * Returns the optimizer of this network's solver. The solver is created in {@code init},
     * so calling this before initialization will fail.
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        Solver current = solver;
        return current.getOptimizer();
    }

    /**
     * Looks up a parameter by its network-level key {@code "<layerIndex>_<paramName>"}.
     *
     * @param param network-level parameter key
     * @return the parameter array from the addressed layer
     * @throws IllegalStateException if the key contains no {@code '_'} separator
     */
    @Override
    public INDArray getParam(String param) {
        int separator = param.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + param + "\"");
        }
        int layerIdx = Integer.parseInt(param.substring(0, separator));
        String layerKey = param.substring(separator + 1);
        return layers[layerIdx].getParam(layerKey);
    }

    /**
     * Not supported; parameters are created in {@code init}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Returns all parameters (not only backprop ones), keyed {@code "<layerIndex>_<paramName>"}. */
    @Override
    public Map<String, INDArray> paramTable() {
        final boolean backpropOnly = false;
        return paramTable(backpropOnly);
    }

    /**
     * Merges every layer's parameter table into one insertion-ordered map. Each key is
     * prefixed with its layer index: {@code "<layerIndex>_<paramName>"}.
     *
     * @param backpropParamsOnly forwarded to each layer's {@code paramTable}
     * @return merged parameter map, ordered by layer then by each layer's own key order
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> merged = new LinkedHashMap<String, INDArray>();
        int layerIdx = 0;
        for (Layer layer : layers) {
            String prefix = layerIdx + "_";
            for (Map.Entry<String, INDArray> e : layer.paramTable(backpropParamsOnly).entrySet()) {
                merged.put(prefix + e.getKey(), e.getValue());
            }
            layerIdx++;
        }
        return merged;
    }

    /**
     * Copies parameter values from {@code paramTable} into this network's existing arrays.
     * All keys and shapes are validated before any value is copied, so a mismatch leaves the
     * network unchanged.
     *
     * @param paramTable new values, keyed exactly as {@link #paramTable()}
     * @throws IllegalArgumentException if the key sets differ or any shape differs
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        Map<String, INDArray> currParamTable = this.paramTable();
        if (!currParamTable.keySet().equals(paramTable.keySet())) {
            throw new IllegalArgumentException("Cannot set param table: parameter keys do not match.\nCurrent: " + currParamTable.keySet() + "\nTo set: " + paramTable.keySet());
        }
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            String key = entry.getKey();
            INDArray existing = currParamTable.get(key);
            INDArray replacement = entry.getValue();
            if (!Arrays.equals(existing.shape(), replacement.shape())) {
                throw new IllegalArgumentException("Cannot set parameter table: parameter \"" + key + "\" shapes do not match. Current = " + Arrays.toString(existing.shape()) + ", to set = " + Arrays.toString(replacement.shape()));
            }
        }
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            currParamTable.get(entry.getKey()).assign(entry.getValue());
        }
    }

    /**
     * Sets a parameter addressed by its network-level key {@code "<layerIndex>_<paramName>"}.
     *
     * @param key network-level parameter key
     * @param val new value, forwarded to the addressed layer's {@code setParam}
     * @throws IllegalStateException if the key contains no {@code '_'} separator
     */
    @Override
    public void setParam(String key, INDArray val) {
        int separator = key.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        }
        int layerIdx = Integer.parseInt(key.substring(0, separator));
        String layerKey = key.substring(separator + 1);
        layers[layerIdx].setParam(layerKey, val);
    }

    /** Returns the full per-layer configuration of this network. */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        MultiLayerConfiguration current = layerWiseConfigurations;
        return current;
    }

    /**
     * Replaces the per-layer configuration. Layers that were already instantiated are not rebuilt.
     *
     * @param layerWiseConfigurations new configuration
     */
    public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) {
        MultiLayerConfiguration replacement = layerWiseConfigurations;
        this.layerWiseConfigurations = replacement;
    }

    /**
     * Sets the network input and minibatch size, then initializes the network if that has
     * not happened yet.
     *
     * @param input input features; must not be {@code null}
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public void initializeLayers(INDArray input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }
        this.input = input;
        setInputMiniBatchSize(input.size(0));
        if (initCalled) {
            return;
        }
        init();
    }

    /** Initializes the network with freshly created parameters (no user-supplied array). */
    @Override
    public void init() {
        final INDArray noParameters = null;
        init(noParameters, false);
    }

    /**
     * Instantiates the layers, sets up the flattened parameter array and the per-layer views
     * into it, registers per-layer variable names, and creates the solver.
     *
     * <p>Returns early without doing anything further if {@code initCalled} is already true.
     *
     * @param parameters           optional pre-existing parameters as a row vector whose length
     *                             equals the total parameter count; {@code null} to create
     *                             fresh parameters, seeded from the default configuration's seed
     * @param cloneParametersArray if true and {@code parameters} is non-null, a copy of
     *                             {@code parameters} is used instead of the array itself
     * @throws IllegalStateException    if the configuration has fewer than one layer
     * @throws IllegalArgumentException if {@code parameters} is not a row vector or has the
     *                                  wrong length
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        int nLayers;
        if (this.layerWiseConfigurations == null || this.layers == null) {
            // Allocates defaults, including an empty (null-filled) layers array.
            this.intializeConfigurations();
        }
        if (this.initCalled) {
            return;
        }
        OneTimeLogger.info(log, "Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}]", new Object[]{this.layerWiseConfigurations.getTrainingWorkspaceMode(), this.layerWiseConfigurations.getInferenceWorkspaceMode()});
        if (this.layerWiseConfigurations.getCacheMode() == CacheMode.HOST) {
            // NOTE(review): mutates a static configuration shared via ComputationGraph.
            ComputationGraph.workspaceConfigurationCache.setPolicyMirroring(MirroringPolicy.HOST_ONLY);
        }
        if ((nLayers = this.getnLayers()) < 1) {
            throw new IllegalStateException("Unable to create network: number of layers is less than 1");
        }
        // Only instantiate layers if they have not been created yet (e.g. not set externally).
        // NOTE(review): initCalled is set only inside this branch; if layers were pre-populated it stays false.
        if (this.layers == null || this.layers[0] == null) {
            boolean initializeParams;
            if (this.layers == null) {
                this.layers = new Layer[nLayers];
            }
            // First pass: count the parameters of each layer to size the flattened array.
            int paramLength = 0;
            int[] nParamsPerLayer = new int[nLayers];
            for (int i = 0; i < nLayers; ++i) {
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(i);
                nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
                paramLength += nParamsPerLayer[i];
            }
            if (parameters != null) {
                if (!parameters.isRowVector()) {
                    throw new IllegalArgumentException("Invalid parameters: should be a row vector");
                }
                if (parameters.length() != paramLength) {
                    throw new IllegalArgumentException("Invalid parameters: expected length " + paramLength + ", got length " + parameters.length());
                }
                this.flattenedParams = cloneParametersArray ? parameters.dup() : parameters;
                initializeParams = false;
            } else {
                this.flattenedParams = Nd4j.create((int)1, (int)paramLength);
                initializeParams = true;
            }
            if (initializeParams) {
                // Seed the RNG so fresh parameter initialization is reproducible for a given config seed.
                Nd4j.getRandom().setSeed(this.getDefaultConfiguration().getSeed());
            }
            // Second pass: hand each layer its contiguous slice of flattenedParams (null if it has no params).
            int paramCountSoFar = 0;
            for (int i = 0; i < nLayers; ++i) {
                INDArray paramsView = nParamsPerLayer[i] > 0 ? this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)paramCountSoFar, (int)(paramCountSoFar + nParamsPerLayer[i]))}) : null;
                paramCountSoFar += nParamsPerLayer[i];
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(i);
                this.layers[i] = conf.getLayer().instantiate(conf, this.listeners, i, paramsView, initializeParams);
                this.layerMap.put(conf.getLayer().getLayerName(), this.layers[i]);
            }
            this.initCalled = true;
        }
        // Rebuild the network-level variable list as "<layerIndex>_<variable>" entries.
        // NOTE(review): assumes variables(false) returns the live (uncopied) list — confirm.
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (int i = 0; i < this.layers.length; ++i) {
            for (String s : this.layers[i].conf().variables()) {
                variables.add(i + "_" + s);
            }
        }
        if (this.solver == null) {
            // Build the solver outside any active workspace scope.
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }
    }

    /**
     * Installs a gradients accumulator on the solver's optimizer, initializing the network
     * first if needed.
     *
     * @param accumulator accumulator to install
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        if (!isInitCalled()) {
            init();
        }
        ConvexOptimizer optimizer = solver.getOptimizer();
        optimizer.setGradientsAccumulator(accumulator);
    }

    /** Returns whether {@code init} has instantiated this network's layers. */
    public boolean isInitCalled() {
        boolean initialized = initCalled;
        return initialized;
    }

    /**
     * Allocates a single zero-filled, 'f'-ordered gradient row vector sized to the total
     * parameter count. Each layer that has parameters receives the contiguous slice of it
     * that matches its parameters. Allocation happens outside any active workspace.
     * Initializes the network first if the layers do not exist yet.
     */
    public void initGradientsView() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (layers == null) {
                init();
            }
            int nLayers = layers.length;
            int[] nParamsPerLayer = new int[nLayers];
            int totalParams = 0;
            for (int i = 0; i < nLayers; i++) {
                NeuralNetConfiguration layerConf = layerWiseConfigurations.getConf(i);
                int count = layerConf.getLayer().initializer().numParams(layerConf);
                nParamsPerLayer[i] = count;
                totalParams += count;
            }
            flattenedGradients = Nd4j.zeros(new int[] {1, totalParams}, 'f');
            int offset = 0;
            for (int i = 0; i < nLayers; i++) {
                int count = nParamsPerLayer[i];
                if (count > 0) {
                    INDArray view = flattenedGradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count));
                    layers[i].setBackpropGradientsViewArray(view);
                    offset += count;
                }
            }
        }
    }

    /** Returns the activations of the last layer, as computed by that layer's {@code activate()}. */
    @Override
    public INDArray activate() {
        Layer[] all = getLayers();
        Layer last = all[all.length - 1];
        return last.activate();
    }

    /**
     * Returns the activations of the layer at the given index.
     *
     * @param layer layer index
     */
    public INDArray activate(int layer) {
        Layer target = getLayer(layer);
        return target.activate();
    }

    /**
     * Not supported; use {@link #feedForward(INDArray)} or {@link #activate(int, INDArray)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(INDArray input) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Activates a single layer on the given input. No preprocessing is applied.
     *
     * @param layer layer index
     * @param input input to that layer
     */
    public INDArray activate(int layer, INDArray input) {
        Layer target = getLayer(layer);
        return target.activate(input);
    }

    /**
     * Not supported for a multi-layer network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activationMean() {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Sets input and labels from a data set and runs a feed-forward pass. If the output layer
     * is an {@link IOutputLayer}, the labels are also passed to it.
     *
     * @param data features and labels
     */
    public void initialize(org.nd4j.linalg.dataset.DataSet data) {
        setInput(data.getFeatureMatrix());
        feedForward(getInput());
        labels = data.getLabels();
        Layer outputLayer = getOutputLayer();
        if (outputLayer instanceof IOutputLayer) {
            ((IOutputLayer) outputLayer).setLabels(labels);
        }
    }

    /**
     * Computes the pre-activation output of layer {@code curr}. The layer's input
     * preprocessor, if any, is applied first, using {@code input.size(0)} as the minibatch size.
     *
     * @param curr     layer index
     * @param input    output of the previous layer (or the network input)
     * @param training whether to run in training mode
     */
    public INDArray zFromPrevLayer(int curr, INDArray input, boolean training) {
        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(curr);
        INDArray layerInput = (preProcessor == null) ? input : preProcessor.preProcess(input, input.size(0));
        return layers[curr].preOutput(layerInput, training);
    }

    /**
     * Computes the activations of layer {@code curr}. The layer's input preprocessor, if any,
     * is applied first, using the network's input minibatch size.
     *
     * @param curr     layer index
     * @param input    output of the previous layer (or the network input)
     * @param training whether to run in training mode
     */
    public INDArray activationFromPrevLayer(int curr, INDArray input, boolean training) {
        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(curr);
        INDArray layerInput = (preProcessor == null) ? input : preProcessor.preProcess(input, getInputMiniBatchSize());
        return layers[curr].activate(layerInput, training);
    }

    /**
     * Runs inference through layers {@code from..to} inclusive and returns the output of
     * layer {@code to}. A strict range is required: {@code from} must be less than {@code to}.
     *
     * @param from  first layer index (inclusive)
     * @param to    last layer index (inclusive)
     * @param input input to layer {@code from}
     * @throws IllegalStateException if {@code input} is null or the range is invalid
     */
    public INDArray activateSelectedLayers(int from, int to, INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform activation; no input found");
        }
        boolean fromInvalid = from < 0 || from >= layers.length || from >= to;
        if (fromInvalid) {
            throw new IllegalStateException("Unable to perform activation; FROM is out of layer space");
        }
        boolean toInvalid = to < 1 || to >= layers.length;
        if (toInvalid) {
            throw new IllegalStateException("Unable to perform activation; TO is out of layer space");
        }
        INDArray current = input;
        int layerIdx = from;
        while (layerIdx <= to) {
            current = activationFromPrevLayer(layerIdx, current, false);
            layerIdx++;
        }
        return current;
    }

    /**
     * Feeds the current input forward through all layers.
     *
     * @param training whether to run in training mode
     * @return a list whose first element is the input, followed by each layer's
     *         pre-activation output in layer order
     */
    public List<INDArray> computeZ(boolean training) {
        List<INDArray> zs = new ArrayList<INDArray>();
        INDArray currentInput = input;
        zs.add(currentInput);
        for (int i = 0; i < layers.length; i++) {
            INDArray z = zFromPrevLayer(i, currentInput, training);
            currentInput = activationFromPrevLayer(i, currentInput, training);
            zs.add(z);
        }
        return zs;
    }

    /**
     * Sets the input, applying layer 0's preprocessor if one is configured, then returns
     * {@link #computeZ(boolean)}.
     *
     * @param input    network input; must not be {@code null}
     * @param training whether to run in training mode
     * @throws IllegalStateException if {@code input} is null
     */
    public List<INDArray> computeZ(INDArray input, boolean training) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        InputPreProcessor firstPreProcessor = getLayerWiseConfigurations().getInputPreProcess(0);
        INDArray prepared = (firstPreProcessor == null) ? input : firstPreProcessor.preProcess(input, getInputMiniBatchSize());
        setInput(prepared);
        return computeZ(training);
    }

    /**
     * Sets the input and returns the activations of all layers.
     *
     * @param input network input
     * @param train whether to run in training mode
     */
    public List<INDArray> feedForward(INDArray input, boolean train) {
        setInput(input);
        List<INDArray> activations = feedForward(train);
        return activations;
    }

    /**
     * Feeds the current input through every layer.
     *
     * @param train whether to run in training mode
     */
    public List<INDArray> feedForward(boolean train) {
        int lastLayer = layers.length - 1;
        return feedForwardToLayer(lastLayer, train);
    }

    /**
     * Inference-mode feed forward up to and including {@code layerNum}.
     *
     * @param layerNum last layer index (inclusive)
     * @param input    network input
     */
    public List<INDArray> feedForwardToLayer(int layerNum, INDArray input) {
        final boolean inference = false;
        return feedForwardToLayer(layerNum, input, inference);
    }

    /**
     * Sets the input, then feeds forward up to and including {@code layerNum}.
     *
     * @param layerNum last layer index (inclusive)
     * @param input    network input
     * @param train    whether to run in training mode
     */
    public List<INDArray> feedForwardToLayer(int layerNum, INDArray input, boolean train) {
        setInput(input);
        List<INDArray> activations = feedForwardToLayer(layerNum, train);
        return activations;
    }

    /**
     * Feeds the current input forward through layers {@code 0..layerNum}.
     *
     * <p>Each layer runs inside a workspace scope chosen by the training workspace mode:
     * <ul>
     *   <li>NONE: a dummy workspace;</li>
     *   <li>SINGLE: the external workspace;</li>
     *   <li>otherwise: the feed-forward workspace.</li>
     * </ul>
     * Each activation is leveraged to the external workspace. When not training in SEPARATE
     * mode, the feed-forward workspace is re-initialized at the end.
     *
     * <p>The workspace variable is declared as {@link MemoryWorkspace}. The decompiled
     * {@code DummyWorkspace} declaration could not hold the workspace-manager result.
     *
     * @param layerNum last layer index (inclusive)
     * @param train    whether to run in training mode
     * @return the input (migrated out of its workspace when attached), followed by each
     *         layer's activations
     */
    public List<INDArray> feedForwardToLayer(int layerNum, boolean train) {
        WorkspaceMode mode = layerWiseConfigurations.getTrainingWorkspaceMode();
        INDArray currInput = (mode == WorkspaceMode.NONE || !input.isAttached()) ? input : input.migrate();
        List<INDArray> activations = new ArrayList<INDArray>();
        activations.add(currInput);
        MemoryWorkspace workspace;
        if (mode == WorkspaceMode.NONE) {
            workspace = new DummyWorkspace();
        } else if (mode == WorkspaceMode.SINGLE) {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
        } else {
            workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspaceFeedForward);
        }
        for (int i = 0; i <= layerNum; i++) {
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                currInput = activationFromPrevLayer(i, currInput, train).leverageTo(workspaceExternal);
                activations.add(currInput);
            }
        }
        if (!train && mode == WorkspaceMode.SEPARATE) {
            Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceFeedForward).initializeWorkspace();
        }
        return activations;
    }

    /** Inference-mode feed forward of the current input through every layer. */
    public List<INDArray> feedForward() {
        final boolean inference = false;
        return feedForward(inference);
    }

    /**
     * Sets the input, applying layer 0's preprocessor if configured (with
     * {@code input.size(0)} as the minibatch size), then runs an inference feed forward.
     *
     * @param input network input; must not be {@code null}
     * @throws IllegalStateException if {@code input} is null
     */
    public List<INDArray> feedForward(INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        InputPreProcessor firstPreProcessor = getLayerWiseConfigurations().getInputPreProcess(0);
        if (firstPreProcessor == null) {
            setInput(input);
        } else {
            setInput(firstPreProcessor.preProcess(input, input.size(0)));
        }
        return feedForward();
    }

    /**
     * Runs an inference feed forward with temporary feature and label masks. The masks are
     * cleared afterwards, even if the forward pass throws, so a failed call does not leave
     * stale masks on the layers.
     *
     * @param input        network input
     * @param featuresMask features mask array (may be {@code null})
     * @param labelsMask   labels mask array (may be {@code null})
     * @return the input followed by each layer's activations
     */
    public List<INDArray> feedForward(INDArray input, INDArray featuresMask, INDArray labelsMask) {
        setLayerMaskArrays(featuresMask, labelsMask);
        try {
            return feedForward(input);
        } finally {
            clearLayerMaskArrays();
        }
    }

    /** Returns the most recently stored gradient. */
    @Override
    public Gradient gradient() {
        Gradient current = gradient;
        return current;
    }

    /** Returns the stored epsilon array. */
    public INDArray epsilon() {
        INDArray current = epsilon;
        return current;
    }

    /** Returns the current gradient and score as a pair, in that order. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(gradient(), score());
    }

    /**
     * Creates a deep copy of this network:
     * <ul>
     *   <li>a cloned configuration;</li>
     *   <li>duplicated parameters;</li>
     *   <li>duplicated updater state, when a solver exists and has state;</li>
     *   <li>{@link FrozenLayer} wrappers re-applied at the same positions.</li>
     * </ul>
     */
    @Override
    public MultiLayerNetwork clone() {
        MultiLayerNetwork copy = new MultiLayerNetwork(layerWiseConfigurations.clone());
        copy.init(params().dup(), false);
        if (solver != null) {
            INDArray updaterState = getUpdater().getStateViewArray();
            if (updaterState != null) {
                copy.getUpdater().setStateViewArray(copy, updaterState.dup(), false);
            }
        }
        if (hasAFrozenLayer()) {
            Layer[] clonedLayers = copy.getLayers();
            for (int i = 0; i < layers.length; i++) {
                if (layers[i] instanceof FrozenLayer) {
                    clonedLayers[i] = new FrozenLayer(copy.getLayer(i));
                }
            }
            copy.setLayers(clonedLayers);
        }
        return copy;
    }

    /**
     * Reports whether any layer of this network is wrapped in a {@link FrozenLayer}.
     *
     * Every layer is inspected, including the last one. The previous loop bound
     * ({@code length - 1}) skipped the final layer, so a network whose only frozen layer was the
     * last one was not re-wrapped by {@link #clone()}.
     *
     * @return true if at least one layer is a FrozenLayer
     */
    private boolean hasAFrozenLayer() {
        for (Layer layer : this.layers) {
            if (layer instanceof FrozenLayer) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the network parameters.
     *
     * @param backwardOnly if true, returns the stored flattened parameter array from
     *                     {@link #params()}; otherwise builds a new 'f'-order flattened array from
     *                     every layer's non-null {@code params()}
     * @return the parameters as described above
     */
    public INDArray params(boolean backwardOnly) {
        if (backwardOnly) {
            return this.params();
        }
        List<INDArray> nonNullParams = new ArrayList<INDArray>();
        Layer[] allLayers = this.getLayers();
        for (int i = 0; i < allLayers.length; i++) {
            INDArray layerParams = allLayers[i].params();
            if (layerParams != null) {
                nonNullParams.add(layerParams);
            }
        }
        return Nd4j.toFlattened('f', nonNullParams);
    }

    /**
     * Returns the stored flattened parameter array itself (not a copy).
     *
     * @return the flattenedParams field
     */
    @Override
    public INDArray params() {
        INDArray flat = this.flattenedParams;
        return flat;
    }

    /**
     * Sets the network parameters from a flattened row.
     *
     * Behaviour:
     * - If {@code params} is the stored flattened array itself, nothing happens.
     * - If a flattened array exists with the same length, its contents are overwritten via assign.
     * - Otherwise {@code params} is sliced (row 0, consecutive column intervals) and each layer
     *   with a positive parameter count receives its slice. The stored flattened array is only
     *   initialised (as a copy of {@code params}) when it was null.
     *   NOTE(review): when a non-null stored array has a different length it is left unchanged.
     *
     * @param params flattened parameters
     */
    @Override
    public void setParams(INDArray params) {
        INDArray current = this.flattenedParams;
        if (current == params) {
            return;
        }
        boolean sameLength = current != null && params.length() == current.length();
        if (sameLength) {
            current.assign(params);
            return;
        }
        if (current == null) {
            this.flattenedParams = params.dup();
        }
        int offset = 0;
        for (int layerIdx = 0; layerIdx < this.getLayers().length; layerIdx++) {
            Layer layer = this.getLayer(layerIdx);
            int count = layer.numParams();
            if (count > 0) {
                layer.setParams(params.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)));
                offset += count;
            }
        }
    }

    /**
     * Not supported by this network.
     *
     * @param params ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        String reason = "Not yet implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the stored flattened gradients array itself (not a copy).
     *
     * @return the flattenedGradients field; may be null before it is initialised
     */
    @Override
    public INDArray getGradientsViewArray() {
        INDArray flatGradients = this.flattenedGradients;
        return flatGradients;
    }

    /**
     * Hands each parameterised layer its slice of a flattened gradients row.
     * Slices are consecutive column intervals of row 0, in layer order. Layers with zero
     * parameters are skipped.
     *
     * @param gradients flattened gradients array
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        int offset = 0;
        for (int i = 0; i < this.layers.length; i++) {
            Layer layer = this.layers[i];
            if (layer.numParams() != 0) {
                layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + layer.numParams())));
                offset += layer.numParams();
            }
        }
    }

    /**
     * Returns the total parameter count, or 0 (with an info log) if init() has not been called.
     *
     * @return number of parameters
     */
    @Override
    public int numParams() {
        if (!this.isInitCalled()) {
            log.info("Model is not initialized. Initialize net with init()");
            return 0;
        }
        return this.numParams(false);
    }

    /**
     * Sums {@code numParams(backwards)} over all layers.
     *
     * @param backwards forwarded to each layer's numParams
     * @return total parameter count
     */
    @Override
    public int numParams(boolean backwards) {
        int total = 0;
        for (Layer layer : this.layers) {
            total += layer.numParams(backwards);
        }
        return total;
    }

    /**
     * Computes the F1 score on the features and labels of a data set.
     *
     * @param data data set providing features and labels
     * @return the F1 score from {@code f1Score(INDArray, INDArray)}
     */
    @Override
    public double f1Score(DataSet data) {
        INDArray featureMatrix = data.getFeatures();
        INDArray labelMatrix = data.getLabels();
        return this.f1Score(featureMatrix, labelMatrix);
    }

    /**
     * Trains the network for one pass (epoch) over {@code iterator}.
     *
     * If the iterator supports async prefetch it is wrapped in an {@link AsyncDataSetIterator}.
     * That wrapper is always shut down in a finally block, so its background resources are not
     * leaked when training throws. Listeners get onEpochStart/onEpochEnd around the pass.
     * Layer-wise pretraining runs first when configured. Then, if backprop is enabled, each
     * minibatch is optimized (standard or truncated BPTT). Iteration stops at the first
     * minibatch whose features or labels are null.
     *
     * @param iterator source of training minibatches
     */
    @Override
    public void fit(DataSetIterator iterator) {
        DataSetIterator iter;
        boolean destructable = false;
        if (iterator.asyncSupported()) {
            // NOTE(review): Math.min(devices * 2, 2) is always at most 2; confirm whether max was intended.
            iter = new AsyncDataSetIterator(iterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), this.layerWiseConfigurations.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else {
            iter = iterator;
        }
        try {
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochStart(this);
            }
            if (this.layerWiseConfigurations.isPretrain()) {
                this.pretrain(iter);
                if (iter.resetSupported()) {
                    iter.reset();
                }
            }
            boolean noWorkspaces = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
            MemoryWorkspace workspace = noWorkspaces ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            MemoryWorkspace cache = noWorkspaces ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationCache, "LOOP_CACHE");
            if (this.layerWiseConfigurations.isBackprop()) {
                this.update(TaskUtils.buildTask((DataSetIterator) iter));
                if (!iter.hasNext() && iter.resetSupported()) {
                    iter.reset();
                }
                long time1 = System.currentTimeMillis();
                while (iter.hasNext()) {
                    org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet) iter.next();
                    long time2 = System.currentTimeMillis();
                    // Time spent fetching this minibatch (ETL), exposed via lastEtlTime
                    this.lastEtlTime.set(time2 - time1);
                    if (next.getFeatureMatrix() == null || next.getLabels() == null) {
                        break;
                    }
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                        this.doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                    } else {
                        if (hasMaskArrays) {
                            this.setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                        }
                        this.setInput(next.getFeatureMatrix());
                        this.setLabels(next.getLabels());
                        if (this.solver == null) {
                            // The solver is long-lived, so build it outside any active workspace
                            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                            }
                        }
                        try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                             MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                            this.solver.optimize();
                        }
                    }
                    if (hasMaskArrays) {
                        this.clearLayerMaskArrays();
                    }
                    time1 = System.currentTimeMillis();
                }
            } else if (this.layerWiseConfigurations.isPretrain()) {
                log.warn("Warning: finetune is not applied.");
            }
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochEnd(this);
            }
            this.clearLayersStates();
        } finally {
            if (destructable) {
                ((AsyncDataSetIterator) iter).shutdown();
            }
        }
    }

    /**
     * Runs full backprop (including the output layer) and stores the resulting gradient and
     * epsilon. Both fields become null when no gradients could be computed.
     */
    protected void backprop() {
        Pair<Gradient, INDArray> result = this.calcBackpropGradients(null, true);
        if (result == null) {
            this.gradient = null;
            this.epsilon = null;
        } else {
            this.gradient = result.getFirst();
            this.epsilon = result.getSecond();
        }
    }

    /**
     * Computes parameter gradients by backpropagating through the layers from last to first.
     *
     * Gradient entries are keyed {@code "<layerIndex>_<paramName>"} and stored, with their
     * flattening order, in a {@link DefaultGradient} backed by the flattened gradients view.
     * Propagation stops at the first {@link FrozenLayer} encountered (walking backwards).
     *
     * @param epsilon         starting error signal; used only when {@code withOutputLayer} is false
     * @param withOutputLayer if true, start from the output layer using the stored labels;
     *                        otherwise start at the last layer with {@code epsilon}
     * @return pair of (gradient, epsilon w.r.t. the network input); null if
     *         {@code withOutputLayer} is true but the final layer is not an {@link IOutputLayer}
     * @throws IllegalStateException if {@code withOutputLayer} is true and no labels are set
     * @throws RuntimeException wrapping any exception raised while backpropagating a layer
     */
    protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray epsilon, boolean withOutputLayer) {
        int layerFrom;
        String multiGradientKey;
        Pair<Object, INDArray> currPair;
        // Lazily allocate the flattened gradients view the DefaultGradient is backed by
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        DefaultGradient gradient = new DefaultGradient(this.flattenedGradients);
        int numLayers = this.getnLayers();
        // (key, gradient array, flattening order) entries, kept in layer order (first layer first)
        LinkedList<Triple> gradientList = new LinkedList<Triple>();
        if (withOutputLayer) {
            if (!(this.getOutputLayer() instanceof IOutputLayer)) {
                log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
                return null;
            }
            IOutputLayer outputLayer = (IOutputLayer)this.getOutputLayer();
            if (this.labels == null) {
                throw new IllegalStateException("No labels found");
            }
            outputLayer.setLabels(this.labels);
            // Output layer computes its own error from the labels, hence the null epsilon
            currPair = outputLayer.backpropGradient(null);
            for (Map.Entry<String, INDArray> entry : ((Gradient)currPair.getFirst()).gradientForVariable().entrySet()) {
                String origName = entry.getKey();
                multiGradientKey = String.valueOf(numLayers - 1) + "_" + origName;
                gradientList.addLast(new Triple<String, INDArray, Character>(multiGradientKey, entry.getValue(), ((Gradient)currPair.getFirst()).flatteningOrderForVariable(origName)));
            }
            // Undo the output layer's input preprocessing so epsilon matches the previous layer's output
            if (this.getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null) {
                currPair = new Pair<Object, INDArray>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), this.getInputMiniBatchSize()));
            }
            layerFrom = numLayers - 2;
        } else {
            currPair = new Pair<Object, INDArray>(null, epsilon);
            layerFrom = numLayers - 1;
        }
        // NOTE(review): declared as DummyWorkspace but the non-NONE branches return workspaces from the
        // workspace manager; confirm this compiles - MemoryWorkspace looks like the intended type.
        DummyWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal) : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(this.workspaceConfigurationFeedForward, workspaceFeedForward));
        for (int j = layerFrom; j >= 0; --j) {
            try (MemoryWorkspace ws = workspace.notifyScopeEntered();){
                Layer currLayer = this.getLayer(j);
                // Frozen layers (and everything below them) receive no gradient
                if (currLayer instanceof FrozenLayer) break;
                // Move the epsilon into the external workspace so it outlives this per-layer scope
                if ((currPair = currLayer.backpropGradient(currPair.getSecond())).getSecond() != null) {
                    currPair.setSecond(currPair.getSecond().leverageTo(workspaceExternal));
                }
                // addFirst into tempList, then addFirst each into gradientList: net effect is that
                // this layer's entries are prepended in their original map-iteration order
                LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<Triple<String, INDArray, Character>>();
                for (Map.Entry<String, INDArray> entry : ((Gradient)currPair.getFirst()).gradientForVariable().entrySet()) {
                    String origName = entry.getKey();
                    multiGradientKey = String.valueOf(j) + "_" + origName;
                    tempList.addFirst(new Triple<String, INDArray, Character>(multiGradientKey, entry.getValue(), ((Gradient)currPair.getFirst()).flatteningOrderForVariable(origName)));
                }
                for (Triple triple : tempList) {
                    gradientList.addFirst(triple);
                }
                if (this.getLayerWiseConfigurations().getInputPreProcess(j) == null) continue;
                currPair = new Pair<Object, INDArray>(currPair.getFirst(), this.getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), this.getInputMiniBatchSize()));
                continue;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        // In SEPARATE mode, reset the feed-forward workspace now that backprop is finished with it
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.SEPARATE) {
            Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceFeedForward).initializeWorkspace();
        }
        for (Triple triple : gradientList) {
            gradient.setGradientFor((String)triple.getFirst(), (INDArray)triple.getSecond(), (Character)triple.getThird());
        }
        return new Pair<Gradient, INDArray>(gradient, currPair.getSecond());
    }

    /**
     * Trains on one minibatch using truncated backpropagation through time.
     *
     * The time axis (dimension 2) of {@code input} and {@code labels} is split into segments of
     * {@code tbpttFwdLength} steps; the last segment may be shorter. Each segment is optimized in
     * turn, and RNN state is carried forward between segments via
     * {@link #updateRnnStateWithTBPTTState()}. RNN state is cleared before and after the whole pass.
     * Logs a warning and returns without training if the inputs are not rank 3 or the time
     * lengths differ.
     *
     * @param input             rank-3 features; the warning text expects [miniBatchSize,nIn,timeSeriesLength]
     * @param labels            rank-3 labels with the same time length as {@code input}
     * @param featuresMaskArray optional features mask, sliced on dimension 1 per segment (may be null)
     * @param labelsMaskArray   optional labels mask, sliced on dimension 1 per segment (may be null)
     */
    protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, INDArray labelsMaskArray) {
        if (input.rank() != 3 || labels.rank() != 3) {
            log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(input.shape()) + "\tand labels with shape " + Arrays.toString(labels.shape()));
            return;
        }
        if (input.size(2) != labels.size(2)) {
            log.warn("Input and label time series have different lengths: {} input length, {} label length", (Object)input.size(2), (Object)labels.size(2));
            return;
        }
        int fwdLen = this.layerWiseConfigurations.getTbpttFwdLength();
        this.update(TaskUtils.buildTask((INDArray)input, (INDArray)labels));
        int timeSeriesLength = input.size(2);
        // Number of segments = ceil(timeSeriesLength / fwdLen)
        int nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0) {
            ++nSubsets;
        }
        this.rnnClearPreviousState();
        // NOTE(review): these settings on workspaceConfigurationExternal are not restored afterwards;
        // if that configuration is shared, the change persists beyond this call.
        workspaceConfigurationExternal.setCyclesBeforeInitialization(0);
        workspaceConfigurationExternal.setPolicyLearning(LearningPolicy.OVER_TIME);
        // NOTE(review): declared as DummyWorkspace but assigned a manager workspace in the non-NONE case; confirm type.
        DummyWorkspace workspaceT = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationTBPTT, workspaceTBPTT);
        DummyWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
        // Outer TBPTT workspace spans all segments; the external workspace is re-entered per segment
        try (MemoryWorkspace wsT = workspaceT.notifyScopeEntered();){
            for (int i = 0; i < nSubsets; ++i) {
                try (MemoryWorkspace wsE = workspace.notifyScopeEntered();){
                    int startTimeIdx = i * fwdLen;
                    int endTimeIdx = startTimeIdx + fwdLen;
                    if (endTimeIdx > timeSeriesLength) {
                        endTimeIdx = timeSeriesLength;
                    }
                    INDArray inputSubset = input.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
                    INDArray labelSubset = labels.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
                    this.setInput(inputSubset);
                    this.setLabels(labelSubset);
                    INDArray featuresMaskSubset = null;
                    INDArray labelsMaskSubset = null;
                    if (featuresMaskArray != null) {
                        featuresMaskSubset = featuresMaskArray.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
                    }
                    if (labelsMaskArray != null) {
                        labelsMaskSubset = labelsMaskArray.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
                    }
                    if (featuresMaskSubset != null || labelsMaskSubset != null) {
                        this.setLayerMaskArrays(featuresMaskSubset, labelsMaskSubset);
                    }
                    // The solver is long-lived, so build it outside any active workspace
                    if (this.solver == null) {
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                            this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                        }
                    }
                    this.solver.optimize();
                    // Carry this segment's final RNN state into the next segment
                    this.updateRnnStateWithTBPTTState();
                    continue;
                }
            }
        }
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() != WorkspaceMode.NONE) {
            workspace.initializeWorkspace();
            workspaceT.initializeWorkspace();
        }
        this.rnnClearPreviousState();
        if (featuresMaskArray != null || labelsMaskArray != null) {
            this.clearLayerMaskArrays();
        }
    }

    /**
     * Copies each recurrent layer's TBPTT state into its "previous state".
     * Nested MultiLayerNetwork layers are processed recursively; other layers are ignored.
     */
    public void updateRnnStateWithTBPTTState() {
        for (Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                RecurrentLayer recurrent = (RecurrentLayer) layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Computes gradients for one truncated-BPTT segment and stores them in {@code this.gradient}.
     *
     * Recurrent layers use {@code tbpttBackpropGradient} with the configured backward length.
     * Other layers use the regular backprop. Entries are keyed {@code "<layerIndex>_<paramName>"}
     * and ordered first layer first.
     *
     * The zero-weight-init guard now inspects the output layer's configuration. Previously it
     * tested whether the runtime layer object itself was an instance of the configuration class
     * {@code BaseLayer}, which is never true, so the guard could not fire.
     *
     * @throws IllegalStateException if no labels are set, or if the output layer's weights are
     *                               configured with {@code WeightInit.ZERO}
     */
    protected void truncatedBPTTGradient() {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        this.gradient = new DefaultGradient(this.flattenedGradients);
        if (!(this.getOutputLayer() instanceof IOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop (truncated BPTT) without an output layer.");
            return;
        }
        IOutputLayer outputLayer = (IOutputLayer) this.getOutputLayer();
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        if (outputLayer.conf().getLayer() instanceof BaseLayer
                && ((BaseLayer) outputLayer.conf().getLayer()).getWeightInit() == WeightInit.ZERO) {
            throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
        }
        outputLayer.setLabels(this.labels);
        int numLayers = this.getnLayers();
        LinkedList<Pair<String, INDArray>> gradientList = new LinkedList<Pair<String, INDArray>>();
        Pair<Gradient, INDArray> currPair = outputLayer.backpropGradient(null);
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String key = String.valueOf(numLayers - 1) + "_" + entry.getKey();
            gradientList.addLast(new Pair<String, INDArray>(key, entry.getValue()));
        }
        if (this.getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null) {
            currPair = new Pair<Gradient, INDArray>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), this.getInputMiniBatchSize()));
        }
        for (int j = numLayers - 2; j >= 0; --j) {
            Layer currLayer = this.getLayer(j);
            currPair = currLayer instanceof RecurrentLayer
                    ? ((RecurrentLayer) currLayer).tbpttBackpropGradient(currPair.getSecond(), this.layerWiseConfigurations.getTbpttBackLength())
                    : currLayer.backpropGradient(currPair.getSecond());
            // Reverse twice (tempList, then addFirst) so this layer's entries are prepended in map order
            LinkedList<Pair<String, INDArray>> tempList = new LinkedList<Pair<String, INDArray>>();
            for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
                String key = String.valueOf(j) + "_" + entry.getKey();
                tempList.addFirst(new Pair<String, INDArray>(key, entry.getValue()));
            }
            for (Pair<String, INDArray> pair : tempList) {
                gradientList.addFirst(pair);
            }
            if (this.getLayerWiseConfigurations().getInputPreProcess(j) == null) {
                continue;
            }
            currPair = new Pair<Gradient, INDArray>(currPair.getFirst(), this.getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), this.getInputMiniBatchSize()));
        }
        for (Pair<String, INDArray> pair : gradientList) {
            this.gradient.setGradientFor(pair.getFirst(), pair.getSecond());
        }
    }

    /**
     * Returns the registered iteration listeners collection itself (not a copy).
     *
     * @return the listeners field; may be null if none were ever set
     */
    @Override
    public Collection<IterationListener> getListeners() {
        Collection<IterationListener> registered = this.listeners;
        return registered;
    }

    /**
     * Replaces the iteration listeners and propagates them to every layer and the solver.
     *
     * Initialises the network first if its layers do not exist yet. The training-listener subset
     * is rebuilt from the listeners that implement {@link TrainingListener}.
     *
     * @param listeners new listeners (may be null)
     */
    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (this.layers == null) {
            this.init();
        }
        for (Layer layer : this.layers) {
            layer.setListeners(listeners);
        }
        if (this.solver != null) {
            this.solver.setListeners(listeners);
        }
        this.trainingListeners.clear();
        if (listeners == null) {
            return;
        }
        for (IterationListener listener : listeners) {
            if (listener instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) listener);
            }
        }
    }

    /**
     * Appends iteration listeners to the existing ones and refreshes the solver's listeners.
     *
     * If no listener collection exists yet, this behaves like
     * {@code setListeners(IterationListener...)}. Null entries are skipped, matching that method's
     * filtering, so a null listener can never be invoked later. Listeners that implement
     * {@link TrainingListener} are also added to the training-listener list.
     *
     * @param listeners listeners to add (null array or null entries are ignored)
     */
    @Override
    public void addListeners(IterationListener ... listeners) {
        if (this.listeners == null) {
            this.setListeners(listeners);
            return;
        }
        if (listeners == null) {
            return;
        }
        for (IterationListener listener : listeners) {
            if (listener == null) {
                continue;
            }
            this.listeners.add(listener);
            if (listener instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) listener);
            }
        }
        if (this.solver != null) {
            this.solver.setListeners(this.listeners);
        }
    }

    /**
     * Replaces the iteration listeners with the non-null entries of {@code listeners}.
     *
     * @param listeners listeners to set; a null array results in an empty listener list
     */
    @Override
    public void setListeners(IterationListener ... listeners) {
        ArrayList<IterationListener> nonNull = new ArrayList<IterationListener>();
        if (listeners != null) {
            for (IterationListener listener : listeners) {
                if (listener != null) {
                    nonNull.add(listener);
                }
            }
        }
        this.setListeners(nonNull);
    }

    /**
     * Fine-tunes only the output layer on the current labels.
     *
     * Does nothing (with a warning) when backprop is disabled or the final layer is not an
     * {@link IOutputLayer}. Otherwise runs a forward pass and fits the output layer on its
     * input and the stored labels.
     *
     * @throws IllegalStateException         if no labels are set
     * @throws UnsupportedOperationException if the output layer uses HESSIAN_FREE optimization
     */
    public void finetune() {
        if (!this.layerWiseConfigurations.isBackprop()) {
            log.warn("Warning: finetune is not applied.");
            return;
        }
        if (!(this.getOutputLayer() instanceof IOutputLayer)) {
            log.warn("Output layer not instance of output layer returning.");
            return;
        }
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        log.info("Finetune phase");
        IOutputLayer finalLayer = (IOutputLayer) this.getOutputLayer();
        boolean hessianFree = finalLayer.conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE;
        if (hessianFree) {
            throw new UnsupportedOperationException();
        }
        this.feedForward();
        finalLayer.fit(finalLayer.input(), this.labels);
    }

    /**
     * Predicts a class index for each example, as the argmax (via BLAS iamax) of its output row.
     *
     * @param d input examples, one per row
     * @return one predicted class index per row of {@code d}
     */
    @Override
    public int[] predict(INDArray d) {
        INDArray networkOutput = this.output(d, Layer.TrainingMode.TEST);
        int[] predictions = new int[d.size(0)];
        if (!d.isRowVector()) {
            for (int row = 0; row < predictions.length; row++) {
                predictions[row] = Nd4j.getBlasWrapper().iamax(networkOutput.getRow(row));
            }
            return predictions;
        }
        predictions[0] = Nd4j.getBlasWrapper().iamax(networkOutput);
        return predictions;
    }

    /**
     * Predicts label names for the features of a data set.
     *
     * @param dataSet data set providing features and label names
     * @return the label name of each predicted class index, in example order
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] classIndices = this.predict(dataSet.getFeatures());
        List<String> labelNames = new ArrayList<String>(classIndices.length);
        for (int classIndex : classIndices) {
            labelNames.add(dataSet.getLabelName(classIndex));
        }
        return labelNames;
    }

    /**
     * Computes label probabilities by passing the last forward-pass activation to the output layer.
     *
     * @param examples input examples
     * @return the output layer's label probabilities
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> activations = this.feedForward(examples);
        IOutputLayer finalLayer = (IOutputLayer) this.getOutputLayer();
        int lastIndex = activations.size() - 1;
        return finalLayer.labelProbabilities(activations.get(lastIndex));
    }

    /**
     * Trains on a single batch of features and labels without mask arrays.
     *
     * @param data   features
     * @param labels labels
     */
    @Override
    public void fit(INDArray data, INDArray labels) {
        INDArray noFeaturesMask = null;
        INDArray noLabelsMask = null;
        this.fit(data, labels, noFeaturesMask, noLabelsMask);
    }

    /**
     * Trains on a single batch of features and labels, with optional mask arrays.
     *
     * Runs layer-wise pretraining first when configured. Then, if backprop is enabled, runs
     * either truncated BPTT or one solver optimization inside the loop-cache and external
     * workspaces. Mask arrays set here are cleared in a finally block, so they never leak into a
     * later call when training throws. Layer states are cleared afterwards.
     *
     * @param features     input features
     * @param labels       labels
     * @param featuresMask features mask (may be null)
     * @param labelsMask   labels mask (may be null)
     */
    public void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this.setInput(features);
        this.setLabels(labels);
        boolean hasMasks = featuresMask != null || labelsMask != null;
        if (hasMasks) {
            this.setLayerMaskArrays(featuresMask, labelsMask);
        }
        try {
            this.update(TaskUtils.buildTask(features, labels));
            boolean noWorkspaces = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
            MemoryWorkspace workspace = noWorkspaces ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            MemoryWorkspace cache = noWorkspaces ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationCache, "LOOP_CACHE");
            if (this.layerWiseConfigurations.isPretrain()) {
                try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                     MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                    this.pretrain(features);
                }
            }
            if (this.layerWiseConfigurations.isBackprop()) {
                if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    this.doTruncatedBPTT(features, labels, featuresMask, labelsMask);
                } else {
                    if (this.solver == null) {
                        // The solver is long-lived, so build it outside any active workspace
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                        }
                    }
                    try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                         MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                        this.solver.optimize();
                    }
                }
            }
        } finally {
            if (hasMasks) {
                this.clearLayerMaskArrays();
            }
        }
        this.clearLayersStates();
    }

    /**
     * Runs unsupervised layer-wise pretraining on {@code data}.
     *
     * @param data training features
     * @throws IllegalStateException if pretraining is not enabled in the configuration
     */
    @Override
    public void fit(INDArray data) {
        this.setInput(data);
        boolean pretrainEnabled = this.layerWiseConfigurations.isPretrain();
        if (!pretrainEnabled) {
            throw new IllegalStateException("Set pretrain to true in the configuration in order to pretrain the model.");
        }
        this.update(TaskUtils.buildTask(data));
        this.pretrain(data);
    }

    /**
     * Performs one round of layer-wise pretraining on {@code input}.
     *
     * @param input training features
     */
    @Override
    public void iterate(INDArray input) {
        final INDArray pretrainInput = input;
        this.pretrain(pretrainInput);
    }

    /**
     * Trains on a single data set, using truncated BPTT when configured.
     * Mask arrays on the data set are applied for the duration of standard training.
     * Layer states are cleared afterwards.
     *
     * @param data data set with features, labels and optional masks
     */
    @Override
    public void fit(DataSet data) {
        boolean truncated = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        if (truncated) {
            this.doTruncatedBPTT(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), data.getLabelsMaskArray());
            this.clearLayersStates();
            return;
        }
        boolean masked = data.hasMaskArrays();
        if (masked) {
            this.setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        }
        this.fit(data.getFeatures(), data.getLabels());
        if (masked) {
            this.clearLayerMaskArrays();
        }
        this.clearLayersStates();
    }

    /**
     * Trains on features with integer class labels, which are one-hot encoded using the output
     * layer's {@code nOut}.
     *
     * The output layer configuration is read as {@link FeedForwardLayer}, the type that
     * declares {@code getNOut()}. This also supports feed-forward output layer configurations
     * other than {@link OutputLayer}, which previously failed with a ClassCastException.
     *
     * @param examples input features
     * @param labels   class index per example
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        FeedForwardLayer layerConf = (FeedForwardLayer) this.getOutputLayer().conf().getLayer();
        this.fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut()));
    }

    /**
     * Computes the network output; {@code TRAIN} selects training-mode behaviour.
     *
     * @param input network input
     * @param train training mode
     * @return the network output
     */
    public INDArray output(INDArray input, Layer.TrainingMode train) {
        boolean isTraining = train == Layer.TrainingMode.TRAIN;
        return this.output(input, isTraining);
    }

    /**
     * Computes the network output, detached from any workspace.
     *
     * The configuration's training workspace mode is temporarily switched to the inference
     * workspace mode for the forward pass. It is restored in a finally block, so a failed forward
     * pass no longer leaves the configuration stuck in inference mode.
     *
     * @param input network input
     * @param train whether to run in training mode
     * @return the detached output of the last layer
     */
    public INDArray output(INDArray input, boolean train) {
        WorkspaceMode cMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        this.layerWiseConfigurations.setTrainingWorkspaceMode(this.layerWiseConfigurations.getInferenceWorkspaceMode());
        try {
            MemoryWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            try (MemoryWorkspace wsE = workspace.notifyScopeEntered()) {
                return this.silentOutput(input, train).detach();
            }
        } finally {
            this.layerWiseConfigurations.setTrainingWorkspaceMode(cMode);
        }
    }

    /**
     * Runs a forward pass and returns the last layer's activation, without workspace handling.
     *
     * @param input network input
     * @param train whether to run in training mode
     * @return the final activation
     */
    protected INDArray silentOutput(INDArray input, boolean train) {
        List<INDArray> activations = this.feedForward(input, train);
        int lastIndex = activations.size() - 1;
        return activations.get(lastIndex);
    }

    /**
     * Computes the network output with mask arrays, detached from any workspace.
     *
     * The training workspace mode is temporarily switched to the inference workspace mode and
     * restored in a finally block, so a failed forward pass no longer leaves the configuration
     * in inference mode.
     *
     * @param input        network input
     * @param train        whether to run in training mode
     * @param featuresMask features mask (may be null)
     * @param labelsMask   labels mask (may be null)
     * @return the detached output of the last layer
     */
    public INDArray output(INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask) {
        WorkspaceMode cMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        this.layerWiseConfigurations.setTrainingWorkspaceMode(this.layerWiseConfigurations.getInferenceWorkspaceMode());
        try {
            MemoryWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            try (MemoryWorkspace wsE = workspace.notifyScopeEntered()) {
                return this.silentOutput(input, train, featuresMask, labelsMask).detach();
            }
        } finally {
            this.layerWiseConfigurations.setTrainingWorkspaceMode(cMode);
        }
    }

    /**
     * Runs a forward pass with the given masks applied, then clears the layer masks.
     * Masks are cleared even if the forward pass throws, so they cannot leak into later calls.
     */
    protected INDArray silentOutput(INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask) {
        this.setLayerMaskArrays(featuresMask, labelsMask);
        try {
            return this.silentOutput(input, train);
        } finally {
            this.clearLayerMaskArrays();
        }
    }

    /** Computes the network output in test (inference) mode. */
    public INDArray output(INDArray input) {
        Layer.TrainingMode mode = Layer.TrainingMode.TEST;
        return this.output(input, mode);
    }

    /**
     * Computes outputs for every batch of {@code iterator} and stacks them vertically.
     * Iteration stops at the first batch whose features or labels are null.
     */
    public INDArray output(DataSetIterator iterator, boolean train) {
        List<INDArray> batchOutputs = new ArrayList<>();
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iterator.next();
            if (batch.getFeatureMatrix() == null || batch.getLabels() == null) {
                break;
            }
            INDArray features = batch.getFeatures();
            INDArray batchOut = batch.hasMaskArrays()
                    ? this.output(features, train, batch.getFeaturesMaskArray(), batch.getLabelsMaskArray())
                    : this.output(features, train);
            batchOutputs.add(batchOut);
        }
        return Nd4j.vstack(batchOutputs.toArray(new INDArray[0]));
    }

    /** Computes outputs for all batches of {@code iterator} in inference mode. */
    public INDArray output(DataSetIterator iterator) {
        boolean training = false;
        return this.output(iterator, training);
    }

    /** Returns entry {@code layerNum - 1} of the activation list produced by a forward pass on {@code x}. */
    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> activations = this.feedForward(x);
        int index = layerNum - 1;
        return activations.get(index);
    }

    /** Logs every layer configuration on a single info line, prefixed by its index. */
    public void printConfiguration() {
        StringBuilder sb = new StringBuilder();
        int layerIdx = 0;
        for (NeuralNetConfiguration layerConf : this.getLayerWiseConfigurations().getConfs()) {
            sb.append(" Layer ").append(layerIdx).append(" conf ").append(layerConf);
            layerIdx++;
        }
        log.info(sb.toString());
    }

    /**
     * Copies state from {@code network} into this instance.
     * The default configuration, layers, input and updater state are cloned or duplicated.
     * Labels are shared by reference.
     */
    public void update(MultiLayerNetwork network) {
        this.defaultConfiguration = network.defaultConfiguration == null ? null : network.defaultConfiguration.clone();
        if (network.input != null) {
            this.setInput(network.input.dup());
        }
        this.labels = network.labels;

        if (network.layers == null) {
            this.layers = null;
        } else {
            this.layers = new Layer[network.layers.length];
            for (int idx = 0; idx < this.layers.length; idx++) {
                this.layers[idx] = network.layers[idx].clone();
            }
        }

        if (network.solver == null) {
            this.solver = null;
            return;
        }
        INDArray updaterView = network.getUpdater().getStateViewArray();
        if (updaterView == null) {
            return;
        }
        MultiLayerUpdater copiedUpdater = new MultiLayerUpdater(this);
        copiedUpdater.setStateViewArray(this, updaterView.dup(), false);
        this.setUpdater(copiedUpdater);
    }

    /**
     * Computes the F1 score of this network's predictions for {@code input} against {@code labels}.
     * Side effects: runs a forward pass and stores {@code labels} on the network.
     */
    @Override
    public double f1Score(INDArray input, INDArray labels) {
        feedForward(input);
        setLabels(labels);
        Evaluation evaluation = new Evaluation();
        INDArray predictions = labelProbabilities(input);
        evaluation.eval(labels, predictions);
        return evaluation.f1();
    }

    /** Returns the column count of the currently stored labels (throws NPE if none are set). */
    @Override
    public int numLabels() {
        INDArray currentLabels = this.labels;
        return currentLabels.columns();
    }

    /** Scores {@code data} in inference (non-training) mode. */
    public double score(org.nd4j.linalg.dataset.DataSet data) {
        boolean training = false;
        return this.score(data, training);
    }

    /**
     * Computes the score (loss plus L1/L2 terms) of {@code data} and stores it as the network score.
     * Mask arrays from {@code data} are applied for the duration of the call and always cleared afterwards,
     * including on the "no output layer" early return and when an exception is thrown.
     *
     * @param data     data set to score
     * @param training whether layers run in training mode
     * @return the computed score, or 0.0 (with a warning) if the final layer is not an IOutputLayer
     */
    public double score(org.nd4j.linalg.dataset.DataSet data, boolean training) {
        boolean hasMaskArray = data.hasMaskArrays();
        if (hasMaskArray) {
            this.setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        }
        try {
            MemoryWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                    ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                // Forward only up to the penultimate layer; the output layer computes the score itself.
                List<INDArray> activations = this.feedForwardToLayer(this.layers.length - 2, data.getFeatureMatrix(), training);
                int n = activations.size();
                this.setLabels(data.getLabels());
                if (!(this.getOutputLayer() instanceof IOutputLayer)) {
                    log.warn("Cannot calculate score wrt labels without an OutputLayer");
                    return 0.0;
                }
                IOutputLayer ol = (IOutputLayer) this.getOutputLayer();
                INDArray olInput = activations.get(n - 1);
                InputPreProcessor preProcessor = this.getLayerWiseConfigurations().getInputPreProcess(n - 1);
                if (preProcessor != null) {
                    olInput = preProcessor.preProcess(olInput, this.input.size(0));
                }
                ol.setInput(olInput);
                ol.setLabels(data.getLabels());
                ol.computeScore(this.calcL1(true), this.calcL2(true), training);
                this.score = ol.score();
            }
            return this.score();
        } finally {
            if (hasMaskArray) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /** Scores every example from {@code iter} and returns all per-example scores flattened in 'f' order. */
    public INDArray scoreExamples(DataSetIterator iter, boolean addRegularizationTerms) {
        List<INDArray> perBatchScores = new ArrayList<>();
        for (; iter.hasNext(); ) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();
            perBatchScores.add(this.scoreExamples(batch, addRegularizationTerms));
        }
        return Nd4j.toFlattened('f', perBatchScores);
    }

    /**
     * Computes a score for each example in {@code data}.
     * Mask arrays are cleared afterwards even if scoring throws.
     *
     * @param data                   data set to score
     * @param addRegularizationTerms whether L1/L2 terms are included in each example's score
     * @return per-example scores
     * @throws UnsupportedOperationException if the final layer is not an IOutputLayer
     */
    public INDArray scoreExamples(org.nd4j.linalg.dataset.DataSet data, boolean addRegularizationTerms) {
        boolean hasMaskArray = data.hasMaskArrays();
        if (hasMaskArray) {
            this.setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        }
        try {
            this.feedForward(data.getFeatureMatrix(), false);
            this.setLabels(data.getLabels());
            if (!(this.getOutputLayer() instanceof IOutputLayer)) {
                throw new UnsupportedOperationException("Cannot calculate score with respect to labels without an OutputLayer");
            }
            IOutputLayer ol = (IOutputLayer) this.getOutputLayer();
            ol.setLabels(data.getLabels());
            double l1 = addRegularizationTerms ? this.calcL1(true) : 0.0;
            double l2 = addRegularizationTerms ? this.calcL2(true) : 0.0;
            return ol.computeScoreForExamples(l1, l2);
        } finally {
            if (hasMaskArray) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /** Fits the network on its currently stored input and labels. */
    @Override
    public void fit() {
        INDArray storedInput = this.input;
        INDArray storedLabels = this.labels;
        this.fit(storedInput, storedLabels);
    }

    /** Not supported for a whole network; always throws. */
    @Override
    public void update(INDArray gradient, String paramType) {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /** Returns the most recently computed score. */
    @Override
    public double score() {
        return score;
    }

    /** Overrides the stored score value. */
    public void setScore(double score) {
        this.score = score;
    }

    /**
     * Runs a forward and backward pass on the stored input/labels, then computes and stores the score.
     * Truncated BPTT is used when configured. Training listeners are notified after the forward pass and,
     * outside of any workspace, after the backward pass.
     *
     * @throws DL4JException if the final layer is not an IOutputLayer
     */
    @Override
    public void computeGradientAndScore() {
        // Validate up front: the score cannot be computed without an IOutputLayer, so fail before
        // spending a full forward/backward pass and notifying listeners.
        if (!(this.getOutputLayer() instanceof IOutputLayer)) {
            throw new DL4JException("Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer");
        }
        List<INDArray> activations;
        if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
            activations = this.rnnActivateUsingStoredState(this.getInput(), true, true);
            for (TrainingListener tl : this.trainingListeners) {
                tl.onForwardPass((Model) this, activations);
            }
            this.truncatedBPTTGradient();
        } else {
            // Forward to the penultimate layer; the output layer receives its (pre-processed) input directly.
            activations = this.feedForwardToLayer(this.layers.length - 2, true);
            for (TrainingListener tl : this.trainingListeners) {
                tl.onForwardPass((Model) this, activations);
            }
            INDArray actSecondLastLayer = activations.get(activations.size() - 1);
            InputPreProcessor outputPreProcessor = this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1);
            if (outputPreProcessor != null) {
                actSecondLastLayer = outputPreProcessor.preProcess(actSecondLastLayer, this.getInputMiniBatchSize());
            }
            this.getOutputLayer().setInput(actSecondLastLayer);
            this.backprop();
        }
        this.score = ((IOutputLayer) this.getOutputLayer()).computeScore(this.calcL1(true), this.calcL2(true), true);
        if (this.trainingListeners.size() > 0) {
            // Listeners run outside of workspaces so anything they retain is not workspace-scoped.
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (TrainingListener tl : this.trainingListeners) {
                    tl.onBackwardPass(this);
                }
            }
        }
    }

    /** Score accumulation is not used by this network; the argument is ignored. */
    @Override
    public void accumulateScore(double accum) {
        // intentionally a no-op
    }

    /**
     * Clears per-layer state and drops the stored input, labels and solver.
     * Safe to call on a network whose layers have not been initialised yet.
     */
    @Override
    public void clear() {
        // layers is null until initialisation (rnnClearPreviousState guards the same case).
        if (this.layers != null) {
            for (Layer layer : this.layers) {
                layer.clear();
            }
        }
        this.input = null;
        this.labels = null;
        this.solver = null;
    }

    /** Deprecated and unsupported for a whole network; always throws. */
    @Override
    @Deprecated
    public void merge(Layer layer, int batchSize) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Merges each layer of {@code network} into the corresponding layer of this network.
     *
     * @param network   network with the same number of layers
     * @param batchSize batch size passed through to each layer merge
     * @throws IllegalArgumentException if the layer counts differ
     */
    @Deprecated
    public void merge(MultiLayerNetwork network, int batchSize) {
        if (network.layers.length != this.layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        // The output layer is the last entry of layers, so this loop already merges it;
        // merging it again afterwards (as before) applied its update twice.
        for (int i = 0; i < this.layers.length; ++i) {
            this.layers[i].merge(network.layers[i], batchSize);
        }
    }

    /**
     * Stores the network input, initialising layers on first use and propagating the minibatch size.
     *
     * @param input features, may be null
     * @throws IllegalArgumentException if {@code input} has length 0
     */
    @Override
    public void setInput(INDArray input) {
        // Validate before touching any state so a rejected input is not left stored on the network.
        if (input != null && input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")");
        }
        this.input = input;
        if (this.layers == null) {
            // Diagnostic only; was logged at info level.
            log.debug("setInput: {}", (Object) Nd4j.getMemoryManager().getCurrentWorkspace());
            this.initializeLayers(this.getInput());
        }
        if (input != null) {
            this.setInputMiniBatchSize(input.size(0));
        }
    }

    /** Returns the last layer of the network. */
    public Layer getOutputLayer() {
        Layer[] allLayers = this.getLayers();
        return allLayers[allLayers.length - 1];
    }

    /** Alias for {@code setParams}. */
    public void setParameters(INDArray params) {
        setParams(params);
    }

    /**
     * Multiplies every per-parameter learning rate of every layer by
     * {@code (lrPolicyDecayRate + Nd4j.EPS_THRESHOLD)}.
     */
    @Override
    public void applyLearningRateScoreDecay() {
        for (Layer layer : this.layers) {
            Map<String, Double> lrByParam = layer.conf().getLearningRateByParam();
            if (lrByParam.isEmpty()) {
                continue;
            }
            for (Map.Entry<String, Double> lrEntry : lrByParam.entrySet()) {
                double decayFactor = layer.conf().getLrPolicyDecayRate() + Nd4j.EPS_THRESHOLD;
                layer.conf().setLearningRateByParam(lrEntry.getKey(), lrEntry.getValue() * decayFactor);
            }
        }
    }

    /** Returns the network-level default configuration. */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /** Returns the currently stored labels (may be null). */
    public INDArray getLabels() {
        return labels;
    }

    /** Returns the currently stored input (may be null). */
    public INDArray getInput() {
        return input;
    }

    /** Stores the labels used by subsequent fit/score calls. */
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Returns the number of layer configurations in the network configuration. */
    public int getnLayers() {
        return layerWiseConfigurations.getConfs().size();
    }

    /** Returns the internal layer array (not a copy). */
    public synchronized Layer[] getLayers() {
        return layers;
    }

    /** Returns the layer at index {@code i}. */
    public Layer getLayer(int i) {
        return layers[i];
    }

    /** Looks up a layer by name; returns null if no layer has that name. */
    public Layer getLayer(String name) {
        return layerMap.get(name);
    }

    /** Returns a fresh mutable list of the layer names. */
    public List<String> getLayerNames() {
        return new ArrayList<>(layerMap.keySet());
    }

    /** Replaces the internal layer array (stored by reference). */
    public void setLayers(Layer[] layers) {
        this.layers = layers;
    }

    /** Returns the stored mask field. */
    public INDArray getMask() {
        return mask;
    }

    /** Sets the stored mask field. */
    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /** Returns the stored mask field (same value as {@code getMask()}). */
    @Override
    public INDArray getMaskArray() {
        return mask;
    }

    /** A whole network is never itself a pretrain layer. */
    @Override
    public boolean isPretrainLayer() {
        boolean pretrain = false;
        return pretrain;
    }

    /**
     * Propagates a mask array through every preprocessor and layer, in order.
     * If {@code maskArray} is null, each layer is told to clear its mask and preprocessors are skipped.
     * A null result from any step resets both the mask and the mask state to null for the following steps.
     *
     * @return the mask and mask state after the final layer
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        if (maskArray == null) {
            for (Layer layer : this.layers) {
                layer.feedForwardMaskArray(null, null, minibatchSize);
            }
            return new Pair<INDArray, MaskState>(null, currentMaskState);
        }
        INDArray mask = maskArray;
        MaskState state = currentMaskState;
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            InputPreProcessor preProcessor = this.getLayerWiseConfigurations().getInputPreProcess(layerIdx);
            if (preProcessor != null) {
                Pair<INDArray, MaskState> preResult = preProcessor.feedForwardMaskArray(mask, state, minibatchSize);
                mask = preResult == null ? null : preResult.getFirst();
                state = preResult == null ? null : preResult.getSecond();
            }
            Pair<INDArray, MaskState> layerResult = this.layers[layerIdx].feedForwardMaskArray(mask, state, minibatchSize);
            mask = layerResult == null ? null : layerResult.getFirst();
            state = layerResult == null ? null : layerResult.getSecond();
        }
        return new Pair<INDArray, MaskState>(mask, state);
    }

    /** Unsupported for a whole network. */
    @Override
    public Gradient error(INDArray errorSignal) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Identifies this layer as a multi-layer container. */
    @Override
    public Layer.Type type() {
        Layer.Type kind = Layer.Type.MULTILAYER;
        return kind;
    }

    /** Unsupported for a whole network. */
    @Override
    public INDArray derivativeActivation(INDArray input) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Unsupported for a whole network. */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray activation) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Activates every layer but the last, then returns the last layer's pre-activation output.
     * Input preprocessors are applied before each layer where configured.
     */
    @Override
    public INDArray preOutput(INDArray x) {
        int outputIdx = this.layers.length - 1;
        INDArray current = x;
        for (int idx = 0; idx < outputIdx; idx++) {
            InputPreProcessor pre = this.getLayerWiseConfigurations().getInputPreProcess(idx);
            if (pre != null) {
                current = pre.preProcess(current, this.getInputMiniBatchSize());
            }
            current = this.layers[idx].activate(current);
        }
        InputPreProcessor outputPre = this.getLayerWiseConfigurations().getInputPreProcess(outputIdx);
        if (outputPre != null) {
            current = outputPre.preProcess(current, this.getInputMiniBatchSize());
        }
        return this.layers[outputIdx].preOutput(current);
    }

    /** Delegates to the boolean overload, treating {@code TRAIN} as training mode. */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return this.preOutput(x, isTraining);
    }

    /** Delegates to the boolean overload, treating {@code TRAIN} as training mode. */
    @Override
    public INDArray activate(Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return this.activate(isTraining);
    }

    /** Delegates to the boolean overload, treating {@code TRAIN} as training mode. */
    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return this.activate(input, isTraining);
    }

    /** Unsupported for a whole network. */
    @Override
    public Layer transpose() {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Backpropagates an externally supplied epsilon through the network.
     * Only valid when the last layer is not an output layer.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        boolean endsInOutputLayer = this.getOutputLayer() instanceof IOutputLayer;
        if (endsInOutputLayer) {
            throw new UnsupportedOperationException("Cannot calculate gradients based on epsilon with OutputLayer");
        }
        return calcBackpropGradients(epsilon, false);
    }

    /** Sets this network's index when it is used as a layer. */
    @Override
    public void setIndex(int index) {
        layerIndex = index;
    }

    /** Returns this network's index when it is used as a layer. */
    @Override
    public int getIndex() {
        return layerIndex;
    }

    /** Sums the L2 regularization terms of all layers. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        double total = 0.0;
        for (Layer layer : this.layers) {
            total += layer.calcL2(backpropParamsOnly);
        }
        return total;
    }

    /** Sums the L1 regularization terms of all layers. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        double total = 0.0;
        for (Layer layer : this.layers) {
            total += layer.calcL1(backpropParamsOnly);
        }
        return total;
    }

    /**
     * Applies a full-network gradient by routing each {@code "<layerIdx>_<paramType>"} entry to its layer.
     * Each entry is also recorded in this network's own gradient.
     *
     * @throws IllegalArgumentException if the flattened gradient length does not match numParams(true)
     * @throws IllegalStateException    if a key lacks the '_' layer separator
     */
    @Override
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != this.numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + this.numParams(true));
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separatorPos = key.indexOf('_');
            if (separatorPos < 0) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            int layerIdx = Integer.parseInt(key.substring(0, separatorPos));
            String paramType = key.substring(separatorPos + 1);
            INDArray value = entry.getValue();
            this.gradient.gradientForVariable().put(key, value);
            this.layers[layerIdx].update(value, paramType);
        }
        this.setBackpropGradientsViewArray(gradient.gradient());
    }

    /** Unsupported for a whole network. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Unsupported for a whole network. */
    @Override
    public INDArray activate(boolean training) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Unsupported for a whole network. */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /** Propagates the minibatch size to every layer; does nothing before layers are initialised. */
    @Override
    public void setInputMiniBatchSize(int size) {
        if (this.layers == null) {
            return;
        }
        for (Layer layer : this.layers) {
            layer.setInputMiniBatchSize(size);
        }
    }

    /** Returns the size of dimension 0 of the stored input. */
    @Override
    public int getInputMiniBatchSize() {
        INDArray storedInput = this.input;
        return storedInput.size(0);
    }

    /** Unsupported for a whole network; use setLayerMaskArrays instead. */
    @Override
    public void setMaskArray(INDArray maskArray) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Runs one (or more) time steps using and updating the stored recurrent state.
     * If the input was 2d and the last layer is recurrent with a 3d output, the output is returned as 2d.
     * The stored input is always cleared before returning; previously the 2d return path left it set.
     *
     * @param input features for the time step(s)
     * @return network output for the step(s)
     */
    public INDArray rnnTimeStep(INDArray input) {
        this.setInputMiniBatchSize(input.size(0));
        // Stored so getInputMiniBatchSize() reflects the original batch size during preprocessing.
        this.input = input;
        boolean inputIs2d = input.rank() == 2;
        try {
            INDArray current = input;
            for (int i = 0; i < this.layers.length; ++i) {
                InputPreProcessor preProcessor = this.getLayerWiseConfigurations().getInputPreProcess(i);
                if (preProcessor != null) {
                    current = preProcessor.preProcess(current, this.getInputMiniBatchSize());
                }
                Layer layer = this.layers[i];
                if (layer instanceof RecurrentLayer) {
                    current = ((RecurrentLayer) layer).rnnTimeStep(current);
                } else if (layer instanceof MultiLayerNetwork) {
                    current = ((MultiLayerNetwork) layer).rnnTimeStep(current);
                } else {
                    current = layer.activate(current, false);
                }
            }
            if (inputIs2d && current.rank() == 3 && this.layers[this.layers.length - 1].type() == Layer.Type.RECURRENT) {
                return current.tensorAlongDimension(0, new int[]{1, 0});
            }
            return current;
        } finally {
            this.input = null;
        }
    }

    /**
     * Returns the stored recurrent state of the layer at index {@code layer}.
     *
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        boolean outOfRange = layer < 0 || layer >= this.layers.length;
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer candidate = this.layers[layer];
        if (!(candidate instanceof RecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        return ((RecurrentLayer) candidate).rnnGetPreviousState();
    }

    /**
     * Replaces the stored recurrent state of the layer at index {@code layer}.
     *
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        boolean outOfRange = layer < 0 || layer >= this.layers.length;
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer candidate = this.layers[layer];
        if (!(candidate instanceof RecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        ((RecurrentLayer) candidate).rnnSetPreviousState(state);
    }

    /** Clears stored recurrent state in all recurrent layers and nested networks; no-op before init. */
    public void rnnClearPreviousState() {
        if (this.layers == null) {
            return;
        }
        for (Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                ((RecurrentLayer) layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Forward pass that starts recurrent layers from their stored state.
     *
     * @return the input followed by each layer's output (one entry per layer plus one)
     */
    public List<INDArray> rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        List<INDArray> outputs = new ArrayList<>();
        INDArray current = input;
        outputs.add(current);
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            InputPreProcessor pre = this.getLayerWiseConfigurations().getInputPreProcess(layerIdx);
            if (pre != null) {
                current = pre.preProcess(current, input.size(0));
            }
            Layer layer = this.layers[layerIdx];
            if (layer instanceof RecurrentLayer) {
                current = ((RecurrentLayer) layer).rnnActivateUsingStoredState(current, training, storeLastForTBPTT);
            } else if (layer instanceof MultiLayerNetwork) {
                List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(current, training, storeLastForTBPTT);
                current = nested.get(nested.size() - 1);
            } else {
                current = layer.activate(current, training);
            }
            outputs.add(current);
        }
        return outputs;
    }

    /** Returns the updater, lazily creating the solver and a default updater on first access. */
    public synchronized Updater getUpdater() {
        if (this.solver != null) {
            return this.solver.getOptimizer().getUpdater();
        }
        this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
        this.solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
        return this.solver.getOptimizer().getUpdater();
    }

    /** Installs {@code updater}, creating the solver first if needed. */
    public void setUpdater(Updater updater) {
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.getOptimizer().setUpdater(updater);
    }

    /**
     * Propagates a features mask through the network and sets a labels mask on the output layer.
     * The labels mask is ignored when the last layer is not an IOutputLayer.
     */
    public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
        if (featuresMaskArray != null) {
            int minibatch = featuresMaskArray.size(0);
            this.feedForwardMaskArray(featuresMaskArray, MaskState.Active, minibatch);
        }
        if (labelsMaskArray == null || !(this.getOutputLayer() instanceof IOutputLayer)) {
            return;
        }
        this.layers[this.layers.length - 1].setMaskArray(labelsMaskArray);
    }

    /** Removes the mask array from every layer. */
    public void clearLayerMaskArrays() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setMaskArray(null);
        }
    }

    /** Classification evaluation using the iterator's own label names. */
    public Evaluation evaluate(DataSetIterator iterator) {
        List<String> labelNames = null;
        return this.evaluate(iterator, labelNames);
    }

    /** Runs a regression evaluation over {@code iterator}, with one column per iterator outcome. */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator) {
        RegressionEvaluation[] evals = new RegressionEvaluation[]{new RegressionEvaluation(iterator.totalOutcomes())};
        return this.doEvaluation(iterator, evals)[0];
    }

    /** Runs a binary ROC evaluation over {@code iterator} using {@code rocThresholdSteps} thresholds. */
    public ROC evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        ROC[] evals = new ROC[]{new ROC(rocThresholdSteps)};
        return this.doEvaluation(iterator, evals)[0];
    }

    /** Runs a one-vs-all multi-class ROC evaluation over {@code iterator}. */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        ROCMultiClass[] evals = new ROCMultiClass[]{new ROCMultiClass(rocThresholdSteps)};
        return this.doEvaluation(iterator, evals)[0];
    }

    /**
     * Feeds every batch of {@code iterator} through the network in inference mode and updates each evaluation.
     * Iteration stops at the first batch whose features or labels are null.
     * The training workspace mode is restored and any async wrapper is shut down even if an evaluation throws.
     *
     * @param iterator    source of evaluation data; reset first if exhausted and resettable
     * @param evaluations evaluations to update
     * @return the same {@code evaluations} array
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations) {
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }
        boolean useAsync = iterator.asyncSupported();
        DataSetIterator iter = useAsync ? new AsyncDataSetIterator(iterator, 2, true) : iterator;
        WorkspaceMode cMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        this.layerWiseConfigurations.setTrainingWorkspaceMode(this.layerWiseConfigurations.getInferenceWorkspaceMode());
        try {
            MemoryWorkspace workspace = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                    ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal);
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet) iter.next();
                if (next.getFeatureMatrix() == null || next.getLabels() == null) {
                    break;
                }
                try (MemoryWorkspace wsB = workspace.notifyScopeEntered()) {
                    INDArray features = next.getFeatures();
                    INDArray labels = next.getLabels();
                    INDArray lMask = next.getLabelsMaskArray();
                    INDArray out = next.hasMaskArrays()
                            ? this.silentOutput(features, false, next.getFeaturesMaskArray(), lMask)
                            : this.silentOutput(features, false);
                    // Evaluations accumulate state across batches, so they must run outside the workspace.
                    try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                        for (T evaluation : evaluations) {
                            evaluation.eval(labels, out, lMask);
                        }
                    }
                }
                this.clearLayerMaskArrays();
            }
        } finally {
            if (useAsync) {
                ((AsyncDataSetIterator) iter).shutdown();
            }
            this.layerWiseConfigurations.setTrainingWorkspaceMode(cMode);
        }
        return evaluations;
    }

    /** Classification evaluation with top-1 accuracy and the given label names. */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList) {
        int topN = 1;
        return this.evaluate(iterator, labelsList, topN);
    }

    /** Returns the updater's state view array, or null when no updater is available. */
    @Override
    public INDArray updaterState() {
        Updater currentUpdater = this.getUpdater();
        if (currentUpdater == null) {
            return null;
        }
        return this.getUpdater().getStateViewArray();
    }

    /**
     * Fits on a MultiDataSet that has exactly one features array and one labels array.
     * Previously this method always threw after a successful fit (missing return). It also chose the
     * labels mask by null-checking the features masks, which could NPE or drop a labels mask.
     *
     * @param dataSet single-input, single-output data
     * @throws DL4JInvalidInputException if there is more than one features or labels array
     */
    @Override
    public void fit(MultiDataSet dataSet) {
        if (dataSet.getFeatures().length != 1 || dataSet.getLabels().length != 1) {
            throw new DL4JInvalidInputException("MultiLayerNetwork can't handle MultiDataSet. Please consider use of ComputationGraph");
        }
        INDArray[] featuresMasks = dataSet.getFeaturesMaskArrays();
        INDArray[] labelsMasks = dataSet.getLabelsMaskArrays();
        INDArray fMask = featuresMasks != null ? featuresMasks[0] : null;
        INDArray lMask = labelsMasks != null ? labelsMasks[0] : null;
        INDArray features = dataSet.getFeatures()[0];
        INDArray labels = dataSet.getLabels()[0];
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(features, labels, fMask, lMask);
        this.fit((DataSet) ds);
    }

    /** Fits on a MultiDataSetIterator by adapting it to a single DataSetIterator. */
    @Override
    public void fit(MultiDataSetIterator iterator) {
        MultiDataSetWrapperIterator adapted = new MultiDataSetWrapperIterator(iterator);
        this.fit(adapted);
    }

    /** Evaluates over a MultiDataSetIterator by adapting it to a DataSetIterator. */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T[] evaluations) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(iterator);
        return this.doEvaluation(adapted, evaluations);
    }

    /**
     * Classification evaluation with top-N accuracy.
     * If {@code labelsList} is null, the iterator's labels are used.
     *
     * @throws IllegalStateException if the network has no IOutputLayer
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        boolean lacksOutputLayer = this.layers == null || !(this.getOutputLayer() instanceof IOutputLayer);
        if (lacksOutputLayer) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        List<String> labelNames = labelsList != null ? labelsList : iterator.getLabels();
        Evaluation evaluation = new Evaluation(labelNames, topN);
        this.doEvaluation(iterator, evaluation);
        return evaluation;
    }

    /**
     * Reports a one-time heartbeat event for this model.
     * The {@code task} argument is ignored; a task is derived from the model instead.
     */
    private void update(Task task) {
        if (this.initDone) {
            return;
        }
        this.initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        Task modelTask = ModelSerializer.taskByModel(this);
        Environment env = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, env, modelTask);
    }

    /**
     * Builds a table listing each layer's index, type, nIn/nOut, parameter count and parameter shapes,
     * followed by total, trainable and frozen parameter counts.
     * Layers whose configuration is not a FeedForwardLayer show "-" for nIn/nOut. Before this change,
     * such a layer threw ClassCastException, a layer with no learning-rate keys threw
     * StringIndexOutOfBoundsException, and a missing parameter threw NullPointerException.
     *
     * @return the formatted summary
     */
    public String summary() {
        String heavyRule = StringUtils.repeat((String) "=", (int) 140);
        StringBuilder ret = new StringBuilder("\n");
        ret.append(heavyRule).append("\n");
        ret.append(String.format("%-40s%-15s%-15s%-30s\n", "LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"));
        ret.append(heavyRule).append("\n");
        int frozenParams = 0;
        for (Layer currentLayer : this.layers) {
            String name = String.valueOf(currentLayer.getIndex());
            String paramShape = "-";
            String in = "-";
            String out = "-";
            String[] classNameArr = currentLayer.getClass().getName().split("\\.");
            String className = classNameArr[classNameArr.length - 1];
            String paramCount = String.valueOf(currentLayer.numParams());
            if (currentLayer.numParams() > 0) {
                Object layerConf = currentLayer.conf().getLayer();
                if (layerConf instanceof FeedForwardLayer) {
                    FeedForwardLayer ffConf = (FeedForwardLayer) layerConf;
                    in = String.valueOf(ffConf.getNIn());
                    out = String.valueOf(ffConf.getNOut());
                }
                StringBuilder shapes = new StringBuilder();
                for (String aP : currentLayer.conf().getLearningRateByParam().keySet()) {
                    INDArray param = currentLayer.paramTable().get(aP);
                    if (param == null) {
                        continue;
                    }
                    if (shapes.length() > 0) {
                        shapes.append(", ");
                    }
                    shapes.append(aP).append(":").append(ArrayUtils.toString((Object) param.shape()));
                }
                paramShape = shapes.length() > 0 ? shapes.toString() : "-";
            }
            if (currentLayer instanceof FrozenLayer) {
                frozenParams += currentLayer.numParams();
                classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                className = "Frozen " + classNameArr[classNameArr.length - 1];
            }
            ret.append(String.format("%-40s%-15s%-15s%-30s", name + " (" + className + ")", in + "," + out, paramCount, paramShape));
            ret.append("\n");
        }
        // params() may build a flattened view, so query it once.
        long totalParams = this.params().length();
        ret.append(StringUtils.repeat((String) "-", (int) 140));
        ret.append(String.format("\n%30s %d", "Total Parameters: ", totalParams));
        ret.append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams));
        ret.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
        ret.append("\n");
        ret.append(heavyRule);
        ret.append("\n");
        return ret.toString();
    }

    /** Resets input and mask on every layer, then calls each layer's {@code clear()}. */
    protected void clearLayersStates() {
        for (Layer layer : this.layers) {
            layer.setInput(null);
            layer.setMaskArray(null);
            layer.clear();
        }
    }

    /**
     * Two networks are equal when their parameters, layer-wise configuration and updater are all equal.
     * All three are evaluated. Note that getUpdater() may lazily create a solver on either network.
     * NOTE(review): no hashCode override is visible in this chunk; if none exists, the
     * equals/hashCode contract is violated.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof MultiLayerNetwork)) {
            return false;
        }
        MultiLayerNetwork other = (MultiLayerNetwork) obj;
        boolean sameParams = other.params().equals(this.params());
        boolean sameConf = this.getLayerWiseConfigurations().equals(other.getLayerWiseConfigurations());
        boolean sameUpdater = this.getUpdater().equals(other.getUpdater());
        return sameParams && sameConf && sameUpdater;
    }

    /** Sets the initialisation flag. */
    public void setInitDone(boolean initDone) {
        this.initDone = initDone;
    }

    /** Returns the flattened gradients view array. */
    public INDArray getFlattenedGradients() {
        return flattenedGradients;
    }
}

