/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.FwdPassType;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.linalg.workspace.WorkspacesCloseable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ComputationGraph
implements Serializable,
Model,
NeuralNetwork {
    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);
    // Graph configuration this network was built from; read throughout init().
    protected ComputationGraphConfiguration configuration;
    // Set to true at the end of init(INDArray, boolean); once set, further init calls return immediately.
    protected boolean initCalled = false;
    // Optimizer wrapper; created in init() if not already present.
    protected transient Solver solver;
    // Single [1, numParams] row vector backing every vertex's parameters; each vertex receives a slice view.
    // Null when the network has no parameters at all.
    protected INDArray flattenedParams;
    // Gradient counterpart of flattenedParams; allocated by initGradientsView().
    protected transient INDArray flattenedGradients;
    protected Gradient gradient;
    protected double score;
    // NOTE(review): not read or written anywhere in the visible code - confirm whether still used.
    private boolean initDone = false;
    protected boolean clearTbpttState = true;
    // Helper workspace pointers handed to each LayerWorkspaceMgr via setHelperWorkspacePointers(...).
    protected transient Map<String, Pointer> helperWorkspaces = new HashMap<String, Pointer>();
    // NOTE(review): initialized to -1; its use lies outside this view (presumably tracks an owning thread id - confirm).
    private final transient AtomicLong occupiedBy = new AtomicLong(-1L);
    // Workspace names used when building LayerWorkspaceMgr instances.
    protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
    protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
    protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
    protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
    // Per-instance workspace config: cyclesBeforeInitialization depends on the vertex count, so it is built in the constructor.
    protected final WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
    protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.05).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    // Per-instance workspace config, also sized by the vertex count in the constructor.
    protected final WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
    protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP).build();
    // Last ETL time recorded via setLastEtlTime(long), stored separately for each thread.
    protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal();
    // All runtime vertices indexed by vertex index: network inputs first, then the configured vertices.
    protected GraphVertex[] vertices;
    // Vertex name -> runtime vertex, built in init().
    protected Map<String, GraphVertex> verticesMap;
    // Vertex indices in topological order.
    protected int[] topologicalOrder;
    protected GraphIndices graphIndices;
    // Layers of all layer-bearing vertices, in vertex-index order.
    protected org.deeplearning4j.nn.api.Layer[] layers;
    private int numInputArrays;
    private int numOutputArrays;
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;
    // Vertex indices of the network outputs; computed lazily by getOutputLayerIndices().
    private transient int[] outputLayerIdxs;
    private NeuralNetConfiguration defaultConfiguration;
    private Collection<TrainingListener> trainingListeners = new ArrayList<TrainingListener>();

    /**
     * Creates an (uninitialized) graph for the given configuration. Call {@link #init()} before use.
     *
     * @param configuration graph configuration; its declared inputs/outputs fix the array arity
     */
    public ComputationGraph(ComputationGraphConfiguration configuration) {
        this.configuration = configuration;
        this.numInputArrays = configuration.getNetworkInputs().size();
        this.numOutputArrays = configuration.getNetworkOutputs().size();
        this.inputs = new INDArray[this.numInputArrays];
        this.labels = new INDArray[this.numOutputArrays];
        this.defaultConfiguration = configuration.getDefaultConfiguration();

        // Both per-instance workspace configs learn their size over a number of cycles
        // proportional to the number of configured vertices.
        int vertexCount = configuration.getVertices().size();
        int workingMemCycles = 2 * vertexCount;
        this.WS_LAYER_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.02)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(workingMemCycles)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
        this.WS_LAYER_ACT_X_CONFIG = WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.02)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(vertexCount)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
    }

    /** Records the last ETL time for the calling thread. */
    public void setLastEtlTime(long time) {
        Long boxed = Long.valueOf(time);
        lastEtlTime.set(boxed);
    }

    /** Returns the last ETL time recorded on the calling thread, or 0 if none was recorded. */
    public long getLastEtlTime() {
        Long recorded = lastEtlTime.get();
        if (recorded == null) {
            return 0L;
        }
        return recorded;
    }

    /**
     * Applies a cache mode to every layer in the network.
     *
     * @param mode cache mode to set; {@code null} is treated as {@link CacheMode#NONE}
     */
    public void setCacheMode(CacheMode mode) {
        CacheMode effective = (mode != null) ? mode : CacheMode.NONE;
        for (org.deeplearning4j.nn.api.Layer each : layers) {
            each.setCacheMode(effective);
        }
    }

    /** Returns the configuration this graph was constructed with. */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /** Returns the number of layers, or 0 if the layer array has not been built yet. */
    public int getNumLayers() {
        if (layers == null) {
            return 0;
        }
        return layers.length;
    }

    /** Returns the layer at position {@code idx} of the layer array. */
    public org.deeplearning4j.nn.api.Layer getLayer(int idx) {
        org.deeplearning4j.nn.api.Layer[] all = layers;
        return all[idx];
    }

    /** Returns the internal layer array (not a copy). */
    public org.deeplearning4j.nn.api.Layer[] getLayers() {
        return layers;
    }

    /**
     * Returns the layer held by the vertex with the given name.
     *
     * @throws IllegalStateException if no vertex with that name exists
     */
    public org.deeplearning4j.nn.api.Layer getLayer(String name) {
        boolean known = verticesMap.containsKey(name);
        Preconditions.checkState(known, "Layer with name %s does not exist in the network", (Object) name);
        GraphVertex owner = verticesMap.get(name);
        return owner.getLayer();
    }

    /** Returns the internal vertex array (not a copy), indexed by vertex index. */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /** Returns the vertex with the given name, or {@code null} if there is none. */
    public GraphVertex getVertex(String name) {
        return verticesMap.get(name);
    }

    /** Returns the number of network input arrays declared by the configuration. */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /** Returns the number of network output arrays declared by the configuration. */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Sets a single network input, allocating the input holder first if it is currently null.
     *
     * @param inputNum index of the network input
     * @param input    array to use for that input
     */
    public void setInput(int inputNum, INDArray input) {
        INDArray[] holder = this.inputs;
        if (holder == null) {
            holder = new INDArray[this.numInputArrays];
            this.inputs = holder;
        }
        holder[inputNum] = input;
    }

    /**
     * Replaces all network inputs at once.
     *
     * @param inputs one array per network input, or {@code null} to clear them
     * @throws IllegalArgumentException if a non-null array has the wrong length
     */
    public void setInputs(INDArray ... inputs) {
        boolean acceptable = inputs == null || inputs.length == this.numInputArrays;
        if (!acceptable) {
            throw new IllegalArgumentException("Invalid input array: network has " + this.numInputArrays + " inputs, but array is of length " + inputs.length);
        }
        this.inputs = inputs;
    }

    /** Returns the given network input, or {@code null} if no inputs are set. */
    public INDArray getInput(int inputNum) {
        return (this.inputs == null) ? null : this.inputs[inputNum];
    }

    /** Returns the internal array of network inputs (may be null). */
    public INDArray[] getInputs() {
        return inputs;
    }

    /** Returns the current input mask arrays (may be null). */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /** Returns the current label mask arrays (may be null). */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Sets the label array for a single network output.
     *
     * <p>If the label holder is currently null it is allocated first. This happens after
     * {@code setLabels(null)}; it presumably also happens when the transient field has not been
     * restored, e.g. after deserialization (confirm). This mirrors {@link #setInput(int, INDArray)}.
     *
     * @param labelNum index of the network output
     * @param label    label array for that output
     */
    public void setLabel(int labelNum, INDArray label) {
        if (this.labels == null) {
            this.labels = new INDArray[this.numOutputArrays];
        }
        this.labels[labelNum] = label;
    }

    /**
     * Replaces all network labels at once.
     *
     * @param labels one array per network output, or {@code null} to clear them
     * @throws IllegalArgumentException if a non-null array has the wrong length
     */
    public void setLabels(INDArray ... labels) {
        if (labels == null || labels.length == this.numOutputArrays) {
            this.labels = labels;
            return;
        }
        throw new IllegalArgumentException("Invalid output array: network has " + this.numOutputArrays + " outputs, but array is of length " + labels.length);
    }

    /**
     * Installs a gradients accumulator on the solver's optimizer, initializing the network first
     * if needed so that the solver exists.
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        if (!initCalled) {
            init();
        }
        ConvexOptimizer optimizer = solver.getOptimizer();
        optimizer.setGradientsAccumulator(accumulator);
    }

    /** Initializes the network with freshly allocated, newly initialized parameters. */
    @Override
    public void init() {
        final INDArray noParameters = null;
        this.init(noParameters, false);
    }

    /**
     * Initializes the network. This method:
     * <ul>
     *   <li>builds the runtime vertex array from the configuration;</li>
     *   <li>gives each vertex a view into one flattened parameter array;</li>
     *   <li>wires vertex input/output connections;</li>
     *   <li>creates the solver if needed;</li>
     *   <li>decides which layers may modify their input in place.</li>
     * </ul>
     * Returns immediately if initialization has already completed.
     *
     * @param parameters           optional existing parameters (row vector whose length equals the total
     *                             parameter count). If null, a new array is allocated when there are any
     *                             parameters, and vertices are instantiated with initializeParams = true.
     * @param cloneParametersArray if true, {@code parameters} is duplicated (and first cast to the network
     *                             datatype if it differs). If false, it is used directly as the backing array.
     * @throws IllegalStateException      if {@code parameters} has a different datatype and
     *                                    {@code cloneParametersArray} is false, if it is not rank 2 with
     *                                    size(0) == 1 in that mismatch case, or if a vertex instantiates to null
     * @throws IllegalArgumentException   if {@code parameters} is not a row vector or has the wrong length
     * @throws DL4JInvalidConfigException if a vertex reports a negative parameter count
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        String vertexName;
        boolean initializeParams;
        int i;
        if (this.initCalled) {
            return;
        }
        // A parameter array of a different datatype can only be used by casting a copy of it.
        DataType netDtype = this.getConfiguration().getDataType();
        if (parameters != null && parameters.dataType() != netDtype) {
            Preconditions.checkState((parameters.rank() == 2 && parameters.size(0) == 1L ? 1 : 0) != 0, (String)"Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", (Object)parameters);
            if (cloneParametersArray) {
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
                    parameters = parameters.castTo(netDtype);
                }
            } else {
                throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype + " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument set to false. Cannot initialize net with specified datatype array if that array does not match network datatype");
            }
        }
        // Default any unset workspace / cache modes to NONE.
        if (this.configuration.getTrainingWorkspaceMode() == null) {
            this.configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.configuration.getInferenceWorkspaceMode() == null) {
            this.configuration.setInferenceWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.configuration.getCacheMode() == null) {
            this.configuration.setCacheMode(CacheMode.NONE);
        }
        OneTimeLogger.info((Logger)log, (String)"Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", (Object[])new Object[]{this.configuration.getTrainingWorkspaceMode(), this.configuration.getInferenceWorkspaceMode(), this.configuration.getCacheMode()});
        // The code below assumes indices 0..numInputs-1 are the network inputs and the rest are configured vertices.
        GraphIndices indices = this.calculateIndices();
        this.topologicalOrder = indices.getTopologicalSortOrder();
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        this.vertices = new GraphVertex[networkInputNames.size() + this.configuration.getVertices().size()];
        HashMap<String, Integer> allNamesReverse = new HashMap<String, Integer>();
        // Network inputs occupy the first slots of the vertices array.
        int vertexNumber = 0;
        for (String name : networkInputNames) {
            InputVertex gv = new InputVertex(this, name, vertexNumber, null, netDtype);
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }
        // Count parameters per vertex (inputs have none) and in total.
        long numParams = 0L;
        long[] numParamsForVertex = new long[this.topologicalOrder.length];
        for (i = 0; i < this.configuration.getNetworkInputs().size(); ++i) {
            numParamsForVertex[i] = 0L;
        }
        while (i < this.topologicalOrder.length) {
            String name = indices.getIdxToName().get(i);
            org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
            n.setDataType(netDtype);
            numParamsForVertex[i] = n.numParams(true);
            if (numParamsForVertex[i] < 0L) {
                throw new DL4JInvalidConfigException("Layer " + name + " had parameters < 0 " + numParamsForVertex[i]);
            }
            numParams += numParamsForVertex[i];
            ++i;
        }
        // Choose the backing parameter array: the supplied one, a fresh [1, numParams] array, or none.
        if (parameters != null) {
            if (!parameters.isRowVectorOrScalar()) {
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            }
            if (parameters.length() != numParams) {
                throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length " + parameters.length());
            }
            this.flattenedParams = cloneParametersArray ? parameters.dup() : parameters;
            initializeParams = false;
        } else if (numParams > 0L) {
            this.flattenedParams = Nd4j.create((DataType)netDtype, (long[])new long[]{1L, numParams});
            initializeParams = true;
        } else {
            this.flattenedParams = null;
            initializeParams = false;
        }
        // Seed the global RNG from the configuration before vertices initialize their parameters.
        if (initializeParams) {
            Nd4j.getRandom().setSeed(this.conf().getSeed());
        }
        // Slice a contiguous view of flattenedParams for each vertex; slices are laid out in topological order.
        // (The counter i is incremented here but not otherwise used.)
        INDArray[] paramsViewForVertex = new INDArray[this.topologicalOrder.length];
        long paramOffsetSoFar = 0L;
        i = 0;
        for (int vertexIdx : this.topologicalOrder) {
            long nParamsThisVertex = numParamsForVertex[vertexIdx];
            if (nParamsThisVertex != 0L) {
                paramsViewForVertex[vertexIdx] = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.interval((long)0L, (long)0L, (boolean)true), NDArrayIndex.interval((long)paramOffsetSoFar, (long)(paramOffsetSoFar + nParamsThisVertex))});
            }
            ++i;
            paramOffsetSoFar += nParamsThisVertex;
        }
        // Instantiate the non-input vertices in vertex-index order (i and vertexNumber advance together).
        // Collect their layers, and register each layer variable as "<vertexName>_<variable>".
        int numLayers = 0;
        ArrayList<org.deeplearning4j.nn.api.Layer> tempLayerList = new ArrayList<org.deeplearning4j.nn.api.Layer>();
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (i = this.configuration.getNetworkInputs().size(); i < this.topologicalOrder.length; ++i) {
            String name = indices.getIdxToName().get(i);
            org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
            GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber], initializeParams, netDtype);
            if (gv == null) {
                throw new IllegalStateException("Encountered null layer/vertex during initialization for layer \"" + name + "\": " + n.getClass().getSimpleName() + " initialization returned null layer/vertex?");
            }
            if (gv.hasLayer()) {
                ++numLayers;
                org.deeplearning4j.nn.api.Layer l = gv.getLayer();
                tempLayerList.add(l);
                List<String> layerVariables = l.conf().variables();
                if (layerVariables != null) {
                    for (String s : layerVariables) {
                        variables.add(gv.getVertexName() + "_" + s);
                    }
                }
            }
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }
        this.layers = tempLayerList.toArray(new org.deeplearning4j.nn.api.Layer[numLayers]);
        // Name -> vertex lookup.
        this.verticesMap = new HashMap<String, GraphVertex>();
        for (GraphVertex gv : this.vertices) {
            this.verticesMap.put(gv.getVertexName(), gv);
        }
        // Invert the input map: for each vertex, the names of the vertices consuming its output.
        // A name appears once per connection, so duplicates are kept.
        HashMap<String, ArrayList<String>> verticesOutputTo = new HashMap<String, ArrayList<String>>();
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) continue;
            for (String s : vertexInputNames) {
                ArrayList<String> list = (ArrayList<String>)verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<String>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName);
            }
        }
        // Input connections: for each input of a vertex, record the source vertex index together with
        // the position of that name in this vertex's input list (indexOf -> first occurrence).
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            int vertexIndex = gv.getVertexIndex();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) continue;
            VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
            for (int j = 0; j < vertexInputNames.size(); ++j) {
                String inName = vertexInputNames.get(j);
                int inputVertexIndex = (Integer)allNamesReverse.get(inName);
                int inputNumber = vertexInputs.get(vertexName).indexOf(inName);
                if (inputNumber == -1) {
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of inputs for vertex " + gv.getVertexName() + "; error in graph structure?");
                }
                inputIndices[j] = new VertexIndices(inputVertexIndex, inputNumber);
            }
            gv.setInputVertices(inputIndices);
        }
        // Output connections: one entry per (consumer, input position) pair. Consumers are visited via a
        // de-duplicated set, so a vertex feeding the same consumer twice still yields exactly one entry per
        // position, and the array size (one per connection) is filled exactly.
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            List thisVertexOutputsTo = (List)verticesOutputTo.get(vertexName);
            if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty()) continue;
            VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
            int j = 0;
            for (String s : new HashSet(thisVertexOutputsTo)) {
                List<String> nextVertexInputNames = vertexInputs.get(s);
                for (int k = 0; k < nextVertexInputNames.size(); ++k) {
                    if (!vertexName.equals(nextVertexInputNames.get(k))) continue;
                    int outputVertexIndex = (Integer)allNamesReverse.get(s);
                    outputIndices[j++] = new VertexIndices(outputVertexIndex, k);
                }
            }
            gv.setOutputVertices(outputIndices);
        }
        // Flag the vertices named as network outputs.
        for (String s : this.configuration.getNetworkOutputs()) {
            GraphVertex gv = this.verticesMap.get(s);
            gv.setOutputVertex(true);
        }
        // Create the solver (outside of any active workspace) if one has not been set already.
        if (this.solver == null) {
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }
        // Map each vertex name to the names of all vertices that take it as an input.
        HashMap seenAsInputTo = new HashMap();
        for (Map.Entry<String, List<String>> entry : this.configuration.getVertexInputs().entrySet()) {
            for (String s : entry.getValue()) {
                if (!seenAsInputTo.containsKey(s)) {
                    seenAsInputTo.put(s, new ArrayList());
                }
                List seen = (List)seenAsInputTo.get(s);
                seen.add(entry.getKey());
            }
        }
        // Allow a layer to modify its first input in place only when it cannot affect another consumer:
        // either it is the sole consumer of that input, or it is the last consumer in topological order.
        // Layers fed directly by a network input are never allowed to.
        for (org.deeplearning4j.nn.api.Layer l : this.layers) {
            String layerName = l.conf().getLayer().getLayerName();
            List<String> inputs = this.configuration.getVertexInputs().get(layerName);
            String in = inputs.get(0);
            if (this.configuration.getNetworkInputs().contains(in)) continue;
            List seen = (List)seenAsInputTo.get(in);
            if (seen.size() == 1) {
                l.allowInputModification(true);
                continue;
            }
            int thisIdx = indices.getNameToIdx().get(layerName);
            int thisTopoPos = ArrayUtils.indexOf((int[])indices.getTopologicalSortOrder(), (int)thisIdx);
            int maxTopoPosition = -1;
            for (String s : seen) {
                int idx = indices.getNameToIdx().get(s);
                int topoPos = ArrayUtils.indexOf((int[])indices.getTopologicalSortOrder(), (int)idx);
                maxTopoPosition = Math.max(maxTopoPosition, topoPos);
            }
            if (thisTopoPos != maxTopoPosition) continue;
            l.allowInputModification(true);
        }
        this.synchronizeIterEpochCounts();
        this.initCalled = true;
    }

    /**
     * Allocates the flattened gradient array and gives each parameterized vertex a view of its slice.
     *
     * <p>The array is a {@code [1, numParams]} row vector with the same datatype as the parameters.
     * Slices are laid out in topological order, matching the parameter layout. The network is
     * initialized first if necessary, and all allocation happens outside of any active workspace.
     * When the network has no parameters, no gradient array is allocated.
     */
    public void initGradientsView() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (!this.initCalled) {
                this.init();
            }
            GraphIndices indices = this.calculateIndices();
            int vertexCount = this.topologicalOrder.length;
            int firstNonInput = this.configuration.getNetworkInputs().size();

            // Network inputs carry no parameters; count everything else.
            long[] paramCounts = new long[vertexCount];
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertices = this.configuration.getVertices();
            long totalParams = 0L;
            for (int v = firstNonInput; v < vertexCount; v++) {
                String name = indices.getIdxToName().get(v);
                long count = configVertices.get(name).numParams(true);
                paramCounts[v] = count;
                totalParams += count;
            }

            if (totalParams > 0L) {
                this.flattenedGradients = Nd4j.create(this.flattenedParams.dataType(), new long[]{1L, totalParams});
            }

            long offset = 0L;
            for (int vertexIdx : this.topologicalOrder) {
                long count = paramCounts[vertexIdx];
                if (count != 0L) {
                    INDArray view = this.flattenedGradients.get(new INDArrayIndex[]{
                            NDArrayIndex.interval(0L, 0L, true),
                            NDArrayIndex.interval(offset, offset + count)});
                    this.vertices[vertexIdx].setBackpropGradientsViewArray(view);
                }
                offset += count;
            }
        }
    }

    /**
     * Returns the vertex indices of the network outputs, in configuration order.
     * The array is computed on first use and cached.
     */
    protected int[] getOutputLayerIndices() {
        if (this.outputLayerIdxs != null) {
            return this.outputLayerIdxs;
        }
        int[] result = new int[this.numOutputArrays];
        this.outputLayerIdxs = result;
        int pos = 0;
        for (String outputName : this.configuration.getNetworkOutputs()) {
            GraphVertex outputVertex = this.verticesMap.get(outputName);
            result[pos] = outputVertex.getVertexIndex();
            pos++;
        }
        return this.outputLayerIdxs;
    }

    /** Layer-wise pretrains all pretrainable layers for a single epoch. */
    public void pretrain(DataSetIterator iter) {
        final int singleEpoch = 1;
        this.pretrain(iter, singleEpoch);
    }

    /**
     * Layer-wise pretrains all pretrainable layers from a single-input {@link DataSetIterator}.
     *
     * @throws UnsupportedOperationException if the network has more than one input
     */
    public void pretrain(DataSetIterator iter, int numEpochs) {
        boolean singleInput = this.numInputArrays == 1;
        if (!singleInput) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
        }
        this.pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter), numEpochs);
    }

    /** Layer-wise pretrains all pretrainable layers for a single epoch. */
    public void pretrain(MultiDataSetIterator iter) {
        final int singleEpoch = 1;
        this.pretrain(iter, singleEpoch);
    }

    /**
     * Layer-wise pretrains all pretrainable layers. If the JVM runs out of memory, a memory crash
     * dump is written before the error is rethrown.
     */
    public void pretrain(MultiDataSetIterator iter, int numEpochs) {
        try {
            this.pretrainHelper(iter, numEpochs);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Greedy layer-wise pretraining of every pretrainable, non-output layer in the graph.
     *
     * <p>Vertices are visited in topological order. Each layer is therefore pretrained only after
     * every layer feeding into it has been pretrained. (The previous version walked
     * {@code vertices[i]} in vertex-index order despite bounding the loop by
     * {@code topologicalOrder.length}.)
     *
     * @param iter      training data
     * @param numEpochs epochs of pretraining per layer
     */
    private void pretrainHelper(MultiDataSetIterator iter, int numEpochs) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
            // Output layers are never pretrained; nor are layers that do not support it.
            if (layer instanceof IOutputLayer || !layer.isPretrainLayer()) {
                continue;
            }
            this.pretrainLayerHelper(vertex.getVertexName(), iter, numEpochs);
        }
    }

    /**
     * Pretrains one named layer from a single-input {@link DataSetIterator}.
     *
     * @throws UnsupportedOperationException if the network has more than one input
     */
    public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) {
        if (this.numInputArrays == 1) {
            this.pretrainLayer(layerName, ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator));
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains one named layer for a single epoch. If the JVM runs out of memory, a memory crash
     * dump is written before the error is rethrown.
     */
    public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
        final int singleEpoch = 1;
        try {
            this.pretrainLayerHelper(layerName, iter, singleEpoch);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Pretrains a single layer vertex for {@code numEpochs} epochs.
     *
     * <p>For every minibatch, this method:
     * <ol>
     *   <li>runs the graph forward up to the target vertex;</li>
     *   <li>sets the activations of the vertex's inputs as its input;</li>
     *   <li>calls {@code fit(input, workspaceMgr)} on the vertex's layer.</li>
     * </ol>
     * Does nothing if the named vertex has no layer.
     *
     * <p>Each minibatch runs inside a try-with-resources activation workspace. The decompiled
     * version used {@code continue} inside {@code finally}, which silently discarded any exception
     * (including OutOfMemoryError) thrown while training a batch. It also ignored {@code numEpochs}.
     *
     * @throws IllegalArgumentException if numEpochs &lt; 1, or numEpochs &gt; 1 and the iterator cannot be reset
     * @throws IllegalStateException    if no vertex with the given name exists
     */
    private void pretrainLayerHelper(String layerName, MultiDataSetIterator iter, int numEpochs) {
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iter.resetSupported(), "Cannot perform multiple epochs of pretraining using an iterator that does not support resetting (iterator.resetSupported() returned false)");
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (!this.verticesMap.containsKey(layerName)) {
            throw new IllegalStateException("Invalid vertex name: " + layerName + " - all vertex names: " + this.verticesMap.keySet());
        }
        GraphVertex toTrain = this.verticesMap.get(layerName);
        if (!toTrain.hasLayer()) {
            return;
        }
        int idx = toTrain.getVertexIndex();
        LayerWorkspaceMgr workspaceMgr = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                ? LayerWorkspaceMgr.noWorkspaces()
                : LayerWorkspaceMgr.builder()
                        .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                        .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                        .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                        .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                        .with(ArrayType.UPDATER_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .build();
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);
        // Start from the beginning if a previous consumer exhausted the iterator.
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            if (epoch > 0) {
                // Guaranteed resettable by the precondition above.
                iter.reset();
            }
            // A fresh async wrapper per epoch, over the (possibly just reset) iterator.
            MultiDataSetIterator withAsync = iter.asyncSupported() ? new AsyncMultiDataSetIterator(iter) : iter;
            while (withAsync.hasNext()) {
                MultiDataSet mds = withAsync.next();
                // try-with-resources skips close() on a null workspace and never swallows exceptions.
                try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
                    Map<String, INDArray> activations = this.ffToLayerActivationsInWS(false, idx, new int[]{idx}, FwdPassType.STANDARD, false, mds.getFeatures(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays(), true);
                    // Feed each input vertex's activations into the corresponding input edge of the target.
                    for (VertexIndices vi : toTrain.getInputVertices()) {
                        String inName = this.vertices[vi.getVertexIndex()].getVertexName();
                        INDArray act = activations.get(inName);
                        toTrain.setInput(vi.getVertexEdgeNumber(), act, workspaceMgr);
                    }
                    org.deeplearning4j.nn.api.Layer layer = toTrain.getLayer();
                    layer.fit(layer.input(), workspaceMgr);
                }
            }
        }
    }

    /**
     * Fits the network on a single {@link DataSet}. Only valid for graphs with exactly one input
     * and one output.
     *
     * <p>Any feature/label mask arrays present are forwarded. The decompiled version never assigned
     * the features-mask variable (it wrote a different temporary), so the features mask was lost and
     * the method did not satisfy definite assignment. Layer mask arrays are cleared afterwards if
     * masks were used, and layer states are always cleared.
     *
     * @throws UnsupportedOperationException if the graph has more than one input or output
     */
    @Override
    public void fit(DataSet dataSet) {
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSet");
        }
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            INDArray featuresMask = dataSet.getFeaturesMaskArray();
            INDArray labelsMask = dataSet.getLabelsMaskArray();
            INDArray[] fMask = featuresMask != null ? new INDArray[]{featuresMask} : null;
            INDArray[] lMask = labelsMask != null ? new INDArray[]{labelsMask} : null;
            this.fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()}, fMask, lMask);
        } else {
            this.fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()});
        }
        if (hasMaskArrays) {
            this.clearLayerMaskArrays();
        }
        this.clearLayersStates();
    }

    /**
     * Trains the graph for {@code numEpochs} epochs, calling {@link #fit(DataSetIterator)} once per epoch.
     *
     * @param iterator  training data; must support reset when {@code numEpochs > 1}
     * @param numEpochs number of epochs, must be positive
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code numEpochs <= 0}, or more than one epoch is requested
     *                                  with an iterator that cannot be reset
     */
    public void fit(@NonNull DataSetIterator iterator, int numEpochs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using an iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            this.fit(iterator);
        }
    }

    /**
     * Trains the graph for one epoch on a {@link DataSetIterator}.
     *
     * <p>The iterator is wrapped in a {@code MultiDataSetIteratorAdapter}, and training is delegated
     * to {@link #fit(MultiDataSetIterator)}.
     *
     * @param iterator training data, must not be null
     * @throws NullPointerException if {@code iterator} is null
     */
    @Override
    public void fit(@NonNull DataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        fit(adapted);
    }

    /**
     * Trains the graph on a single {@link MultiDataSet}.
     *
     * <p>Layer mask arrays are cleared afterwards when the data set reports that it has masks.
     *
     * @param multiDataSet features, labels and optional masks
     */
    @Override
    public void fit(MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        INDArray[] labels = multiDataSet.getLabels();
        INDArray[] featuresMasks = multiDataSet.getFeaturesMaskArrays();
        INDArray[] labelsMasks = multiDataSet.getLabelsMaskArrays();
        fit(features, labels, featuresMasks, labelsMasks);
        if (!multiDataSet.hasMaskArrays()) {
            return;
        }
        clearLayerMaskArrays();
    }

    /**
     * Trains the graph for {@code numEpochs} epochs, calling {@link #fit(MultiDataSetIterator)} once per epoch.
     *
     * @param iterator  training data; must support reset when {@code numEpochs > 1}
     * @param numEpochs number of epochs, must be positive
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code numEpochs <= 0}, or more than one epoch is requested
     *                                  with an iterator that cannot be reset
     */
    public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using an iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            this.fit(iterator);
        }
    }

    /**
     * Trains the graph for one epoch over every {@link MultiDataSet} produced by {@code multi}.
     *
     * <p>If the iterator is already exhausted and supports reset, it is reset first. When the
     * iterator supports async prefetch, it is wrapped in an {@code AsyncMultiDataSetIterator}.
     * That wrapper is always shut down, even if training throws, so its prefetch resources are
     * not leaked. Epoch-end listeners and the epoch counter run only after a successful epoch.
     *
     * @param multi training data
     */
    @Override
    public synchronized void fit(MultiDataSetIterator multi) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        // Rewind an already-exhausted iterator so repeated calls each train a full epoch.
        if (!multi.hasNext() && multi.resetSupported()) {
            multi.reset();
        }
        for (TrainingListener tl : this.trainingListeners) {
            tl.onEpochStart(this);
        }
        boolean destructable = false;
        MultiDataSetIterator multiDataSetIterator;
        if (multi.asyncSupported()) {
            multiDataSetIterator = new AsyncMultiDataSetIterator(multi, Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
            destructable = true;
        } else {
            multiDataSetIterator = multi;
        }
        try {
            long time1 = System.currentTimeMillis();
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet mds = (MultiDataSet) multiDataSetIterator.next();
                long time2 = System.currentTimeMillis();
                // Time spent waiting on the iterator (ETL) for this minibatch.
                this.lastEtlTime.set(time2 - time1);
                this.fit(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
                time1 = System.currentTimeMillis();
            }
        } finally {
            // Always release the async wrapper we created, even when fitting fails.
            if (destructable) {
                ((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();
            }
        }
        for (TrainingListener tl : this.trainingListeners) {
            tl.onEpochEnd(this);
        }
        this.incrementEpochCount();
    }

    /**
     * Trains the graph on the given inputs and labels, without feature or label masks.
     *
     * @param inputs network input arrays
     * @param labels network label arrays
     */
    public void fit(INDArray[] inputs, INDArray[] labels) {
        INDArray[] noFeatureMasks = null;
        INDArray[] noLabelMasks = null;
        fit(inputs, labels, noFeatureMasks, noLabelMasks);
    }

    /**
     * Trains the graph on the given inputs, labels and optional mask arrays.
     *
     * <p>If an {@link OutOfMemoryError} occurs, a memory crash dump is written before the error is rethrown.
     *
     * @param inputs            network input arrays
     * @param labels            network label arrays
     * @param featureMaskArrays feature masks, may be null
     * @param labelMaskArrays   label masks, may be null
     */
    public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        try {
            fitHelper(inputs, labels, featureMaskArrays, labelMaskArrays);
        } catch (OutOfMemoryError oom) {
            // Record a crash report, then propagate the original error unchanged.
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Performs one training step on the given minibatch.
     *
     * <p>The step either runs truncated BPTT or lets the (lazily created) solver optimize. Masks are
     * cleared afterwards if any were supplied. Layer states are always cleared, and iteration/epoch
     * counts are synchronized. The method returns immediately if the graph has no parameters.
     */
    private synchronized void fitHelper(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        // A graph without parameters has nothing to train.
        if (numParams() == 0L) {
            return;
        }
        if (flattenedGradients == null) {
            initGradientsView();
        }
        setInputs(inputs);
        setLabels(labels);
        setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
        update(TaskUtils.buildTask(inputs, labels));

        LayerWorkspaceMgr workspaceMgr;
        if (configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        boolean truncatedBptt = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        if (truncatedBptt) {
            doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays, workspaceMgr);
        } else {
            if (solver == null) {
                // Build the solver outside of any workspace.
                try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }
            solver.optimize(workspaceMgr);
        }

        boolean anyMasks = featureMaskArrays != null || labelMaskArrays != null;
        if (anyMasks) {
            clearLayerMaskArrays();
        }
        clearLayersStates();
        synchronizeIterEpochCounts();
    }

    /**
     * Returns the graph's vertex indices in topological order, as computed (or cached) by
     * {@link #calculateIndices()}.
     *
     * @return vertex indices in topological order
     */
    public int[] topologicalSortOrder() {
        GraphIndices indices = calculateIndices();
        return indices.getTopologicalSortOrder();
    }

    /**
     * Computes (and caches) the topological sort order of the graph plus name/index lookup maps.
     *
     * <p>If the configuration already stores a topological order (both the index and name forms),
     * that order is reused. Otherwise indices are assigned: network inputs first, in declaration
     * order, then vertices in the configuration's vertex-map iteration order. The order is then
     * computed with Kahn's algorithm and stored back into the configuration.
     *
     * @return the cached or newly computed graph indices
     * @throws IllegalStateException if the graph contains a cycle
     */
    public GraphIndices calculateIndices() {
        if (this.graphIndices != null) {
            return this.graphIndices;
        }
        // Reuse a topological order already persisted in the configuration.
        if (this.configuration.getTopologicalOrder() != null && this.configuration.getTopologicalOrderStr() != null) {
            int[] t = this.configuration.getTopologicalOrder();
            List<String> s = this.configuration.getTopologicalOrderStr();
            HashMap<String, Integer> m1 = new HashMap<String, Integer>();
            HashMap<Integer, String> m2 = new HashMap<Integer, String>();
            for (int i = 0; i < t.length; ++i) {
                m1.put(s.get(i), t[i]);
                m2.put(t[i], s.get(i));
            }
            this.graphIndices = GraphIndices.builder().topologicalSortOrder(t).nameToIdx(m1).idxToName(m2).build();
            return this.graphIndices;
        }

        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        int numVertices = networkInputNames.size() + nodeMap.size();
        int[] out = new int[numVertices];
        int outCounter = 0;

        // Assign indices: network inputs first, then the configured vertices.
        HashMap<Integer, String> vertexNamesMap = new HashMap<Integer, String>();
        HashMap<String, Integer> vertexNamesMap2 = new HashMap<String, Integer>();
        int i = 0;
        for (String inputName : networkInputNames) {
            vertexNamesMap.put(i, inputName);
            vertexNamesMap2.put(inputName, i);
            ++i;
        }
        for (String vertexName : nodeMap.keySet()) {
            vertexNamesMap.put(i, vertexName);
            vertexNamesMap2.put(vertexName, i);
            ++i;
        }

        // inputEdges: vertex -> vertices it receives input from (null/empty = no incoming edges).
        // outputEdges: vertex -> vertices it feeds into.
        Map<Integer, Set<Integer>> inputEdges = new HashMap<Integer, Set<Integer>>();
        Map<Integer, Set<Integer>> outputEdges = new HashMap<Integer, Set<Integer>>();
        for (String inputName : networkInputNames) {
            inputEdges.put(vertexNamesMap2.get(inputName), null);
        }
        for (String vertexName : nodeMap.keySet()) {
            int idx = vertexNamesMap2.get(vertexName);
            List<String> inputsToThisVertex = this.configuration.getVertexInputs().get(vertexName);
            if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
                inputEdges.put(idx, null);
                continue;
            }
            Set<Integer> inputSet = new HashSet<Integer>();
            for (String inputName : inputsToThisVertex) {
                Integer inputIdx = vertexNamesMap2.get(inputName);
                inputSet.add(inputIdx);
                Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
                if (outputSetForInputIdx == null) {
                    outputSetForInputIdx = new HashSet<Integer>();
                    outputEdges.put(inputIdx, outputSetForInputIdx);
                }
                outputSetForInputIdx.add(idx);
            }
            inputEdges.put(idx, inputSet);
        }

        // Kahn's algorithm: repeatedly emit a vertex with no remaining incoming edges.
        LinkedList<Integer> noIncomingEdges = new LinkedList<Integer>();
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> inputsFrom = entry.getValue();
            if (inputsFrom == null || inputsFrom.isEmpty()) {
                noIncomingEdges.add(entry.getKey());
            }
        }
        while (!noIncomingEdges.isEmpty()) {
            Integer next = noIncomingEdges.removeFirst();
            out[outCounter++] = next;
            Set<Integer> vertexOutputsTo = outputEdges.get(next);
            if (vertexOutputsTo == null) {
                continue;
            }
            for (Integer v : vertexOutputsTo) {
                Set<Integer> remaining = inputEdges.get(v);
                remaining.remove(next);
                if (remaining.isEmpty()) {
                    noIncomingEdges.add(v);
                }
            }
        }

        // Any vertex that still has unresolved incoming edges is part of (or downstream of) a cycle.
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> remaining = entry.getValue();
            if (remaining != null && !remaining.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + vertexNamesMap.get(entry.getKey()) + "\")");
            }
        }

        // Persist the order in the configuration, so reuse of this configuration takes the cached branch above.
        ArrayList<String> orderNames = new ArrayList<String>(out.length);
        for (int idx : out) {
            orderNames.add(vertexNamesMap.get(idx));
        }
        this.configuration.setTopologicalOrder(out);
        this.configuration.setTopologicalOrderStr(orderNames);
        this.graphIndices = GraphIndices.builder().topologicalSortOrder(out).nameToIdx(vertexNamesMap2).idxToName(vertexNamesMap).build();
        return this.graphIndices;
    }

    /**
     * Computes gradients and score; delegates to {@link #computeGradientAndScore()}.
     *
     * @param workspaceMgr ignored. The no-arg overload builds its own workspace manager from this
     *                     graph's training workspace mode.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        computeGradientAndScore();
    }

    /**
     * Runs a forward pass, backpropagates gradients and computes the total score.
     *
     * <p>The score is the sum of each network output layer's score, plus the regularization score
     * added once (on the first output only). Training listeners are notified after the forward
     * pass and after the backward pass. All vertices are cleared at the end.
     */
    public void computeGradientAndScore() {
        this.synchronizeIterEpochCounts();
        LayerWorkspaceMgr workspaceMgr;
        if (this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        boolean tbptt = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        FwdPassType fwdType;
        if (tbptt) {
            fwdType = FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE;
        } else {
            fwdType = FwdPassType.STANDARD;
        }
        // NOTE(review): second call, kept from the original flow; likely redundant.
        this.synchronizeIterEpochCounts();

        try (MemoryWorkspace wsAllActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            Map<String, INDArray> activations = this.ffToLayerActivationsInWS(true, -1, this.getOutputLayerIndices(), fwdType, tbptt, this.inputs, this.inputMaskArrays, this.labelMaskArrays, false);
            if (!this.trainingListeners.isEmpty()) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onForwardPass((Model) this, activations);
                    }
                }
            }
            this.calcBackpropGradients(false, false, new INDArray[0]);
            workspaceMgr.assertCurrentWorkspace(ArrayType.ACTIVATIONS, null);

            // Regularization is added to the first output's score only, then zeroed.
            double regScore = this.calcRegularizationScore(true);
            this.score = 0.0;
            int outNum = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                GraphVertex gv = this.verticesMap.get(outputName);
                if (gv instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv;
                    if (!lv.isSetLayerInput()) {
                        lv.applyPreprocessorAndSetInput(workspaceMgr);
                    }
                }
                org.deeplearning4j.nn.api.Layer vertexLayer = gv.getLayer();
                if (vertexLayer instanceof FrozenLayerWithBackprop) {
                    vertexLayer = ((FrozenLayerWithBackprop) vertexLayer).getInsideLayer();
                }
                INDArray labelMask = this.labelMaskArrays == null ? null : this.labelMaskArrays[outNum];
                vertexLayer.setMaskArray(labelMask);
                try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    this.score += ((IOutputLayer) vertexLayer).computeScore(regScore, true, workspaceMgr);
                }
                regScore = 0.0;
                outNum++;
            }

            if (!this.trainingListeners.isEmpty()) {
                try (MemoryWorkspace memoryWorkspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onBackwardPass(this);
                    }
                }
            }
        }
        for (GraphVertex vertex : this.vertices) {
            vertex.clear();
        }
    }

    /**
     * Sets the single network input and feeds forward up to layer {@code layerTillIndex}.
     *
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, int layerTillIndex, boolean train) {
        boolean singleInput = this.numInputArrays == 1;
        if (!singleInput) {
            throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
        }
        setInput(0, input);
        return feedForward(train, layerTillIndex);
    }

    /**
     * Sets the inputs and runs a detached forward pass.
     *
     * <p>{@code layerTillIndex} is passed unchanged to {@code ffToLayerActivationsDetached} as its
     * inclusive stop index. It is not translated from a layer index the way
     * {@code feedForward(boolean, int)} does. Writes a memory crash dump on {@link OutOfMemoryError}
     * and rethrows.
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train, boolean clearInputs) {
        setInputs(input);
        Map<String, INDArray> activations;
        try {
            activations = ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerTillIndex, null, input, this.inputMaskArrays, this.labelMaskArrays, clearInputs);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return activations;
    }

    /**
     * Sets the inputs and feeds forward up to layer {@code layerTillIndex} (an index into this graph's layers).
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train) {
        setInputs(input);
        Map<String, INDArray> activations = feedForward(train, layerTillIndex);
        return activations;
    }

    /**
     * Runs a detached forward pass on the stored inputs, up to and including the layer at
     * {@code layerTillIndex} in this graph's layer array.
     *
     * <p>{@code ffToLayerActivationsDetached} stops at a position in the topological order, not at
     * a graph-vertex index. The layer's vertex index is therefore translated to its topological
     * position, so the layer and all of its ancestors are evaluated.
     *
     * @param train          training (true) or inference (false) mode
     * @param layerTillIndex index into this graph's layer array
     * @return activations keyed by vertex name
     * @throws IllegalStateException if the layer's vertex is not present in the topological order
     */
    public Map<String, INDArray> feedForward(boolean train, int layerTillIndex) {
        int graphVertexIndexOfLayer = this.layers[layerTillIndex].getIndex();
        int stopPosition = -1;
        for (int pos = 0; pos < this.topologicalOrder.length; pos++) {
            if (this.topologicalOrder[pos] == graphVertexIndexOfLayer) {
                stopPosition = pos;
                break;
            }
        }
        if (stopPosition < 0) {
            throw new IllegalStateException("Layer " + layerTillIndex + " (graph vertex index " + graphVertexIndexOfLayer + ") not found in topological order");
        }
        try {
            return this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, stopPosition, null, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
        }
        catch (OutOfMemoryError e) {
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Sets the single network input and feeds forward through the whole graph.
     *
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (this.numInputArrays == 1) {
            setInput(0, input);
            return feedForward(train);
        }
        throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
    }

    /**
     * Feeds forward through the whole graph with the given inputs, clearing layer inputs afterwards.
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        final boolean clearInputs = true;
        return feedForward(input, train, clearInputs);
    }

    /**
     * Sets the inputs and runs a detached forward pass through every vertex.
     *
     * <p>Writes a memory crash dump on {@link OutOfMemoryError} and rethrows.
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train, boolean clearInputs) {
        setInputs(input);
        int lastIndex = this.vertices.length - 1;
        try {
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, lastIndex, null, input, this.inputMaskArrays, this.labelMaskArrays, clearInputs);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Inference-mode forward pass on the stored inputs.
     */
    public Map<String, INDArray> feedForward() {
        final boolean train = false;
        return feedForward(train);
    }

    /**
     * Detached forward pass through every vertex using the stored inputs and masks.
     *
     * <p>Writes a memory crash dump on {@link OutOfMemoryError} and rethrows.
     */
    public Map<String, INDArray> feedForward(boolean train) {
        int lastIndex = this.vertices.length - 1;
        try {
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, lastIndex, null, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Detached forward pass on the stored inputs, with optional filtering.
     *
     * <p>Like the other {@code feedForward} overloads, this writes a memory crash dump on
     * {@link OutOfMemoryError} before rethrowing.
     *
     * @param train                            training (true) or inference (false) mode
     * @param excludeOutputLayers              if true, the output-layer vertices are skipped
     * @param includeNonLayerVertexActivations if false, only activations of layer and input
     *                                         vertices are returned
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers, boolean includeNonLayerVertexActivations) {
        int[] exclude = excludeOutputLayers ? this.getOutputLayerIndices() : null;
        Map<String, INDArray> m;
        try {
            m = this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, this.vertices.length - 1, exclude, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
        } catch (OutOfMemoryError e) {
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
        if (includeNonLayerVertexActivations) {
            return m;
        }
        HashMap<String, INDArray> out = new HashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> e : m.entrySet()) {
            GraphVertex v = this.verticesMap.get(e.getKey());
            if (v instanceof LayerVertex || v instanceof InputVertex) {
                out.put(e.getKey(), e.getValue());
            }
        }
        return out;
    }

    /**
     * Inference-mode output using the stored input and label masks.
     */
    public INDArray[] output(INDArray ... input) {
        INDArray[] featureMasks = this.inputMaskArrays;
        INDArray[] labelMasks = this.labelMaskArrays;
        return output(false, input, featureMasks, labelMasks);
    }

    /**
     * Inference-mode output for a graph with exactly one output.
     */
    public INDArray outputSingle(INDArray ... input) {
        final boolean train = false;
        return outputSingle(train, input);
    }

    /**
     * Computes outputs without a caller-supplied output workspace.
     */
    public INDArray[] output(boolean train, INDArray ... input) {
        MemoryWorkspace noOutputWorkspace = null;
        return output(train, noOutputWorkspace, input);
    }

    /**
     * Computes outputs using the stored input/label masks and the given output workspace (may be null).
     */
    public INDArray[] output(boolean train, MemoryWorkspace outputWorkspace, INDArray ... input) {
        INDArray[] featureMasks = this.inputMaskArrays;
        INDArray[] labelMasks = this.labelMaskArrays;
        return output(train, input, featureMasks, labelMasks, outputWorkspace);
    }

    /**
     * Computes outputs with the given input masks and no label masks.
     *
     * @throws NullPointerException if {@code input} is null
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks) {
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        INDArray[] noLabelMasks = null;
        return output(train, input, inputMasks, noLabelMasks);
    }

    /**
     * Computes outputs with the given masks and no output workspace.
     *
     * @throws NullPointerException if {@code input} is null
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks) {
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        MemoryWorkspace noOutputWorkspace = null;
        return output(train, input, inputMasks, labelMasks, noOutputWorkspace);
    }

    /**
     * Computes outputs inside the output workspace and converts them with {@code outputAdapter}.
     *
     * <p>A {@code ModelAdapter} is given the model and raw inputs directly. Any other adapter
     * receives the inference-mode network outputs.
     *
     * @param inputs        network inputs, must not be null
     * @param inputMasks    input masks, may be null
     * @param labelMasks    label masks, may be null
     * @param outputAdapter converter from outputs to the result type, must not be null
     * @param <T>           result type
     * @return the adapter's result
     * @throws NullPointerException if {@code inputs} or {@code outputAdapter} is null
     */
    public synchronized <T> T output(@NonNull INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelMasks, @NonNull OutputAdapter<T> outputAdapter) {
        if (inputs == null) {
            throw new NullPointerException("inputs is marked non-null but is null");
        }
        if (outputAdapter == null) {
            throw new NullPointerException("outputAdapter is marked non-null but is null");
        }
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) {
            if (outputAdapter instanceof ModelAdapter) {
                // Safe: the ModelAdapter is the same object already typed as OutputAdapter<T>.
                @SuppressWarnings("unchecked")
                ModelAdapter<T> modelAdapter = (ModelAdapter<T>) outputAdapter;
                return modelAdapter.apply(this, inputs, inputMasks, labelMasks);
            }
            return outputAdapter.apply(this.output(false, inputs, inputMasks, labelMasks, ws));
        }
    }

    /**
     * Computes the network outputs for the given inputs and masks.
     *
     * <p>Masks are set before the pass. Mask arrays and layer states are cleared after a successful
     * pass. Writes a memory crash dump on {@link OutOfMemoryError} and rethrows.
     *
     * @throws NullPointerException if {@code input} is null
     */
    public synchronized INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks, MemoryWorkspace outputWorkspace) {
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        INDArray[] result;
        try {
            setLayerMaskArrays(inputMasks, labelMasks);
            result = outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, inputMasks, labelMasks, true, false, outputWorkspace);
            clearLayerMaskArrays();
            clearLayersStates();
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Single-output convenience method that clears layer inputs afterwards.
     */
    public INDArray outputSingle(boolean train, INDArray ... input) {
        final boolean clearInputs = true;
        return outputSingle(train, clearInputs, input);
    }

    /**
     * Returns the only output of a single-output graph.
     *
     * @throws IllegalStateException if the graph does not have exactly one output
     */
    public INDArray outputSingle(boolean train, boolean clearInputs, INDArray ... input) {
        if (this.numOutputArrays == 1) {
            INDArray[] all = output(train, clearInputs, input);
            return all[0];
        }
        throw new IllegalStateException("Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: " + this.numOutputArrays);
    }

    /**
     * Computes outputs without masks.
     *
     * <p>When {@code clearInputs} is false the inputs are treated as detached. Writes a memory crash
     * dump on {@link OutOfMemoryError} and rethrows.
     */
    public synchronized INDArray[] output(boolean train, boolean clearInputs, INDArray ... input) {
        try {
            return outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, null, null, clearInputs, !clearInputs, null);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Computes and merges outputs for every data set in {@code iterator}, by adapting it to a {@link MultiDataSetIterator}.
     */
    public INDArray[] output(DataSetIterator iterator) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        return output(adapted);
    }

    /**
     * Computes inference-mode outputs for every {@link MultiDataSet} in {@code iterator}.
     *
     * <p>The per-minibatch outputs are merged (per network output) via {@code DataSetUtil.mergeFeatures}.
     * The iterator is consumed from its current position and is not reset.
     *
     * @param iterator data to compute outputs for
     * @return merged outputs, one array per network output
     * @throws IllegalStateException if the iterator yields no data
     */
    public INDArray[] output(MultiDataSetIterator iterator) {
        ArrayList<INDArray[]> outputs = new ArrayList<INDArray[]>();
        while (iterator.hasNext()) {
            MultiDataSet next = (MultiDataSet) iterator.next();
            outputs.add(this.output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays()));
        }
        if (outputs.isEmpty()) {
            // mergeFeatures would otherwise fail with an uninformative index error.
            throw new IllegalStateException("Cannot compute output: iterator has no data");
        }
        INDArray[][] arr = outputs.toArray(new INDArray[0][]);
        return (INDArray[]) DataSetUtil.mergeFeatures(arr, (INDArray[][]) null).getFirst();
    }

    /**
     * Merged output for a single-output graph over a {@link DataSetIterator}.
     *
     * @throws IllegalArgumentException if the graph does not have exactly one output
     */
    public INDArray outputSingle(DataSetIterator iterator) {
        boolean singleOutput = this.numOutputArrays == 1;
        Preconditions.checkArgument(singleOutput, "Cannot use this method with nets that have more than 1 output array. This network has %s outputs", (int) this.numOutputArrays);
        INDArray[] merged = output(iterator);
        return merged[0];
    }

    /**
     * Merged output for a single-output graph over a {@link MultiDataSetIterator}.
     *
     * @throws IllegalArgumentException if the graph does not have exactly one output
     */
    public INDArray outputSingle(MultiDataSetIterator iterator) {
        boolean singleOutput = this.numOutputArrays == 1;
        Preconditions.checkArgument(singleOutput, "Cannot use this method with nets that have more than 1 output array. This network has %s outputs", (int) this.numOutputArrays);
        INDArray[] merged = output(iterator);
        return merged[0];
    }

    /**
     * Computes the activations of the named vertices.
     *
     * @param layers       names of the vertices whose activations are returned, non-empty
     * @param train        training (true) or inference (false) mode
     * @param features     network inputs
     * @param featureMasks input masks, may be null
     * @return one activation array per requested name, in request order
     * @throws IllegalStateException if {@code layers} is null/empty or a name is not in the network
     */
    public INDArray[] output(List<String> layers, boolean train, INDArray[] features, INDArray[] featureMasks) {
        Preconditions.checkState(layers != null && !layers.isEmpty(), "Layers must not be null or empty: got layer names %s", layers);
        int[] layerNums = new int[layers.size()];
        for (int i = 0; i < layers.size(); ++i) {
            String n = layers.get(i);
            Preconditions.checkState(this.verticesMap.containsKey(n), "Layer with name %s not found in network", (Object) n);
            layerNums[i] = this.verticesMap.get(n).getVertexIndex();
        }
        return this.outputOfLayersDetached(train, FwdPassType.STANDARD, layerNums, features, featureMasks, null, true, false, null);
    }

    /**
     * Validates that {@code array} lives in the workspace expected for {@code arrayType}.
     *
     * <p>An {@code ND4JWorkspaceException} is rethrown as an {@link IllegalStateException}, with the
     * vertex name and class added to the message.
     */
    protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op) {
        try {
            mgr.validateArrayLocation(arrayType, array, false, isInputVertex);
        } catch (ND4JWorkspaceException wsException) {
            GraphVertex vertex = this.verticesMap.get(vertexName);
            String clazz;
            if (vertex instanceof LayerVertex) {
                clazz = vertex.getLayer().getClass().getSimpleName();
            } else {
                clazz = vertex.getClass().getSimpleName();
            }
            throw new IllegalStateException(op + ": array (" + arrayType + ") workspace validation failed (vertex " + vertexName + " - class: " + clazz + ") - array is defined in incorrect workspace", wsException);
        }
    }

    /**
     * Feed-forward through the graph, returning activations that are not placed in any workspace.
     *
     * <p>Vertices are visited in topological order, from position 0 up to and including position
     * {@code layerIndex}. This is a position in {@code topologicalOrder}, not a vertex index. No
     * workspace may be open when this is called (asserted at the start).
     *
     * @param train             training (true) or inference (false) mode; also selects the workspace mode
     * @param fwdPassType       STANDARD, RNN_TIMESTEP or RNN_ACTIVATE_WITH_STORED_STATE; other values throw
     * @param storeLastForTBPTT only used for RNN_ACTIVATE_WITH_STORED_STATE
     * @param layerIndex        inclusive stop position in the topological order, in [0, topologicalOrder.length)
     * @param excludeIdxs       vertex indices to skip; may be null
     * @param features          network input arrays
     * @param fMask             feature masks; may be null
     * @param lMask             label masks; may be null
     * @param clearLayers       if true, each visited vertex is cleared after its output is propagated
     * @return map from vertex name to activation, including the network inputs
     * @throws NullPointerException     if {@code fwdPassType} or {@code features} is null
     * @throws IllegalArgumentException if {@code layerIndex} is out of range or the pass type is unsupported
     */
    protected synchronized Map<String, INDArray> ffToLayerActivationsDetached(boolean train, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, int layerIndex, int[] excludeIdxs, @NonNull INDArray[] features, INDArray[] fMask, INDArray[] lMask, boolean clearLayers) {
        LayerWorkspaceMgr workspaceMgr;
        WorkspaceMode wsm;
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked non-null but is null");
        }
        if (features == null) {
            throw new NullPointerException("features is marked non-null but is null");
        }
        if (layerIndex < 0 || layerIndex >= this.topologicalOrder.length) {
            throw new IllegalArgumentException("Invalid layer index - index must be >= 0 and < " + this.topologicalOrder.length + ", got index " + layerIndex);
        }
        this.setInputs(features);
        this.setLayerMaskArrays(fMask, lMask);
        WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active before call to ffToLayerActivationsDetached", (boolean)true);
        // NOTE(review): workspaceMode duplicates wsm and is not used afterwards (likely a decompiler artifact).
        WorkspaceMode workspaceMode = wsm = train ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        if (wsm == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            // ACTIVATIONS are explicitly excluded from workspaces, so returned activations are detached.
            workspaceMgr = LayerWorkspaceMgr.builder().noWorkspaceFor(ArrayType.ACTIVATIONS).with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            // If the first input is attached to a workspace, register that workspace's id as a
            // no-leverage override. Assumes features has at least one element — TODO confirm caller guarantees it.
            if (features[0].isAttached()) {
                workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
            }
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);
        // Network inputs are included in the returned activations, keyed by network input name.
        HashMap<String, INDArray> activations = new HashMap<String, INDArray>();
        for (int i = 0; i < features.length; ++i) {
            activations.put(this.configuration.getNetworkInputs().get(i), features[i]);
        }
        boolean traceLog = log.isTraceEnabled();
        // i is a position in the topological order; topologicalOrder[i] is the vertex index.
        for (int i = 0; i <= layerIndex; ++i) {
            GraphVertex current = this.vertices[this.topologicalOrder[i]];
            String vName = current.getVertexName();
            int vIdx = current.getVertexIndex();
            if (excludeIdxs != null && ArrayUtils.contains((int[])excludeIdxs, (int)vIdx)) continue;
            if (traceLog) {
                log.trace("About forward pass: {} (\"{}\") - {}", new Object[]{i, vName, current.getClass().getSimpleName()});
            }
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM);){
                INDArray out;
                VertexIndices[] inputsTo = current.getOutputVertices();
                if (current.isInputVertex()) {
                    // Input vertices simply emit the corresponding stored network input.
                    out = this.inputs[vIdx];
                } else {
                    if (fwdPassType == FwdPassType.STANDARD) {
                        out = current.doForward(train, workspaceMgr);
                    } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                        // Recurrent layers (direct, wrapped, or nested MultiLayerNetwork) use rnnTimeStep;
                        // everything else falls back to a normal forward pass.
                        if (current.hasLayer()) {
                            INDArray input = current.getInputs()[0];
                            org.deeplearning4j.nn.api.Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer)l).rnnTimeStep(this.reshapeTimeStepInput(input), workspaceMgr);
                            } else if (l instanceof BaseWrapperLayer && ((BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
                                RecurrentLayer rl = (RecurrentLayer)((BaseWrapperLayer)l).getUnderlying();
                                out = rl.rnnTimeStep(this.reshapeTimeStepInput(input), workspaceMgr);
                            } else {
                                out = l instanceof MultiLayerNetwork ? ((MultiLayerNetwork)l).rnnTimeStep(this.reshapeTimeStepInput(input)) : current.doForward(train, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        // Same dispatch as above, but using rnnActivateUsingStoredState (storeLastForTBPTT applies here).
                        if (current.hasLayer()) {
                            org.deeplearning4j.nn.api.Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer)l).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof BaseWrapperLayer && ((BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
                                RecurrentLayer rl = (RecurrentLayer)((BaseWrapperLayer)l).getUnderlying();
                                out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                // For a nested MultiLayerNetwork, only its last activation is kept.
                                List<INDArray> temp = ((MultiLayerNetwork)l).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT);
                                out = temp.get(temp.size() - 1);
                            } else {
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else {
                        throw new IllegalArgumentException("Unsupported forward pass type for this method: " + (Object)((Object)fwdPassType));
                    }
                    this.validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);
                // Propagate this vertex's output to each consuming vertex on the matching input edge.
                if (inputsTo != null) {
                    for (VertexIndices v : inputsTo) {
                        int inputToIndex = v.getVertexIndex();
                        int vIdxEdge = v.getVertexEdgeNumber();
                        this.vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
                    }
                }
                if (clearLayers) {
                    current.clear();
                }
            }
            if (!traceLog) continue;
            log.trace("Completed forward pass: {} (\"{}\") - {}", new Object[]{i, vName, current.getClass().getSimpleName()});
        }
        return activations;
    }

    /**
     * Performs a forward pass through the graph in topological order, stopping after vertex {@code layerIndex}
     * (inclusive), or running through every vertex when {@code layerIndex == -1}.
     * <p>
     * Unless the selected workspace mode is {@link WorkspaceMode#NONE}, this method requires workspace
     * WS_ALL_LAYERS_ACT to already be open and active, and places activations and inputs in it.
     *
     * @param train             whether this is a training pass; also selects the training vs. inference workspace mode
     * @param layerIndex        vertex index at which to stop (inclusive), or -1 to process all vertices
     * @param excludeIdxs       vertex indices to skip entirely (no activation recorded or propagated); may be null
     * @param fwdPassType       {@link FwdPassType#STANDARD} or {@link FwdPassType#RNN_ACTIVATE_WITH_STORED_STATE}
     * @param storeLastForTBPTT forwarded to {@code rnnActivateUsingStoredState} for recurrent layers
     * @param input             network input arrays
     * @param fMask             feature mask arrays (passed to {@code setLayerMaskArrays})
     * @param lMask             label mask arrays (passed to {@code setLayerMaskArrays})
     * @param clearInputs       if true, each vertex is cleared after its output has been passed to its consumers
     * @return map from vertex name to the activations computed for that vertex
     * @throws IllegalArgumentException if {@code layerIndex} is neither -1 nor in [0, number of vertices)
     * @throws IllegalStateException    if {@code fwdPassType} is not supported by this method
     */
    protected synchronized Map<String, INDArray> ffToLayerActivationsInWS(boolean train, int layerIndex, int[] excludeIdxs, FwdPassType fwdPassType, boolean storeLastForTBPTT, INDArray[] input, INDArray[] fMask, INDArray[] lMask, boolean clearInputs) {
        if (layerIndex != -1 && (layerIndex < 0 || layerIndex >= this.topologicalOrder.length)) {
            throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + this.topologicalOrder.length + ", got index " + layerIndex);
        }
        this.setInputs(input);
        this.setLayerMaskArrays(fMask, lMask);

        WorkspaceMode wsm = train ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        LayerWorkspaceMgr workspaceMgr;
        if (wsm == WorkspaceMode.NONE) {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsInWs", true);
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (input[0].isAttached()) {
                workspaceMgr.setNoLeverageOverride(input[0].data().getParentWorkspace().getId());
            }
            if (this.configuration.getCacheMode() != CacheMode.NONE) {
                workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        boolean traceLog = log.isTraceEnabled();
        Map<String, INDArray> activations = new HashMap<>();
        // -1 is the only "run everything" sentinel; every other accepted value (including 0) is a real vertex index.
        int stopIndex = layerIndex >= 0 ? ArrayUtils.indexOf(this.topologicalOrder, layerIndex) : this.topologicalOrder.length - 1;
        for (int i = 0; i <= stopIndex; ++i) {
            GraphVertex current = this.vertices[this.topologicalOrder[i]];
            String vName = current.getVertexName();
            int vIdx = current.getVertexIndex();
            if (traceLog) {
                log.trace("About forward pass: {} (\"{}\") - {}", new Object[]{i, vName, current.getClass().getSimpleName()});
            }
            if (excludeIdxs != null && ArrayUtils.contains(excludeIdxs, vIdx)) {
                continue;
            }
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                INDArray out;
                VertexIndices[] inputsTo = current.getOutputVertices();
                if (current.isInputVertex()) {
                    out = this.inputs[vIdx];
                } else {
                    if (fwdPassType == FwdPassType.STANDARD) {
                        out = current.doForward(train, workspaceMgr);
                    } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            org.deeplearning4j.nn.api.Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof BaseWrapperLayer && ((BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
                                // Wrapped recurrent layer: drive the underlying recurrent layer directly
                                RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) l).getUnderlying();
                                out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                // A nested network returns all its layer activations; only the last one is this vertex's output
                                List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT);
                                out = temp.get(temp.size() - 1);
                            } else {
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else {
                        throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType);
                    }
                    this.validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);
                // Hand this vertex's output to every consumer on the corresponding input edge
                if (inputsTo != null) {
                    for (VertexIndices v : inputsTo) {
                        this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), out, workspaceMgr);
                    }
                }
                if (clearInputs) {
                    current.clear();
                }
            }
            if (traceLog) {
                log.trace("Completed forward pass: {} (\"{}\") - {}", new Object[]{i, vName, current.getClass().getSimpleName()});
            }
        }
        return activations;
    }

    /**
     * Runs a forward pass only as far as the last requested vertex (in topological order) and returns the
     * activations of the vertices listed in {@code layerIndexes}, in the same order.
     * <p>
     * Each vertex's activations go into a per-vertex workspace named {@code "WS_LAYER_ACT_<n>"} taken from a
     * pool of workspace managers. That workspace is closed after the last topological step that consumes the
     * vertex's output, and its manager is then returned to the pool for reuse. Requested outputs are either
     * placed in {@code outputWorkspace} (when it is a real, already-open workspace) or scoped out of workspaces.
     * <p>
     * Any exception raised during the pass is captured. Every opened activation workspace is closed and the
     * memory manager's current workspace is restored before the exception is rethrown. Non-RuntimeExceptions
     * are wrapped in a RuntimeException.
     *
     * @param train            whether this is a training pass; also selects the training vs. inference workspace mode
     * @param fwdPassType      {@link FwdPassType#STANDARD} or {@link FwdPassType#RNN_TIMESTEP}
     * @param layerIndexes     vertex indices whose activations should be returned
     * @param features         network inputs; the length must equal the number of network inputs
     * @param fMask            feature mask arrays (passed to {@code setLayerMaskArrays})
     * @param lMasks           label mask arrays (passed to {@code setLayerMaskArrays})
     * @param clearLayerInputs if true, each vertex is cleared after its output has been passed to its consumers
     * @param detachedInputs   if true, INPUT and ACTIVATIONS are scoped out for newly created workspace managers
     * @param outputWorkspace  optional workspace for the returned arrays; if non-null and not a DummyWorkspace it
     *                         must already be open, and it remains open (the caller closes it)
     * @return one activation array per entry of {@code layerIndexes}; duplicate indexes receive the same array
     * @throws NullPointerException     if {@code fwdPassType}, {@code layerIndexes} or {@code features} is null
     * @throws IllegalArgumentException if the input count or a layer index is invalid
     */
    protected INDArray[] outputOfLayersDetached(boolean train, @NonNull FwdPassType fwdPassType, @NonNull int[] layerIndexes, @NonNull INDArray[] features, INDArray[] fMask, INDArray[] lMasks, boolean clearLayerInputs, boolean detachedInputs, MemoryWorkspace outputWorkspace) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked non-null but is null");
        }
        if (layerIndexes == null) {
            throw new NullPointerException("layerIndexes is marked non-null but is null");
        }
        if (features == null) {
            throw new NullPointerException("features is marked non-null but is null");
        }
        if (features.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid number of input arrays: network has " + this.numInputArrays + " inputs, got " + features.length + " input arrays");
        }
        for (int i = 0; i < layerIndexes.length; ++i) {
            if (layerIndexes[i] < 0 || layerIndexes[i] >= this.topologicalOrder.length) {
                throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + this.topologicalOrder.length + ", got index " + layerIndexes[i]);
            }
        }
        this.setInputs(features);
        this.setLayerMaskArrays(fMask, lMasks);

        // A "real" output workspace is a non-null workspace that is not a DummyWorkspace
        boolean hasOutputWorkspace = outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace);
        MemoryWorkspace outputPrevious = null;
        if (!hasOutputWorkspace) {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to outputOfLayersDetached");
        } else {
            Preconditions.checkState(outputWorkspace.isScopeActive(), "Workspace \"" + outputWorkspace.getId() + "\" was provided for the network/layer outputs. When provided, this workspace must be opened before calling the output method; furthermore, closing the workspace is the responsibility of the user");
            outputPrevious = outputWorkspace.getParentWorkspace();
        }

        // For each vertex: the last topological step that consumes its output (vertices with no consumers: the final step)
        int[] vertexOutputsFullyConsumedByStep = new int[this.topologicalOrder.length];
        for (GraphVertex gv : this.vertices) {
            int idx = gv.getVertexIndex();
            int maxStepOfOutputTo = -1;
            VertexIndices[] outputsTo = gv.getOutputVertices();
            if (outputsTo != null) {
                for (VertexIndices vi : outputsTo) {
                    int posInTopoSort = ArrayUtils.indexOf(this.topologicalOrder, vi.getVertexIndex());
                    if (posInTopoSort == -1) {
                        throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
                    }
                    maxStepOfOutputTo = Math.max(maxStepOfOutputTo, posInTopoSort);
                }
            } else {
                maxStepOfOutputTo = this.topologicalOrder.length - 1;
            }
            vertexOutputsFullyConsumedByStep[idx] = maxStepOfOutputTo;
        }

        INDArray[] outputs = new INDArray[layerIndexes.length];
        // Only run the pass as far as the last requested vertex in topological order
        int stopIndex = -1;
        for (int i = 0; i < layerIndexes.length; ++i) {
            stopIndex = Math.max(stopIndex, ArrayUtils.indexOf(this.topologicalOrder, layerIndexes[i]));
        }

        List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
        List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>();
        Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();
        WorkspaceMode wsm = train ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        boolean noWS = wsm == WorkspaceMode.NONE;
        LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces) : null;
        // closeAtEndIteration[step]: activation workspaces that can be closed once topological step 'step' is done
        @SuppressWarnings("unchecked")
        List<MemoryWorkspace>[] closeAtEndIteration = new List[this.topologicalOrder.length];
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Throwable t = null;
        try {
            for (int i = 0; i <= stopIndex; ++i) {
                LayerWorkspaceMgr workspaceMgr;
                GraphVertex current = this.vertices[this.topologicalOrder[i]];
                GraphVertex prev = i > 0 ? this.vertices[this.topologicalOrder[i - 1]] : null;
                String vName = current.getVertexName();
                int vIdx = current.getVertexIndex();

                // Obtain a workspace manager: reuse one whose workspace has been closed, or create a new one
                if (noWS) {
                    workspaceMgr = allNone;
                } else if (freeWorkspaceManagers.size() > 0) {
                    workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
                } else {
                    String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
                    workspaceMgr = LayerWorkspaceMgr.builder()
                            .with(ArrayType.INPUT, wsName, this.WS_LAYER_ACT_X_CONFIG)
                            .with(ArrayType.ACTIVATIONS, wsName, this.WS_LAYER_ACT_X_CONFIG)
                            .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                            .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                            .build();
                    if (detachedInputs) {
                        workspaceMgr.setScopedOutFor(ArrayType.INPUT);
                        workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                    } else if (features[0].isAttached()) {
                        workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
                    }
                    allWorkspaceManagers.add(workspaceMgr);
                }
                workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

                // Requested outputs go to the user's output workspace, or are scoped out; the original ACTIVATIONS
                // setting is remembered so it can be restored after this vertex
                boolean isRequiredOutput = false;
                String origWSAct = null;
                WorkspaceConfiguration origWSActConf = null;
                if (ArrayUtils.contains(layerIndexes, vIdx)) {
                    isRequiredOutput = true;
                    if (hasOutputWorkspace) {
                        origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                        origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                        workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
                    } else if (!workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS)) {
                        origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                        origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                        workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                    }
                }

                MemoryWorkspace wsActivations = null;
                if (!hasOutputWorkspace || !isRequiredOutput) {
                    wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                    openActivationsWorkspaces.put(wsActivations, workspaceMgr);
                }
                // Activation workspaces are not opened/closed in nested order, so pin their "previous" workspace to
                // the one that was current on entry
                if (wsActivations != null) {
                    wsActivations.setPreviousWorkspace(initialWorkspace);
                }
                int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
                if (!hasOutputWorkspace || wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId())) {
                    if (closeAtEndIteration[closeableAt] == null) {
                        closeAtEndIteration[closeableAt] = new ArrayList<>();
                    }
                    closeAtEndIteration[closeableAt].add(wsActivations);
                }

                try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    VertexIndices[] inputsTo = current.getOutputVertices();
                    INDArray out = null;
                    if (current.isInputVertex()) {
                        out = features[vIdx];
                    } else {
                        if (fwdPassType == FwdPassType.STANDARD) {
                            // NOTE(review): 'prev' is the topologically preceding vertex, which need not be an input of
                            // 'current'; in that case no input edge matches and no permutation is applied.
                            if (i > 0 && current.hasLayer() && prev.hasLayer() && ConvolutionUtils.layerHasConvolutionLayout(prev.getLayer().conf().getLayer()) && ConvolutionUtils.layerHasConvolutionLayout(current.getLayer().conf().getLayer())) {
                                CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer(prev.getLayer().conf().getLayer());
                                CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer(current.getLayer().conf().getLayer());
                                if (preLayerFormat != currLayerFormat) {
                                    int inputIdx = -1;
                                    for (int inputVertex = 0; inputVertex < current.getInputVertices().length; ++inputVertex) {
                                        if (current.getInputVertices()[inputVertex].getVertexIndex() == prev.getVertexIndex()) {
                                            inputIdx = inputVertex;
                                        }
                                    }
                                    if (inputIdx >= 0) {
                                        // NOTE(review): for an NCHW predecessor the permutation used is (0,3,1,2), which
                                        // maps an NHWC-ordered array to NCHW; confirm the intended direction.
                                        if (preLayerFormat == CNN2DFormat.NCHW) {
                                            current.setInput(inputIdx, current.getInputs()[inputIdx].permute(new int[]{0, 3, 1, 2}), workspaceMgr);
                                        } else if (preLayerFormat == CNN2DFormat.NHWC) {
                                            current.setInput(inputIdx, current.getInputs()[inputIdx].permute(new int[]{0, 2, 3, 1}), workspaceMgr);
                                        } else {
                                            throw new IllegalStateException("No CNN2DDataFormat type found for previous layer!");
                                        }
                                    }
                                }
                                out = current.doForward(train, workspaceMgr);
                            } else if (i > 0 && current.hasLayer() && prev.hasLayer() && Convolution1DUtils.hasRnnDataFormat(prev.getLayer().conf().getLayer()) && Convolution1DUtils.hasRnnDataFormat(current.getLayer().conf().getLayer())) {
                                RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(prev.getLayer().conf().getLayer());
                                RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(current.getLayer().conf().getLayer());
                                int inputIdx = -1;
                                for (int inputVertex = 0; inputVertex < current.getInputVertices().length; ++inputVertex) {
                                    if (current.getInputVertices()[inputVertex].getVertexIndex() == prev.getVertexIndex()) {
                                        inputIdx = inputVertex;
                                    }
                                }
                                // Swap the last two axes of the input received from prev when the RNN formats differ
                                if (preLayerFormat != currLayerFormat && inputIdx >= 0) {
                                    current.setInput(inputIdx, current.getInputs()[inputIdx].permute(new int[]{0, 2, 1}), workspaceMgr);
                                }
                                out = current.doForward(train, workspaceMgr);
                            } else {
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                            if (current.hasLayer()) {
                                INDArray input = current.getInputs()[0];
                                org.deeplearning4j.nn.api.Layer l = current.getLayer();
                                if (l instanceof RecurrentLayer) {
                                    out = ((RecurrentLayer) l).rnnTimeStep(this.reshapeTimeStepInput(input), workspaceMgr);
                                } else if (l instanceof BaseWrapperLayer && ((BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
                                    RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) l).getUnderlying();
                                    out = rl.rnnTimeStep(this.reshapeTimeStepInput(input), workspaceMgr);
                                } else if (l instanceof MultiLayerNetwork) {
                                    out = ((MultiLayerNetwork) l).rnnTimeStep(this.reshapeTimeStepInput(input));
                                } else {
                                    out = current.doForward(train, workspaceMgr);
                                }
                            } else {
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                        }
                        this.validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                    }
                    if (inputsTo != null) {
                        for (VertexIndices v : inputsTo) {
                            this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), out, workspaceMgr);
                        }
                    }
                    if (clearLayerInputs) {
                        current.clear();
                    }
                    if (isRequiredOutput) {
                        // Fill every slot that requested this vertex, so duplicate entries are not left null
                        for (int k = 0; k < layerIndexes.length; ++k) {
                            if (layerIndexes[k] == vIdx) {
                                outputs[k] = out;
                            }
                        }
                        if (origWSAct != null) {
                            workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
                        }
                    }
                }

                // Close activation workspaces whose last consumer was this step, and return their managers to the pool
                if (closeAtEndIteration[i] == null) {
                    continue;
                }
                for (MemoryWorkspace wsAct : closeAtEndIteration[i]) {
                    wsAct.close();
                    LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                    freeWorkspaceManagers.add(canNowReuse);
                }
            }
        } catch (Throwable t2) {
            t = t2;
        } finally {
            // Close whatever activation workspaces are still open (normal completion or failure alike)
            for (MemoryWorkspace ws : openActivationsWorkspaces.keySet()) {
                while (ws.isScopeActive()) {
                    // A close failure is ignored (and closing retried) unless an exception is already pending.
                    // NOTE(review): if close() keeps failing with no pending exception, this loop does not terminate.
                    try {
                        ws.close();
                    } catch (Throwable closeError) {
                        if (t != null) {
                            log.error("Encountered second exception while trying to close workspace after initial exception");
                            log.error("Original exception:", t);
                            throw closeError;
                        }
                    }
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
            if (t != null) {
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                throw new RuntimeException("Error during neural network forward pass", t);
            }
            if (!hasOutputWorkspace) {
                WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayersDetached");
            } else {
                Preconditions.checkState(outputWorkspace.isScopeActive(), "Expected output workspace to still be open at end of outputOfLayersDetached, but its scope was no longer active");
                outputWorkspace.setPreviousWorkspace(outputPrevious);
            }
        }
        return outputs;
    }

    /**
     * Promotes a rank-2 array of shape [d0, d1] to rank 3 with shape [d0, d1, 1] for single-time-step RNN calls.
     * Arrays of any other rank are returned unchanged.
     *
     * @param input the array to adapt (must be non-null)
     * @return the reshaped array when {@code input} is rank 2, otherwise {@code input} itself
     */
    private INDArray reshapeTimeStepInput(INDArray input) {
        if (input.rank() != 2) {
            return input;
        }
        long[] dims = input.shape();
        return input.reshape(new long[]{dims[0], dims[1], 1L});
    }

    /**
     * Backpropagates externally supplied errors, one per network output, through the graph.
     * Truncated BPTT is used when the configuration's backprop type is {@link BackpropType#TruncatedBPTT}.
     * If an OutOfMemoryError occurs, a memory crash dump is written before the error is rethrown.
     *
     * @param epsilons one error array per network output; must be non-null with exactly that many entries
     * @return this network's {@code gradient} field as it stands after {@code calcBackpropGradients} returns
     * @throws IllegalArgumentException if {@code epsilons} is null or its length differs from the output count
     */
    public Gradient backpropGradient(INDArray ... epsilons) {
        boolean oneEpsilonPerOutput = epsilons != null && epsilons.length == this.numOutputArrays;
        if (!oneEpsilonPerOutput) {
            throw new IllegalArgumentException("Invalid input: must have epsilons length equal to number of output arrays");
        }
        try {
            boolean useTruncatedBptt = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
            this.calcBackpropGradients(true, useTruncatedBptt, epsilons);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return this.gradient;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected void calcBackpropGradients(boolean clearLayers, boolean truncatedBPTT, INDArray ... externalEpsilons) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (externalEpsilons == null || externalEpsilons.length == 0 && this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE) {
            WorkspaceUtils.assertOpenAndActive((String)WS_ALL_LAYERS_ACT, (String)"Expected workspace WS_ALL_LAYERS_ACT to be active and open in calcBackpropGradients when workspace mode is not set to NONE");
        }
        if (externalEpsilons != null && externalEpsilons.length > 0) {
            List<String> outputLayers = this.configuration.getNetworkOutputs();
            for (String s : outputLayers) {
                GraphVertex gv = this.getVertex(s);
                if (!(gv instanceof LayerVertex) || !(((LayerVertex)gv).getLayer() instanceof IOutputLayer)) continue;
                throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer: output layers cannot use external errors for backprop. Layer name: " + s);
            }
        }
        int[] vertexActGradsFullyConsumedByStep = new int[this.topologicalOrder.length];
        for (GraphVertex gv : this.vertices) {
            int idx = gv.getVertexIndex();
            int minStepOfInputFrom = Integer.MAX_VALUE;
            VertexIndices[] inputsFrom = gv.getInputVertices();
            if (inputsFrom != null) {
                for (VertexIndices vi : inputsFrom) {
                    int posInTopoSort = ArrayUtils.indexOf((int[])this.topologicalOrder, (int)vi.getVertexIndex());
                    if (posInTopoSort == -1) {
                        throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
                    }
                    minStepOfInputFrom = Math.min(minStepOfInputFrom, posInTopoSort);
                }
            }
            vertexActGradsFullyConsumedByStep[idx] = minStepOfInputFrom == Integer.MAX_VALUE ? 0 : minStepOfInputFrom;
        }
        boolean noWS = this.configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE;
        LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces) : null;
        ArrayList<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<LayerWorkspaceMgr>();
        ArrayList<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<LayerWorkspaceMgr>();
        IdentityHashMap<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<MemoryWorkspace, LayerWorkspaceMgr>();
        List[] closeAtEndIteraton = new List[this.topologicalOrder.length];
        LinkedList<Triple> gradients = new LinkedList<Triple>();
        boolean[] setVertexEpsilon = new boolean[this.topologicalOrder.length];
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        boolean traceLog = log.isTraceEnabled();
        Throwable t = null;
        try {
            for (int i = this.topologicalOrder.length - 1; i >= 0; --i) {
                INDArray[] epsilons;
                Pair<Gradient, INDArray[]> pair;
                LayerWorkspaceMgr workspaceMgr;
                boolean hitFrozen = false;
                GraphVertex current = this.vertices[this.topologicalOrder[i]];
                int vIdx = current.getVertexIndex();
                String vertexName = current.getVertexName();
                if (traceLog) {
                    log.trace("About backprop: {} (\"{}\") - {}", new Object[]{i, vertexName, current.getClass().getSimpleName()});
                }
                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
                    hitFrozen = true;
                }
                if (current.isInputVertex() || hitFrozen) {
                    if (closeAtEndIteraton[i] != null) {
                        for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                            wsAct.close();
                            LayerWorkspaceMgr canNowReuse = (LayerWorkspaceMgr)((Object)openActivationsWorkspaces.remove(wsAct));
                            freeWorkspaceManagers.add(canNowReuse);
                        }
                    }
                    closeAtEndIteraton[i] = null;
                    continue;
                }
                if (noWS) {
                    workspaceMgr = allNone;
                } else if (freeWorkspaceManagers.size() > 0) {
                    workspaceMgr = (LayerWorkspaceMgr)((Object)freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1));
                } else {
                    String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
                    workspaceMgr = LayerWorkspaceMgr.builder().with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATION_GRAD, wsName, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
                    allWorkspaceManagers.add(workspaceMgr);
                }
                workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);
                if (current.isOutputVertex()) {
                    int thisOutputNumber = this.configuration.getNetworkOutputs().indexOf(current.getVertexName());
                    org.deeplearning4j.nn.api.Layer currentLayer = current.getLayer();
                    if (currentLayer instanceof FrozenLayerWithBackprop) {
                        currentLayer = ((FrozenLayerWithBackprop)currentLayer).getInsideLayer();
                    }
                    if (currentLayer instanceof IOutputLayer) {
                        IOutputLayer outputLayer = (IOutputLayer)currentLayer;
                        INDArray currLabels = this.labels[thisOutputNumber];
                        outputLayer.setLabels(currLabels);
                    } else {
                        if ((externalEpsilons == null || externalEpsilons.length == 0) && this.labels[thisOutputNumber] != null) {
                            throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type " + current.getLayer().getClass().getSimpleName() + " is set as network output (but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with a labels array. ");
                        }
                        current.setEpsilon(externalEpsilons[thisOutputNumber]);
                        setVertexEpsilon[this.topologicalOrder[i]] = true;
                    }
                }
                MemoryWorkspace wsActivationGrads = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
                openActivationsWorkspaces.put(wsActivationGrads, workspaceMgr);
                wsActivationGrads.setPreviousWorkspace(initialWorkspace);
                int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
                if (closeableAt >= 0) {
                    if (closeAtEndIteraton[closeableAt] == null) {
                        closeAtEndIteraton[closeableAt] = new ArrayList();
                    }
                    closeAtEndIteraton[closeableAt].add(wsActivationGrads);
                }
                try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM);){
                    pair = current.doBackward(truncatedBPTT, workspaceMgr);
                    for (INDArray iNDArray : epsilons = (INDArray[])pair.getSecond()) {
                        if (iNDArray == null) continue;
                        this.validateArrayWorkspaces(workspaceMgr, iNDArray, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
                    }
                }
                VertexIndices[] inputVertices = current.getInputVertices();
                if (inputVertices != null) {
                    int j = 0;
                    for (VertexIndices vertexIndices : inputVertices) {
                        GraphVertex gv = this.vertices[vertexIndices.getVertexIndex()];
                        if (setVertexEpsilon[gv.getVertexIndex()]) {
                            INDArray currentEps = gv.getEpsilon();
                            if (currentEps == null) {
                                gv.setEpsilon(currentEps);
                            } else {
                                gv.setEpsilon(currentEps.addi(epsilons[j++]));
                            }
                        } else {
                            gv.setEpsilon(epsilons[j++]);
                        }
                        setVertexEpsilon[gv.getVertexIndex()] = true;
                    }
                }
                if (pair.getFirst() != null) {
                    Gradient g = (Gradient)pair.getFirst();
                    Map<String, INDArray> map = g.gradientForVariable();
                    LinkedList<Triple> tempList = new LinkedList<Triple>();
                    for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                        String origName = entry.getKey();
                        String newName = current.getVertexName() + "_" + origName;
                        tempList.addFirst(new Triple((Object)newName, (Object)entry.getValue(), (Object)g.flatteningOrderForVariable(origName)));
                    }
                    for (Triple triple : tempList) {
                        gradients.addFirst(triple);
                    }
                }
                if (closeAtEndIteraton[i] != null) {
                    for (MemoryWorkspace memoryWorkspace : closeAtEndIteraton[i]) {
                        memoryWorkspace.close();
                        LayerWorkspaceMgr canNowReuse = (LayerWorkspaceMgr)((Object)openActivationsWorkspaces.remove(memoryWorkspace));
                        freeWorkspaceManagers.add(canNowReuse);
                    }
                    closeAtEndIteraton[i] = null;
                }
                if (!traceLog) continue;
                log.trace("Completed backprop: {} (\"{}\") - {}", new Object[]{i, vertexName, current.getClass().getSimpleName()});
            }
        }
        catch (Throwable t2) {
            t = t2;
        }
        finally {
            for (MemoryWorkspace ws : openActivationsWorkspaces.keySet()) {
                try {
                    ws.close();
                }
                catch (Throwable t2) {
                    if (t == null) continue;
                    log.error("Encountered second exception while trying to close workspace after initial exception");
                    log.error("Original exception:", t);
                    throw t2;
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
            if (t != null) {
                if (t instanceof RuntimeException) {
                    throw (RuntimeException)t;
                }
                throw new RuntimeException("Error during neural network backpropagation calculation", t);
            }
        }
        DefaultGradient gradient = new DefaultGradient(this.flattenedGradients);
        for (Triple tr : gradients) {
            gradient.setGradientFor((String)tr.getFirst(), (INDArray)tr.getSecond(), (Character)tr.getThird());
        }
        this.gradient = gradient;
        if (truncatedBPTT && this.clearTbpttState) {
            this.rnnClearPreviousState();
        }
        if (clearLayers) {
            for (GraphVertex gv : this.vertices) {
                gv.clear();
            }
        }
    }

    /**
     * Creates a copy of this network with a cloned configuration and duplicated parameters.
     *
     * <p>If this network already has a solver, the updater state is duplicated into the clone as
     * well. Layers that are {@link FrozenLayer}s in this network are marked frozen in the clone.
     * The clone receives its own copy of the training-listener collection, so later listener changes
     * on either network do not affect the other.
     *
     * @return a new, independent ComputationGraph
     */
    public ComputationGraph clone() {
        ComputationGraph cg = new ComputationGraph(this.configuration.clone());
        cg.init(this.params().dup(), false);
        // Only copy updater state when a solver exists: getUpdater() would otherwise create a
        // solver/updater on this instance as a side effect.
        if (this.solver != null) {
            ComputationGraphUpdater u = this.getUpdater();
            INDArray updaterState = u.getStateViewArray();
            if (updaterState != null) {
                cg.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        // Copy rather than alias: setListeners() clears and refills trainingListeners in place, so a
        // shared collection would silently propagate listener changes between the two networks.
        cg.trainingListeners = this.trainingListeners == null
                ? null
                : new ArrayList<TrainingListener>(this.trainingListeners);
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            GraphVertex vertex = this.vertices[this.topologicalOrder[i]];
            if (!vertex.hasLayer()) {
                continue;
            }
            String layerName = vertex.getVertexName();
            if (this.getLayer(layerName) instanceof FrozenLayer) {
                cg.getVertex(layerName).setLayerAsFrozen();
            }
        }
        return cg;
    }

    /**
     * Sums the regularization score reported by every layer of this network.
     *
     * @param backpropParamsOnly forwarded unchanged to each layer's {@code calcRegularizationScore}
     * @return the total regularization score over all layers
     */
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        double total = 0.0;
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            total += layer.calcRegularizationScore(backpropParamsOnly);
        }
        return total;
    }

    /**
     * Replaces the training listeners of this network.
     *
     * <p>Initializes the network first if it has no layers yet. The listeners are then set on every
     * layer and, if present, on the solver. The network's own listener collection is emptied and
     * refilled with {@code listeners}, or left empty when {@code listeners} is null.
     *
     * @param listeners the listeners to install; may be null
     */
    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        if (layers == null) {
            init();
        }
        for (int idx = 0; idx < layers.length; idx++) {
            layers[idx].setListeners(listeners);
        }
        if (solver != null) {
            solver.setListeners(listeners);
        }
        trainingListeners.clear();
        if (listeners == null) {
            return;
        }
        trainingListeners.addAll(listeners);
    }

    /**
     * Varargs convenience overload of {@link #setListeners(Collection)}.
     *
     * <p>Null entries, and a null array, are ignored.
     *
     * @param listeners the listeners to install
     */
    @Override
    public void setListeners(TrainingListener ... listeners) {
        List<TrainingListener> nonNull = new ArrayList<TrainingListener>();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                if (listener != null) {
                    nonNull.add(listener);
                }
            }
        }
        this.setListeners(nonNull);
    }

    /**
     * Appends listeners to the ones already registered on this network.
     *
     * <p>A null array and null entries are ignored, which matches
     * {@link #setListeners(TrainingListener...)}.
     *
     * @param listeners the listeners to add
     */
    @Override
    public void addListeners(TrainingListener ... listeners) {
        if (this.trainingListeners == null) {
            this.setListeners(listeners);
            return;
        }
        // Build a new list: setListeners() clears trainingListeners before refilling it, so passing
        // the existing collection itself would lose its contents.
        ArrayList<TrainingListener> newListeners = new ArrayList<TrainingListener>(this.trainingListeners);
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                if (listener != null) {
                    newListeners.add(listener);
                }
            }
        }
        this.setListeners(newListeners);
        if (this.solver != null) {
            this.solver.setListeners(this.trainingListeners);
        }
    }

    /**
     * Returns the training listeners registered on this network.
     *
     * @return the listener collection held by this network itself (not a copy)
     */
    public Collection<TrainingListener> getListeners() {
        return trainingListeners;
    }

    /**
     * Returns this network's updater, creating the solver and updater on first access if needed.
     *
     * @return the updater
     */
    public ComputationGraphUpdater getUpdater() {
        final boolean createIfMissing = true;
        return getUpdater(createIfMissing);
    }

    /**
     * Returns this network's updater.
     *
     * @param initializeIfAbsent when true and no solver exists yet, builds a solver and installs a
     *                           new {@code ComputationGraphUpdater} on its optimizer
     * @return the updater obtained from the solver's optimizer, or null when there is no solver
     */
    public ComputationGraphUpdater getUpdater(boolean initializeIfAbsent) {
        boolean mustCreateSolver = initializeIfAbsent && this.solver == null;
        if (mustCreateSolver) {
            this.solver = new Solver.Builder()
                    .configure(this.conf())
                    .listeners(this.getListeners())
                    .model(this)
                    .build();
            this.solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        if (this.solver == null) {
            return null;
        }
        return this.solver.getOptimizer().getComputationGraphUpdater(initializeIfAbsent);
    }

    /**
     * Installs the given updater on the solver's optimizer, building a solver first if none exists.
     *
     * @param updater the updater to use
     */
    public void setUpdater(ComputationGraphUpdater updater) {
        if (this.solver == null) {
            this.solver = new Solver.Builder()
                    .configure(this.conf())
                    .listeners(this.getListeners())
                    .model(this)
                    .build();
        }
        this.solver
                .getOptimizer()
                .setUpdaterComputationGraph(updater);
    }

    /**
     * Returns the layer for the given network output.
     *
     * @param outputLayerIdx index into the configuration's network outputs, in [0, numOutputArrays)
     * @return the layer registered under that output's name
     * @throws IllegalArgumentException if the index is negative or not less than the number of outputs
     */
    public org.deeplearning4j.nn.api.Layer getOutputLayer(int outputLayerIdx) {
        // Also reject negative indices so callers get the same descriptive error, not a List.get failure.
        if (outputLayerIdx < 0 || outputLayerIdx >= this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx + ", total number of network outputs = " + this.numOutputArrays);
        }
        return this.getLayer(this.configuration.getNetworkOutputs().get(outputLayerIdx));
    }

    /**
     * Returns the flattened parameters.
     *
     * @param backwardOnly ignored
     * @return the same array as {@link #params()}
     * @deprecated the flag has no effect; use {@link #params()} instead
     */
    @Deprecated
    public INDArray params(boolean backwardOnly) {
        INDArray all = params();
        return all;
    }

    /**
     * Scores a single-input/single-output DataSet in inference (non-training) mode.
     *
     * @param dataSet the data to score
     * @return the score
     */
    public double score(DataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Scores a DataSet by converting it to a MultiDataSet.
     *
     * @param dataSet  the data to score
     * @param training whether to score in training mode
     * @return the score
     * @throws UnsupportedOperationException unless the network has exactly one input and one output
     */
    public double score(DataSet dataSet, boolean training) {
        boolean singleInputSingleOutput = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInputSingleOutput) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return score(ComputationGraphUtil.toMultiDataSet(dataSet), training);
    }

    /**
     * Scores a MultiDataSet in inference (non-training) mode.
     *
     * @param dataSet the data to score
     * @return the score
     */
    public double score(MultiDataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Scores a MultiDataSet. On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     *
     * @param dataSet  the data to score
     * @param training whether to score in training mode
     * @return the score
     */
    public double score(MultiDataSet dataSet, boolean training) {
        double result;
        try {
            result = this.scoreHelper(dataSet, training);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Computes the score of {@code dataSet}.
     *
     * <p>The result is the sum of {@code computeScore} over all network outputs. The regularization
     * score (backprop params only) is passed to the first output only, so it is counted once.
     * Layer states are always cleared before returning, including on the early-return path and
     * when an exception is thrown.
     *
     * @param dataSet  the data to score
     * @param training selects the training or the inference workspace mode, and is forwarded to the forward pass and scoring
     * @return the score, or 0.0 (with a warning logged) if any network output is not an IOutputLayer
     */
    private double scoreHelper(MultiDataSet dataSet, boolean training) {
        WorkspaceMode wsm = training ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        LayerWorkspaceMgr mgr = wsm == WorkspaceMode.NONE
                ? LayerWorkspaceMgr.noWorkspaces()
                : LayerWorkspaceMgr.builder()
                        .noWorkspaceFor(ArrayType.ACTIVATIONS)
                        .noWorkspaceFor(ArrayType.INPUT)
                        .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                        .build();
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            this.setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }
        double score = 0.0;
        this.setInputs(dataSet.getFeatures());
        try {
            this.ffToLayerActivationsDetached(training, FwdPassType.STANDARD, false, this.vertices.length - 1, this.getOutputLayerIndices(), dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays(), false);
            try (WorkspacesCloseable ws = mgr.notifyScopeEntered(new ArrayType[]{ArrayType.ACTIVATIONS, ArrayType.FF_WORKING_MEM, ArrayType.RNN_FF_LOOP_WORKING_MEM})) {
                INDArray[] labels = dataSet.getLabels();
                this.setLabels(labels);
                double r = this.calcRegularizationScore(true);
                int i = 0;
                for (String s : this.configuration.getNetworkOutputs()) {
                    GraphVertex gv = this.verticesMap.get(s);
                    org.deeplearning4j.nn.api.Layer outLayer = gv.getLayer();
                    if (!(outLayer instanceof IOutputLayer)) {
                        log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                        return 0.0;
                    }
                    IOutputLayer ol = (IOutputLayer) outLayer;
                    ol.setLabels(labels[i++]);
                    score += ((LayerVertex) gv).computeScore(r, training, mgr);
                    // Regularization is only added to the first output's score.
                    r = 0.0;
                }
            }
        } finally {
            // Previously skipped on the early return and on exceptions, leaving layer state attached.
            this.clearLayersStates();
        }
        return score;
    }

    /**
     * Computes per-example scores for a single-input/single-output DataSet.
     *
     * @param data                   the data to score
     * @param addRegularizationTerms whether to include the regularization score
     * @return the per-example scores
     * @throws UnsupportedOperationException unless the network has exactly one input and one output
     */
    public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
        boolean singleInputSingleOutput = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInputSingleOutput) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return scoreExamples(ComputationGraphUtil.toMultiDataSet(data), addRegularizationTerms);
    }

    /**
     * Computes per-example scores for a MultiDataSet. On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     *
     * @param dataSet                the data to score
     * @param addRegularizationTerms whether to include the regularization score
     * @return the per-example scores
     */
    public INDArray scoreExamples(MultiDataSet dataSet, boolean addRegularizationTerms) {
        INDArray result;
        try {
            result = this.scoreExamplesHelper(dataSet, addRegularizationTerms);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Computes per-example scores, summed across all network outputs.
     *
     * <p>If {@code addRegularizationTerms} is true, the regularization score (backprop params only) is
     * passed to the first output only. Mask arrays (when the dataset has them) and layer states are
     * always cleared before returning, including when an exception is thrown.
     *
     * @param dataSet                the data to score
     * @param addRegularizationTerms whether to include the regularization score
     * @return the summed per-example scores; the first output's scores are detached before accumulating into them
     * @throws UnsupportedOperationException if a network output is not an IOutputLayer
     */
    private INDArray scoreExamplesHelper(MultiDataSet dataSet, boolean addRegularizationTerms) {
        LayerWorkspaceMgr mgr = this.configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE
                ? LayerWorkspaceMgr.noWorkspaces()
                : LayerWorkspaceMgr.builder()
                        .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                        .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                        .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                        .build();
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            this.setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }
        INDArray out = null;
        this.setInputs(dataSet.getFeatures());
        try {
            try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
                this.ffToLayerActivationsInWS(false, this.vertices.length - 1, this.getOutputLayerIndices(), FwdPassType.STANDARD, false, dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays(), false);
                INDArray[] labels = dataSet.getLabels();
                this.setLabels(labels);
                double r = addRegularizationTerms ? this.calcRegularizationScore(true) : 0.0;
                int i = 0;
                for (String s : this.configuration.getNetworkOutputs()) {
                    GraphVertex gv = this.verticesMap.get(s);
                    org.deeplearning4j.nn.api.Layer outLayer = gv.getLayer();
                    if (!(outLayer instanceof IOutputLayer)) {
                        throw new UnsupportedOperationException("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                    }
                    IOutputLayer ol = (IOutputLayer) outLayer;
                    ol.setLabels(labels[i++]);
                    INDArray scoreCurrLayer;
                    try (MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                        scoreCurrLayer = ((LayerVertex) gv).computeScoreForExamples(r, mgr);
                    }
                    if (out == null) {
                        // Detach the first result so it survives the ACTIVATIONS workspace scope.
                        out = scoreCurrLayer.detach();
                    } else {
                        out.addi(scoreCurrLayer);
                    }
                    // Regularization is only added to the first output's scores.
                    r = 0.0;
                }
            }
        } finally {
            // Previously skipped when an exception escaped the loop.
            if (hasMaskArrays) {
                this.clearLayerMaskArrays();
            }
            this.clearLayersStates();
        }
        return out;
    }

    /**
     * Fits the network on the inputs, labels and mask arrays currently stored on it.
     */
    @Override
    public void fit() {
        fit(inputs, labels, inputMaskArrays, labelMaskArrays);
    }

    /**
     * Not implemented for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Applies a full-network gradient.
     *
     * <p>Each per-variable entry is keyed {@code "<layerName>_<paramType>"}. The entry is stored in
     * this network's gradient and forwarded to the owning layer's {@code update}. Finally, the
     * flattened gradient is installed as the backprop gradients view.
     *
     * @param gradient the gradient; its flattened length must equal {@code numParams(true)}
     * @throws IllegalArgumentException if the flattened gradient has the wrong length
     * @throws IllegalStateException    if a key contains no '_' separator
     */
    @Override
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != this.numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + this.numParams(true));
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray val = entry.getValue();
            int idx = key.lastIndexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, idx);
            // Split at the LAST underscore (as getParam/setParam do). split("_")[1] returned the wrong
            // token for layer names containing '_', e.g. "dense_1_W" -> "1" instead of "W".
            String paramType = key.substring(idx + 1);
            this.gradient.gradientForVariable().put(key, val);
            this.getLayer(layerName).update(val, paramType);
        }
        this.setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * One-time reporting hook. On the first call, sets {@code initDone} and reports a
     * {@code STANDALONE} event through {@link Heartbeat}, using a task built from this model.
     * Later calls do nothing.
     *
     * @param task not used; the reported task is always derived from this model
     */
    private void update(Task task) {
        if (this.initDone) {
            return;
        }
        this.initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        Task modelTask = ModelSerializer.taskByModel(this);
        Environment environment = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, environment, modelTask);
    }

    /**
     * Returns the most recently stored score.
     *
     * @return the score field
     */
    @Override
    public double score() {
        double current = this.score;
        return current;
    }

    /**
     * Overwrites the stored score.
     *
     * @param score the new score value
     */
    public void setScore(double score) {
        this.score = score;
    }

    /**
     * Returns the flattened parameter array held by this network.
     *
     * @return the {@code flattenedParams} field itself (not a copy)
     */
    @Override
    public INDArray params() {
        return flattenedParams;
    }

    /**
     * Returns the updater's state view array.
     *
     * <p>Note: {@link #getUpdater()} creates a solver/updater if none exists yet.
     *
     * @return the updater state view, or null if no updater is available
     */
    @Override
    public INDArray updaterState() {
        ComputationGraphUpdater updater = this.getUpdater();
        if (updater == null) {
            return null;
        }
        return updater.getUpdaterStateViewArray();
    }

    /**
     * Returns the total number of parameters, counted with {@code backwards = true}.
     *
     * @return the parameter count
     */
    @Override
    public long numParams() {
        final boolean backwards = true;
        return numParams(backwards);
    }

    /**
     * Sums {@code numParams(backwards)} over every layer of the network.
     *
     * @param backwards forwarded unchanged to each layer
     * @return the total parameter count
     */
    @Override
    public long numParams(boolean backwards) {
        // Accumulate in long: the previous int accumulator with an (int) cast silently wrapped once
        // the total exceeded Integer.MAX_VALUE, even though the method returns long.
        long nParams = 0L;
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            nParams += layer.numParams(backwards);
        }
        return nParams;
    }

    /**
     * Sets all network parameters from one flattened array.
     *
     * <p>Three cases:
     * <ul>
     *   <li>If {@code params} is already this network's parameter array, nothing happens.</li>
     *   <li>If a flattened parameter array of the same length exists, it is overwritten in place.</li>
     *   <li>Otherwise {@code params} is cut into consecutive segments, one per layer-bearing vertex in
     *       topological order (layers with no params are skipped). Each segment is passed to that
     *       layer's {@code setParams}.</li>
     * </ul>
     *
     * @param params the flattened parameters. NOTE(review): the slicing reads row 0 only, so this
     *               appears to expect a row vector; confirm.
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.flattenedParams) {
            return;
        }
        if (this.flattenedParams != null && this.flattenedParams.length() == params.length()) {
            this.flattenedParams.assign(params);
            return;
        }
        // long offset: the former int offset with (int) casts wrapped past Integer.MAX_VALUE.
        long idx = 0L;
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            GraphVertex vertex = this.vertices[this.topologicalOrder[i]];
            if (!vertex.hasLayer()) {
                continue;
            }
            org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
            long range = layer.numParams();
            if (range <= 0L) {
                continue;
            }
            INDArray get = params.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(idx, idx + range)});
            layer.setParams(get);
            idx += range;
        }
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamsViewArray(INDArray gradient) {
        String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the flattened gradients array held by this network.
     *
     * @return the {@code flattenedGradients} field itself (not a copy)
     */
    @Override
    public INDArray getGradientsViewArray() {
        return flattenedGradients;
    }

    /**
     * Distributes a flattened gradient array across the layers.
     *
     * <p>Consecutive segments go to each layer-bearing vertex in topological order; layers with no
     * params are skipped. Each layer receives its segment as its backprop gradients view.
     *
     * @param gradient the flattened gradient array. NOTE(review): the slicing reads row 0 only, so
     *                 this appears to expect a row vector; confirm.
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradient) {
        // long offset: the former int offset with (int) casts wrapped past Integer.MAX_VALUE.
        long paramsSoFar = 0L;
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            GraphVertex vertex = this.vertices[this.topologicalOrder[i]];
            if (!vertex.hasLayer()) {
                continue;
            }
            org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
            long range = layer.numParams();
            if (range <= 0L) {
                continue;
            }
            layer.setBackpropGradientsViewArray(gradient.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(paramsSoFar, paramsSoFar + range)}));
            paramsSoFar += range;
        }
    }

    /**
     * Not supported: a ComputationGraph cannot be pretrained from a single INDArray.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        String reason = "Cannot pretrain ComputationGraph with single INDArray";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the most recently computed gradient.
     *
     * @return the {@code gradient} field
     */
    @Override
    public Gradient gradient() {
        Gradient current = this.gradient;
        return current;
    }

    /**
     * Returns the current gradient and score as a pair.
     *
     * @return a new pair of ({@link #gradient()}, {@link #score()})
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient g = this.gradient();
        Double s = this.score();
        return new Pair<Gradient, Double>(g, s);
    }

    /**
     * Returns the minibatch size: dimension 0 of the first label array, or of the first input
     * array when no labels are set.
     *
     * @return the minibatch size, narrowed to int
     */
    @Override
    public int batchSize() {
        boolean haveLabels = this.labels != null && this.labels[0] != null;
        INDArray reference = haveLabels ? this.labels[0] : this.inputs[0];
        return (int) reference.size(0);
    }

    /**
     * Returns the network's default layer-level configuration.
     *
     * @return the {@code defaultConfiguration} field
     */
    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Returns the single network input.
     *
     * @return the first input array, or null if no inputs are set
     * @throws UnsupportedOperationException if the network does not have exactly one input
     */
    @Override
    public INDArray input() {
        if (this.numInputArrays != 1) {
            throw new UnsupportedOperationException("Cannot return single input: ComputationGraph  has multiple inputs");
        }
        if (this.inputs == null) {
            return null;
        }
        return this.inputs[0];
    }

    /**
     * Returns the solver's optimizer.
     *
     * <p>NOTE(review): there is no null check on {@code solver}, so this throws a NullPointerException
     * if no solver has been created yet.
     *
     * @return the optimizer
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        return solver.getOptimizer();
    }

    /**
     * Looks up a single parameter by its full key {@code "<layerName>_<paramType>"}. The key is
     * split at its last underscore.
     *
     * @param paramName the full parameter key
     * @return the parameter array from the owning layer
     * @throws IllegalStateException if the key contains no '_' separator
     */
    @Override
    public INDArray getParam(String paramName) {
        int separator = paramName.lastIndexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + paramName + "\"");
        }
        org.deeplearning4j.nn.api.Layer owner = this.getLayer(paramName.substring(0, separator));
        return owner.getParam(paramName.substring(separator + 1));
    }

    /**
     * Returns all parameters (not only backprop parameters), keyed {@code "<vertexName>_<paramKey>"}.
     *
     * @return the parameter table
     */
    @Override
    public Map<String, INDArray> paramTable() {
        final boolean backpropParamsOnly = false;
        return paramTable(backpropParamsOnly);
    }

    /**
     * Builds a table of parameters for all vertices, in vertex-array order.
     *
     * @param backpropParamsOnly forwarded to each vertex's {@code paramTable}
     * @return an insertion-ordered map from {@code "<vertexName>_<paramKey>"} to parameter array
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> result = new LinkedHashMap<String, INDArray>();
        for (int v = 0; v < this.vertices.length; v++) {
            GraphVertex vertex = this.vertices[v];
            Map<String, INDArray> vertexParams = vertex.paramTable(backpropParamsOnly);
            for (Map.Entry<String, INDArray> e : vertexParams.entrySet()) {
                result.put(vertex.getVertexName() + "_" + e.getKey(), e.getValue());
            }
        }
        return result;
    }

    /**
     * Copies the values of {@code paramTable} into this network's existing parameter arrays.
     *
     * <p>The key set must equal that of {@link #paramTable()}, and every array's shape must match the
     * existing one. All shapes are validated before any value is assigned, so a mismatch leaves the
     * parameters untouched.
     *
     * @param paramTable the new parameter values, keyed like {@link #paramTable()}
     * @throws NullPointerException if {@code paramTable} is null
     */
    @Override
    public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
        if (paramTable == null) {
            throw new NullPointerException("paramTable is marked non-null but is null");
        }
        // Build the current table once (it was previously built twice).
        Map<String, INDArray> current = this.paramTable();
        Preconditions.checkArgument(paramTable.keySet().equals(current.keySet()), "Cannot set param table: parameter set keys are not equal");
        for (Map.Entry<String, INDArray> e : current.entrySet()) {
            String s = e.getKey();
            long[] shapeCurrent = e.getValue().shape();
            long[] shapeNew = paramTable.get(s).shape();
            // Report the new SHAPE; the old code passed the whole new array as the last argument.
            Preconditions.checkState(Arrays.equals(shapeCurrent, shapeNew), "Cannot set parameters: shape array for parameter \"%s\" does not match existing shape: parameter shape = %s, new param shape = %s", (Object) s, (Object) shapeCurrent, (Object) shapeNew);
        }
        for (Map.Entry<String, INDArray> e : current.entrySet()) {
            e.getValue().assign(paramTable.get(e.getKey()));
        }
    }

    /**
     * Sets a single parameter by its full key {@code "<layerName>_<paramType>"}. The key is split at
     * its last underscore.
     *
     * @param key the full parameter key
     * @param val the new parameter value, passed to the owning layer's {@code setParam}
     * @throws IllegalStateException if the key contains no '_' separator
     */
    @Override
    public void setParam(String key, INDArray val) {
        int separator = key.lastIndexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        }
        org.deeplearning4j.nn.api.Layer owner = this.getLayer(key.substring(0, separator));
        owner.setParam(key.substring(separator + 1), val);
    }

    /**
     * Drops the references this network holds to its inputs, labels and mask arrays.
     */
    @Override
    public void clear() {
        inputs = null;
        labels = null;
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    /**
     * Applies each layer's parameter constraints.
     *
     * @param iteration the current iteration, forwarded to each layer
     * @param epoch     the current epoch, forwarded to each layer
     */
    @Override
    public void applyConstraints(int iteration, int epoch) {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].applyConstraints(iteration, epoch);
        }
    }

    /**
     * Runs one RNN time step without an output workspace.
     *
     * <p>Delegates to {@link #rnnTimeStep(MemoryWorkspace, INDArray...)}, so an OutOfMemoryError
     * produces a crash dump, consistent with the workspace overload and the score methods.
     *
     * @param inputs the network inputs
     * @return the network outputs
     */
    public INDArray[] rnnTimeStep(INDArray ... inputs) {
        // Cast selects the two-argument overload rather than this varargs method.
        return this.rnnTimeStep((MemoryWorkspace) null, inputs);
    }

    /**
     * Runs one RNN time step. On OutOfMemoryError a memory crash dump is written before the error
     * is rethrown.
     *
     * @param outputWorkspace workspace for the outputs; may be null
     * @param inputs          the network inputs
     * @return the network outputs
     */
    public INDArray[] rnnTimeStep(MemoryWorkspace outputWorkspace, INDArray ... inputs) {
        INDArray[] result;
        try {
            result = this.rnnTimeStepHelper(outputWorkspace, inputs);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Runs the forward pass in {@code RNN_TIMESTEP} mode through the output layers.
     *
     * <p>When every input has rank 2, any output of rank 3 whose dimension 2 has size 1 is replaced
     * with {@code tensorAlongDimension(0, 1, 0)}. {@code this.inputs} is cleared afterwards.
     *
     * @param outputWs workspace for the outputs; may be null
     * @param inputs   the network inputs
     * @return the network outputs
     */
    private INDArray[] rnnTimeStepHelper(MemoryWorkspace outputWs, INDArray ... inputs) {
        boolean inputIs2d = true;
        for (int k = 0; k < inputs.length && inputIs2d; k++) {
            inputIs2d = inputs[k].rank() == 2;
        }
        INDArray[] outputs = this.outputOfLayersDetached(false, FwdPassType.RNN_TIMESTEP, this.getOutputLayerIndices(), inputs, null, null, true, false, outputWs);
        if (inputIs2d) {
            for (int k = 0; k < outputs.length; k++) {
                INDArray o = outputs[k];
                boolean singleStep3d = o.rank() == 3 && o.size(2) == 1L;
                if (singleStep3d) {
                    outputs[k] = o.tensorAlongDimension(0L, new int[]{1, 0});
                }
            }
        }
        this.inputs = null;
        return outputs;
    }

    /**
     * Returns the stored RNN state of the layer at index {@code layer} in the layers array.
     *
     * @param layer index into the layers array
     * @return the layer's previous state, or null if it is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        String name = this.layers[layer].conf().getLayer().getLayerName();
        return this.rnnGetPreviousState(name);
    }

    /**
     * Returns the stored RNN state of the named layer. A {@code BaseWrapperLayer} is unwrapped once
     * to reach its underlying layer.
     *
     * @param layerName the vertex/layer name
     * @return the previous state, or null if the (unwrapped) layer is not a RecurrentLayer
     */
    public Map<String, INDArray> rnnGetPreviousState(String layerName) {
        org.deeplearning4j.nn.api.Layer target = this.verticesMap.get(layerName).getLayer();
        if (target instanceof BaseWrapperLayer) {
            target = ((BaseWrapperLayer) target).getUnderlying();
        }
        if (target instanceof RecurrentLayer) {
            return ((RecurrentLayer) target).rnnGetPreviousState();
        }
        return null;
    }

    /**
     * Collects the stored RNN state of every recurrent layer. A {@code BaseWrapperLayer} is
     * unwrapped once to reach its underlying layer.
     *
     * @return a map from layer name to that layer's previous state
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        HashMap<String, Map<String, INDArray>> states = new HashMap<String, Map<String, INDArray>>();
        for (int idx = 0; idx < this.layers.length; idx++) {
            org.deeplearning4j.nn.api.Layer candidate = this.layers[idx];
            if (candidate instanceof BaseWrapperLayer) {
                candidate = ((BaseWrapperLayer) candidate).getUnderlying();
            }
            if (candidate instanceof RecurrentLayer) {
                String name = candidate.conf().getLayer().getLayerName();
                states.put(name, ((RecurrentLayer) candidate).rnnGetPreviousState());
            }
        }
        return states;
    }

    /**
     * Sets the stored RNN state of the layer at index {@code layer} in the layers array.
     *
     * @param layer index into the layers array
     * @param state the state to set
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        String name = this.layers[layer].conf().getLayer().getLayerName();
        this.rnnSetPreviousState(name, state);
    }

    /**
     * Sets the stored RNN state of the named layer. A {@code BaseWrapperLayer} is unwrapped once to
     * reach its underlying layer.
     *
     * @param layerName the vertex/layer name
     * @param state     the state to set
     * @throws UnsupportedOperationException if the (unwrapped) layer is not a RecurrentLayer
     */
    public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
        org.deeplearning4j.nn.api.Layer target = this.verticesMap.get(layerName).getLayer();
        if (target instanceof BaseWrapperLayer) {
            target = ((BaseWrapperLayer) target).getUnderlying();
        }
        if (!(target instanceof RecurrentLayer)) {
            throw new UnsupportedOperationException("Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
        }
        RecurrentLayer recurrent = (RecurrentLayer) target;
        recurrent.rnnSetPreviousState(state);
    }

    /**
     * Sets the stored RNN state of several layers.
     *
     * @param previousStates a map from layer name to the state for that layer
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
        for (Map.Entry<String, Map<String, INDArray>> layerState : previousStates.entrySet()) {
            String layerName = layerState.getKey();
            Map<String, INDArray> state = layerState.getValue();
            this.rnnSetPreviousState(layerName, state);
        }
    }

    /**
     * Clears the stored RNN state of every recurrent layer.
     *
     * <p>Handles three kinds of layer: direct {@code RecurrentLayer}s, nested
     * {@code MultiLayerNetwork}s, and {@code BaseWrapperLayer}s whose underlying layer is recurrent.
     * The wrapper case matches how {@link #rnnGetPreviousState(String)} and
     * {@link #rnnSetPreviousState(String, Map)} unwrap layers. Does nothing if the network has no
     * layers yet.
     */
    public void rnnClearPreviousState() {
        if (this.layers == null) {
            return;
        }
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                ((RecurrentLayer) layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            } else if (layer instanceof BaseWrapperLayer) {
                // Without this branch, state in a wrapped recurrent layer was never cleared.
                org.deeplearning4j.nn.api.Layer underlying = ((BaseWrapperLayer) layer).getUnderlying();
                if (underlying instanceof RecurrentLayer) {
                    ((RecurrentLayer) underlying).rnnClearPreviousState();
                }
            }
        }
    }

    /**
     * Fits the network with truncated backpropagation through time (TBPTT).
     *
     * <p>The time-series length is taken from dimension 2 of the rank-3 inputs and labels. The series
     * are split into consecutive segments of {@code configuration.getTbpttFwdLength()} steps; the
     * final segment may be shorter. For each segment, the sliced inputs, labels and masks are set on
     * the network, the solver runs once, and the recurrent state is carried into the next segment via
     * {@code rnnUpdateStateWithTBPTTState()}.
     *
     * <p>RNN state is cleared before the first segment. It is cleared again at the end only if
     * {@code clearTbpttState} is set. Mask arrays are cleared at the end.
     *
     * @param inputs       network inputs; rank-3 arrays are sliced per segment, other ranks are used whole
     * @param labels       network labels, sliced like {@code inputs}. NOTE(review): iterated without a
     *                     null check, so it must be non-null.
     * @param featureMasks feature masks, may be null
     * @param labelMasks   label masks, may be null
     * @param workspaceMgr workspace manager passed to {@code solver.optimize}
     */
    protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks, LayerWorkspaceMgr workspaceMgr) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        // Determine a single time-series length (dimension 2) shared by all rank-3 inputs/labels.
        // On a mismatch, a warning is logged and the method returns WITHOUT training.
        long timeSeriesLength = -1L;
        for (INDArray in : inputs) {
            if (in.rank() != 3) continue;
            if (timeSeriesLength == -1L) {
                timeSeriesLength = in.size(2);
                continue;
            }
            if (timeSeriesLength == in.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        for (INDArray out : labels) {
            if (out.rank() != 3) continue;
            if (timeSeriesLength == -1L) {
                timeSeriesLength = out.size(2);
                continue;
            }
            if (timeSeriesLength == out.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        // NOTE(review): if no array has rank 3, timeSeriesLength stays -1 and the segment count
        // computed below depends on integer division of -1; confirm the intended behavior for that case.
        long fwdLen = this.configuration.getTbpttFwdLength();
        // Ceiling division: one extra, shorter segment when the length is not a multiple of fwdLen.
        long nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0L) {
            ++nSubsets;
        }
        this.rnnClearPreviousState();
        int i = 0;
        while ((long)i < nSubsets) {
            long startTimeIdx = (long)i * fwdLen;
            long endTimeIdx = startTimeIdx + fwdLen;
            // Clamp the last segment to the series end.
            if (endTimeIdx > timeSeriesLength) {
                endTimeIdx = timeSeriesLength;
            }
            // getSubsetsForTbptt takes the start index as int.
            if (startTimeIdx > Integer.MAX_VALUE) {
                throw new ND4JArraySizeException();
            }
            // list holds {inputs, labels, featureMasks, labelMasks} for this segment.
            List<INDArray[]> list = this.getSubsetsForTbptt((int)startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);
            this.setInputs(list.get(0));
            this.setLabels(list.get(1));
            this.setLayerMaskArrays(list.get(2), list.get(3));
            if (this.solver == null) {
                // Built with workspaces scoped out; presumably so solver-owned arrays are not tied to an
                // active workspace (confirm).
                try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                    this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                }
            }
            this.solver.optimize(workspaceMgr);
            // Carry the recurrent state forward into the next segment.
            this.rnnUpdateStateWithTBPTTState();
            ++i;
        }
        if (this.clearTbpttState) {
            this.rnnClearPreviousState();
        }
        this.clearLayerMaskArrays();
    }

    /**
     * Slices the arrays for one TBPTT segment covering time steps [startTimeIdx, endTimeIdx).
     *
     * <p>Slicing rules:
     * <ul>
     *   <li>Rank-3 inputs and labels are sliced along dimension 2; other ranks are passed through unchanged.</li>
     *   <li>Non-null mask entries are sliced along dimension 1; null entries stay null.</li>
     *   <li>A null mask array yields null in the result.</li>
     * </ul>
     *
     * @return a 4-element list: {inputs, labels, featureMasks, labelMasks}
     */
    private List<INDArray[]> getSubsetsForTbptt(int startTimeIdx, long endTimeIdx, INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
        INDArray[] subInputs = new INDArray[inputs.length];
        for (int a = 0; a < inputs.length; a++) {
            INDArray arr = inputs[a];
            if (arr.rank() == 3) {
                subInputs[a] = arr.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((long)startTimeIdx, (long)endTimeIdx)});
            } else {
                subInputs[a] = arr;
            }
        }
        INDArray[] subLabels = new INDArray[labels.length];
        for (int b = 0; b < labels.length; b++) {
            INDArray arr = labels[b];
            if (arr.rank() == 3) {
                subLabels[b] = arr.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((long)startTimeIdx, (long)endTimeIdx)});
            } else {
                subLabels[b] = arr;
            }
        }
        INDArray[] subFeatureMasks = null;
        if (featureMasks != null) {
            subFeatureMasks = new INDArray[featureMasks.length];
            for (int c = 0; c < featureMasks.length; c++) {
                if (featureMasks[c] != null) {
                    subFeatureMasks[c] = featureMasks[c].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((long)startTimeIdx, (long)endTimeIdx)});
                }
            }
        }
        INDArray[] subLabelMasks = null;
        if (labelMasks != null) {
            subLabelMasks = new INDArray[labelMasks.length];
            for (int d = 0; d < labelMasks.length; d++) {
                if (labelMasks[d] != null) {
                    subLabelMasks[d] = labelMasks[d].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((long)startTimeIdx, (long)endTimeIdx)});
                }
            }
        }
        return Arrays.asList(subInputs, subLabels, subFeatureMasks, subLabelMasks);
    }

    /**
     * Runs a detached forward pass over all vertices in {@code RNN_ACTIVATE_WITH_STORED_STATE} mode.
     * It uses the network's current input and label mask arrays.
     *
     * @param inputs            the network inputs
     * @param training          whether to run in training mode
     * @param storeLastForTBPTT forwarded to the forward pass (store last state for TBPTT)
     * @return the activations map produced by the forward pass
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training, boolean storeLastForTBPTT) {
        int lastVertexIdx = this.vertices.length - 1;
        return this.ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, storeLastForTBPTT, lastVertexIdx, null, inputs, this.inputMaskArrays, this.labelMaskArrays, true);
    }

    /**
     * Sets the feature and label mask arrays and propagates the feature masks through the graph.
     * <p>
     * Previously set masks are cleared first. Feature masks are propagated in topological order via
     * {@code GraphVertex.feedForwardMaskArrays}. Label masks are set directly on the layer of the corresponding
     * network output vertex.
     *
     * @param featureMaskArrays one mask (or null) per network input; may itself be null
     * @param labelMaskArrays   one mask (or null) per network output; may itself be null
     * @throws IllegalArgumentException if an array's length does not match the number of network inputs/outputs.
     *                                  This is checked before any state is modified.
     * @throws ND4JArraySizeException   if the minibatch size taken from the feature masks exceeds Integer.MAX_VALUE
     */
    public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        // Validate up front so that an invalid call does not leave the network with partially-updated mask state
        if (featureMaskArrays != null && featureMaskArrays.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid number of feature mask arrays");
        }
        if (labelMaskArrays != null && labelMaskArrays.length != this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid number of label mask arrays");
        }
        this.clearLayerMaskArrays();
        this.inputMaskArrays = featureMaskArrays;
        this.labelMaskArrays = labelMaskArrays;
        if (featureMaskArrays != null) {
            // Minibatch size is taken from the last non-null feature mask (-1 if all of them are null)
            long minibatchSize = -1L;
            for (INDArray fMask : featureMaskArrays) {
                if (fMask != null) {
                    minibatchSize = fMask.size(0);
                }
            }
            // Loop-invariant: the size is narrowed to int for feedForwardMaskArrays below
            if (minibatchSize > Integer.MAX_VALUE) {
                throw new ND4JArraySizeException();
            }
            Map<Integer, Pair<INDArray, MaskState>> masksByVertex = new HashMap<Integer, Pair<INDArray, MaskState>>();
            for (int vertexIdx : this.topologicalOrder) {
                GraphVertex current = this.vertices[vertexIdx];
                if (current.isInputVertex()) {
                    INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                    masksByVertex.put(current.getVertexIndex(), new Pair<INDArray, MaskState>(fMask, MaskState.Active));
                    continue;
                }
                VertexIndices[] inputVertices = current.getInputVertices();
                INDArray[] inputMasks = null;
                MaskState maskState = null;
                for (int j = 0; j < inputVertices.length; ++j) {
                    Pair<INDArray, MaskState> p = masksByVertex.get(inputVertices[j].getVertexIndex());
                    if (p == null) {
                        continue;
                    }
                    if (inputMasks == null) {
                        inputMasks = new INDArray[inputVertices.length];
                    }
                    inputMasks[j] = p.getFirst();
                    // The first non-Passthrough state wins; a Passthrough state may be replaced by a later input's
                    if (maskState == null || maskState == MaskState.Passthrough) {
                        maskState = p.getSecond();
                    }
                }
                Pair<INDArray, MaskState> outPair = current.feedForwardMaskArrays(inputMasks, maskState, (int) minibatchSize);
                masksByVertex.put(vertexIdx, outPair);
            }
        }
        if (labelMaskArrays != null) {
            for (int i = 0; i < labelMaskArrays.length; ++i) {
                if (labelMaskArrays[i] == null) {
                    continue;
                }
                String outputName = this.configuration.getNetworkOutputs().get(i);
                GraphVertex v = this.verticesMap.get(outputName);
                org.deeplearning4j.nn.api.Layer ol = v.getLayer();
                ol.setMaskArray(labelMaskArrays[i]);
            }
        }
    }

    /**
     * Removes the mask array from every layer and forgets the stored feature and label mask arrays.
     */
    public void clearLayerMaskArrays() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Copies each recurrent layer's stored TBPTT state into its "previous state".
     * Nested {@link MultiLayerNetwork} layers are delegated to their own equivalent method.
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                RecurrentLayer recurrent = (RecurrentLayer) layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Classification evaluation with no explicit class names and top-N of 1.
     * For a DataSetIterator, null class names mean the iterator's own labels are used.
     *
     * @param iterator data to evaluate on
     * @return the populated classification evaluation
     */
    public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
        List<String> noLabels = null;
        return this.evaluate(iterator, noLabels);
    }

    /**
     * Classification evaluation with no explicit class names and top-N of 1.
     *
     * @param iterator data to evaluate on
     * @return the populated classification evaluation
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator) {
        List<String> noLabels = null;
        return this.evaluate(iterator, noLabels);
    }

    /**
     * Classification evaluation with top-N of 1.
     *
     * @param iterator   data to evaluate on
     * @param labelsList class names (may be null)
     * @return the populated classification evaluation
     */
    public <T extends Evaluation> T evaluate(DataSetIterator iterator, List<String> labelsList) {
        final int topN = 1;
        return this.evaluate(iterator, labelsList, topN);
    }

    /**
     * Classification evaluation with top-N of 1.
     *
     * @param iterator   data to evaluate on
     * @param labelsList class names (may be null)
     * @return the populated classification evaluation
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList) {
        final int topN = 1;
        return this.evaluate(iterator, labelsList, topN);
    }

    /**
     * Evaluates the network as a classifier with top-N accuracy.
     * <p>
     * If output-layer validation is enabled in the configuration, output layer 0 is checked for compatibility with
     * classification evaluation before any data is processed.
     *
     * @param iterator   data to evaluate on
     * @param labelsList class names; when null, {@link DataSetIterator#getLabels()} is used instead
     * @param topN       N for top-N accuracy
     * @return the populated classification evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends Evaluation> T evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        List<String> classLabels = labelsList == null ? iterator.getLabels() : labelsList;
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), Evaluation.class);
        }
        org.deeplearning4j.eval.Evaluation classifierEval = new org.deeplearning4j.eval.Evaluation(classLabels, topN);
        org.deeplearning4j.eval.Evaluation[] results = this.doEvaluation(iterator,
                new org.deeplearning4j.eval.Evaluation[]{classifierEval});
        return (T) results[0];
    }

    /**
     * Evaluates the network as a classifier with top-N accuracy.
     * <p>
     * If output-layer validation is enabled in the configuration, output layer 0 is checked for compatibility with
     * classification evaluation before any data is processed.
     *
     * @param iterator   data to evaluate on
     * @param labelsList class names (passed through as-is, may be null)
     * @param topN       N for top-N accuracy
     * @return the populated classification evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList, int topN) {
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), Evaluation.class);
        }
        org.deeplearning4j.eval.Evaluation classifierEval = new org.deeplearning4j.eval.Evaluation(labelsList, topN);
        org.deeplearning4j.eval.Evaluation[] results = this.doEvaluation(iterator,
                new org.deeplearning4j.eval.Evaluation[]{classifierEval});
        return (T) results[0];
    }

    /**
     * Regression evaluation without explicit column names.
     *
     * @param iterator data to evaluate on
     * @return the populated regression evaluation
     */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
        List<String> defaultColumnNames = null;
        return this.evaluateRegression(iterator, defaultColumnNames);
    }

    /**
     * Regression evaluation without explicit column names.
     *
     * @param iterator data to evaluate on
     * @return the populated regression evaluation
     */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator) {
        List<String> defaultColumnNames = null;
        return this.evaluateRegression(iterator, defaultColumnNames);
    }

    /**
     * Regression evaluation with the given column names.
     *
     * @param iterator    data to evaluate on
     * @param columnNames names of the regression output columns (may be null)
     * @return the populated regression evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(DataSetIterator iterator, List<String> columnNames) {
        RegressionEvaluation regression = new RegressionEvaluation(columnNames);
        RegressionEvaluation[] results = this.doEvaluation(iterator, new RegressionEvaluation[]{regression});
        return (T) results[0];
    }

    /**
     * Regression evaluation with the given column names.
     *
     * @param iterator    data to evaluate on
     * @param columnNames names of the regression output columns (may be null)
     * @return the populated regression evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator, List<String> columnNames) {
        RegressionEvaluation regression = new RegressionEvaluation(columnNames);
        RegressionEvaluation[] results = this.doEvaluation(iterator, new RegressionEvaluation[]{regression});
        return (T) results[0];
    }

    /**
     * ROC evaluation with {@code rocThresholdSteps = 0}.
     *
     * @param iterator data to evaluate on
     * @return the populated ROC evaluation
     */
    @Deprecated
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(DataSetIterator iterator) {
        final int defaultThresholdSteps = 0;
        return this.evaluateROC(iterator, defaultThresholdSteps);
    }

    /**
     * Evaluates output 0 with a binary ROC evaluation. When output-layer validation is enabled, output layer 0 is
     * checked for ROC compatibility first.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps threshold steps passed to the ROC constructor
     * @return the populated ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), org.nd4j.evaluation.classification.ROC.class);
        }
        ROC[] results = this.doEvaluation(iterator, new ROC[]{new ROC(rocThresholdSteps)});
        return (T) results[0];
    }

    /**
     * ROC evaluation with {@code rocThresholdSteps = 0}.
     *
     * @param iterator data to evaluate on
     * @return the populated ROC evaluation
     */
    @Deprecated
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(MultiDataSetIterator iterator) {
        final int defaultThresholdSteps = 0;
        return this.evaluateROC(iterator, defaultThresholdSteps);
    }

    /**
     * Evaluates output 0 with a binary ROC evaluation. When output-layer validation is enabled, output layer 0 is
     * checked for ROC compatibility first.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps threshold steps passed to the ROC constructor
     * @return the populated ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), org.nd4j.evaluation.classification.ROC.class);
        }
        ROC[] results = this.doEvaluation(iterator, new ROC[]{new ROC(rocThresholdSteps)});
        return (T) results[0];
    }

    /**
     * Multi-class ROC evaluation with {@code rocThresholdSteps = 0}.
     *
     * @param iterator data to evaluate on
     * @return the populated multi-class ROC evaluation
     */
    @Deprecated
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator) {
        final int defaultThresholdSteps = 0;
        return this.evaluateROCMultiClass(iterator, defaultThresholdSteps);
    }

    /**
     * Evaluates output 0 with a one-vs-all multi-class ROC. When output-layer validation is enabled, output layer 0
     * is checked for compatibility first.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps threshold steps passed to the ROCMultiClass constructor
     * @return the populated multi-class ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), ROCMultiClass.class);
        }
        org.deeplearning4j.eval.ROCMultiClass rocEval = new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps);
        org.deeplearning4j.eval.ROCMultiClass[] results = this.doEvaluation(iterator,
                new org.deeplearning4j.eval.ROCMultiClass[]{rocEval});
        return (T) results[0];
    }

    /**
     * Evaluates output 0 with a one-vs-all multi-class ROC. When output-layer validation is enabled, output layer 0
     * is checked for compatibility first.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps threshold steps passed to the ROCMultiClass constructor
     * @return the populated multi-class ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends ROCMultiClass> T evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
        org.deeplearning4j.nn.api.Layer firstOutput = this.getOutputLayer(0);
        boolean validate = this.getConfiguration().isValidateOutputLayerConfig();
        if (validate) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(firstOutput.conf().getLayer(), ROCMultiClass.class);
        }
        org.deeplearning4j.eval.ROCMultiClass rocEval = new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps);
        org.deeplearning4j.eval.ROCMultiClass[] results = this.doEvaluation(iterator,
                new org.deeplearning4j.eval.ROCMultiClass[]{rocEval});
        return (T) results[0];
    }

    /**
     * Runs the given evaluations against output 0 of the network, adapting the DataSetIterator to a
     * MultiDataSetIterator.
     *
     * @param iterator    data to evaluate on
     * @param evaluations evaluations to update in place
     * @return the same evaluations, typed as passed in
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T ... evaluations) {
        // Pass the T[] through unchanged: casting it to IEvaluation[] would make the call return IEvaluation[],
        // which is not assignable to T[]
        return this.doEvaluation((MultiDataSetIterator) new MultiDataSetIteratorAdapter(iterator), evaluations);
    }

    /**
     * Runs the given evaluations against output 0 of the network.
     * On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     *
     * @param iterator    data to evaluate on
     * @param evaluations evaluations to update in place
     * @return the same evaluations, typed as passed in
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T ... evaluations) {
        try {
            return this.doEvaluationHelper(iterator, evaluations);
        } catch (OutOfMemoryError e) {
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Runs evaluations for multiple network outputs. The map is keyed by output index.
     *
     * @param iterator    data to evaluate on (adapted to a MultiDataSetIterator)
     * @param evaluations evaluations per output index; updated in place
     * @return the same map
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(DataSetIterator iterator, Map<Integer, T[]> evaluations) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        return this.evaluate(adapted, evaluations);
    }

    /**
     * Runs evaluations for multiple network outputs. The map is keyed by output index.
     * On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     *
     * @param iterator    data to evaluate on
     * @param evaluations evaluations per output index; updated in place
     * @return the same map
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(MultiDataSetIterator iterator, Map<Integer, T[]> evaluations) {
        Map<Integer, T[]> result;
        try {
            result = this.doEvaluationHelper(iterator, evaluations);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Runs the given evaluations against network output 0 and returns them.
     *
     * @param iterator    data to evaluate on
     * @param evaluations evaluations to update in place
     * @return the same evaluations array
     */
    @SafeVarargs
    private final <T extends IEvaluation> T[] doEvaluationHelper(MultiDataSetIterator iterator, T ... evaluations) {
        // Keep the map typed as T[]; the previous IEvaluation[] map/return did not type-check against T[]
        Map<Integer, T[]> byOutput = Collections.singletonMap(0, evaluations);
        return this.doEvaluationHelper(iterator, byOutput).get(0);
    }

    /**
     * Core evaluation loop shared by the evaluate/doEvaluation entry points.
     * <p>
     * Iterates over {@code iterator}, wrapped in an AsyncMultiDataSetIterator when async is supported. Minibatches
     * without features or labels are skipped. Each remaining minibatch is evaluated either with one standard forward
     * pass or, when the backprop type is TruncatedBPTT, segment by segment via rnnTimeStep. Layer states are cleared
     * after every minibatch.
     * <p>
     * The training workspace mode is temporarily set to the inference workspace mode. It is restored, and any async
     * iterator is shut down, even if evaluation fails.
     *
     * @param iterator    data source; reset first if it is exhausted and supports reset
     * @param evaluations evaluations keyed by network output index; null entries are skipped
     * @return the same {@code evaluations} map, with each evaluation updated in place
     * @throws IllegalStateException if the network has no output layer or an output index is out of range
     */
    private <T extends IEvaluation> Map<Integer, T[]> doEvaluationHelper(MultiDataSetIterator iterator, Map<Integer, T[]> evaluations) {
        if (this.layers == null || !(this.getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no external workspaces open at start of evaluation (doEvaluationHelper)");
        if (iterator.resetSupported() && !iterator.hasNext()) {
            iterator.reset();
        }
        MultiDataSetIterator iter = iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;
        WorkspaceMode cMode = this.configuration.getTrainingWorkspaceMode();
        this.configuration.setTrainingWorkspaceMode(this.configuration.getInferenceWorkspaceMode());
        boolean useRnnSegments = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        try {
            MemoryWorkspace outputWs = this.getConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED
                    ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)
                    : new DummyWorkspace();
            while (iter.hasNext()) {
                MultiDataSet next = (MultiDataSet) iter.next();
                if (next.getFeatures() == null || next.getLabels() == null) {
                    continue;
                }
                // Exactly one of the two paths runs per minibatch. Previously the standard path fell through into
                // the TBPTT code, which throws for networks without time-series features.
                if (useRnnSegments) {
                    this.evaluateMinibatchTbptt(next, evaluations, outputWs);
                } else {
                    this.evaluateMinibatchStandard(next, evaluations, outputWs);
                }
                // Clear inputs, masks etc. so invalidated/out-of-scope arrays do not leak between minibatches
                this.clearLayersStates();
            }
        } finally {
            if (iterator.asyncSupported()) {
                ((AsyncMultiDataSetIterator) iter).shutdown();
            }
            this.configuration.setTrainingWorkspaceMode(cMode);
        }
        return evaluations;
    }

    /** Evaluates one minibatch with a single standard (non-TBPTT) forward pass over all outputs. */
    private <T extends IEvaluation> void evaluateMinibatchStandard(MultiDataSet next, Map<Integer, T[]> evaluations, MemoryWorkspace outputWs) {
        INDArray[] features = next.getFeatures();
        INDArray[] featuresMasks = next.getFeaturesMaskArrays();
        INDArray[] labels = next.getLabels();
        INDArray[] labelMasks = next.getLabelsMaskArrays();
        List<Serializable> meta = next.getExampleMetaData();
        try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
            INDArray[] out = this.outputOfLayersDetached(false, FwdPassType.STANDARD, this.getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
            for (Integer i : evaluations.keySet()) {
                Preconditions.checkState(i >= 0 && i < labels.length, "Invalid output index: evaluation/output indices must be between 0 and numOutputs-1 (%s), got index %s", this.numOutputArrays, (int) i);
                T[] evalsThisOutput = evaluations.get(i);
                if (evalsThisOutput == null) {
                    continue;
                }
                Preconditions.checkState(i >= 0 && i < this.getNumOutputArrays(), "Invalid output index: indices for outputs must be between 0 and %s inclusive - found index %s", this.numOutputArrays, (int) i);
                INDArray currOut = out[i];
                INDArray currLabel = labels[i];
                // try-with-resources: exceptions from eval propagate (the old hand-expanded finally swallowed them)
                try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    for (IEvaluation evaluation : evalsThisOutput) {
                        evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
                    }
                }
            }
        }
    }

    /** Evaluates one minibatch in time segments of tbpttFwdLength, carrying RNN state between segments. */
    private <T extends IEvaluation> void evaluateMinibatchTbptt(MultiDataSet next, Map<Integer, T[]> evaluations, MemoryWorkspace outputWs) {
        this.rnnClearPreviousState();
        int fwdLen = this.configuration.getTbpttFwdLength();
        // Time-series length is taken from the last rank-3 feature array
        long tsLength = -1L;
        int nF = next.getFeatures().length;
        for (int f = 0; f < nF; ++f) {
            if (next.getFeatures(f).rank() == 3) {
                tsLength = next.getFeatures(f).size(2);
            }
        }
        if (tsLength < 0L) {
            throw new IllegalStateException("Invalid configuration: detected TBPTT backprop type without time series features");
        }
        long nSubsets = tsLength / (long) fwdLen;
        if (tsLength % (long) fwdLen != 0L) {
            ++nSubsets;
        }
        for (int s = 0; (long) s < nSubsets; ++s) {
            int startTimeIdx = s * fwdLen;
            long endTimeIdx = Math.min((long) (startTimeIdx + fwdLen), tsLength);
            List<INDArray[]> subset = this.getSubsetsForTbptt(startTimeIdx, endTimeIdx, next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
            this.setLayerMaskArrays(subset.get(2), subset.get(3));
            try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                INDArray[] outSub = this.rnnTimeStep(ws, subset.get(0));
                for (Integer idx : evaluations.keySet()) {
                    T[] evalsThisOutput = evaluations.get(idx);
                    if (evalsThisOutput == null) {
                        continue;
                    }
                    INDArray labelSub = subset.get(1) == null ? null : subset.get(1)[idx];
                    INDArray maskSub = subset.get(3) == null ? null : subset.get(3)[idx];
                    INDArray currOut = outSub[idx];
                    try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                        for (IEvaluation evaluation : evalsThisOutput) {
                            evaluation.eval(labelSub, currOut, maskSub);
                        }
                    }
                }
            }
        }
        this.rnnClearPreviousState();
    }

    /**
     * Returns the vertex summary table without input/output shape columns.
     *
     * @return the formatted summary
     */
    public String summary() {
        InputType[] noInputTypes = null;
        return this.summary(noInputTypes);
    }

    /**
     * Builds a human-readable table describing every vertex of the graph in topological order.
     * <p>
     * Columns: name and type, nIn/nOut, parameter count, parameter shapes, and vertex inputs. When
     * {@code inputTypes} are supplied, per-vertex input and output shapes are added as two extra columns. A footer
     * reports total, trainable and frozen parameter counts.
     *
     * @param inputTypes optional input types, one per network input in configuration order; null or an empty array
     *                   omits the shape columns
     * @return the formatted summary
     */
    public String summary(InputType ... inputTypes) {
        // An empty varargs array carries no shape information, so treat it like null. Previously the 7-column header
        // was paired with a 5-element width array in this case, causing an ArrayIndexOutOfBoundsException.
        if (inputTypes != null && inputTypes.length == 0) {
            inputTypes = null;
        }
        StringBuilder ret = new StringBuilder();
        ret.append("\n");
        // long: the frozen parameter count of a large model can exceed Integer.MAX_VALUE
        long frozenParams = 0L;
        HashMap<String, InputType> vertexOutputs = new HashMap<String, InputType>();
        int currLayerIdx = -1;
        ArrayList<String[]> lines = new ArrayList<String[]>();
        if (inputTypes == null) {
            lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"});
        } else {
            lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs", "InputShape", "OutputShape"});
        }
        String[] header = lines.get(0);
        int[] maxLength = new int[header.length];
        for (int i = 0; i < header.length; ++i) {
            maxLength[i] = header[i].length();
        }
        if (this.topologicalOrder == null) {
            GraphIndices indices = this.calculateIndices();
            this.topologicalOrder = indices.getTopologicalSortOrder();
        }
        for (int currVertexIdx : this.topologicalOrder) {
            GraphVertex currentVertex = this.vertices[currVertexIdx];
            String currentVertexName = currentVertex.getVertexName();
            String[] classNameArr = currentVertex.getClass().toString().split("\\.");
            String className = classNameArr[classNameArr.length - 1];
            String connections = "-";
            String inShape = "-";
            String outShape = "-";
            String paramCount = "-";
            String in = "-";
            String out = "-";
            String paramShape = "-";
            if (currentVertex.isInputVertex()) {
                if (inputTypes != null) {
                    vertexOutputs.put(currentVertexName, inputTypes[this.configuration.getNetworkInputs().indexOf(currentVertexName)]);
                }
            } else {
                connections = this.configuration.getVertexInputs().get(currentVertexName).toString();
                ArrayList<InputType> inputTypeList = new ArrayList<InputType>();
                if (currentVertex.hasLayer()) {
                    org.deeplearning4j.nn.api.Layer currentLayer = ((LayerVertex) currentVertex).getLayer();
                    classNameArr = currentLayer.getClass().getName().split("\\.");
                    className = classNameArr[classNameArr.length - 1];
                    paramCount = String.format("%,d", currentLayer.numParams());
                    if (currentLayer.numParams() > 0L) {
                        if (currentLayer instanceof BidirectionalLayer) {
                            BidirectionalLayer bi = (BidirectionalLayer) currentLayer;
                            in = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNIn());
                            out = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNOut());
                        } else {
                            // Only feed-forward layer configurations expose nIn/nOut; other layers keep "-".
                            // This replaces a swallowed ClassCastException.
                            Layer layerConf = currentLayer.conf().getLayer();
                            if (layerConf instanceof FeedForwardLayer) {
                                FeedForwardLayer ffConf = (FeedForwardLayer) layerConf;
                                in = String.valueOf(ffConf.getNIn());
                                out = String.valueOf(ffConf.getNOut());
                            }
                        }
                        // "name:shape" entries joined by ", " (no longer throws when there are no parameter names)
                        List<String> shapeEntries = new ArrayList<String>();
                        for (String paramName : currentLayer.conf().variables()) {
                            shapeEntries.add(paramName + ":" + ArrayUtils.toString((Object) currentLayer.paramTable().get(paramName).shape()));
                        }
                        paramShape = String.join(", ", shapeEntries);
                    }
                    if (currentLayer instanceof FrozenLayer) {
                        frozenParams += currentLayer.numParams();
                        classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                        className = "Frozen " + classNameArr[classNameArr.length - 1];
                    }
                    if (inputTypes != null) {
                        String inputVertexName = this.vertices[currentVertex.getInputVertices()[0].getVertexIndex()].getVertexName();
                        InputType currentInType = vertexOutputs.get(inputVertexName);
                        inShape = currentInType.toString();
                        inputTypeList.add(currentInType);
                        InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex) this.configuration.getVertices().get(currentVertexName)).getPreProcessor();
                        if (layerVertexPreProcesor != null) {
                            inShape = inShape + "-->" + layerVertexPreProcesor.getOutputType(currentInType);
                        }
                    }
                    ++currLayerIdx;
                } else if (inputTypes != null) {
                    VertexIndices[] inputVertices = currentVertex.getInputVertices();
                    if (inputVertices != null) {
                        for (VertexIndices inputVertex : inputVertices) {
                            GraphVertex thisInputVertex = this.vertices[inputVertex.getVertexIndex()];
                            inputTypeList.add(vertexOutputs.get(thisInputVertex.getVertexName()));
                        }
                    }
                }
                if (inputTypes != null) {
                    InputType currentVertexOutputType = this.configuration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
                    outShape = currentVertexOutputType.toString();
                    vertexOutputs.put(currentVertexName, currentVertexOutputType);
                }
            }
            String[] line = inputTypes == null
                    ? new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections}
                    : new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections, inShape, outShape};
            for (int i = 0; i < line.length; ++i) {
                maxLength[i] = Math.max(maxLength[i], line[i] == null ? 0 : line[i].length());
            }
            lines.add(line);
        }
        // Every column but the last is padded by 3 spaces
        StringBuilder sbFormat = new StringBuilder();
        int totalLength = 0;
        for (int col = 0; col < maxLength.length; ++col) {
            int currLength = col == maxLength.length - 1 ? maxLength[col] : maxLength[col] + 3;
            sbFormat.append("%-").append(currLength).append("s");
            totalLength += currLength;
        }
        sbFormat.append("\n");
        String format = sbFormat.toString();
        String doubleRule = StringUtils.repeat("=", totalLength);
        ret.append(doubleRule).append("\n");
        boolean first = true;
        for (String[] line : lines) {
            ret.append(String.format(format, (Object[]) line));
            if (first) {
                // Separate the header row from the body
                ret.append(doubleRule).append("\n");
                first = false;
            }
        }
        long totalParams = this.params().length();
        ret.append(StringUtils.repeat("-", totalLength))
                .append(String.format("\n%30s %,d", "Total Parameters: ", totalParams))
                .append(String.format("\n%30s %,d", "Trainable Parameters: ", totalParams - frozenParams))
                .append(String.format("\n%30s %,d", "Frozen Parameters: ", frozenParams))
                .append("\n")
                .append(doubleRule)
                .append("\n");
        return ret.toString();
    }

    /**
     * Returns a memory status report for this network, generated by {@code CrashReportingUtil.generateMemoryStatus}.
     *
     * @param minibatch  minibatch size to report for
     * @param inputTypes input types of the network inputs
     * @return the memory status report
     */
    public String memoryInfo(int minibatch, InputType ... inputTypes) {
        String report = CrashReportingUtil.generateMemoryStatus(this, minibatch, inputTypes);
        return report;
    }

    /**
     * Clears per-layer state and noise weight parameters on every layer, then clears every vertex.
     */
    public void clearLayersStates() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }
        // Vertices are GraphVertex instances; iterating them as Serializable cannot call clearVertex()
        for (GraphVertex vertex : this.vertices) {
            vertex.clearVertex();
        }
    }

    /**
     * Increments the configuration's epoch counter and pushes the current iteration/epoch counts to every layer.
     */
    public void incrementEpochCount() {
        int nextEpoch = this.configuration.getEpochCount() + 1;
        this.configuration.setEpochCount(nextEpoch);
        this.synchronizeIterEpochCounts();
    }

    /**
     * Copies the configuration's iteration and epoch counts onto every layer.
     */
    protected void synchronizeIterEpochCounts() {
        int iteration = this.getConfiguration().getIterationCount();
        int epoch = this.getConfiguration().getEpochCount();
        for (int idx = 0; idx < this.layers.length; idx++) {
            org.deeplearning4j.nn.api.Layer layer = this.layers[idx];
            layer.setIterationCount(iteration);
            layer.setEpochCount(epoch);
        }
    }

    /**
     * @return the iteration count stored in the configuration
     */
    public int getIterationCount() {
        int count = this.configuration.getIterationCount();
        return count;
    }

    /**
     * @return the epoch count stored in the configuration
     */
    public int getEpochCount() {
        int count = this.configuration.getEpochCount();
        return count;
    }

    /**
     * Saves the network, including its updater state, to the given file.
     *
     * @param f destination file
     * @throws IOException if writing fails
     */
    public void save(File f) throws IOException {
        final boolean includeUpdater = true;
        this.save(f, includeUpdater);
    }

    /**
     * Saves the network to the given file via {@code ModelSerializer.writeModel}.
     *
     * @param f            destination file
     * @param saveUpdater  whether to include the updater state
     * @throws IOException if writing fails
     */
    public void save(File f, boolean saveUpdater) throws IOException {
        Model model = this;
        ModelSerializer.writeModel(model, f, saveUpdater);
    }

    /**
     * Restores a ComputationGraph from the given file via {@code ModelSerializer.restoreComputationGraph}.
     *
     * @param f           source file
     * @param loadUpdater whether to load the updater state
     * @return the restored network
     * @throws IOException if reading fails
     */
    public static ComputationGraph load(File f, boolean loadUpdater) throws IOException {
        ComputationGraph restored = ModelSerializer.restoreComputationGraph(f, loadUpdater);
        return restored;
    }

    /**
     * Returns a copy of this network whose parameters (and updater state, if any) use the given floating-point
     * data type. If the parameters already have that type, {@code this} is returned unchanged.
     *
     * @param dataType target floating-point data type
     * @return the converted network, or {@code this} if no conversion is needed
     * @throws NullPointerException  if {@code dataType} is null
     * @throws IllegalStateException (via Preconditions) if {@code dataType} is not a floating-point type
     */
    public ComputationGraph convertDataType(@NonNull DataType dataType) {
        if (dataType == null) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        Preconditions.checkState((boolean) dataType.isFPType(), (String) "Invalid DataType: %s. Can only convert network to a floating point type", (Object) dataType);
        if (this.params().dataType() == dataType) {
            return this;
        }
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            INDArray castParams = this.params().castTo(dataType);
            // Round-trip the configuration through JSON to obtain an independent copy
            ComputationGraphConfiguration convertedConf = ComputationGraphConfiguration.fromJson(this.getConfiguration().toJson());
            convertedConf.setDataType(dataType);
            ComputationGraph converted = new ComputationGraph(convertedConf);
            converted.init(castParams, false);
            ComputationGraphUpdater existingUpdater = this.getUpdater(false);
            boolean hasUpdaterState = existingUpdater != null && existingUpdater.getStateViewArray() != null;
            if (hasUpdaterState) {
                INDArray oldUpdaterState = existingUpdater.getStateViewArray();
                converted.getUpdater(true).getStateViewArray().assign(oldUpdaterState);
            }
            return converted;
        }
    }

    /**
     * Sets a fixed learning rate for all layers, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param newLr new learning rate
     */
    public void setLearningRate(double newLr) {
        ComputationGraph net = this;
        NetworkUtils.setLearningRate(net, newLr);
    }

    /**
     * Sets a learning-rate schedule for all layers, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param newLr new learning-rate schedule
     */
    public void setLearningRate(ISchedule newLr) {
        ComputationGraph net = this;
        NetworkUtils.setLearningRate(net, newLr);
    }

    /**
     * Sets a fixed learning rate for the named layer, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param layerName name of the layer
     * @param newLr     new learning rate
     */
    public void setLearningRate(String layerName, double newLr) {
        ComputationGraph net = this;
        NetworkUtils.setLearningRate(net, layerName, newLr);
    }

    /**
     * Sets a learning-rate schedule for the named layer, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param layerName name of the layer
     * @param newLr     new learning-rate schedule
     */
    public void setLearningRate(String layerName, ISchedule newLr) {
        ComputationGraph net = this;
        NetworkUtils.setLearningRate(net, layerName, newLr);
    }

    /**
     * Returns the learning rate of the named layer, as reported by {@code NetworkUtils.getLearningRate}.
     *
     * @param layerName name of the layer
     * @return the learning rate (may be null, depending on NetworkUtils — verify)
     */
    public Double getLearningRate(String layerName) {
        ComputationGraph net = this;
        return NetworkUtils.getLearningRate(net, layerName);
    }

    /**
     * Returns nOut of the layer at the given index (0 for non-feed-forward layers).
     *
     * @param layer layer index, between 0 and numLayers-1 inclusive
     * @return the layer's output size
     * @throws IllegalArgumentException if the index is out of range
     */
    public long layerSize(int layer) {
        // '>=' : index == layers.length is out of range (previously it slipped through to an AIOOBE)
        if (layer < 0 || layer >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        return this.layerSize(this.layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Returns nIn of the layer at the given index (0 for non-feed-forward layers).
     *
     * @param layer layer index, between 0 and numLayers-1 inclusive
     * @return the layer's input size
     * @throws IllegalArgumentException if the index is out of range
     */
    public long layerInputSize(int layer) {
        // '>=' : index == layers.length is out of range (previously it slipped through to an AIOOBE)
        if (layer < 0 || layer >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        return this.layerInputSize(this.layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Returns nOut of the named layer, or 0 if its configuration is not a FeedForwardLayer.
     *
     * @param layerName name of the layer
     * @return the layer's output size
     * @throws IllegalArgumentException if no layer with that name exists
     */
    public long layerSize(String layerName) {
        org.deeplearning4j.nn.api.Layer found = this.getLayer(layerName);
        if (found == null) {
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
        }
        Layer layerConf = found.conf().getLayer();
        // instanceof is false for null, covering the missing-configuration case too
        return layerConf instanceof FeedForwardLayer ? ((FeedForwardLayer) layerConf).getNOut() : 0L;
    }

    /**
     * Returns nIn of the named layer, or 0 if its configuration is not a FeedForwardLayer.
     *
     * @param layerName name of the layer
     * @return the layer's input size
     * @throws IllegalArgumentException if no layer with that name exists
     */
    public long layerInputSize(String layerName) {
        org.deeplearning4j.nn.api.Layer found = this.getLayer(layerName);
        if (found == null) {
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
        }
        Layer layerConf = found.conf().getLayer();
        return layerConf instanceof FeedForwardLayer ? ((FeedForwardLayer) layerConf).getNIn() : 0L;
    }

    /**
     * Two graphs are equal when their parameters, configurations and updaters are all equal.
     * All three comparisons are always evaluated.
     *
     * @param obj object to compare with
     * @return true if {@code obj} is an equal ComputationGraph
     */
    public boolean equals(Object obj) {
        if (!(obj instanceof ComputationGraph)) {
            return false;
        }
        ComputationGraph other = (ComputationGraph) obj;
        boolean sameParams = other.params().equals(this.params());
        boolean sameConf = this.getConfiguration().equals(other.getConfiguration());
        boolean sameUpdater = this.getUpdater().equals(other.getUpdater());
        return sameParams && sameConf && sameUpdater;
    }

    /**
     * Java serialization hook: writes the model, including updater state, with ModelSerializer.
     *
     * @param oos destination stream
     * @throws IOException if writing fails
     */
    private void writeObject(ObjectOutputStream oos) throws IOException {
        final boolean includeUpdater = true;
        ModelSerializer.writeModel((Model) this, oos, includeUpdater);
    }

    /**
     * Java serialization hook: restores a ComputationGraph with ModelSerializer and copies its configurations,
     * parameters and (if present) updater state into this instance.
     *
     * @param ois source stream
     * @throws ClassNotFoundException declared by the serialization contract
     * @throws IOException            if reading fails
     */
    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
        ComputationGraph cg = ModelSerializer.restoreComputationGraph(ois, true);
        this.defaultConfiguration = cg.defaultConfiguration.clone();
        this.configuration = cg.configuration.clone();
        this.init();
        this.flattenedParams.assign(cg.flattenedParams);
        // Look up the restored updater without creating one. getUpdater() initializes on demand and would build a
        // throw-away updater just for this null check. Same pattern as convertDataType/close.
        ComputationGraphUpdater restoredUpdater = cg.getUpdater(false);
        if (restoredUpdater != null && restoredUpdater.getStateViewArray() != null) {
            this.getUpdater(true).getStateViewArray().assign(restoredUpdater.getStateViewArray());
        }
    }

    /**
     * Closes the parameter, gradient and updater-state arrays where they are closeable, destroys all workspaces for
     * the current thread and requests a garbage collection.
     * <p>
     * Each array is null-checked, so this no longer throws a NullPointerException when the parameters were never
     * initialised.
     */
    @Override
    public void close() {
        if (this.flattenedParams != null && this.flattenedParams.closeable()) {
            this.flattenedParams.close();
        }
        if (this.flattenedGradients != null && this.flattenedGradients.closeable()) {
            this.flattenedGradients.close();
        }
        ComputationGraphUpdater u = this.getUpdater(false);
        if (u != null) {
            INDArray state = u.getStateViewArray();
            if (state != null && state.closeable()) {
                state.close();
            }
        }
        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
        System.gc();
    }

    /**
     * @return the flattened gradient array (may be null if not yet allocated)
     */
    public INDArray getFlattenedGradients() {
        INDArray gradients = this.flattenedGradients;
        return gradients;
    }

    /**
     * Sets the initialisation-done flag.
     *
     * @param initDone new flag value
     */
    public void setInitDone(boolean initDone) {
        boolean flag = initDone;
        this.initDone = flag;
    }

    /**
     * @return the map of helper workspaces (returned directly, not copied)
     */
    public Map<String, Pointer> getHelperWorkspaces() {
        Map<String, Pointer> workspaces = this.helperWorkspaces;
        return workspaces;
    }
}

