/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.updater.LayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.updater.UpdaterUtils;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;

/**
 * Base {@link Updater} for models made of several {@link Trainable} layers.
 *
 * <p>On construction, the parameters of every layer returned by {@link #getOrderedLayers()} are
 * walked in order. Consecutive parameters whose updater configurations are equal (per
 * {@link UpdaterUtils#updaterConfigurationsEquals}) are merged into a single {@link UpdaterBlock}.
 * Each block covers a contiguous range of the flattened parameter/gradient views and of the
 * updater state view array, so {@link #update} calls the updater once per block rather than once
 * per parameter.
 *
 * @param <T> the model type this updater is attached to
 */
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {
    /** Epsilon substituted for a zero L2 norm when renormalizing, to avoid dividing by zero. */
    private static final double ZERO_NORM_EPSILON = 1.0E-5;

    /** The model whose parameters this updater modifies. */
    protected final T network;
    /**
     * Layers keyed by the layer-name prefix of gradient keys; looked up in {@link #update}.
     * NOTE(review): never assigned in this class - subclasses are presumably expected to populate it.
     */
    protected Map<String, Trainable> layersByName;
    /** Groups of contiguous parameters that share an identical updater configuration. */
    protected final List<UpdaterBlock> updaterBlocks;
    /**
     * Flattened state for all updaters. It is {@code null} when no state array was supplied and no
     * updater requires state.
     */
    protected INDArray updaterStateViewArray;
    /** Whether {@link #gradientsForMinibatchDivision} has been computed yet. */
    protected boolean initializedMinibatchDivision;
    /** Cached subsets of the flattened gradient view that are divided by the minibatch size. */
    protected List<INDArray> gradientsForMinibatchDivision;

    /**
     * Creates an updater, allocating new updater state if any updater needs it.
     *
     * @param network the model to update
     */
    public BaseMultiLayerUpdater(T network) {
        this(network, null);
    }

    /**
     * Creates an updater, optionally reusing existing updater state.
     *
     * <p>NOTE: this constructor calls the abstract methods {@link #getOrderedLayers()} and
     * {@link #getFlattenedGradientsView()}, i.e. before any subclass constructor body has run.
     *
     * @param network      the model to update
     * @param updaterState existing state to use as the updater state view array. It is used as-is:
     *                     not copied, not length-checked, and not re-initialized. If {@code null},
     *                     a new uninitialized {@code [1, stateSize]} array is allocated when any
     *                     updater has state, and the blocks are flagged to initialize it.
     */
    public BaseMultiLayerUpdater(T network, INDArray updaterState) {
        this.network = network;
        Trainable[] layers = this.getOrderedLayers();
        this.updaterBlocks = new ArrayList<UpdaterBlock>();
        INDArray paramsView = network.params();
        INDArray gradientView = this.getFlattenedGradientsView();

        // Phase 1: walk every parameter of every layer, merging consecutive parameters with an equal
        // updater configuration into one UpdaterBlock. The parameter offsets assume the flattened
        // params/gradient views follow this same layer/parameter iteration order (TODO confirm that
        // paramTable key order is stable).
        Trainable lastLayer = null;
        String lastVariable = null;
        UpdaterBlock currentBlock = null;
        int paramsViewSoFar = 0;
        int currentUpdaterOffset = 0;
        for (Trainable layer : layers) {
            Map<String, INDArray> layerParamTable = layer.paramTable(false);
            if (layerParamTable == null) {
                continue;
            }
            for (String var : layerParamTable.keySet()) {
                long paramSizeThisVariable = layerParamTable.get(var).length();
                IUpdater u = layer.getConfig().getUpdaterByParam(var);
                Preconditions.checkNotNull(u, "Updater for parameter %s, layer \"%s\" was null",
                        var, layer.getConfig().getLayerName());
                int updaterStateSizeThisVariable = (int) u.stateSize(paramSizeThisVariable);
                // Narrowed to int because UpdaterBlock / ParamState offsets are ints.
                int paramsViewEnd = (int) (paramsViewSoFar + paramSizeThisVariable);

                INDArray paramsViewSubset = null;
                INDArray gradientViewSubset = null;
                if (paramSizeThisVariable > 0L) {
                    long end = paramsViewSoFar + paramSizeThisVariable;
                    paramsViewSubset = intervalView(paramsView, paramsViewSoFar, end);
                    gradientViewSubset = intervalView(gradientView, paramsViewSoFar, end);
                }
                UpdaterBlock.ParamState paramState = new UpdaterBlock.ParamState(layer, var,
                        paramsViewSoFar, paramsViewEnd, paramsViewSubset, gradientViewSubset);

                if (currentBlock == null
                        || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable, layer, var)) {
                    // First parameter, or a different updater configuration: start a new block.
                    List<UpdaterBlock.ParamState> states = new ArrayList<UpdaterBlock.ParamState>();
                    states.add(paramState);
                    currentBlock = new UpdaterBlock(paramsViewSoFar, paramsViewEnd, currentUpdaterOffset,
                            currentUpdaterOffset + updaterStateSizeThisVariable, states);
                    this.updaterBlocks.add(currentBlock);
                } else {
                    // Same configuration as the previous parameter: extend the current block.
                    currentBlock.setParamOffsetEnd((int) (currentBlock.getParamOffsetEnd() + paramSizeThisVariable));
                    currentBlock.setUpdaterViewOffsetEnd(
                            currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
                    currentBlock.getLayersAndVariablesInBlock().add(paramState);
                }

                lastLayer = layer;
                lastVariable = var;
                paramsViewSoFar = paramsViewEnd;
                currentUpdaterOffset += updaterStateSizeThisVariable;
            }
        }
        // After the walk, the running updater offset equals the total state size of all updaters.
        int updaterStateSize = currentUpdaterOffset;

        // Phase 2: adopt the supplied state, or allocate new (uninitialized) state if any is needed.
        boolean updaterRequiresInit = false;
        if (updaterState != null) {
            this.updaterStateViewArray = updaterState;
        } else if (updaterStateSize > 0) {
            this.updaterStateViewArray = Nd4j.createUninitialized(network.params().dataType(),
                    new long[] {1L, updaterStateSize}, Nd4j.order());
            updaterRequiresInit = true;
        }

        // Phase 3: give each block its views into the state and gradient arrays, then initialize it.
        int updaterViewSoFar = 0;
        int gradientViewSoFar = 0;
        for (UpdaterBlock ub : this.updaterBlocks) {
            int viewStateSize = ub.getUpdaterViewOffsetEnd() - ub.getUpdaterViewOffsetStart();
            int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart();
            if (viewStateSize > 0) {
                ub.setUpdaterView(intervalView(this.updaterStateViewArray, updaterViewSoFar,
                        updaterViewSoFar + viewStateSize));
                ub.setUpdaterViewRequiresInitialization(updaterRequiresInit);
            }
            if (gradSize > 0) {
                ub.setGradientView(intervalView(gradientView, gradientViewSoFar, gradientViewSoFar + gradSize));
            }
            ub.init();
            updaterViewSoFar += viewStateSize;
            gradientViewSoFar += gradSize;
        }
    }

    /** Returns the trainable layers in the order their parameters appear in the flattened views. */
    protected abstract Trainable[] getOrderedLayers();

    /** Returns the model's flattened gradient view array. */
    protected abstract INDArray getFlattenedGradientsView();

    /** Returns the parameters passed to {@link UpdaterBlock#updateExternalGradient} for external gradients. */
    protected abstract INDArray getParams();

    /** Returns whether gradients should be divided by the minibatch size in {@link #update}. */
    protected abstract boolean isMiniBatch();

    /**
     * Copies {@code viewArray} into this updater's existing state array.
     *
     * @param viewArray the state to copy; may be {@code null} only when this updater has no state
     * @throws IllegalStateException if exactly one of this updater's state and {@code viewArray} is
     *                               {@code null}, or if their lengths differ
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.updaterStateViewArray == null) {
            if (viewArray == null) {
                // No stored state and none provided: nothing to do.
                return;
            }
            throw new IllegalStateException("Attempting to set updater state view array, but this updater has no state"
                    + " (state view array is null); got array of length " + viewArray.length());
        }
        if (viewArray == null) {
            throw new IllegalStateException("Attempting to set updater state view array with null value");
        }
        if (this.updaterStateViewArray.length() != viewArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length "
                    + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
        }
        this.updaterStateViewArray.assign(viewArray);
    }

    /** Copies {@code viewArray} into the whole state array; {@code layer} and {@code initialize} are ignored. */
    @Override
    public void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize) {
        this.setStateViewArray(viewArray);
    }

    /** Returns the live (not copied) updater state array, possibly {@code null}. */
    @Override
    public INDArray getStateViewArray() {
        return this.updaterStateViewArray;
    }

    /**
     * Returns a copy of the updater state, after committing pending executioner operations.
     * NOTE(review): throws NullPointerException when this updater has no state.
     */
    public synchronized INDArray getStateViewArrayCopy() {
        Nd4j.getExecutioner().commit();
        return this.updaterStateViewArray.dup();
    }

    /** Updates the whole model; {@code layer} is ignored. */
    @Override
    public void update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
        this.update(gradient, iteration, epoch, batchSize, workspaceMgr);
    }

    /**
     * Applies one update step to the whole model.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Split the gradient into per-layer gradients. Each key is split at its last {@code '_'}
     *       into layer name and parameter name. This split is skipped for a single-layer updater
     *       with exactly one layer.</li>
     *   <li>Divide by the minibatch size when {@link #isMiniBatch()} is true.</li>
     *   <li>Apply each layer's gradient normalization via {@link #preApply}.</li>
     *   <li>Run every {@link UpdaterBlock} inside the updater working-memory workspace scope.</li>
     * </ol>
     *
     * <p>If {@code gradient.gradient()} is not the same object as the flattened gradient view, it is
     * treated as external. It is then applied through {@link UpdaterBlock#updateExternalGradient}
     * against {@link #getParams()}.
     *
     * @throws IllegalStateException if a gradient key has no {@code '_'} layer separator
     */
    public synchronized void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
        boolean isExternal = gradient.gradient() != this.getFlattenedGradientsView();

        Map<String, Gradient> layerGradients = new HashMap<String, Gradient>();
        Trainable[] layers = this.getOrderedLayers();
        if (layers.length == 1 && this.isSingleLayerUpdater()) {
            layerGradients.put(layers[0].getConfig().getLayerName(), gradient);
        } else {
            for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
                String key = entry.getKey();
                int idx = key.lastIndexOf('_');
                if (idx == -1) {
                    throw new IllegalStateException("Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
                }
                String layerName = key.substring(0, idx);
                Gradient g = layerGradients.get(layerName);
                if (g == null) {
                    g = new DefaultGradient();
                    layerGradients.put(layerName, g);
                }
                g.setGradientFor(key.substring(idx + 1), entry.getValue());
            }
        }

        if (this.isMiniBatch()) {
            this.divideByMinibatch(isExternal, gradient, batchSize);
        }

        for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
            Trainable layer = this.layersByName.get(entry.getKey());
            this.preApply(layer, entry.getValue(), iteration);
        }

        if (this.getClass() != LayerUpdater.class) {
            workspaceMgr.assertNotOpen(ArrayType.UPDATER_WORKING_MEM, "Updater working memory");
        }
        for (UpdaterBlock updaterBlock : this.updaterBlocks) {
            if (updaterBlock.skipDueToPretrainConfig(this instanceof LayerUpdater)) {
                continue;
            }
            // try-with-resources closes the workspace (if any) and, unlike the previous hand-written
            // finally block, does not discard exceptions thrown by the update itself.
            try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.UPDATER_WORKING_MEM)) {
                if (isExternal) {
                    updaterBlock.updateExternalGradient(iteration, epoch, gradient.gradient(), this.getParams());
                } else {
                    updaterBlock.update(iteration, epoch);
                }
            }
        }
    }

    /**
     * Divides (in place) the gradient subsets whose parameters request minibatch division by
     * {@code batchSize}.
     *
     * <p>For the model's own gradient view, the subsets are computed once and cached. For external
     * gradients, they are recomputed from {@code gradient.gradient()} on every call.
     */
    protected void divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize) {
        if (!this.initializedMinibatchDivision) {
            this.gradientsForMinibatchDivision = this.getMinibatchDivisionSubsets(this.getFlattenedGradientsView());
            this.initializedMinibatchDivision = true;
        }
        List<INDArray> toDivide = isExternal
                ? this.getMinibatchDivisionSubsets(gradient.gradient())
                : this.gradientsForMinibatchDivision;
        for (INDArray arr : toDivide) {
            arr.divi(batchSize);
        }
    }

    /**
     * Returns views of {@code from} covering the maximal contiguous runs of parameters for which
     * {@link Trainable#updaterDivideByMinibatch} is true.
     *
     * <p>Layers with a {@code null} parameter table are skipped, as in the constructor.
     *
     * @param from a flattened array laid out like the model's gradient view
     * @return the views to divide by the minibatch size (possibly empty)
     */
    protected List<INDArray> getMinibatchDivisionSubsets(INDArray from) {
        List<INDArray> out = new ArrayList<INDArray>();
        long paramsSoFar = 0L;
        long currentStart = 0L;
        long currentEnd = 0L;
        for (Trainable t : this.getOrderedLayers()) {
            Map<String, INDArray> paramTable = t.paramTable(false);
            if (paramTable == null) {
                continue;
            }
            for (Map.Entry<String, INDArray> param : paramTable.entrySet()) {
                long length = param.getValue().length();
                if (t.updaterDivideByMinibatch(param.getKey())) {
                    // Extend the current run of divisible parameters.
                    currentEnd += length;
                } else {
                    // Excluded parameter: flush the current run and restart just after this parameter.
                    if (currentEnd > currentStart) {
                        out.add(intervalView(from, currentStart, currentEnd));
                    }
                    currentStart = paramsSoFar + length;
                    currentEnd = currentStart;
                }
                paramsSoFar += length;
            }
        }
        if (currentEnd > currentStart && currentStart < from.length()) {
            out.add(intervalView(from, currentStart, currentEnd));
        }
        return out;
    }

    /**
     * Returns whether, when there is exactly one layer, the whole gradient is passed to that layer
     * without splitting keys on {@code '_'}. The base implementation returns {@code false}.
     */
    protected boolean isSingleLayerUpdater() {
        return false;
    }

    /**
     * Applies the layer's configured gradient normalization in place. It does nothing for layers
     * with no configuration, no parameters, or normalization {@code null}/{@code None}.
     *
     * <p>Per-layer strategies operate on {@code layer.getGradientsViewArray()} and are skipped when
     * that view is {@code null}. Per-parameter-type strategies operate on each array in
     * {@code gradient.gradientForVariable()}.
     *
     * @throws RuntimeException for an unrecognized normalization strategy
     */
    public void preApply(Trainable layer, Gradient gradient, int iteration) {
        if (layer.getConfig() == null || layer.numParams() == 0L) {
            return;
        }
        GradientNormalization normalization = layer.getConfig().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None) {
            return;
        }
        double threshold = layer.getConfig().getGradientNormalizationThreshold();
        INDArray layerGradientView = layer.getGradientsViewArray();
        switch (normalization) {
            case RenormalizeL2PerLayer:
                if (layerGradientView != null) {
                    layerGradientView.divi(nonZeroNorm(layerGradientView.norm2Number().doubleValue()));
                }
                break;
            case RenormalizeL2PerParamType:
                for (INDArray g : gradient.gradientForVariable().values()) {
                    g.divi(nonZeroNorm(g.norm2Number().doubleValue()));
                }
                break;
            case ClipElementWiseAbsoluteValue:
                if (layerGradientView != null) {
                    // Clamp every element into [-threshold, threshold], in place.
                    CustomOp op = DynamicCustomOp.builder("clipbyvalue")
                            .addInputs(new INDArray[] {layerGradientView})
                            .callInplace(true)
                            .addFloatingPointArguments(new Double[] {-threshold, threshold})
                            .build();
                    Nd4j.getExecutioner().exec(op);
                }
                break;
            case ClipL2PerLayer:
                if (layerGradientView != null) {
                    double layerL2 = layerGradientView.norm2Number().doubleValue();
                    if (layerL2 > threshold) {
                        layerGradientView.muli(threshold / layerL2);
                    }
                }
                break;
            case ClipL2PerParamType:
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        g.divi(l2 / threshold);
                    }
                }
                break;
            default:
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
        }
    }

    /** Returns {@code l2}, or a small epsilon if it is exactly zero, so it is safe to divide by. */
    private static double nonZeroNorm(double l2) {
        return l2 == 0.0 ? ZERO_NORM_EPSILON : l2;
    }

    /** Returns the view of row 0 of {@code arr} covering columns {@code [start, end)}. */
    private static INDArray intervalView(INDArray arr, long start, long end) {
        return arr.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(start, end));
    }

    /** Two updaters are equal when they are the same class and their state arrays are equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BaseMultiLayerUpdater<?> that = (BaseMultiLayerUpdater<?>) o;
        return this.updaterStateViewArray != null
                ? this.updaterStateViewArray.equals(that.updaterStateViewArray)
                : that.updaterStateViewArray == null;
    }

    /** Hashes only the state array, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return this.updaterStateViewArray != null ? this.updaterStateViewArray.hashCode() : 0;
    }

    public T getNetwork() {
        return this.network;
    }

    public Map<String, Trainable> getLayersByName() {
        return this.layersByName;
    }

    public List<UpdaterBlock> getUpdaterBlocks() {
        return this.updaterBlocks;
    }

    public INDArray getUpdaterStateViewArray() {
        return this.updaterStateViewArray;
    }

    public boolean isInitializedMinibatchDivision() {
        return this.initializedMinibatchDivision;
    }

    public List<INDArray> getGradientsForMinibatchDivision() {
        return this.gradientsForMinibatchDivision;
    }
}

