/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import com.google.common.base.Function;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Skeleton {@link Updater} that turns the raw per-parameter gradients of a layer into update steps.
 *
 * <p>For each call to {@link #update(Layer, Gradient, int)} the pipeline is:
 * <ol>
 *   <li>{@link #preApply}: layer-wide gradient normalization / clipping as selected by the layer's
 *       {@link GradientNormalization} setting;</li>
 *   <li>{@link #checkSchedules}: optional learning-rate / momentum schedule changes;</li>
 *   <li>the per-parameter {@link GradientUpdater} returned by {@link #init(String, INDArray, Layer)};</li>
 *   <li>{@link #postApply}: L2 / L1 regularization, minibatch averaging and optional unit-norm scaling.</li>
 * </ol>
 */
public abstract class BaseUpdater implements Updater {
    /** Parameter key that is excluded from L1/L2 regularization (the bias). */
    private static final String BIAS_KEY = "b";

    /**
     * Per-parameter gradient updaters keyed by parameter name. Read by {@link #checkSchedules} to push
     * schedule changes; presumably populated by subclass {@code init} implementations (TODO confirm).
     */
    protected Map<String, GradientUpdater> updaterForVariable = new HashMap<String, GradientUpdater>();

    /**
     * Applies the full update pipeline to every parameter gradient of {@code layer}, replacing each
     * entry of {@code gradient} with the processed update.
     *
     * @param layer     layer whose configuration drives normalization, schedules and regularization
     * @param gradient  gradients keyed by parameter name; its entries are replaced in place
     * @param iteration current training iteration (used by schedules and passed to the updaters)
     */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration) {
        preApply(layer, gradient, iteration);
        boolean useSchedules = layer.conf().isUseSchedules();
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String paramName = entry.getKey();
            INDArray rawGradient = entry.getValue();
            if (useSchedules) {
                checkSchedules(layer, iteration, paramName);
            }
            GradientUpdater updater = init(paramName, rawGradient, layer);
            INDArray step = updater.getGradient(rawGradient, iteration);
            postApply(layer, step, paramName);
            // Re-sets an existing key while iterating; assumed not to structurally modify the
            // underlying map (NOTE(review): confirm against the Gradient implementation).
            gradient.setGradientFor(paramName, step);
        }
    }

    /**
     * Post-processes one parameter's update in place: adds L2 ({@code l2 * w}) and L1
     * ({@code l1 * sign(w)}) regularization terms for non-bias parameters, divides by the input
     * minibatch size when minibatch mode is on, and optionally rescales the update to unit L2 norm.
     *
     * @param layer    layer owning the parameter
     * @param gradient update array for {@code param}; modified in place
     * @param param    parameter name
     */
    public void postApply(Layer layer, INDArray gradient, String param) {
        NeuralNetConfiguration conf = layer.conf();
        if (conf.isUseRegularization() && !param.equals(BIAS_KEY)) {
            double l2 = conf.getLayer().getL2();
            double l1 = conf.getLayer().getL1();
            if (l2 > 0.0 || l1 > 0.0) {
                // Only fetch the parameter values when a regularization term will actually be added.
                INDArray params = layer.getParam(param);
                if (l2 > 0.0) {
                    gradient.addi(params.mul(l2));
                }
                if (l1 > 0.0) {
                    gradient.addi(Transforms.sign(params).muli(l1));
                }
            }
        }
        if (conf.isMiniBatch()) {
            gradient.divi(layer.getInputMiniBatchSize());
        }
        if (conf.isConstrainGradientToUnitNorm()) {
            double norm = l2Norm(gradient);
            // An all-zero update has norm 0; dividing would turn it into NaN (0/0).
            if (norm != 0.0) {
                gradient.divi(norm);
            }
        }
    }

    /**
     * Applies any learning-rate or momentum change scheduled for {@code iteration} to the layer
     * configuration, and forwards the new values to the existing updater for {@code param}, if any.
     *
     * @param layer     layer whose configuration holds the schedules
     * @param iteration current training iteration, used as the schedule key
     * @param param     parameter whose cached updater (if present) should receive the new values
     */
    public void checkSchedules(Layer layer, int iteration, String param) {
        NeuralNetConfiguration conf = layer.conf();
        GradientUpdater updater = this.updaterForVariable.get(param);
        if (conf.getLayer().getLearningRateAfter().containsKey(iteration)) {
            conf.getLayer().setLearningRate(conf.getLayer().getLearningRateAfter().get(iteration));
            if (updater != null) {
                updater.update(new Object[]{conf.getLayer().getLearningRateAfter().get(iteration)});
            }
        }
        if (conf.getLayer().getMomentumAfter().containsKey(iteration)) {
            conf.getLayer().setMomentum(conf.getLayer().getMomentumAfter().get(iteration));
            if (updater != null) {
                // Momentum updates are forwarded together with the (possibly just changed) learning rate.
                updater.update(new Object[]{conf.getLayer().getLearningRate(), conf.getLayer().getMomentumAfter().get(iteration)});
            }
        }
    }

    /**
     * Normalizes or clips all gradients of the layer in place according to the layer's
     * {@link GradientNormalization} setting. Does nothing when the setting is {@code null} or
     * {@code None}.
     *
     * @param layer     layer whose configuration selects the strategy and threshold
     * @param gradient  gradients keyed by parameter name; arrays are modified in place
     * @param iteration current training iteration (unused here)
     * @throws RuntimeException if the strategy is not handled
     */
    public void preApply(Layer layer, Gradient gradient, int iteration) {
        GradientNormalization normalization = layer.conf().getLayer().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None) {
            return;
        }
        final double threshold = layer.conf().getLayer().getGradientNormalizationThreshold();
        switch (normalization) {
            case RenormalizeL2PerLayer: {
                double layerL2 = layerL2Norm(gradient);
                // A zero norm means every gradient is all zeros; dividing would produce NaN (0/0).
                if (layerL2 != 0.0) {
                    for (INDArray g : gradient.gradientForVariable().values()) {
                        g.divi(layerL2);
                    }
                }
                break;
            }
            case RenormalizeL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = l2Norm(g);
                    // Skip all-zero arrays to avoid 0/0 -> NaN.
                    if (l2 != 0.0) {
                        g.divi(l2);
                    }
                }
                break;
            }
            case ClipElementWiseAbsoluteValue: {
                Condition absValueCondition = new AbsValueGreaterThan(threshold);
                // Only invoked on elements with |x| > threshold, so the sign picks the clip bound.
                Function<Number, Number> clipFn = new Function<Number, Number>() {
                    @Override
                    public Number apply(Number number) {
                        return number.doubleValue() > threshold ? threshold : -threshold;
                    }
                };
                for (INDArray g : gradient.gradientForVariable().values()) {
                    BooleanIndexing.applyWhere(g, absValueCondition, clipFn);
                }
                break;
            }
            case ClipL2PerLayer: {
                double layerL2 = layerL2Norm(gradient);
                if (layerL2 > threshold) {
                    double scalingFactor = threshold / layerL2;
                    for (INDArray g : gradient.gradientForVariable().values()) {
                        g.muli(scalingFactor);
                    }
                }
                break;
            }
            case ClipL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = l2Norm(g);
                    if (l2 > threshold) {
                        g.muli(threshold / l2);
                    }
                }
                break;
            }
            default: {
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
            }
        }
    }

    /** Returns the L2 norm of {@code array} as a double. */
    private static double l2Norm(INDArray array) {
        return array.norm2Number().doubleValue();
    }

    /** Returns sqrt of the sum of squared per-parameter L2 norms, i.e. the L2 norm of the whole layer's gradient. */
    private static double layerL2Norm(Gradient gradient) {
        double sumSquares = 0.0;
        for (INDArray g : gradient.gradientForVariable().values()) {
            double l2 = l2Norm(g);
            sumSquares += l2 * l2;
        }
        return FastMath.sqrt(sumSquares);
    }

    /** Subclass-defined initialization hook (not called from this class). */
    public abstract void init();

    /**
     * Returns the {@link GradientUpdater} for one parameter. Called by {@link #update} once per
     * parameter on every iteration; the returned updater's {@code getGradient} produces the step.
     *
     * @param variable parameter name
     * @param gradient raw (already pre-normalized) gradient for that parameter
     * @param layer    layer owning the parameter
     * @return the updater to apply to {@code gradient}
     */
    public abstract GradientUpdater init(String variable, INDArray gradient, Layer layer);
}

