/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater.graph;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Holds one {@link Updater} per layer of a {@link ComputationGraph} and applies
 * them to the per-layer parts of a combined {@link Gradient}.
 *
 * <p>Gradient keys are expected to have the form {@code layerName_paramName};
 * the key is split at the <em>last</em> underscore, so layer names may
 * themselves contain underscores.
 */
public class ComputationGraphUpdater implements Serializable, Cloneable {
    /** Per-layer updaters, indexed by the values stored in {@link #layerUpdatersMap}. */
    private final Updater[] layerUpdaters;
    /** Maps a layer name to the index of its updater in {@link #layerUpdaters}. */
    private final Map<String, Integer> layerUpdatersMap;

    /**
     * Creates one updater per layer of the given graph, obtained from
     * {@link UpdaterCreator#getUpdater}, and records each layer's name against
     * the index of its updater.
     *
     * @param graph the graph whose layers need updaters
     */
    public ComputationGraphUpdater(ComputationGraph graph) {
        this.layerUpdaters = new Updater[graph.getNumLayers()];
        this.layerUpdatersMap = new HashMap<String, Integer>();
        int i = 0;
        for (Layer layer : graph.getLayers()) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layer);
            this.layerUpdatersMap.put(layer.conf().getLayer().getLayerName(), i);
            ++i;
        }
    }

    /**
     * Creates an updater with {@code size} empty slots; the caller is
     * responsible for filling {@link #layerUpdaters}.
     *
     * @param size             number of layer updaters
     * @param layerUpdatersMap layer name to updater index mapping (stored as given, not copied)
     */
    private ComputationGraphUpdater(int size, Map<String, Integer> layerUpdatersMap) {
        this.layerUpdaters = new Updater[size];
        this.layerUpdatersMap = layerUpdatersMap;
    }

    /**
     * Copy constructor: clones every layer updater and copies the name/index map.
     *
     * @param updater the instance to copy
     */
    private ComputationGraphUpdater(ComputationGraphUpdater updater) {
        this.layerUpdaters = new Updater[updater.layerUpdaters.length];
        for (int i = 0; i < this.layerUpdaters.length; ++i) {
            this.layerUpdaters[i] = updater.layerUpdaters[i].clone();
        }
        this.layerUpdatersMap = new HashMap<String, Integer>(updater.layerUpdatersMap);
    }

    /**
     * Returns a copy of this updater in which each layer updater has been cloned.
     *
     * @return a new, independent {@code ComputationGraphUpdater}
     */
    @Override
    public ComputationGraphUpdater clone() {
        return new ComputationGraphUpdater(this);
    }

    /**
     * Splits the combined gradient into per-layer gradients, runs each layer's
     * updater on its part, and writes the resulting values back into
     * {@code gradient} under the original {@code layerName_paramName} keys.
     *
     * @param graph     the graph the gradient belongs to; used to look up layers by name
     * @param gradient  combined gradient for all layers; modified in place
     * @param iteration iteration number, passed through to each layer updater
     * @param batchSize batch size, passed through to each layer updater
     * @throws IllegalStateException if a gradient key contains no {@code _} separator,
     *                               or names a layer this updater has no entry for
     */
    public void update(ComputationGraph graph, Gradient gradient, int iteration, int batchSize) {
        // Group the flat "layer_param" entries into one Gradient per layer.
        Map<String, Gradient> layerGradients = new HashMap<String, Gradient>();
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int idx = key.lastIndexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid key: ComputationGraph Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, idx);
            Gradient g = layerGradients.get(layerName);
            if (g == null) {
                g = new DefaultGradient();
                layerGradients.put(layerName, g);
            }
            g.setGradientFor(key.substring(idx + 1), entry.getValue());
        }

        for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
            String layerName = entry.getKey();
            Gradient layerGradient = entry.getValue();
            // Look up as Integer: auto-unboxing a missing entry would throw a bare NullPointerException.
            Integer updaterIdx = this.layerUpdatersMap.get(layerName);
            if (updaterIdx == null) {
                throw new IllegalStateException("Invalid key: no updater registered for layer \"" + layerName + "\"");
            }
            this.layerUpdaters[updaterIdx].update(graph.getLayer(layerName), layerGradient, iteration, batchSize);
            // Copy the layer's entries back into the combined gradient under their full keys.
            for (Map.Entry<String, INDArray> paramEntry : layerGradient.gradientForVariable().entrySet()) {
                gradient.setGradientFor(layerName + "_" + paramEntry.getKey(), paramEntry.getValue());
            }
        }
    }

    /**
     * Creates a new {@link Aggregator}.
     *
     * @param addThis if {@code true}, this updater is aggregated into the new aggregator
     * @return the new aggregator
     */
    public Aggregator getAggregator(boolean addThis) {
        Aggregator aggregator = new Aggregator();
        if (addThis) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    /**
     * Combines several {@link ComputationGraphUpdater}s layer by layer, by
     * delegating to one {@link UpdaterAggregator} per layer.
     */
    public static class Aggregator implements Serializable {
        /** One aggregator per layer; {@code null} until something is aggregated or merged in. */
        private UpdaterAggregator[] aggregators;
        /** Layer name to index mapping, taken from the first updater/aggregator added. */
        private Map<String, Integer> layerNamesMap;

        /**
         * Adds an updater to this aggregator. The first updater added determines
         * the number of layers and the layer name mapping; later updaters are
         * aggregated position by position.
         *
         * @param updater the updater to add
         */
        public void aggregate(ComputationGraphUpdater updater) {
            if (this.aggregators == null) {
                this.aggregators = new UpdaterAggregator[updater.layerUpdaters.length];
                for (int i = 0; i < updater.layerUpdaters.length; ++i) {
                    this.aggregators[i] = updater.layerUpdaters[i].getAggregator(true);
                }
                this.layerNamesMap = new HashMap<String, Integer>(updater.layerUpdatersMap);
            } else {
                for (int i = 0; i < this.aggregators.length; ++i) {
                    this.aggregators[i].aggregate(updater.layerUpdaters[i]);
                }
            }
        }

        /**
         * Merges another aggregator into this one. If this aggregator is still
         * empty it adopts the other's per-layer aggregators and layer name
         * mapping; otherwise the per-layer aggregators are merged position by
         * position. Merging an empty aggregator is a no-op.
         *
         * @param aggregator the aggregator to merge in
         */
        public void merge(Aggregator aggregator) {
            if (this.aggregators == null) {
                this.aggregators = aggregator.aggregators;
                // The name map must travel with the aggregators, otherwise getUpdater()
                // would build an updater that cannot resolve any layer name.
                this.layerNamesMap = aggregator.layerNamesMap == null
                        ? null
                        : new HashMap<String, Integer>(aggregator.layerNamesMap);
            } else if (aggregator.aggregators != null) {
                for (int i = 0; i < this.aggregators.length; ++i) {
                    this.aggregators[i].merge(aggregator.aggregators[i]);
                }
            }
        }

        /**
         * Builds a {@link ComputationGraphUpdater} whose layer updaters are the
         * results of the per-layer aggregators.
         *
         * @return the combined updater
         * @throws IllegalStateException if nothing has been aggregated or merged yet
         */
        public ComputationGraphUpdater getUpdater() {
            if (this.aggregators == null) {
                throw new IllegalStateException("Cannot create updater: no updaters have been aggregated");
            }
            // Give the new updater its own copy so it does not share mutable state with this aggregator.
            ComputationGraphUpdater updater = new ComputationGraphUpdater(
                    this.aggregators.length, new HashMap<String, Integer>(this.layerNamesMap));
            for (int i = 0; i < this.aggregators.length; ++i) {
                updater.layerUpdaters[i] = this.aggregators[i].getUpdater();
            }
            return updater;
        }
    }
}

