/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ComputationGraph
implements Serializable,
Model {
    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);
    // Configuration this graph was constructed from.
    protected ComputationGraphConfiguration configuration;
    // Set to true at the end of init(); init() returns immediately once this is true.
    protected boolean initCalled = false;
    // Optimizer used by the fit(...) methods; created lazily on first use.
    protected transient Solver solver;
    // All network parameters in one row vector (created as 1 x numParams in init()); vertices receive views into it.
    protected INDArray flattenedParams;
    // Gradient buffer with the same layout as flattenedParams; allocated by initGradientsView().
    protected transient INDArray flattenedGradients;
    // NOTE(review): not read or written in the visible part of this class.
    protected Gradient gradient;
    // Total score computed by the last computeGradientAndScore() call.
    protected double score;
    // NOTE(review): not referenced in the visible part of this class; init() uses initCalled instead.
    private boolean initDone = false;
    // All vertices, indexed by vertex number: network inputs first, then configured vertices.
    protected GraphVertex[] vertices;
    // Vertex lookup by vertex name.
    protected Map<String, GraphVertex> verticesMap;
    // Vertex indices in topological order, as computed by topologicalSortOrder().
    protected int[] topologicalOrder;
    // Layers of the vertices that wrap a layer, in configuration-vertex iteration order.
    protected Layer[] layers;
    // Number of network inputs / outputs, fixed at construction from the configuration.
    private int numInputArrays;
    private int numOutputArrays;
    // Current input and label arrays (setInputs/setLabels may replace them, including with null).
    private INDArray[] inputs;
    private INDArray[] labels;
    // Mask arrays; presumably maintained by setLayerMaskArrays/clearLayerMaskArrays (not visible here) — TODO confirm.
    private INDArray[] inputMaskArrays;
    private INDArray[] labelMaskArrays;
    // Network-wide default configuration, taken from the graph configuration; used to build the Solver.
    private NeuralNetConfiguration defaultConfiguration;
    // Listeners handed to the Solver and to layers being pretrained.
    private Collection<IterationListener> listeners = new ArrayList<IterationListener>();

    public ComputationGraph(ComputationGraphConfiguration configuration) {
        this.configuration = configuration;
        this.numInputArrays = configuration.getNetworkInputs().size();
        this.numOutputArrays = configuration.getNetworkOutputs().size();
        this.inputs = new INDArray[this.numInputArrays];
        this.labels = new INDArray[this.numOutputArrays];
        this.defaultConfiguration = configuration.getDefaultConfiguration();
    }

    /** @return the configuration this graph was built from */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /** @return the number of vertices that wrap a layer, or 0 if {@link #init()} has not built them yet */
    public int getNumLayers() {
        if (layers == null) {
            return 0;
        }
        return layers.length;
    }

    /**
     * @param idx index into the layer array (not a vertex index)
     * @return the layer at that position
     */
    public Layer getLayer(int idx) {
        Layer[] all = layers;
        return all[idx];
    }

    /** @return the layers of all layer-bearing vertices (the internal array, not a copy) */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * @param name vertex name
     * @return the layer wrapped by the named vertex
     */
    public Layer getLayer(String name) {
        GraphVertex vertex = verticesMap.get(name);
        return vertex.getLayer();
    }

    /** @return all vertices indexed by vertex number (the internal array, not a copy) */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /**
     * @param name vertex name
     * @return the vertex with that name, or null if there is none
     */
    public GraphVertex getVertex(String name) {
        GraphVertex vertex = verticesMap.get(name);
        return vertex;
    }

    /** @return number of network inputs declared by the configuration */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /** @return number of network outputs declared by the configuration */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Sets a single network input.
     *
     * <p>{@link #setInputs(INDArray...)} accepts {@code null} and then leaves no holder array. In that
     * case the holder is re-created here, so setting individual inputs still works instead of throwing
     * a NullPointerException.
     *
     * @param inputNum index of the network input, in {@code [0, getNumInputArrays())}
     * @param input    the input array
     */
    public void setInput(int inputNum, INDArray input) {
        if (this.inputs == null) {
            this.inputs = new INDArray[this.numInputArrays];
        }
        this.inputs[inputNum] = input;
    }

    /**
     * Replaces all network inputs at once. The given array is stored as-is (not copied).
     *
     * @param inputs one array per network input, or null to clear
     * @throws IllegalArgumentException if non-null and its length differs from the number of network inputs
     */
    public void setInputs(INDArray ... inputs) {
        boolean lengthMismatch = inputs != null && inputs.length != this.numInputArrays;
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid input array: network has " + this.numInputArrays + " inputs, but array is of length " + inputs.length);
        }
        this.inputs = inputs;
    }

    /**
     * @param inputNum index of the network input
     * @return that input, or null if no input holder array is currently set
     */
    public INDArray getInput(int inputNum) {
        return inputs == null ? null : inputs[inputNum];
    }

    /** @return the current input holder array (may be null) */
    public INDArray[] getInputs() {
        return inputs;
    }

    /** @return the current feature mask arrays (may be null) */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /** @return the current label mask arrays (may be null) */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Sets the label array for a single network output.
     *
     * <p>{@link #setLabels(INDArray[])} accepts {@code null}. In that case the holder array is
     * re-created here rather than throwing a NullPointerException.
     *
     * @param labelNum index of the network output, in {@code [0, getNumOutputArrays())}
     * @param label    the label array
     */
    public void setLabel(int labelNum, INDArray label) {
        if (this.labels == null) {
            this.labels = new INDArray[this.numOutputArrays];
        }
        this.labels[labelNum] = label;
    }

    /**
     * Replaces all label arrays at once. The given array is stored as-is (not copied).
     *
     * @param labels one array per network output, or null to clear
     * @throws IllegalArgumentException if non-null and its length differs from the number of network outputs
     */
    public void setLabels(INDArray[] labels) {
        if (labels == null || labels.length == this.numOutputArrays) {
            this.labels = labels;
            return;
        }
        throw new IllegalArgumentException("Invalid output array: network has " + this.numOutputArrays + " outputs, but array is of length " + labels.length);
    }

    /**
     * Builds the runtime graph (vertices, layers, parameter views, edge wiring) from the configuration.
     * Calls after the first are no-ops.
     *
     * <p>Vertex numbering:
     * <ul>
     *   <li>network inputs get indices {@code 0..numInputs-1}, in configuration order;</li>
     *   <li>configured vertices follow, in the iteration order of {@code configuration.getVertices()}.</li>
     * </ul>
     * This is the same numbering {@link #topologicalSortOrder()} uses.
     *
     * <p>All parameters live in a single {@code 1 x numParams} row ({@link #flattenedParams}). Each
     * vertex with parameters receives a contiguous view into it; vertices are laid out in topological
     * order.
     */
    public void init() {
        if (this.initCalled) {
            return;
        }
        this.topologicalOrder = this.topologicalSortOrder();
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        int numInputs = networkInputNames.size();
        this.vertices = new GraphVertex[numInputs + configVertexMap.size()];
        Map<String, Integer> allNamesReverse = new HashMap<String, Integer>();

        // Input vertices take the first indices and carry no parameters.
        int vertexNumber = 0;
        for (String name : networkInputNames) {
            InputVertex gv = new InputVertex(this, name, vertexNumber, null);
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }

        // Parameter count per vertex index; entries for input vertices keep the array default of 0.
        int numParams = 0;
        int[] numParamsForVertex = new int[this.topologicalOrder.length];
        int configIdx = numInputs;
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            int nThisVertex = nodeEntry.getValue().numParams(true);
            numParamsForVertex[configIdx++] = nThisVertex;
            numParams += nThisVertex;
        }

        // Carve out one contiguous view per vertex, in topological order.
        this.flattenedParams = Nd4j.create(1, numParams);
        INDArray[] paramsViewForVertex = new INDArray[this.topologicalOrder.length];
        int paramOffsetSoFar = 0;
        for (int vertexIdx : this.topologicalOrder) {
            int nParamsThisVertex = numParamsForVertex[vertexIdx];
            if (nParamsThisVertex != 0) {
                paramsViewForVertex[vertexIdx] = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex)});
            }
            paramOffsetSoFar += nParamsThisVertex;
        }

        // Instantiate the configured vertices, collecting those that wrap a layer.
        List<Layer> tempLayerList = new ArrayList<Layer>();
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            String name = nodeEntry.getKey();
            GraphVertex gv = nodeEntry.getValue().instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber]);
            if (gv.hasLayer()) {
                tempLayerList.add(gv.getLayer());
            }
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }
        this.layers = tempLayerList.toArray(new Layer[tempLayerList.size()]);

        this.verticesMap = new HashMap<String, GraphVertex>();
        for (GraphVertex gv : this.vertices) {
            this.verticesMap.put(gv.getVertexName(), gv);
        }

        // Reverse edges: vertex name -> names of the vertices that consume its output, in discovery order.
        Map<String, List<String>> verticesOutputTo = new HashMap<String, List<String>>();
        for (GraphVertex gv : this.vertices) {
            String vertexName = gv.getVertexName();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) {
                continue;
            }
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<String>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName);
            }
        }

        // Input wiring: for each input j of a vertex, record (source vertex index, source output number).
        for (GraphVertex gv : this.vertices) {
            String vertexName = gv.getVertexName();
            int vertexIndex = gv.getVertexIndex();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) {
                continue;
            }
            VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
            for (int j = 0; j < vertexInputNames.size(); ++j) {
                String inName = vertexInputNames.get(j);
                int inputVertexIndex = allNamesReverse.get(inName);
                GraphVertex inputVertex = this.vertices[inputVertexIndex];
                List<String> inputVertexOutputsTo = verticesOutputTo.get(inName);
                int outputNumberOfInput = inputVertexOutputsTo.indexOf(vertexName);
                if (outputNumberOfInput == -1) {
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs " + "for vertex " + inputVertex + "; error in graph structure?");
                }
                inputIndices[j] = new VertexIndices(inputVertexIndex, outputNumberOfInput);
            }
            gv.setInputVertices(inputIndices);
        }

        // Output wiring: for each consumer, record (consumer vertex index, consumer input number).
        // Vertices with no consumers never get output indices assigned.
        for (GraphVertex gv : this.vertices) {
            String vertexName = gv.getVertexName();
            List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);
            if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty()) {
                continue;
            }
            VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
            int j = 0;
            for (String s : thisVertexOutputsTo) {
                List<String> nextVertexInputNames = vertexInputs.get(s);
                int outputVertexInputNumber = nextVertexInputNames.indexOf(vertexName);
                int outputVertexIndex = allNamesReverse.get(s);
                outputIndices[j++] = new VertexIndices(outputVertexIndex, outputVertexInputNumber);
            }
            gv.setOutputVertices(outputIndices);
        }
        this.initCalled = true;
    }

    /**
     * Allocates the flattened gradient buffer and gives every vertex that has parameters a view into
     * it. The layout matches the parameter layout: contiguous per vertex, in topological order.
     * Calls {@link #init()} first if it has not been called yet.
     */
    protected void initGradientsView() {
        if (!this.initCalled) {
            this.init();
        }
        int[] paramCounts = new int[this.topologicalOrder.length];
        int totalParams = 0;
        // Input vertices occupy the first indices and keep a count of 0.
        int cursor = this.configuration.getNetworkInputs().size();
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = this.configuration.getVertices();
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : configVertexMap.entrySet()) {
            int count = entry.getValue().numParams(true);
            paramCounts[cursor] = count;
            totalParams += count;
            cursor++;
        }
        this.flattenedGradients = Nd4j.create((int)1, (int)totalParams);
        int offset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            int count = paramCounts[vertexIdx];
            if (count != 0) {
                INDArrayIndex[] range = new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)offset, (int)(offset + count))};
                this.vertices[vertexIdx].setBackpropGradientsViewArray(this.flattenedGradients.get(range));
            }
            offset += count;
        }
    }

    /**
     * Layer-wise pretraining from a single-input, single-output DataSetIterator. The iterator is
     * adapted to a MultiDataSetIterator and passed on.
     *
     * @param iter training data
     * @throws UnsupportedOperationException if the graph does not have exactly one input and one output
     */
    public void pretrain(DataSetIterator iter) {
        boolean singleInAndOut = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInAndOut) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSetIterator");
        }
        this.pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter));
    }

    /**
     * Layer-wise pretraining of every non-output layer, visited in topological order.
     *
     * <p>For each target layer:
     * <ol>
     *   <li>Compute the minimal set of upstream vertices needed to produce its input.</li>
     *   <li>For each minibatch, run a forward pass over just those vertices.</li>
     *   <li>Fit the target layer on its first input.</li>
     *   <li>Reset the iterator once the target is done.</li>
     * </ol>
     *
     * <p>The target vertex is {@code vertices[topologicalOrder[i]]}. Earlier code mixed that with
     * {@code vertices[i]}, which checked and logged a different vertex than the one pretrained.
     *
     * @param iter training data; reset after each pretrained layer
     * @throws IllegalStateException if a non-output layer is not a {@link BasePretrainNetwork}
     */
    public void pretrain(MultiDataSetIterator iter) {
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            int targetIdx = this.topologicalOrder[i];
            GraphVertex target = this.vertices[targetIdx];
            if (!target.hasLayer() || target.getLayer() instanceof BaseOutputLayer) continue;

            // Walk backwards through the topological order, keeping only vertices that feed
            // (directly or transitively) into the target.
            LinkedList<Integer> partialTopoSort = new LinkedList<Integer>();
            Set<Integer> seenSoFar = new HashSet<Integer>();
            partialTopoSort.add(targetIdx);
            seenSoFar.add(targetIdx);
            for (int j = i - 1; j >= 0; --j) {
                int candidateIdx = this.topologicalOrder[j];
                VertexIndices[] outputsTo = this.vertices[candidateIdx].getOutputVertices();
                // Vertices with no consumers have no output indices and cannot be needed.
                if (outputsTo == null) continue;
                boolean needed = false;
                for (VertexIndices vi : outputsTo) {
                    if (seenSoFar.contains(vi.getVertexIndex())) {
                        needed = true;
                        break;
                    }
                }
                if (!needed) continue;
                partialTopoSort.addFirst(candidateIdx);
                seenSoFar.add(candidateIdx);
            }
            int[] fwdPassOrder = new int[partialTopoSort.size()];
            int k = 0;
            for (Integer g : partialTopoSort) {
                fwdPassOrder[k++] = g;
            }

            Layer layer = target.getLayer();
            if (!(layer instanceof BasePretrainNetwork)) {
                throw new IllegalStateException("Cannot pretrain network with layer that is not pretrainable");
            }
            log.info("Pretraining on layer \"{}\"", target.getVertexName());
            BasePretrainNetwork toPretrain = (BasePretrainNetwork)layer;
            if (this.listeners != null) {
                toPretrain.setListeners(this.listeners);
            }
            while (iter.hasNext()) {
                MultiDataSet multiDataSet = (MultiDataSet)iter.next();
                this.setInputs(multiDataSet.getFeatures());
                // Forward pass over every vertex in the partial order except the target itself (the last entry).
                for (int j = 0; j < fwdPassOrder.length - 1; ++j) {
                    GraphVertex current = this.vertices[fwdPassOrder[j]];
                    if (current.isInputVertex()) {
                        VertexIndices[] inputsTo = current.getOutputVertices();
                        INDArray input = this.inputs[current.getVertexIndex()];
                        for (VertexIndices v : inputsTo) {
                            this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), input.dup());
                        }
                        continue;
                    }
                    INDArray out = current.doForward(true);
                    VertexIndices[] outputsTo = current.getOutputVertices();
                    if (outputsTo == null) continue;
                    for (VertexIndices v : outputsTo) {
                        this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), out);
                    }
                }
                toPretrain.fit(target.getInputs()[0]);
            }
            iter.reset();
        }
    }

    /**
     * Fits the network on one single-input, single-output DataSet, including any mask arrays.
     *
     * <p>The masks are passed straight to {@link #fit(INDArray[], INDArray[], INDArray[], INDArray[])}.
     * Going through the 2-arg overload would pass {@code null} masks to {@code setLayerMaskArrays}.
     * If the DataSet had masks, they are cleared afterwards, even when fitting throws.
     *
     * @param dataSet training data
     * @throws UnsupportedOperationException if the graph does not have exactly one input and one output
     */
    public void fit(DataSet dataSet) {
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSet");
        }
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        INDArray[] fMask = null;
        INDArray[] lMask = null;
        if (hasMaskArrays) {
            INDArray featuresMask = dataSet.getFeaturesMaskArray();
            INDArray labelsMask = dataSet.getLabelsMaskArray();
            fMask = featuresMask != null ? new INDArray[]{featuresMask} : null;
            lMask = labelsMask != null ? new INDArray[]{labelsMask} : null;
        }
        try {
            this.fit(new INDArray[]{dataSet.getFeatureMatrix()}, new INDArray[]{dataSet.getLabels()}, fMask, lMask);
        } finally {
            if (hasMaskArrays) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /**
     * Trains on a single-input, single-output DataSetIterator. If the configuration enables
     * pretraining, that runs first; backprop then consumes the iterator.
     *
     * <p>Backprop stops at the first DataSet whose feature matrix or labels are null. For each
     * DataSet:
     * <ul>
     *   <li>present masks are installed before the step;</li>
     *   <li>absent individual masks are passed as {@code null} rather than as an array holding null;</li>
     *   <li>masks are cleared afterwards, even if the step throws.</li>
     * </ul>
     *
     * @param dataSetIterator training data
     * @throws UnsupportedOperationException if the graph does not have exactly one input and one output
     */
    public void fit(DataSetIterator dataSetIterator) {
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSetIterator");
        }
        if (this.configuration.isPretrain()) {
            this.pretrain(dataSetIterator);
        }
        if (!this.configuration.isBackprop()) {
            return;
        }
        this.update(TaskUtils.buildTask((org.nd4j.linalg.dataset.api.iterator.DataSetIterator)dataSetIterator));
        DataSet next;
        while (dataSetIterator.hasNext() && (next = (DataSet)dataSetIterator.next()).getFeatureMatrix() != null && next.getLabels() != null) {
            boolean hasMaskArrays = next.hasMaskArrays();
            INDArray[] fMask = null;
            INDArray[] lMask = null;
            if (hasMaskArrays) {
                INDArray featuresMask = next.getFeaturesMaskArray();
                INDArray labelsMask = next.getLabelsMaskArray();
                fMask = featuresMask != null ? new INDArray[]{featuresMask} : null;
                lMask = labelsMask != null ? new INDArray[]{labelsMask} : null;
                this.setLayerMaskArrays(fMask, lMask);
            }
            try {
                if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    this.doTruncatedBPTT(new INDArray[]{next.getFeatures()}, new INDArray[]{next.getLabels()}, fMask, lMask);
                } else {
                    this.setInput(0, next.getFeatureMatrix());
                    this.setLabel(0, next.getLabels());
                    if (this.solver == null) {
                        this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
                    }
                    this.solver.optimize();
                }
            } finally {
                if (hasMaskArrays) {
                    this.clearLayerMaskArrays();
                }
            }
        }
    }

    /**
     * Fits the network on one MultiDataSet, including its mask arrays when it has any.
     *
     * <p>The masks go directly to the 4-arg fit. The 2-arg overload would reset them to
     * {@code null}. If masks were present, they are cleared afterwards, even when fitting throws.
     *
     * @param multiDataSet training data
     */
    public void fit(MultiDataSet multiDataSet) {
        boolean hasMaskArrays = multiDataSet.hasMaskArrays();
        INDArray[] fMasks = hasMaskArrays ? multiDataSet.getFeaturesMaskArrays() : null;
        INDArray[] lMasks = hasMaskArrays ? multiDataSet.getLabelsMaskArrays() : null;
        try {
            this.fit(multiDataSet.getFeatures(), multiDataSet.getLabels(), fMasks, lMasks);
        } finally {
            if (hasMaskArrays) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /**
     * Trains on a MultiDataSetIterator. If the configuration enables pretraining, that runs first;
     * backprop then consumes batches until the iterator is exhausted, or until a batch has null
     * features or labels.
     *
     * @param multiDataSetIterator training data
     */
    public void fit(MultiDataSetIterator multiDataSetIterator) {
        if (this.configuration.isPretrain()) {
            this.pretrain(multiDataSetIterator);
        }
        if (!this.configuration.isBackprop()) {
            return;
        }
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet)multiDataSetIterator.next();
            if (batch.getFeatures() == null || batch.getLabels() == null) {
                break;
            }
            if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                // TBPTT handles its own masks.
                this.doTruncatedBPTT(batch.getFeatures(), batch.getLabels(), batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
                continue;
            }
            boolean masked = batch.hasMaskArrays();
            if (masked) {
                this.setLayerMaskArrays(batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
            }
            this.setInputs(batch.getFeatures());
            this.setLabels(batch.getLabels());
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
            }
            this.solver.optimize();
            if (masked) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /**
     * Fits the network on the given inputs and labels without mask arrays.
     *
     * @param inputs one array per network input
     * @param labels one array per network output
     */
    public void fit(INDArray[] inputs, INDArray[] labels) {
        INDArray[] noMasks = null;
        this.fit(inputs, labels, noMasks, noMasks);
    }

    /**
     * Fits the network on the given inputs, labels and (optional) mask arrays.
     *
     * <p>With truncated BPTT, the mask arrays are now forwarded to {@code doTruncatedBPTT}. Previously
     * {@code null} was passed there, silently dropping the caller's masks.
     *
     * @param inputs            one array per network input
     * @param labels            one array per network output
     * @param featureMaskArrays feature masks, or null
     * @param labelMaskArrays   label masks, or null
     * @throws UnsupportedOperationException if the configuration enables pretraining (not supported here)
     */
    public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        this.setInputs(inputs);
        this.setLabels(labels);
        this.setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
        this.update(TaskUtils.buildTask(inputs, labels));
        if (this.configuration.isPretrain()) {
            throw new UnsupportedOperationException("Pretraining: Not yet implemented");
        }
        if (!this.configuration.isBackprop()) {
            return;
        }
        if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            this.doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
        } else {
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
            }
            this.solver.optimize();
        }
    }

    /**
     * Computes (and, once {@link #init()} has stored it, returns the cached) topological order of all
     * vertices, using Kahn's algorithm.
     *
     * <p>Vertex indices: network inputs get {@code 0..numInputs-1} in configuration order; configured
     * vertices follow, in the iteration order of {@code configuration.getVertices()}.
     *
     * @return vertex indices, each appearing after all of its inputs
     * @throws IllegalStateException if a vertex names an input that does not exist, or if the graph
     *                               contains a cycle
     */
    public int[] topologicalSortOrder() {
        if (this.topologicalOrder != null) {
            return this.topologicalOrder;
        }
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        int numVertices = networkInputNames.size() + nodeMap.size();
        int[] out = new int[numVertices];
        int outCounter = 0;

        // Assign vertex indices (must match the numbering in init()).
        Map<Integer, String> indexToName = new HashMap<Integer, String>();
        Map<String, Integer> nameToIndex = new HashMap<String, Integer>();
        int i = 0;
        for (String inputName : networkInputNames) {
            indexToName.put(i, inputName);
            nameToIndex.put(inputName, i);
            ++i;
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String vertexName = entry.getKey();
            indexToName.put(i, vertexName);
            nameToIndex.put(vertexName, i);
            ++i;
        }

        // inputEdges: vertex -> inputs not yet emitted (null = vertex has no inputs).
        // outputEdges: vertex -> vertices that consume its output.
        Map<Integer, Set<Integer>> inputEdges = new HashMap<Integer, Set<Integer>>();
        Map<Integer, Set<Integer>> outputEdges = new HashMap<Integer, Set<Integer>>();
        for (String inputName : networkInputNames) {
            inputEdges.put(nameToIndex.get(inputName), null);
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String vertexName = entry.getKey();
            int idx = nameToIndex.get(vertexName);
            List<String> inputsToThisVertex = vertexInputs.get(vertexName);
            if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
                inputEdges.put(idx, null);
                continue;
            }
            Set<Integer> inputSet = new HashSet<Integer>();
            for (String s : inputsToThisVertex) {
                Integer inputIdx = nameToIndex.get(s);
                if (inputIdx == null) {
                    // Previously this only printed an empty line and later surfaced as a misleading "cycle" error.
                    throw new IllegalStateException("Invalid configuration: vertex \"" + vertexName + "\" has input \"" + s + "\" that is not a network input or vertex");
                }
                inputSet.add(inputIdx);
                Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
                if (outputSetForInputIdx == null) {
                    outputSetForInputIdx = new HashSet<Integer>();
                    outputEdges.put(inputIdx, outputSetForInputIdx);
                }
                outputSetForInputIdx.add(idx);
            }
            inputEdges.put(idx, inputSet);
        }

        // Kahn's algorithm: repeatedly emit a vertex with no remaining inputs.
        LinkedList<Integer> noIncomingEdges = new LinkedList<Integer>();
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> inputsFrom = entry.getValue();
            if (inputsFrom == null || inputsFrom.isEmpty()) {
                noIncomingEdges.add(entry.getKey());
            }
        }
        while (!noIncomingEdges.isEmpty()) {
            int n = noIncomingEdges.removeFirst();
            out[outCounter++] = n;
            Set<Integer> consumers = outputEdges.get(n);
            if (consumers == null) {
                continue;
            }
            for (Integer v : consumers) {
                Set<Integer> remaining = inputEdges.get(v);
                remaining.remove(n);
                if (remaining.isEmpty()) {
                    noIncomingEdges.add(v);
                }
            }
        }

        // Any vertex still waiting on inputs is part of (or downstream of) a cycle.
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> remaining = entry.getValue();
            if (remaining != null && !remaining.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + indexToName.get(entry.getKey()) + "\")");
            }
        }
        return out;
    }

    /**
     * Runs a forward and a backward pass, then recomputes {@link #score} as the sum of the scores of
     * all network output layers. With truncated BPTT, the forward pass uses the stored RNN state.
     * The L1/L2 regularization terms are added only with the first output layer, so they are counted
     * once.
     */
    @Override
    public void computeGradientAndScore() {
        boolean truncated = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        if (truncated) {
            this.rnnActivateUsingStoredState(this.inputs, true, true);
        } else {
            this.feedForward(true, true);
        }
        this.backprop(truncated);
        double l1Term = this.calcL1();
        double l2Term = this.calcL2();
        this.score = 0.0;
        for (String outputName : this.configuration.getNetworkOutputs()) {
            BaseOutputLayer outputLayer = (BaseOutputLayer)this.verticesMap.get(outputName).getLayer();
            this.score += outputLayer.computeScore(l1Term, l2Term, true);
            l1Term = 0.0;
            l2Term = 0.0;
        }
    }

    /**
     * Feed-forward for a single-input graph.
     *
     * @param input the network input
     * @param train whether to run in training mode
     * @return activations keyed by vertex name (network inputs and layer vertices)
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (this.numInputArrays == 1) {
            this.setInput(0, input);
            return this.feedForward(train);
        }
        throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
    }

    /**
     * Feed-forward with one array per network input.
     *
     * @param input the network inputs
     * @param train whether to run in training mode
     * @return activations keyed by vertex name (network inputs and layer vertices)
     * @throws UnsupportedOperationException if the number of arrays does not match the number of network inputs
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        if (input.length != this.numInputArrays) {
            throw new UnsupportedOperationException("Cannot feedForward with " + input.length + " inputs for graph network with " + this.numInputArrays + " expected inputs");
        }
        int slot = 0;
        for (INDArray in : input) {
            this.setInput(slot++, in);
        }
        return this.feedForward(train);
    }

    /** Feed-forward in inference mode using the currently set inputs. */
    public Map<String, INDArray> feedForward() {
        boolean inferenceMode = false;
        return this.feedForward(inferenceMode);
    }

    /** Feed-forward using the currently set inputs, including output layers. */
    public Map<String, INDArray> feedForward(boolean train) {
        boolean includeOutputLayers = false;
        return this.feedForward(train, includeOutputLayers);
    }

    /**
     * Runs a forward pass over all vertices in topological order, pushing each vertex's output into
     * its consumers.
     *
     * <p>An input vertex with no consumers has no output indices (init() never assigns them). It is
     * now skipped instead of causing a NullPointerException; this mirrors the null check already
     * applied to non-input vertices.
     *
     * @param train               whether to run in training mode
     * @param excludeOutputLayers if true, network-output vertices wrapping a {@link BaseOutputLayer}
     *                            are not evaluated
     * @return activations keyed by vertex name, for network inputs and for vertices that have a layer
     */
    private Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers) {
        Map<String, INDArray> layerActivations = new HashMap<String, INDArray>();
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex current = this.vertices[vertexIdx];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = this.inputs[current.getVertexIndex()];
                layerActivations.put(current.getVertexName(), input);
                if (inputsTo != null) {
                    for (VertexIndices v : inputsTo) {
                        // Each consumer gets its own copy of the raw network input.
                        this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), input.dup());
                    }
                }
                continue;
            }
            if (excludeOutputLayers && current.isOutputVertex() && current.hasLayer() && current.getLayer() instanceof BaseOutputLayer) {
                continue;
            }
            INDArray out = current.doForward(train);
            if (current.hasLayer()) {
                layerActivations.put(current.getVertexName(), out);
            }
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo == null) {
                continue;
            }
            for (VertexIndices v : outputsTo) {
                this.vertices[v.getVertexIndex()].setInput(v.getVertexEdgeNumber(), out);
            }
        }
        return layerActivations;
    }

    /**
     * Computes the network outputs in inference mode.
     *
     * @param input one array per network input
     * @return one array per network output, in configuration order
     */
    public INDArray[] output(INDArray ... input) {
        boolean trainMode = false;
        return this.output(trainMode, input);
    }

    /**
     * Sets the inputs, runs a forward pass, and collects the activations of the network outputs.
     *
     * @param train whether to run in training mode
     * @param input one array per network input
     * @return one array per network output, in configuration order (null for an output with no
     *         recorded activation)
     */
    public INDArray[] output(boolean train, INDArray ... input) {
        this.setInputs(input);
        Map<String, INDArray> activations = this.feedForward(train);
        INDArray[] result = new INDArray[this.numOutputArrays];
        List<String> outputNames = this.configuration.getNetworkOutputs();
        int pos = 0;
        for (Iterator<String> it = outputNames.iterator(); it.hasNext(); ) {
            result[pos++] = activations.get(it.next());
        }
        return result;
    }

    /**
     * Backpropagates through the graph in reverse topological order and stores the resulting
     * gradient, backed by the flattened gradient view, in {@code this.gradient}.
     *
     * <p>Before each output vertex is processed, its layer receives the labels at that vertex's
     * position in the configuration's network outputs. Each vertex's epsilons are passed to the
     * vertices that feed it, via {@code setError}.
     *
     * <p>Parameter gradients are keyed as {@code <vertexName>_<paramName>}. They are ordered first
     * by the topological position of their vertex, then by the iteration order of that vertex's
     * gradient map.
     *
     * @param truncatedBPTT passed through to {@code GraphVertex.doBackward}
     */
    protected void backprop(boolean truncatedBPTT) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        LinkedList<Triple<String, INDArray, Character>> gradients =
                new LinkedList<Triple<String, INDArray, Character>>();
        for (int i = this.topologicalOrder.length - 1; i >= 0; --i) {
            GraphVertex current = this.vertices[this.topologicalOrder[i]];
            if (current.isInputVertex()) {
                continue;
            }
            if (current.isOutputVertex()) {
                // Assumes every network output vertex wraps a BaseOutputLayer (the cast fails otherwise)
                BaseOutputLayer outputLayer = (BaseOutputLayer) current.getLayer();
                int thisOutputNumber = this.configuration.getNetworkOutputs().indexOf(current.getVertexName());
                outputLayer.setLabels(this.labels[thisOutputNumber]);
            }

            Pair<Gradient, INDArray[]> pair = current.doBackward(truncatedBPTT);
            INDArray[] epsilons = pair.getSecond();
            VertexIndices[] inputVertices = current.getInputVertices();
            if (inputVertices != null) {
                for (int j = 0; j < inputVertices.length; j++) {
                    VertexIndices vertexIndices = inputVertices[j];
                    GraphVertex gv = this.vertices[vertexIndices.getVertexIndex()];
                    gv.setError(vertexIndices.getVertexEdgeNumber(), epsilons[j]);
                }
            }

            Gradient g = pair.getFirst();
            if (g == null) {
                continue;
            }
            List<Triple<String, INDArray, Character>> vertexGradients =
                    new ArrayList<Triple<String, INDArray, Character>>();
            for (Map.Entry<String, INDArray> entry : g.gradientForVariable().entrySet()) {
                String origName = entry.getKey();
                String newName = current.getVertexName() + "_" + origName;
                vertexGradients.add(new Triple<String, INDArray, Character>(
                        newName, entry.getValue(), g.flatteningOrderForVariable(origName)));
            }
            // Vertices are visited in reverse topological order, so prepending each vertex's block
            // leaves the final list in forward topological order (map order within a vertex).
            gradients.addAll(0, vertexGradients);
        }

        DefaultGradient gradient = new DefaultGradient(this.flattenedGradients);
        for (Triple<String, INDArray, Character> t : gradients) {
            gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
        }
        this.gradient = gradient;
    }

    /**
     * Creates an independent copy of this network.
     *
     * <p>The copy is built from a clone of the configuration, initialized, and then given a
     * duplicate of this network's parameter array.
     *
     * @return the new, initialized network
     */
    public ComputationGraph clone() {
        ComputationGraph copy = new ComputationGraph(this.configuration.clone());
        copy.init();
        INDArray paramsCopy = this.params().dup();
        copy.setParams(paramsCopy);
        return copy;
    }

    /**
     * Sums {@code calcL2()} over every layer in the network.
     *
     * @return total of the per-layer L2 terms
     */
    public double calcL2() {
        double total = 0.0;
        for (int idx = 0; idx < this.layers.length; idx++) {
            total += this.layers[idx].calcL2();
        }
        return total;
    }

    /**
     * Sums {@code calcL1()} over every layer in the network.
     *
     * @return total of the per-layer L1 terms
     */
    public double calcL1() {
        double total = 0.0;
        for (int idx = 0; idx < this.layers.length; idx++) {
            total += this.layers[idx].calcL1();
        }
        return total;
    }

    /**
     * Stores the listeners and propagates them to every layer and, if one exists, the solver.
     * The network is initialized first when its layers have not been created yet.
     *
     * @param listeners iteration listeners to attach
     */
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (this.layers == null) {
            // layers must exist before listeners can be attached to them
            this.init();
        }
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setListeners(listeners);
        }
        if (this.solver == null) {
            return;
        }
        this.solver.setListeners(listeners);
    }

    /**
     * Varargs convenience overload; copies the listeners into a list and delegates.
     *
     * @param listeners iteration listeners to attach
     */
    public void setListeners(IterationListener ... listeners) {
        List<IterationListener> collected = new ArrayList<IterationListener>(listeners.length);
        for (IterationListener listener : listeners) {
            collected.add(listener);
        }
        this.setListeners(collected);
    }

    /**
     * @return the listeners most recently passed to {@code setListeners}
     */
    public Collection<IterationListener> getListeners() {
        return this.listeners;
    }

    /**
     * Returns the updater held by the solver's optimizer.
     *
     * <p>If no solver exists yet, one is built from {@code conf()}, the current listeners and
     * this model, and a fresh {@link ComputationGraphUpdater} is installed on it.
     *
     * @return the optimizer's computation-graph updater
     */
    public ComputationGraphUpdater getUpdater() {
        if (this.solver == null) {
            Solver created = new Solver.Builder()
                    .configure(this.conf())
                    .listeners(this.getListeners())
                    .model(this)
                    .build();
            this.solver = created;
            created.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        return optimizer.getComputationGraphUpdater();
    }

    /**
     * Installs the given updater on the solver's optimizer, building the solver first if needed
     * (without installing a default updater).
     *
     * @param updater the updater to use
     */
    public void setUpdater(ComputationGraphUpdater updater) {
        if (this.solver == null) {
            this.solver = new Solver.Builder()
                    .configure(this.conf())
                    .listeners(this.getListeners())
                    .model(this)
                    .build();
        }
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        optimizer.setUpdaterComputationGraph(updater);
    }

    /**
     * Returns the layer backing the network output at the given position, where positions follow
     * the order of the configuration's network outputs.
     *
     * @param outputLayerIdx zero-based output position
     * @return the layer of the corresponding output vertex
     * @throws IllegalArgumentException if {@code outputLayerIdx} is negative or not less than the
     *                                  number of network outputs
     */
    public Layer getOutputLayer(int outputLayerIdx) {
        // Negative indices used to escape as a raw IndexOutOfBoundsException from List.get
        if (outputLayerIdx < 0 || outputLayerIdx >= this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx + ", total number of network outputs = " + this.numOutputArrays);
        }
        return this.getLayer(this.configuration.getNetworkOutputs().get(outputLayerIdx));
    }

    /**
     * Returns the network parameters.
     *
     * @param backwardOnly when true, the {@code flattenedParams} field itself is returned (not a
     *                     copy); otherwise the non-null parameter arrays of every layer vertex, in
     *                     topological order, are concatenated into a new 'f'-order flattened array
     * @return the parameter array
     */
    public INDArray params(boolean backwardOnly) {
        if (backwardOnly) {
            return this.flattenedParams;
        }
        List<INDArray> parts = new ArrayList<INDArray>(this.layers.length);
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            INDArray layerParams = vertex.getLayer().params();
            if (layerParams != null) {
                parts.add(layerParams);
            }
        }
        return Nd4j.toFlattened('f', parts);
    }

    /**
     * Scores a single-input/single-output network on a {@link DataSet}, not in training mode.
     */
    public double score(DataSet dataSet) {
        final boolean inTraining = false;
        return this.score(dataSet, inTraining);
    }

    /**
     * Scores a single-input/single-output network on a {@link DataSet} by converting it to a
     * {@link MultiDataSet}.
     *
     * @throws UnsupportedOperationException if the network does not have exactly one input and
     *                                       one output
     */
    public double score(DataSet dataSet, boolean training) {
        boolean singleInputSingleOutput = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInputSingleOutput) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return this.score(ComputationGraphUtil.toMultiDataSet(dataSet), training);
    }

    /**
     * Scores the network on a {@link MultiDataSet}, not in training mode.
     */
    public double score(MultiDataSet dataSet) {
        final boolean inTraining = false;
        return this.score(dataSet, inTraining);
    }

    /**
     * Computes the total score of the network on a {@link MultiDataSet}.
     *
     * <p>This runs a forward pass, assigns each output layer its labels, and sums
     * {@code computeScore} over the network outputs. The L1/L2 totals are passed only to the
     * first output, so they are counted once.
     *
     * <p>If the data set has mask arrays, they are applied for the duration of the call. They are
     * always cleared afterwards, including on the early-return and exception paths.
     *
     * @param dataSet  features, labels and optional masks
     * @param training whether the forward pass runs in training mode
     * @return the summed score, or 0.0 (with a warning) if some network output is not a
     *         {@link BaseOutputLayer}
     */
    public double score(MultiDataSet dataSet, boolean training) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        try {
            if (hasMaskArrays) {
                this.setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
            }
            this.feedForward(dataSet.getFeatures(), training);
            INDArray[] labels = dataSet.getLabels();
            this.setLabels(labels);
            double l1 = this.calcL1();
            double l2 = this.calcL2();
            double score = 0.0;
            int i = 0;
            for (String s : this.configuration.getNetworkOutputs()) {
                Layer outLayer = this.verticesMap.get(s).getLayer();
                if (!(outLayer instanceof BaseOutputLayer)) {
                    log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                    return 0.0;
                }
                BaseOutputLayer ol = (BaseOutputLayer) outLayer;
                ol.setLabels(labels[i++]);
                score += ol.computeScore(l1, l2, true);
                // regularization terms are added only once, to the first output
                l1 = 0.0;
                l2 = 0.0;
            }
            return score;
        } finally {
            // previously skipped on the early return above, leaving stale masks on the layers
            if (hasMaskArrays) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /**
     * Per-example scores for a single-input/single-output network, computed by converting the
     * {@link DataSet} to a {@link MultiDataSet}.
     *
     * @throws UnsupportedOperationException if the network does not have exactly one input and
     *                                       one output
     */
    public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        MultiDataSet asMulti = ComputationGraphUtil.toMultiDataSet(data);
        return this.scoreExamples(asMulti, addRegularizationTerms);
    }

    /**
     * Computes per-example scores on a {@link MultiDataSet}.
     *
     * <p>This runs an inference forward pass and sums {@code computeScoreForExamples} over all
     * network outputs, accumulating in place into the first output's array. When requested, the
     * L1/L2 totals are passed only to the first output.
     *
     * <p>If the data set has mask arrays, they are applied for the duration of the call. They are
     * always cleared afterwards, including when an exception is thrown.
     *
     * @param data                   features, labels and optional masks
     * @param addRegularizationTerms whether to include the L1/L2 terms
     * @return per-example scores (null if the network has no outputs)
     * @throws UnsupportedOperationException if some network output is not a
     *                                       {@link BaseOutputLayer}
     */
    public INDArray scoreExamples(MultiDataSet data, boolean addRegularizationTerms) {
        boolean hasMaskArray = data.hasMaskArrays();
        try {
            if (hasMaskArray) {
                this.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
            }
            this.feedForward(data.getFeatures(), false);
            this.setLabels(data.getLabels());
            INDArray out = null;
            double l1 = addRegularizationTerms ? this.calcL1() : 0.0;
            double l2 = addRegularizationTerms ? this.calcL2() : 0.0;
            int i = 0;
            for (String s : this.configuration.getNetworkOutputs()) {
                Layer outLayer = this.verticesMap.get(s).getLayer();
                if (!(outLayer instanceof BaseOutputLayer)) {
                    throw new UnsupportedOperationException("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                }
                BaseOutputLayer ol = (BaseOutputLayer) outLayer;
                ol.setLabels(this.labels[i++]);
                INDArray scoreCurrLayer = ol.computeScoreForExamples(l1, l2);
                if (out == null) {
                    out = scoreCurrLayer;
                } else {
                    out.addi(scoreCurrLayer);
                }
                // regularization terms are added only once, to the first output
                l1 = 0.0;
                l2 = 0.0;
            }
            return out;
        } finally {
            // previously skipped when the exception above was thrown, leaving masks on the layers
            if (hasMaskArray) {
                this.clearLayerMaskArrays();
            }
        }
    }

    /**
     * Fits the network on the inputs, labels and mask arrays currently stored on this instance.
     */
    @Override
    public void fit() {
        INDArray[] storedInputs = this.inputs;
        INDArray[] storedLabels = this.labels;
        INDArray[] storedInputMasks = this.inputMaskArrays;
        INDArray[] storedLabelMasks = this.labelMaskArrays;
        this.fit(storedInputs, storedLabels, storedInputMasks, storedLabelMasks);
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Sends a one-time STANDALONE heartbeat event; later calls do nothing once {@code initDone}
     * is set.
     *
     * <p>NOTE(review): the {@code task} argument is ignored. The reported task is always rebuilt
     * from this model via {@code ModelSerializer.taskByModel(this)}.
     *
     * @param task unused (see note)
     */
    private void update(Task task) {
        if (this.initDone) {
            return;
        }
        this.initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        Task reportedTask = ModelSerializer.taskByModel(this);
        Environment environment = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, environment, reportedTask);
    }

    /**
     * @return the most recently stored score value
     */
    @Override
    public double score() {
        double current = this.score;
        return current;
    }

    /**
     * Stores a score value, as later returned by {@code score()}.
     *
     * @param score the value to store
     */
    public void setScore(double score) {
        this.score = score;
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void accumulateScore(double accum) {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * @return the backing flattened parameter array (same as {@code params(true)})
     */
    @Override
    public INDArray params() {
        final boolean backwardOnly = true;
        return this.params(backwardOnly);
    }

    /**
     * @return the parameter count with {@code backwards == true}
     */
    @Override
    public int numParams() {
        final boolean backwards = true;
        return this.numParams(backwards);
    }

    /**
     * Sums {@code numParams(backwards)} over every layer.
     *
     * @param backwards passed through to each layer
     * @return total parameter count
     */
    @Override
    public int numParams(boolean backwards) {
        int total = 0;
        for (int idx = 0; idx < this.layers.length; idx++) {
            total += this.layers[idx].numParams(backwards);
        }
        return total;
    }

    /**
     * Sets the network parameters.
     *
     * <ul>
     *   <li>If {@code params} is the backing {@code flattenedParams} array itself, nothing happens.</li>
     *   <li>If a backing array of the same length exists, {@code params} is copied into it.</li>
     *   <li>Otherwise consecutive column ranges of row 0 of {@code params} are handed to each layer
     *       vertex in topological order. Each range's width is the layer's backprop parameter count
     *       for {@link BasePretrainNetwork} layers and {@code numParams()} otherwise. Layers with
     *       no parameters are skipped.</li>
     * </ul>
     *
     * @param params the new parameters
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.flattenedParams) {
            return;
        }
        boolean sameLengthBacking = this.flattenedParams != null && this.flattenedParams.length() == params.length();
        if (sameLengthBacking) {
            this.flattenedParams.assign(params);
            return;
        }
        int offset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            Layer layer = vertex.getLayer();
            int count;
            if (layer instanceof BasePretrainNetwork) {
                count = ((BasePretrainNetwork) layer).numParamsBackprop();
            } else {
                count = layer.numParams();
            }
            if (count <= 0) {
                continue;
            }
            INDArray slice = params.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)});
            layer.setParams(slice);
            offset += count;
        }
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        // Was a bare RuntimeException; UnsupportedOperationException matches the sibling stubs and,
        // being a RuntimeException subtype, is still caught by existing handlers.
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void applyLearningRateScoreDecay() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Pretraining from a single array is not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * @return the gradient most recently stored (e.g. by backprop)
     */
    @Override
    public Gradient gradient() {
        Gradient current = this.gradient;
        return current;
    }

    /**
     * @return a pair of {@code gradient()} and {@code score()}
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient g = this.gradient();
        double s = this.score();
        return new Pair<Gradient, Double>(g, s);
    }

    /**
     * @return size of dimension 0 of the first input array
     */
    @Override
    public int batchSize() {
        INDArray firstInput = this.inputs[0];
        return firstInput.size(0);
    }

    /**
     * @return the default (network-level) configuration
     */
    @Override
    public NeuralNetConfiguration conf() {
        NeuralNetConfiguration current = this.defaultConfiguration;
        return current;
    }

    /**
     * Replacing the configuration is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the single input array, or null if no inputs are set.
     *
     * @throws UnsupportedOperationException if the network has more than one input
     */
    @Override
    public INDArray input() {
        if (this.numInputArrays != 1) {
            throw new UnsupportedOperationException("Cannot return single input: ComputationGraph  has multiple inputs");
        }
        if (this.inputs == null) {
            return null;
        }
        return this.inputs[0];
    }

    /**
     * Intentionally a no-op: no input validation is performed here.
     */
    @Override
    public void validateInput() {
    }

    /**
     * @return the solver's optimizer (requires the solver to have been created)
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        Solver current = this.solver;
        return current.getOptimizer();
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray getParam(String param) {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Builds an insertion-ordered table of every layer's parameters, keyed as
     * {@code <layerName>_<paramKey>}, following the order of {@code this.layers}.
     *
     * @return the combined parameter table
     */
    @Override
    public Map<String, INDArray> paramTable() {
        Map<String, INDArray> table = new LinkedHashMap<String, INDArray>();
        for (int idx = 0; idx < this.layers.length; idx++) {
            Layer layer = this.layers[idx];
            Map<String, INDArray> layerTable = layer.paramTable();
            for (Map.Entry<String, INDArray> param : layerTable.entrySet()) {
                String qualifiedKey = layer.conf().getLayer().getLayerName() + "_" + param.getKey();
                table.put(qualifiedKey, param.getValue());
            }
        }
        return table;
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Not supported for ComputationGraph.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParam(String key, INDArray val) {
        final String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Drops the stored inputs, labels and mask arrays (layers' own mask arrays are not touched).
     */
    @Override
    public void clear() {
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
        this.inputs = null;
        this.labels = null;
    }

    /**
     * Performs one forward pass of recurrent time stepping using each recurrent layer's stored
     * state.
     *
     * <p>Layers that are {@link BaseRecurrentLayer} or {@link MultiLayerNetwork} are driven via
     * their own {@code rnnTimeStep} on the vertex's first input. All other vertices use
     * {@code doForward(false)}. Network inputs are handed to consumers as {@code dup()} copies.
     *
     * <p>If every input is rank 2, any rank-3 output whose dimension 2 has size 1 is converted
     * to a 2d array via {@code tensorAlongDimension(0, 1, 0)}.
     *
     * @param inputs one array per network input
     * @return one array per network output, in configuration order
     */
    public INDArray[] rnnTimeStep(INDArray ... inputs) {
        boolean allInputs2d = true;
        for (INDArray in : inputs) {
            if (in.rank() != 2) {
                allInputs2d = false;
                break;
            }
        }

        INDArray[] outputs = new INDArray[this.numOutputArrays];
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];

            if (vertex.isInputVertex()) {
                VertexIndices[] consumers = vertex.getOutputVertices();
                INDArray networkInput = inputs[vertex.getVertexIndex()];
                for (VertexIndices edge : consumers) {
                    this.vertices[edge.getVertexIndex()].setInput(edge.getVertexEdgeNumber(), networkInput.dup());
                }
                continue;
            }

            Layer layer = vertex.hasLayer() ? vertex.getLayer() : null;
            INDArray activation;
            if (layer instanceof BaseRecurrentLayer) {
                activation = ((BaseRecurrentLayer) layer).rnnTimeStep(vertex.getInputs()[0]);
            } else if (layer instanceof MultiLayerNetwork) {
                activation = ((MultiLayerNetwork) layer).rnnTimeStep(vertex.getInputs()[0]);
            } else {
                activation = vertex.doForward(false);
            }

            if (vertex.isOutputVertex()) {
                int outputPosition = this.configuration.getNetworkOutputs().indexOf(vertex.getVertexName());
                outputs[outputPosition] = activation;
            }

            VertexIndices[] consumers = vertex.getOutputVertices();
            if (consumers == null) {
                continue;
            }
            for (VertexIndices edge : consumers) {
                this.vertices[edge.getVertexIndex()].setInput(edge.getVertexEdgeNumber(), activation);
            }
        }

        if (allInputs2d) {
            for (int k = 0; k < outputs.length; k++) {
                INDArray candidate = outputs[k];
                boolean singleStep3d = candidate.rank() == 3 && candidate.size(2) == 1;
                if (singleStep3d) {
                    outputs[k] = candidate.tensorAlongDimension(0, new int[]{1, 0});
                }
            }
        }
        return outputs;
    }

    /**
     * Returns the stored recurrent state of the layer at the given index in {@code this.layers}.
     *
     * @param layer index into the network's layer array
     * @return the layer's state, or null if it is not a {@link BaseRecurrentLayer}
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        String layerName = this.layers[layer].conf().getLayer().getLayerName();
        return this.rnnGetPreviousState(layerName);
    }

    /**
     * Returns the stored recurrent state of the named layer.
     *
     * @param layerName vertex/layer name
     * @return the layer's state, or null if the vertex does not hold a {@link BaseRecurrentLayer}
     * @throws IllegalArgumentException if no vertex with that name exists
     */
    public Map<String, INDArray> rnnGetPreviousState(String layerName) {
        GraphVertex vertex = this.verticesMap.get(layerName);
        if (vertex == null) {
            // previously an uninformative NullPointerException
            throw new IllegalArgumentException("No vertex with name \"" + layerName + "\" exists in this ComputationGraph");
        }
        Layer l = vertex.getLayer();
        if (!(l instanceof BaseRecurrentLayer)) {
            return null;
        }
        return ((BaseRecurrentLayer) l).rnnGetPreviousState();
    }

    /**
     * Collects the stored state of every {@link BaseRecurrentLayer}, keyed by layer name.
     *
     * @return map from layer name to that layer's state
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> states = new HashMap<String, Map<String, INDArray>>();
        for (Layer l : this.layers) {
            if (l instanceof BaseRecurrentLayer) {
                states.put(l.conf().getLayer().getLayerName(), ((BaseRecurrentLayer) l).rnnGetPreviousState());
            }
        }
        return states;
    }

    /**
     * Sets the stored recurrent state of the layer at the given index in {@code this.layers}.
     *
     * @param layer index into the network's layer array
     * @param state the state to install
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        String layerName = this.layers[layer].conf().getLayer().getLayerName();
        this.rnnSetPreviousState(layerName, state);
    }

    /**
     * Sets the stored recurrent state of the named layer.
     *
     * @param layerName vertex/layer name
     * @param state     the state to install
     * @throws IllegalArgumentException      if no vertex with that name exists
     * @throws UnsupportedOperationException if the vertex does not hold a {@link BaseRecurrentLayer}
     */
    public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
        GraphVertex vertex = this.verticesMap.get(layerName);
        if (vertex == null) {
            // previously an uninformative NullPointerException
            throw new IllegalArgumentException("No vertex with name \"" + layerName + "\" exists in this ComputationGraph");
        }
        Layer l = vertex.getLayer();
        if (!(l instanceof BaseRecurrentLayer)) {
            throw new UnsupportedOperationException("Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
        }
        ((BaseRecurrentLayer) l).rnnSetPreviousState(state);
    }

    /**
     * Sets the stored state of several recurrent layers at once.
     *
     * @param previousStates map from layer name to state
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
        for (Map.Entry<String, Map<String, INDArray>> entry : previousStates.entrySet()) {
            this.rnnSetPreviousState(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Clears the stored state of every recurrent layer and every nested
     * {@link MultiLayerNetwork}. Does nothing if the layers have not been created yet.
     */
    public void rnnClearPreviousState() {
        if (this.layers != null) {
            for (Layer layer : this.layers) {
                if (layer instanceof BaseRecurrentLayer) {
                    ((BaseRecurrentLayer) layer).rnnClearPreviousState();
                } else if (layer instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) layer).rnnClearPreviousState();
                }
            }
        }
    }

    /**
     * Fits the network with truncated backpropagation through time.
     *
     * <p>All rank-3 inputs and labels must share the same time-series length (dimension 2).
     *
     * <p>The series is cut into consecutive segments of {@code configuration.getTbpttFwdLength()}
     * steps. When the length is not a multiple of that value, the last segment is shorter and
     * covers the remaining steps. Arrays that are not rank 3 are passed unchanged to every
     * segment. Non-null feature/label masks are sliced along dimension 1.
     *
     * <p>For each segment:
     * <ul>
     *   <li>inputs, labels and masks are set;</li>
     *   <li>the solver (built lazily) runs once;</li>
     *   <li>the recurrent layers' TBPTT state is carried into the next segment.</li>
     * </ul>
     * Recurrent state is cleared before the first segment and after the last.
     *
     * <p>A warning is logged and nothing is fitted in these cases:
     * <ul>
     *   <li>time-series lengths differ;</li>
     *   <li>the forward length is not positive;</li>
     *   <li>the forward length exceeds the series length (this includes the case where no array
     *       is rank 3).</li>
     * </ul>
     */
    protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        int timeSeriesLength = -1;
        for (INDArray in : inputs) {
            if (in.rank() != 3) continue;
            if (timeSeriesLength == -1) {
                timeSeriesLength = in.size(2);
                continue;
            }
            if (timeSeriesLength == in.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        for (INDArray out : labels) {
            if (out.rank() != 3) continue;
            if (timeSeriesLength == -1) {
                timeSeriesLength = out.size(2);
                continue;
            }
            if (timeSeriesLength == out.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        int fwdLen = this.configuration.getTbpttFwdLength();
        if (fwdLen <= 0) {
            // would otherwise divide by zero (or loop over negative segment sizes) below
            log.warn("Cannot do TBPTT: Truncated BPTT forward length (" + fwdLen + ") must be positive");
            return;
        }
        if (fwdLen > timeSeriesLength) {
            log.warn("Cannot do TBPTT: Truncated BPTT forward length (" + fwdLen + ") > input time series length (" + timeSeriesLength + ")");
            return;
        }
        // Round up: plain integer division silently dropped the trailing
        // (timeSeriesLength % fwdLen) time steps, which were never trained on.
        int nSubsets = timeSeriesLength / fwdLen + (timeSeriesLength % fwdLen == 0 ? 0 : 1);
        this.rnnClearPreviousState();
        INDArray[] newInputs = new INDArray[inputs.length];
        INDArray[] newLabels = new INDArray[labels.length];
        INDArray[] newFeatureMasks = featureMasks != null ? new INDArray[featureMasks.length] : null;
        INDArray[] newLabelMasks = labelMasks != null ? new INDArray[labelMasks.length] : null;
        for (int i = 0; i < nSubsets; ++i) {
            int startTimeIdx = i * fwdLen;
            // the final segment may be shorter than fwdLen
            int endTimeIdx = Math.min(startTimeIdx + fwdLen, timeSeriesLength);
            for (int j = 0; j < inputs.length; ++j) {
                newInputs[j] = inputs[j].rank() != 3 ? inputs[j] : inputs[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)});
            }
            for (int j = 0; j < labels.length; ++j) {
                newLabels[j] = labels[j].rank() != 3 ? labels[j] : labels[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)});
            }
            if (featureMasks != null) {
                for (int j = 0; j < featureMasks.length; ++j) {
                    if (featureMasks[j] == null) continue;
                    newFeatureMasks[j] = featureMasks[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)});
                }
            }
            if (labelMasks != null) {
                for (int j = 0; j < labelMasks.length; ++j) {
                    if (labelMasks[j] == null) continue;
                    newLabelMasks[j] = labelMasks[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)});
                }
            }
            this.setInputs(newInputs);
            this.setLabels(newLabels);
            this.setLayerMaskArrays(newFeatureMasks, newLabelMasks);
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
            }
            this.solver.optimize();
            // carry the state reached at the end of this segment into the next one
            this.rnnUpdateStateWithTBPTTState();
        }
        this.rnnClearPreviousState();
    }

    /**
     * Forward pass that starts recurrent layers from their stored state.
     *
     * <ul>
     *   <li>{@link BaseRecurrentLayer} vertices use their own {@code rnnActivateUsingStoredState}.</li>
     *   <li>{@link MultiLayerNetwork} vertices contribute the last array of their activation list.</li>
     *   <li>All other vertices use {@code doForward(training)}.</li>
     * </ul>
     * Network inputs are recorded as-is and handed to consumers as {@code dup()} copies.
     *
     * @param inputs            one array per network input
     * @param training          whether to run in training mode
     * @param storeLastForTBPTT passed through to the recurrent layers / nested networks
     * @return activations of inputs and layer vertices, keyed by vertex name
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training, boolean storeLastForTBPTT) {
        Map<String, INDArray> activations = new HashMap<String, INDArray>();
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];

            if (vertex.isInputVertex()) {
                VertexIndices[] consumers = vertex.getOutputVertices();
                INDArray networkInput = inputs[vertex.getVertexIndex()];
                activations.put(vertex.getVertexName(), networkInput);
                for (VertexIndices edge : consumers) {
                    this.vertices[edge.getVertexIndex()].setInput(edge.getVertexEdgeNumber(), networkInput.dup());
                }
                continue;
            }

            INDArray activation;
            if (!vertex.hasLayer()) {
                activation = vertex.doForward(training);
            } else {
                Layer layer = vertex.getLayer();
                if (layer instanceof BaseRecurrentLayer) {
                    activation = ((BaseRecurrentLayer) layer).rnnActivateUsingStoredState(vertex.getInputs()[0], training, storeLastForTBPTT);
                } else if (layer instanceof MultiLayerNetwork) {
                    List<INDArray> perLayer = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(vertex.getInputs()[0], training, storeLastForTBPTT);
                    activation = perLayer.get(perLayer.size() - 1);
                } else {
                    activation = vertex.doForward(training);
                }
                activations.put(vertex.getVertexName(), activation);
            }

            VertexIndices[] consumers = vertex.getOutputVertices();
            if (consumers != null) {
                for (VertexIndices edge : consumers) {
                    this.vertices[edge.getVertexIndex()].setInput(edge.getVertexEdgeNumber(), activation);
                }
            }
        }
        return activations;
    }

    /**
     * Stores the feature/label mask arrays and applies them to the relevant layers.
     *
     * <p>Each non-null feature mask is reshaped with
     * {@code TimeSeriesUtils.reshapeTimeSeriesMaskToVector}. It is then set on every FEED_FORWARD
     * or CONVOLUTIONAL layer reachable downstream of the corresponding input vertex. Propagation
     * stops at {@link BaseRecurrentLayer} vertices, and each vertex is visited at most once per
     * input.
     *
     * <p>Null entries are skipped; {@code doTruncatedBPTT} produces such entries for inputs
     * without masks. Each label mask is set directly on the layer of the corresponding output
     * vertex.
     *
     * @param featureMaskArrays one mask per network input (entries may be null), or null
     * @param labelMaskArrays   one mask per network output, or null
     * @throws IllegalArgumentException if an array count does not match the number of network
     *                                  inputs/outputs; in that case no state is modified
     */
    public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        // Validate before mutating anything so a bad call leaves no partial state behind
        if (featureMaskArrays != null && featureMaskArrays.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid number of feature mask arrays");
        }
        if (labelMaskArrays != null && labelMaskArrays.length != this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid number of label mask arrays");
        }
        this.inputMaskArrays = featureMaskArrays;
        this.labelMaskArrays = labelMaskArrays;
        if (featureMaskArrays != null) {
            for (int i = 0; i < featureMaskArrays.length; ++i) {
                if (featureMaskArrays[i] == null) {
                    // No mask for this input; reshaping null would throw a NullPointerException
                    continue;
                }
                String inputName = this.configuration.getNetworkInputs().get(i);
                INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featureMaskArrays[i]);
                LinkedList<String> stack = new LinkedList<String>();
                Set<String> visited = new HashSet<String>();
                VertexIndices[] outputsFromThisInput = this.verticesMap.get(inputName).getOutputVertices();
                if (outputsFromThisInput != null) {
                    for (VertexIndices v : outputsFromThisInput) {
                        stack.addLast(this.vertices[v.getVertexIndex()].getVertexName());
                    }
                }
                while (!stack.isEmpty()) {
                    String nextVertexName = stack.removeLast();
                    if (!visited.add(nextVertexName)) {
                        // already handled via another path; re-applying the same mask is redundant
                        continue;
                    }
                    GraphVertex nextVertex = this.verticesMap.get(nextVertexName);
                    if (nextVertex.hasLayer()) {
                        Layer l = nextVertex.getLayer();
                        if (l instanceof BaseRecurrentLayer) {
                            // do not propagate the feature mask past a recurrent layer
                            continue;
                        }
                        if (l.type() == Layer.Type.FEED_FORWARD || l.type() == Layer.Type.CONVOLUTIONAL) {
                            l.setMaskArray(reshapedFeaturesMask);
                        }
                    }
                    VertexIndices[] downstream = nextVertex.getOutputVertices();
                    if (downstream == null) {
                        continue;
                    }
                    for (VertexIndices v : downstream) {
                        stack.addLast(this.vertices[v.getVertexIndex()].getVertexName());
                    }
                }
            }
        }
        if (labelMaskArrays != null) {
            for (int i = 0; i < labelMaskArrays.length; ++i) {
                String outputName = this.configuration.getNetworkOutputs().get(i);
                Layer ol = this.verticesMap.get(outputName).getLayer();
                ol.setMaskArray(labelMaskArrays[i]);
            }
        }
    }

    /**
     * Removes the mask array from every layer and forgets the stored input/label mask arrays.
     */
    public void clearLayerMaskArrays() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Makes the state saved during the last TBPTT segment the new stored state of every
     * recurrent layer and every nested {@link MultiLayerNetwork}.
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (Layer layer : this.layers) {
            if (layer instanceof BaseRecurrentLayer) {
                BaseRecurrentLayer recurrent = (BaseRecurrentLayer) layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Sets the {@code initDone} flag, which controls whether {@code update(Task)} still reports
     * its one-time heartbeat event.
     *
     * @param initDone new flag value
     */
    public void setInitDone(boolean initDone) {
        this.initDone = initDone;
    }
}

