package org.deeplearning4j.nn.updater.graph;

import java.util.Arrays;
import java.util.HashMap;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.class */
/**
 * Updater for a {@link ComputationGraph}.
 *
 * <p>Supplies the base updater with the graph's parameterised vertices (in the graph's
 * topological sort order), its flattened gradient view, its parameters and its mini-batch
 * setting.
 */
public class ComputationGraphUpdater extends BaseMultiLayerUpdater<ComputationGraph> {
    /**
     * Cached result of {@link #getOrderedLayers()}; {@code null} until first computed.
     *
     * <p>Intentionally declared without an initializer: if the superclass constructor invokes
     * {@link #getOrderedLayers()} (not visible from this class - confirm), an explicit
     * initializer would run afterwards and discard the cached value.
     */
    protected Trainable[] orderedLayers;

    /**
     * Creates an updater for the given graph, passing {@code null} as the second argument of
     * {@link #ComputationGraphUpdater(ComputationGraph, INDArray)}.
     *
     * @param computationGraph the graph whose parameters this updater manages
     */
    public ComputationGraphUpdater(ComputationGraph computationGraph) {
        this(computationGraph, null);
    }

    /**
     * Creates an updater for the given graph and indexes its trainable vertices by layer name.
     *
     * @param computationGraph the graph whose parameters this updater manages
     * @param iNDArray         forwarded unchanged to the superclass constructor (presumably an
     *                         existing updater state array - confirm against
     *                         {@code BaseMultiLayerUpdater}); may be {@code null}
     */
    public ComputationGraphUpdater(ComputationGraph computationGraph, INDArray iNDArray) {
        super(computationGraph, iNDArray);
        // Diamond instead of a raw HashMap so the assignment is checked against the field's type.
        this.layersByName = new HashMap<>();
        for (Trainable layer : getOrderedLayers()) {
            this.layersByName.put(layer.getConfig().getLayerName(), layer);
        }
    }

    /**
     * Returns the graph vertices that have parameters, ordered by the graph's topological sort
     * order. The array is computed on the first call and cached in {@link #orderedLayers}.
     *
     * @return the cached array of vertices whose {@code numParams()} is non-zero
     */
    @Override
    protected Trainable[] getOrderedLayers() {
        if (this.orderedLayers != null) {
            return this.orderedLayers;
        }

        final ComputationGraph graph = (ComputationGraph) this.network;
        final GraphVertex[] vertices = graph.getVertices();
        final int[] topologicalOrder = graph.topologicalSortOrder();

        // Sized for the worst case (every vertex has parameters); trimmed below if not.
        Trainable[] ordered = new Trainable[vertices.length];
        int count = 0;
        for (int vertexIndex : topologicalOrder) {
            final GraphVertex vertex = vertices[vertexIndex];
            if (vertex.numParams() == 0) {
                // Vertices without parameters are not included.
                continue;
            }
            ordered[count++] = vertex;
        }
        if (count != ordered.length) {
            ordered = Arrays.copyOf(ordered, count);
        }

        this.orderedLayers = ordered;
        return this.orderedLayers;
    }

    /**
     * Returns the graph's flattened gradients, calling {@code initGradientsView()} on the graph
     * first if they are currently {@code null}.
     *
     * @return the value of the graph's {@code getFlattenedGradients()} after that check
     */
    @Override
    public INDArray getFlattenedGradientsView() {
        final ComputationGraph graph = (ComputationGraph) this.network;
        if (graph.getFlattenedGradients() == null) {
            graph.initGradientsView();
        }
        return graph.getFlattenedGradients();
    }

    /**
     * Returns the graph's parameters.
     *
     * @return the result of {@code params()} on the underlying graph
     */
    @Override
    protected INDArray getParams() {
        return ((ComputationGraph) this.network).params();
    }

    /**
     * Reports the mini-batch flag from the graph's configuration.
     *
     * @return the result of {@code conf().isMiniBatch()} on the underlying graph
     */
    @Override
    protected boolean isMiniBatch() {
        return ((ComputationGraph) this.network).conf().isMiniBatch();
    }
}
