/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.serde;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.DoubleConsumer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonMappingException;
import org.nd4j.shade.jackson.databind.deser.ResolvableDeserializer;
import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
import org.nd4j.shade.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for Jackson deserializers of network configurations that also upgrade
 * configurations written in older (legacy) JSON layouts.
 *
 * <p>Subclasses implement {@link #deserialize}. The {@code requires*FromLegacy} methods detect
 * layers whose modern fields were left unset after default deserialization, and the
 * {@code handle*BackwardCompatibility} methods rebuild those fields from a layer's legacy JSON
 * properties.
 *
 * @param <T> the configuration type produced by this deserializer
 */
public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T>
        implements ResolvableDeserializer {
    private static final Logger log = LoggerFactory.getLogger(BaseNetConfigDeserializer.class);

    /** The wrapped default deserializer; {@link #resolve} forwards resolution to it. */
    protected final JsonDeserializer<?> defaultDeserializer;

    /**
     * Lazily built lookup from lower-case (Locale.ROOT) activation name to implementation class.
     * Only accessed through the {@code static synchronized} {@link #getMap()}.
     */
    private static Map<String, Class<? extends IActivation>> activationMap;

    /**
     * @param defaultDeserializer the default deserializer being wrapped
     * @param deserializedType    the type this deserializer produces
     */
    public BaseNetConfigDeserializer(JsonDeserializer<?> defaultDeserializer, Class<T> deserializedType) {
        super(deserializedType);
        this.defaultDeserializer = defaultDeserializer;
    }

    @Override
    public abstract T deserialize(JsonParser jp, DeserializationContext ctxt)
            throws IOException, JsonProcessingException;

    /**
     * Returns {@code true} if any {@link BaseLayer} has no {@link IUpdater} set even though its
     * initializer reports more than zero parameters.
     */
    protected boolean requiresIUpdaterFromLegacy(Layer[] layers) {
        for (Layer l : layers) {
            if (l instanceof BaseLayer) {
                BaseLayer bl = (BaseLayer) l;
                if (bl.getIUpdater() == null && bl.initializer().numParams(bl) > 0L) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns {@code true} only if no layer has an {@code IDropout} set. */
    protected boolean requiresDropoutFromLegacy(Layer[] layers) {
        for (Layer l : layers) {
            if (l.getIDropout() != null) {
                return false;
            }
        }
        return true;
    }

    /** Returns {@code true} if any {@link BaseLayer} has a {@code null} regularization list. */
    protected boolean requiresRegularizationFromLegacy(Layer[] layers) {
        for (Layer l : layers) {
            if (l instanceof BaseLayer && ((BaseLayer) l).getRegularization() == null) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if any {@link BaseLayer} has no weight-init function set. */
    protected boolean requiresWeightInitFromLegacy(Layer[] layers) {
        for (Layer l : layers) {
            if (l instanceof BaseLayer && ((BaseLayer) l).getWeightInitFn() == null) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if any {@link BaseLayer} has no activation function set. */
    protected boolean requiresActivationFromLegacy(Layer[] layers) {
        for (Layer l : layers) {
            if (l instanceof BaseLayer && ((BaseLayer) l).getActivationFn() == null) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if any {@link BaseOutputLayer} has no loss function set. */
    protected boolean requiresLegacyLossHandling(Layer[] layers) {
        for (Layer l : layers) {
            if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rebuilds the layer's {@link IUpdater} from the legacy {@code updater} JSON property and its
     * companion hyperparameter properties ({@code learningRate}, {@code epsilon}, {@code rho},
     * {@code momentum}, {@code adamMeanDecay}, {@code adamVarDecay}, {@code rmsDecay}).
     *
     * <p>A hyperparameter missing from the JSON keeps the value from
     * {@code getIUpdaterWithDefaultConfig()}. A missing or NaN {@code epsilon} falls back to the
     * legacy per-updater default (1e-8 or 1e-6).
     *
     * @param layer layer that receives the reconstructed updater
     * @param on    legacy JSON of the layer; nothing happens if {@code null} or without {@code updater}
     * @throws IllegalArgumentException if the {@code updater} value is not an {@link Updater} name
     */
    protected void handleUpdaterBackwardCompatibility(BaseLayer layer, ObjectNode on) {
        if (on == null || !on.has("updater")) {
            return;
        }
        String updaterName = on.get("updater").asText();
        if (updaterName == null) {
            return;
        }
        Updater u = Updater.valueOf(updaterName);
        IUpdater iu = u.getIUpdaterWithDefaultConfig();
        double eps = on.has("epsilon") ? on.get("epsilon").asDouble() : Double.NaN;
        switch (u) {
            case SGD: {
                Sgd sgd = (Sgd) iu;
                setIfPresent(on, "learningRate", sgd::setLearningRate);
                break;
            }
            case ADAM: {
                Adam adam = (Adam) iu;
                setIfPresent(on, "learningRate", adam::setLearningRate);
                setIfPresent(on, "adamMeanDecay", adam::setBeta1);
                setIfPresent(on, "adamVarDecay", adam::setBeta2);
                adam.setEpsilon(epsilonOrDefault(eps, 1.0E-8));
                break;
            }
            case ADAMAX: {
                AdaMax adaMax = (AdaMax) iu;
                setIfPresent(on, "learningRate", adaMax::setLearningRate);
                setIfPresent(on, "adamMeanDecay", adaMax::setBeta1);
                setIfPresent(on, "adamVarDecay", adaMax::setBeta2);
                adaMax.setEpsilon(epsilonOrDefault(eps, 1.0E-8));
                break;
            }
            case ADADELTA: {
                AdaDelta adaDelta = (AdaDelta) iu;
                setIfPresent(on, "rho", adaDelta::setRho);
                adaDelta.setEpsilon(epsilonOrDefault(eps, 1.0E-6));
                break;
            }
            case NESTEROVS: {
                Nesterovs nesterovs = (Nesterovs) iu;
                setIfPresent(on, "learningRate", nesterovs::setLearningRate);
                setIfPresent(on, "momentum", nesterovs::setMomentum);
                break;
            }
            case NADAM: {
                Nadam nadam = (Nadam) iu;
                setIfPresent(on, "learningRate", nadam::setLearningRate);
                setIfPresent(on, "adamMeanDecay", nadam::setBeta1);
                setIfPresent(on, "adamVarDecay", nadam::setBeta2);
                nadam.setEpsilon(epsilonOrDefault(eps, 1.0E-8));
                break;
            }
            case ADAGRAD: {
                AdaGrad adaGrad = (AdaGrad) iu;
                setIfPresent(on, "learningRate", adaGrad::setLearningRate);
                adaGrad.setEpsilon(epsilonOrDefault(eps, 1.0E-6));
                break;
            }
            case RMSPROP: {
                RmsProp rmsProp = (RmsProp) iu;
                setIfPresent(on, "learningRate", rmsProp::setLearningRate);
                rmsProp.setEpsilon(epsilonOrDefault(eps, 1.0E-8));
                setIfPresent(on, "rmsDecay", rmsProp::setRmsDecay);
                break;
            }
            default:
                // Updaters not listed above keep their default configuration unchanged.
                break;
        }
        layer.setIUpdater(iu);
    }

    /** Passes the value of {@code field} to {@code setter}, but only if {@code on} contains it. */
    private static void setIfPresent(ObjectNode on, String field, DoubleConsumer setter) {
        if (on.has(field)) {
            setter.accept(on.get(field).asDouble());
        }
    }

    /** Returns {@code eps}, or {@code defaultEps} when {@code eps} is NaN (i.e. not supplied). */
    private static double epsilonOrDefault(double eps, double defaultEps) {
        return Double.isNaN(eps) ? defaultEps : eps;
    }

    /**
     * Converts legacy {@code l1}/{@code l2}/{@code l1Bias}/{@code l2Bias} JSON properties into
     * regularization lists. {@code l1*} values become {@link L1Regularization}; {@code l2*} values
     * become a {@link WeightDecay} with {@code false} as its second constructor argument. Only
     * strictly positive values are added.
     *
     * <p>Runs only when {@code on} has {@code l1} or {@code l2}. In that case both the weight and
     * bias regularization lists of the layer are replaced, even if they end up empty.
     */
    protected void handleL1L2BackwardCompatibility(BaseLayer baseLayer, ObjectNode on) {
        if (on == null || !(on.has("l1") || on.has("l2"))) {
            return;
        }
        List<Regularization> regularization = new ArrayList<>();
        List<Regularization> regularizationBias = new ArrayList<>();

        double l1 = doubleOrZero(on, "l1");
        if (l1 > 0.0) {
            regularization.add(new L1Regularization(l1));
        }
        double l2 = doubleOrZero(on, "l2");
        if (l2 > 0.0) {
            regularization.add(new WeightDecay(l2, false));
        }
        double l1Bias = doubleOrZero(on, "l1Bias");
        if (l1Bias > 0.0) {
            regularizationBias.add(new L1Regularization(l1Bias));
        }
        double l2Bias = doubleOrZero(on, "l2Bias");
        if (l2Bias > 0.0) {
            regularizationBias.add(new WeightDecay(l2Bias, false));
        }

        baseLayer.setRegularization(regularization);
        baseLayer.setRegularizationBias(regularizationBias);
    }

    /** Returns {@code on.get(field).doubleValue()} if present, otherwise {@code 0.0}. */
    private static double doubleOrZero(ObjectNode on, String field) {
        return on.has(field) ? on.get(field).doubleValue() : 0.0;
    }

    /**
     * Rebuilds the layer's weight-init function from the legacy {@code weightInit} property (and,
     * for {@code DISTRIBUTION}, the {@code dist} property). This is best-effort: failures are
     * logged as a warning and the layer is left unchanged.
     */
    protected void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode on) {
        if (on == null || !on.has("weightInit")) {
            return;
        }
        String wi = on.get("weightInit").asText();
        try {
            WeightInit w = WeightInit.valueOf(wi);
            Distribution d = null;
            if (w == WeightInit.DISTRIBUTION && on.has("dist")) {
                String dist = on.get("dist").toString();
                d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
            }
            IWeightInit iwi = w.getWeightInitFunction(d);
            baseLayer.setWeightInitFn(iwi);
        } catch (Exception e) {
            log.warn("Failed to infer weight initialization from legacy JSON format", e);
        }
    }

    /**
     * Sets the layer's activation function from the legacy {@code activationFunction} property,
     * if the layer has none yet. The name is matched case-insensitively (using
     * {@link Locale#ROOT}) against {@link Activation} names. Unknown names and instantiation
     * failures are logged, and the activation is left {@code null}.
     */
    protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on) {
        if (on == null || baseLayer.getActivationFn() != null || !on.has("activationFunction")) {
            return;
        }
        String afn = on.get("activationFunction").asText();
        Class<? extends IActivation> activationClass = getMap().get(afn.toLowerCase(Locale.ROOT));
        if (activationClass == null) {
            log.error("Unknown legacy activation function \"{}\"; activation left unset", afn);
            return;
        }
        IActivation a = null;
        try {
            a = activationClass.getDeclaredConstructor().newInstance();
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException
                | InvocationTargetException e) {
            log.error("Failed to instantiate activation function {}", activationClass.getName(), e);
        }
        baseLayer.setActivationFn(a);
    }

    /**
     * Sets the output layer's loss function from the legacy {@code lossFunction} property, if the
     * layer has none yet. Recognised names are MCXENT, MSE, NEGATIVELOGLIKELIHOOD, SQUARED_LOSS and
     * XENT. Any other name is logged, and the loss is left {@code null}.
     */
    protected void handleLossBackwardCompatibility(BaseOutputLayer baseLayer, ObjectNode on) {
        // Key off "lossFunction" itself: the value read below comes from that property.
        if (on == null || baseLayer.getLossFn() != null || !on.has("lossFunction")) {
            return;
        }
        String lfn = on.get("lossFunction").asText();
        ILossFunction loss;
        switch (lfn) {
            case "MCXENT":
                loss = new LossMCXENT();
                break;
            case "MSE":
                loss = new LossMSE();
                break;
            case "NEGATIVELOGLIKELIHOOD":
                loss = new LossNegativeLogLikelihood();
                break;
            case "SQUARED_LOSS":
                loss = new LossL2();
                break;
            case "XENT":
                loss = new LossBinaryXENT();
                break;
            default:
                log.warn("Unsupported legacy loss function \"{}\"; loss function left unset", lfn);
                loss = null;
                break;
        }
        baseLayer.setLossFn(loss);
    }

    /** Returns the lazily built activation-name lookup, creating it on first use. */
    private static synchronized Map<String, Class<? extends IActivation>> getMap() {
        if (activationMap == null) {
            Map<String, Class<? extends IActivation>> map = new HashMap<>();
            for (Activation a : Activation.values()) {
                map.put(a.toString().toLowerCase(Locale.ROOT), a.getActivationFunction().getClass());
            }
            activationMap = map;
        }
        return activationMap;
    }

    /**
     * Forwards resolution to {@link #defaultDeserializer} when it is itself a
     * {@link ResolvableDeserializer}; otherwise does nothing.
     */
    @Override
    public void resolve(DeserializationContext ctxt) throws JsonMappingException {
        if (defaultDeserializer instanceof ResolvableDeserializer) {
            ((ResolvableDeserializer) defaultDeserializer).resolve(ctxt);
        }
    }
}

