/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.GravesLSTM;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Importer for a Keras functional-API model (model config {@code class_name} "Model").
 *
 * <p>Parses the Keras model configuration (JSON or YAML), an optional Keras training configuration
 * and optional weights. Builds the equivalent DL4J {@link ComputationGraphConfiguration} or
 * {@link ComputationGraph} from them.</p>
 *
 * <p>The Keras layer graph is held in two adjacency maps keyed by layer name:
 * {@link #inputToOutput} (layer to the layers it feeds) and {@link #outputToInput} (layer to its
 * inbound layers). Loss layers created from the training configuration are added to both maps. They
 * replace the corresponding Keras output layers as graph outputs.</p>
 *
 * <p>Note that {@link #inferInputType(String)} updates {@link #truncatedBPTT} as a side effect.</p>
 */
public class KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasModel.class);

    /* Top-level Keras model configuration fields and values. */
    public static final String MODEL_FIELD_CLASS_NAME = "class_name";
    public static final String MODEL_CLASS_NAME_SEQUENTIAL = "Sequential";
    public static final String MODEL_CLASS_NAME_MODEL = "Model";
    public static final String MODEL_FIELD_CONFIG = "config";
    public static final String MODEL_CONFIG_FIELD_LAYERS = "layers";
    public static final String MODEL_CONFIG_FIELD_INPUT_LAYERS = "input_layers";
    public static final String MODEL_CONFIG_FIELD_OUTPUT_LAYERS = "output_layers";

    /* Keras training configuration fields. */
    public static final String TRAINING_CONFIG_FIELD_LOSS = "loss";

    /**
     * Initial value of {@link #truncatedBPTT}. Any non-positive value makes
     * {@link #getComputationGraphConfiguration()} use standard backprop.
     */
    public static final int DO_NOT_USE_TRUNCATED_BPTT = -123456789;

    /* Keras parameter names, used as keys of the per-layer weight maps. */
    public static final String PARAM_NAME_GAMMA = "gamma";
    public static final String PARAM_NAME_BETA = "beta";
    public static final String PARAM_NAME_RUNNING_MEAN = "running_mean";
    public static final String PARAM_NAME_RUNNING_STD = "running_std";
    public static final String PARAM_NAME_W = "W";
    public static final String PARAM_NAME_U = "U";
    public static final String PARAM_NAME_B = "b";
    public static final String PARAM_NAME_W_C = "W_c";
    public static final String PARAM_NAME_W_F = "W_f";
    public static final String PARAM_NAME_W_I = "W_i";
    public static final String PARAM_NAME_W_O = "W_o";
    public static final String PARAM_NAME_U_C = "U_c";
    public static final String PARAM_NAME_U_F = "U_f";
    public static final String PARAM_NAME_U_I = "U_i";
    public static final String PARAM_NAME_U_O = "U_o";
    public static final String PARAM_NAME_B_C = "b_c";
    public static final String PARAM_NAME_B_F = "b_f";
    public static final String PARAM_NAME_B_I = "b_i";
    public static final String PARAM_NAME_B_O = "b_o";

    /** Keras model class name read from the configuration. */
    protected String className;
    /** Layer names in configuration order, followed by any loss layers added from the training config. */
    protected List<String> layerNamesOrdered;
    /** Keras layers keyed by layer name. */
    protected Map<String, KerasLayer> layers;
    /** Names of the model's input layers, in configuration order. */
    protected ArrayList<String> inputLayerNames;
    /** Names of the graph outputs; Keras output layers are replaced by their loss layers when a loss is configured. */
    protected ArrayList<String> outputLayerNames;
    /** Layer name to the names of the layers that consume its output. */
    protected Map<String, Set<String>> inputToOutput;
    /** Layer name to the names of its inbound layers. */
    protected Map<String, Set<String>> outputToInput;
    /** Truncated-BPTT length; set from a recurrent input's first shape dimension by {@link #inferInputType(String)}. */
    protected int truncatedBPTT = DO_NOT_USE_TRUNCATED_BPTT;
    /** Keras weights: layer name to (Keras parameter name to value); may be {@code null}. */
    protected Map<String, Map<String, INDArray>> weights = null;
    /** Passed to every {@link KerasLayer} created from the configuration. */
    protected boolean train;

    /**
     * Creates a model from the settings collected in a {@link ModelBuilder}.
     *
     * @param modelBuilder builder holding the model config, optional training config, weights and train flag
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     * @throws IOException if the configuration cannot be parsed
     * @throws InvalidKerasConfigurationException if the configuration is malformed
     */
    public KerasModel(ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.trainingJson, modelBuilder.weights, modelBuilder.train);
    }

    /**
     * Parses a Keras functional-API model.
     *
     * @param modelJson    model configuration as JSON; used in preference to {@code modelYaml} when non-null
     * @param modelYaml    model configuration as YAML; used only when {@code modelJson} is null
     * @param trainingJson optional training configuration as JSON (may be null); its {@code loss} adds loss layers
     * @param weights      optional weights, layer name to (Keras parameter name to value); may be null
     * @param train        passed to each created {@link KerasLayer}
     * @throws IOException if a configuration string cannot be parsed
     * @throws InvalidKerasConfigurationException if neither config is given, the class name is not "Model",
     *         or a required field is missing
     * @throws UnsupportedKerasConfigurationException if a layer or loss is not supported
     */
    @SuppressWarnings("unchecked")
    public KerasModel(String modelJson, String modelYaml, String trainingJson, Map<String, Map<String, INDArray>> weights, boolean train) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> classNameAndLayerLists;
        if (modelJson != null) {
            classNameAndLayerLists = parseJsonString(modelJson);
        } else if (modelYaml != null) {
            classNameAndLayerLists = parseYamlString(modelYaml);
        } else {
            throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
        }
        this.className = (String) checkAndGetModelField(classNameAndLayerLists, MODEL_FIELD_CLASS_NAME);
        if (!MODEL_CLASS_NAME_MODEL.equals(this.className)) {
            throw new InvalidKerasConfigurationException("Expected model class name Model (found " + this.className + ")");
        }
        // Must be set before helperPrepareLayers, which passes it to every KerasLayer.
        this.train = train;
        Map<String, Object> modelConfig = (Map<String, Object>) checkAndGetModelField(classNameAndLayerLists, MODEL_FIELD_CONFIG);
        this.helperPrepareLayers((List<Object>) checkAndGetModelField(modelConfig, MODEL_CONFIG_FIELD_LAYERS));
        this.inputLayerNames = helperParseLayerNames(modelConfig, MODEL_CONFIG_FIELD_INPUT_LAYERS);
        this.outputLayerNames = helperParseLayerNames(modelConfig, MODEL_CONFIG_FIELD_OUTPUT_LAYERS);
        this.helperPrepareGraph();
        if (trainingJson != null) {
            this.helperImportTrainingConfiguration(trainingJson);
        }
        this.weights = weights;
    }

    /**
     * Reads a list of layer references (e.g. {@code input_layers}) from the model config.
     *
     * <p>Each entry is itself a list. Only its first element, the layer name, is used; the remaining
     * elements are presumably Keras node/tensor indices (not verified here).</p>
     */
    private static ArrayList<String> helperParseLayerNames(Map<String, Object> modelConfig, String field) throws InvalidKerasConfigurationException {
        ArrayList<String> layerNames = new ArrayList<String>();
        for (Object layerRef : (List<?>) checkAndGetModelField(modelConfig, field)) {
            layerNames.add((String) ((List<?>) layerRef).get(0));
        }
        return layerNames;
    }

    /**
     * Builds a {@link KerasLayer} for every layer config and records layer names in configuration order.
     *
     * @param layerConfigs the Keras {@code layers} list; each element is a layer config map
     */
    @SuppressWarnings("unchecked")
    protected void helperPrepareLayers(List<Object> layerConfigs) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.layers = new HashMap<String, KerasLayer>();
        this.layerNamesOrdered = new ArrayList<String>();
        for (Object layerConfig : layerConfigs) {
            KerasLayer layer = new KerasLayer((Map) layerConfig, this.train);
            this.layerNamesOrdered.add(layer.getName());
            this.layers.put(layer.getName(), layer);
        }
    }

    /**
     * Builds {@link #inputToOutput} and {@link #outputToInput} from each layer's inbound layer names.
     * Every layer in {@link #layerNamesOrdered} gets an entry in both maps, possibly empty.
     */
    protected void helperPrepareGraph() {
        this.outputToInput = new HashMap<String, Set<String>>();
        this.inputToOutput = new HashMap<String, Set<String>>();
        for (String childName : this.layerNamesOrdered) {
            Set<String> parents = edgeSet(this.outputToInput, childName);
            for (String parentName : this.layers.get(childName).getInboundLayerNames()) {
                parents.add(parentName);
                edgeSet(this.inputToOutput, parentName).add(childName);
            }
            edgeSet(this.inputToOutput, childName);
        }
    }

    /** Returns the edge set stored under {@code key}, creating and storing an empty one if absent. */
    private static Set<String> edgeSet(Map<String, Set<String>> graph, String key) {
        Set<String> edges = graph.get(key);
        if (edges == null) {
            edges = new HashSet<String>();
            graph.put(key, edges);
        }
        return edges;
    }

    /**
     * Reads the Keras training configuration and adds one loss layer per configured output.
     *
     * <p>The {@code loss} field may be a single loss name (applied to every output), a map from
     * output layer name to loss name, or a list of loss names aligned with the output layers. Each
     * loss layer is named {@code <output>_loss}, is wired after its output layer and replaces that
     * output layer in {@link #outputLayerNames} at the same position, so output order is preserved.</p>
     *
     * @param trainingConfigJson Keras training configuration as JSON
     * @throws InvalidKerasConfigurationException if {@code loss} is missing, a loss entry is not a
     *         string, or a loss list does not match the number of outputs
     */
    protected void helperImportTrainingConfiguration(String trainingConfigJson) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> trainingConfig = parseJsonString(trainingConfigJson);
        // Insertion-ordered so loss layers (and hence graph outputs) follow a deterministic order.
        Map<String, KerasLayer> lossLayers = new LinkedHashMap<String, KerasLayer>();
        Object kerasLossObj = checkAndGetTrainingField(trainingConfig, TRAINING_CONFIG_FIELD_LOSS);
        if (kerasLossObj instanceof String) {
            for (String outputLayerName : this.outputLayerNames) {
                lossLayers.put(outputLayerName, createLossLayerFor(outputLayerName, kerasLossObj));
            }
        } else if (kerasLossObj instanceof Map) {
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) kerasLossObj).entrySet()) {
                String outputLayerName = (String) entry.getKey();
                lossLayers.put(outputLayerName, createLossLayerFor(outputLayerName, entry.getValue()));
            }
        } else if (kerasLossObj instanceof List) {
            List<?> kerasLossList = (List<?>) kerasLossObj;
            if (kerasLossList.size() != this.outputLayerNames.size()) {
                throw new InvalidKerasConfigurationException("Found " + kerasLossList.size() + " Keras losses for " + this.outputLayerNames.size() + " output layers");
            }
            for (int i = 0; i < kerasLossList.size(); i++) {
                String outputLayerName = this.outputLayerNames.get(i);
                lossLayers.put(outputLayerName, createLossLayerFor(outputLayerName, kerasLossList.get(i)));
            }
        } else {
            log.warn("Unsupported Keras loss configuration {}; no loss layers added", kerasLossObj);
        }
        for (Map.Entry<String, KerasLayer> entry : lossLayers.entrySet()) {
            String outputLayerName = entry.getKey();
            KerasLayer lossLayer = entry.getValue();
            String lossLayerName = lossLayer.getName();
            this.layers.put(lossLayerName, lossLayer);
            this.layerNamesOrdered.add(lossLayerName);
            int position = this.outputLayerNames.indexOf(outputLayerName);
            if (position >= 0) {
                this.outputLayerNames.set(position, lossLayerName);
            } else {
                this.outputLayerNames.add(lossLayerName);
            }
            edgeSet(this.inputToOutput, outputLayerName).add(lossLayerName);
            edgeSet(this.outputToInput, lossLayerName).add(outputLayerName);
            // Give the loss layer an (empty) outbound entry like every other layer, so graph walks never see null.
            edgeSet(this.inputToOutput, lossLayerName);
        }
    }

    /** Creates the loss layer {@code <outputLayerName>_loss} for a single Keras loss value, which must be a string. */
    private static KerasLayer createLossLayerFor(String outputLayerName, Object kerasLoss) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!(kerasLoss instanceof String)) {
            throw new InvalidKerasConfigurationException("Unknown Keras loss " + kerasLoss);
        }
        return KerasLayer.createLossLayer(outputLayerName + "_loss", (String) kerasLoss);
    }

    /** For subclasses that populate the fields themselves. */
    protected KerasModel() {
    }

    /**
     * Builds the DL4J computation graph configuration for this model.
     *
     * <p>Only layers for which {@link KerasLayer#isDl4jLayer()} is true are added. Truncated BPTT
     * is enabled when a recurrent input sets a positive {@link #truncatedBPTT}.</p>
     *
     * @return the graph configuration
     * @throws InvalidKerasConfigurationException if the class name is neither "Model" nor "Sequential"
     * @throws UnsupportedKerasConfigurationException if an input type cannot be inferred
     */
    public ComputationGraphConfiguration getComputationGraphConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!this.className.equals(MODEL_CLASS_NAME_MODEL) && !this.className.equals(MODEL_CLASS_NAME_SEQUENTIAL)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with ComputationGraph");
        }
        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        String[] inputLayerArray = this.inputLayerNames.toArray(new String[this.inputLayerNames.size()]);
        // inferInputType may set truncatedBPTT, so this must run before the backprop type is chosen below.
        InputType[] inputTypeArray = new InputType[inputLayerArray.length];
        for (int i = 0; i < inputLayerArray.length; i++) {
            inputTypeArray[i] = this.inferInputType(inputLayerArray[i]);
        }
        String[] outputLayerArray = this.outputLayerNames.toArray(new String[this.outputLayerNames.size()]);
        ComputationGraphConfiguration.GraphBuilder graphBuilder = modelBuilder.graphBuilder()
                .addInputs(inputLayerArray)
                .setInputTypes(inputTypeArray)
                .setOutputs(outputLayerArray);
        for (String layerName : this.layerNamesOrdered) {
            KerasLayer layer = this.layers.get(layerName);
            if (!layer.isDl4jLayer()) {
                // Keras layers without a DL4J counterpart are not added to the graph.
                continue;
            }
            List<String> inboundLayerNames = this.inferInboundLayerNames(layerName);
            graphBuilder.addLayer(layerName, layer.getDl4jLayer(), inboundLayerNames.toArray(new String[inboundLayerNames.size()]));
        }
        if (this.truncatedBPTT > 0) {
            graphBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            graphBuilder.backpropType(BackpropType.Standard);
        }
        return graphBuilder.build();
    }

    /**
     * Builds and initializes the DL4J computation graph and copies the Keras weights into it.
     *
     * @throws InvalidKerasConfigurationException if no weights were supplied or they do not match the graph
     */
    public ComputationGraph getComputationGraph() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getComputationGraph(true);
    }

    /**
     * Builds and initializes the DL4J computation graph.
     *
     * @param importWeights whether to copy the Keras weights into the initialized graph
     * @return the initialized graph
     * @throws InvalidKerasConfigurationException if {@code importWeights} is true and no weights were
     *         supplied, or the weights do not match the graph
     */
    public ComputationGraph getComputationGraph(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph model = new ComputationGraph(this.getComputationGraphConfiguration());
        model.init();
        if (importWeights) {
            model = (ComputationGraph) copyWeightsToModel(model, this.weights, this.layers);
        }
        return model;
    }

    /**
     * Infers the DL4J input type of an input layer from its Keras input shape.
     *
     * <p>Searches breadth-first through the layers downstream of the input for the first layer with a
     * DL4J counterpart. A recurrent layer needs a rank-2 shape and sets {@link #truncatedBPTT} to the
     * shape's first dimension. Convolution and subsampling layers need rank 3; any other layer needs
     * rank 1.</p>
     *
     * @param inputLayerName name of one of the model's input layers
     * @return the inferred input type
     * @throws UnsupportedOperationException if {@code inputLayerName} is not an input layer
     * @throws UnsupportedKerasConfigurationException if the shape rank does not fit the consuming layer,
     *         or no DL4J layer is reachable from the input
     */
    protected InputType inferInputType(String inputLayerName) throws UnsupportedOperationException, UnsupportedKerasConfigurationException {
        if (!this.inputLayerNames.contains(inputLayerName)) {
            throw new UnsupportedOperationException("Cannot infer input type for non-input layer " + inputLayerName);
        }
        int[] inputShape = this.layers.get(inputLayerName).getInputShape();
        Deque<String> layerNameQueue = new ArrayDeque<String>(this.outboundLayerNames(inputLayerName));
        // A layer reachable along several paths only needs to be examined once.
        Set<String> visited = new HashSet<String>();
        InputType inputType = null;
        while (inputType == null && !layerNameQueue.isEmpty()) {
            String nextLayerName = layerNameQueue.removeFirst();
            if (!visited.add(nextLayerName)) {
                continue;
            }
            KerasLayer nextLayer = this.layers.get(nextLayerName);
            if (nextLayer.isDl4jLayer()) {
                Layer dl4jLayer = nextLayer.getDl4jLayer();
                if (dl4jLayer instanceof BaseRecurrentLayer) {
                    if (inputShape.length != 2) {
                        throw new UnsupportedKerasConfigurationException("Input to Recurrent layer must have rank 2 (found " + inputShape.length + ")");
                    }
                    inputType = InputType.recurrent(inputShape[1]);
                    this.truncatedBPTT = inputShape[0];
                } else if (dl4jLayer instanceof ConvolutionLayer || dl4jLayer instanceof SubsamplingLayer) {
                    if (inputShape.length != 3) {
                        throw new UnsupportedKerasConfigurationException("Input to Convolutional layer must have rank 3 (found " + inputShape.length + ")");
                    }
                    inputType = InputType.convolutional(inputShape[0], inputShape[1], inputShape[2]);
                } else {
                    if (inputShape.length != 1) {
                        throw new UnsupportedKerasConfigurationException("Input to FeedForward layer must have rank 1 (found " + inputShape.length + ")");
                    }
                    inputType = InputType.feedForward(inputShape[0]);
                }
            }
            layerNameQueue.addAll(this.outboundLayerNames(nextLayer.getName()));
        }
        if (inputType == null) {
            throw new UnsupportedKerasConfigurationException("Could not infer InputType for input layer " + inputLayerName);
        }
        return inputType;
    }

    /** Returns the names of the layers fed by {@code layerName}, or an empty set if none are recorded. */
    private Set<String> outboundLayerNames(String layerName) {
        Set<String> children = this.inputToOutput.get(layerName);
        return children == null ? Collections.<String>emptySet() : children;
    }

    /**
     * Returns the inbound layers to wire into {@code layerName} in the DL4J graph.
     *
     * <p>Any inbound layer whose {@link KerasLayer#isValidInboundLayer()} is false is replaced, in
     * place, by its own inbound layers. This repeats until only valid inbound layers remain.</p>
     *
     * @param layerName name of the layer whose inputs are needed
     * @return inbound layer names in resolution order
     */
    protected List<String> inferInboundLayerNames(String layerName) {
        List<String> inboundLayerNames = new ArrayList<String>(this.outputToInput.get(layerName));
        int i = 0;
        while (i < inboundLayerNames.size()) {
            String candidate = inboundLayerNames.get(i);
            if (this.layers.get(candidate).isValidInboundLayer()) {
                i++;
            } else {
                // Splice the skipped layer's parents into its slot and re-examine that slot.
                inboundLayerNames.remove(i);
                inboundLayerNames.addAll(i, this.outputToInput.get(candidate));
            }
        }
        return inboundLayerNames;
    }

    /**
     * Returns {@code map.get(key)}, failing when the model config lacks the field.
     *
     * @throws InvalidKerasConfigurationException if {@code key} is absent
     */
    protected static Object checkAndGetModelField(Map<String, Object> map, String key) throws InvalidKerasConfigurationException {
        if (!map.containsKey(key)) {
            throw new InvalidKerasConfigurationException("Field " + key + " missing from model config");
        }
        return map.get(key);
    }

    /**
     * Returns {@code map.get(key)}, failing when the training config lacks the field.
     *
     * @throws InvalidKerasConfigurationException if {@code key} is absent
     */
    protected static Object checkAndGetTrainingField(Map<String, Object> map, String key) throws InvalidKerasConfigurationException {
        if (!map.containsKey(key)) {
            throw new InvalidKerasConfigurationException("Field " + key + " missing from training config");
        }
        return map.get(key);
    }

    /** Parses a JSON object into a map. */
    protected static Map<String, Object> parseJsonString(String json) throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
        return mapper.readValue(json, typeRef);
    }

    /** Parses a YAML mapping into a map. */
    protected static Map<String, Object> parseYamlString(String yaml) throws IOException {
        ObjectMapper mapper = new ObjectMapper((JsonFactory) new YAMLFactory());
        TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
        return mapper.readValue(yaml, typeRef);
    }

    /**
     * Copies Keras weights into an initialized DL4J model.
     *
     * <p>Convolution weights are permuted for the TensorFlow backend. Keras GravesLSTM gate weights
     * are written into their column blocks of the DL4J parameter. All other parameters are copied
     * unchanged.</p>
     *
     * @param model       an initialized {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @param weights     layer name to (Keras parameter name to value)
     * @param kerasLayers Keras layers by name; used to read each convolution layer's dim order
     * @return {@code model}, with parameters set
     * @throws InvalidKerasConfigurationException if {@code weights} is null, a layer is missing from the
     *         model, a Keras parameter has no DL4J counterpart, or the Keras backend is unknown
     */
    protected static Model copyWeightsToModel(Model model, Map<String, Map<String, INDArray>> weights, Map<String, KerasLayer> kerasLayers) throws InvalidKerasConfigurationException {
        if (weights == null) {
            throw new InvalidKerasConfigurationException("Cannot import weights: no Keras weights were provided");
        }
        for (Map.Entry<String, Map<String, INDArray>> layerWeights : weights.entrySet()) {
            String layerName = layerWeights.getKey();
            KerasLayer kerasLayer = kerasLayers.get(layerName);
            org.deeplearning4j.nn.api.Layer layer = model instanceof MultiLayerNetwork
                    ? ((MultiLayerNetwork) model).getLayer(layerName)
                    : ((ComputationGraph) model).getLayer(layerName);
            if (layer == null) {
                throw new InvalidKerasConfigurationException("Layer " + layerName + " not found in DL4J model");
            }
            for (Map.Entry<String, INDArray> param : layerWeights.getValue().entrySet()) {
                String kerasParamName = param.getKey();
                String dl4jParamName = mapParameterName(kerasParamName);
                // Validate before touching the layer's parameters: unknown names map to null.
                if (dl4jParamName == null || !layer.paramTable().containsKey(dl4jParamName)) {
                    throw new InvalidKerasConfigurationException("Layer " + layerName + ": Keras param " + kerasParamName + " maps to unknown param " + dl4jParamName);
                }
                INDArray kerasParamValue = param.getValue();
                INDArray dl4jParamValue;
                if (layer instanceof org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) {
                    dl4jParamValue = convertConvolutionParam(layerName, kerasLayer, dl4jParamName, kerasParamValue);
                } else if (layer instanceof GravesLSTM) {
                    dl4jParamValue = convertGravesLstmParam(layer, kerasParamName, dl4jParamName, kerasParamValue);
                } else {
                    // Dense, batch-normalization, etc.: the Keras value is used as-is.
                    dl4jParamValue = kerasParamValue;
                }
                layer.setParam(dl4jParamName, dl4jParamValue);
            }
        }
        return model;
    }

    /**
     * Converts one convolution parameter. Only the weight matrix of a TensorFlow-ordered layer is
     * changed: it is permuted with (3, 2, 0, 1), presumably from Keras 1 TF kernel layout
     * (rows, cols, in, out) to DL4J's (out, in, rows, cols).
     */
    private static INDArray convertConvolutionParam(String layerName, KerasLayer kerasLayer, String dl4jParamName, INDArray kerasParamValue) throws InvalidKerasConfigurationException {
        if (!PARAM_NAME_W.equals(dl4jParamName)) {
            return kerasParamValue;
        }
        if (kerasLayer == null) {
            throw new InvalidKerasConfigurationException("Layer " + layerName + ": no Keras layer configuration available to convert convolution weights");
        }
        switch (kerasLayer.getDimOrder()) {
            case TENSORFLOW:
                return kerasParamValue.permute(new int[]{3, 2, 0, 1});
            case UNKNOWN:
                throw new InvalidKerasConfigurationException("Unknown keras backend " + kerasLayer.getDimOrder());
            case THEANO:
            case NONE:
            default:
                return kerasParamValue;
        }
    }

    /**
     * Writes one Keras LSTM gate parameter into the matching column block of the DL4J GravesLSTM
     * parameter (updated in place) and returns that parameter.
     *
     * <p>Input weights ({@code W_*}) fill rows {@code [0, nIn)}. Recurrent weights ({@code U_*}) fill
     * rows {@code [0, nOut)}, and every {@code U_*} parameter also zeroes the three columns after the
     * four gate blocks (presumably GravesLSTM's peephole weights, which Keras does not have; confirm).
     * Biases ({@code b_*}) fill row 0.</p>
     */
    private static INDArray convertGravesLstmParam(org.deeplearning4j.nn.api.Layer layer, String kerasParamName, String dl4jParamName, INDArray kerasParamValue) {
        BaseRecurrentLayer layerConf = (BaseRecurrentLayer) layer.conf().getLayer();
        int nOut = layerConf.getNOut();
        int gate = lstmGateIndex(kerasParamName);
        if (kerasParamName.startsWith(PARAM_NAME_W)) {
            INDArray dl4jParamValue = layer.getParam(dl4jParamName);
            int nIn = layerConf.getNIn();
            if (gate >= 0) {
                dl4jParamValue.put(new INDArrayIndex[]{NDArrayIndex.interval(0, nIn), NDArrayIndex.interval(gate * nOut, (gate + 1) * nOut)}, kerasParamValue);
            }
            return dl4jParamValue;
        }
        if (kerasParamName.startsWith(PARAM_NAME_U)) {
            INDArray dl4jParamValue = layer.getParam(dl4jParamName);
            if (gate >= 0) {
                dl4jParamValue.put(new INDArrayIndex[]{NDArrayIndex.interval(0, nOut), NDArrayIndex.interval(gate * nOut, (gate + 1) * nOut)}, kerasParamValue);
            }
            dl4jParamValue.put(new INDArrayIndex[]{NDArrayIndex.interval(0, nOut), NDArrayIndex.interval(4 * nOut, 4 * nOut + 3)}, Nd4j.zeros(nOut, 3));
            return dl4jParamValue;
        }
        if (kerasParamName.startsWith(PARAM_NAME_B)) {
            INDArray dl4jParamValue = layer.getParam(dl4jParamName);
            if (gate >= 0) {
                dl4jParamValue.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(gate * nOut, (gate + 1) * nOut)}, kerasParamValue);
            }
            return dl4jParamValue;
        }
        return kerasParamValue;
    }

    /**
     * Maps a Keras LSTM gate parameter to its DL4J GravesLSTM column block. Suffixes c, f, o and i
     * map to blocks 0 to 3; any other name returns -1.
     */
    private static int lstmGateIndex(String kerasParamName) {
        switch (kerasParamName) {
            case PARAM_NAME_W_C:
            case PARAM_NAME_U_C:
            case PARAM_NAME_B_C:
                return 0;
            case PARAM_NAME_W_F:
            case PARAM_NAME_U_F:
            case PARAM_NAME_B_F:
                return 1;
            case PARAM_NAME_W_O:
            case PARAM_NAME_U_O:
            case PARAM_NAME_B_O:
                return 2;
            case PARAM_NAME_W_I:
            case PARAM_NAME_U_I:
            case PARAM_NAME_B_I:
                return 3;
            default:
                return -1;
        }
    }

    /**
     * Maps a Keras parameter name to the DL4J parameter name it is copied into.
     *
     * @return the DL4J name, or {@code null} if the Keras name is not recognized
     */
    private static String mapParameterName(String kerasParamName) {
        switch (kerasParamName) {
            case PARAM_NAME_GAMMA:
                return PARAM_NAME_GAMMA;
            case PARAM_NAME_BETA:
                return PARAM_NAME_BETA;
            case PARAM_NAME_RUNNING_MEAN:
                return "mean";
            case PARAM_NAME_RUNNING_STD:
                return "var";
            case PARAM_NAME_W:
            case PARAM_NAME_W_C:
            case PARAM_NAME_W_F:
            case PARAM_NAME_W_I:
            case PARAM_NAME_W_O:
                return PARAM_NAME_W;
            case PARAM_NAME_U_C:
            case PARAM_NAME_U_F:
            case PARAM_NAME_U_I:
            case PARAM_NAME_U_O:
                return "RW";
            case PARAM_NAME_B:
            case PARAM_NAME_B_C:
            case PARAM_NAME_B_F:
            case PARAM_NAME_B_I:
            case PARAM_NAME_B_O:
                return PARAM_NAME_B;
            default:
                return null;
        }
    }

    /**
     * Fluent builder collecting the inputs of a Keras model import. Setting the JSON config clears
     * the YAML config and vice versa.
     */
    static class ModelBuilder
    implements Cloneable {
        protected String modelJson;
        protected String modelYaml;
        protected String trainingJson = null;
        protected Map<String, Map<String, INDArray>> weights = null;
        protected boolean train = false;

        /** Sets the model config as JSON and clears any YAML config. */
        public ModelBuilder modelJson(String modelJson) {
            this.modelJson = modelJson;
            this.modelYaml = null;
            return this;
        }

        /** Sets the model config as YAML and clears any JSON config. */
        public ModelBuilder modelYaml(String modelYaml) {
            this.modelYaml = modelYaml;
            this.modelJson = null;
            return this;
        }

        /** Sets the optional training config (JSON). */
        public ModelBuilder trainingJson(String trainingJson) {
            this.trainingJson = trainingJson;
            return this;
        }

        /** Sets the weights: layer name to (Keras parameter name to value). */
        public ModelBuilder weights(Map<String, Map<String, INDArray>> weights) {
            this.weights = weights;
            return this;
        }

        /** Sets the flag passed to each created {@link KerasLayer}. */
        public ModelBuilder train(boolean train) {
            this.train = train;
            return this;
        }

        /** Returns a new, empty builder. */
        public static ModelBuilder builder() {
            return new ModelBuilder();
        }

        /** Builds a functional-API {@link KerasModel}. */
        public KerasModel buildModel() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasModel(this);
        }

        /** Builds a {@link KerasSequentialModel}. */
        public KerasSequentialModel buildSequential() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasSequentialModel(this);
        }

        public String getModelJson() {
            return this.modelJson;
        }

        public String getModelYaml() {
            return this.modelYaml;
        }

        public String getTrainingJson() {
            return this.trainingJson;
        }

        public Map<String, Map<String, INDArray>> getWeights() {
            return this.weights;
        }

        public boolean isTrain() {
            return this.train;
        }

        public void setModelJson(String modelJson) {
            this.modelJson = modelJson;
        }

        public void setModelYaml(String modelYaml) {
            this.modelYaml = modelYaml;
        }

        public void setTrainingJson(String trainingJson) {
            this.trainingJson = trainingJson;
        }

        public void setWeights(Map<String, Map<String, INDArray>> weights) {
            this.weights = weights;
        }

        public void setTrain(boolean train) {
            this.train = train;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ModelBuilder)) {
                return false;
            }
            ModelBuilder other = (ModelBuilder) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return Objects.equals(this.getModelJson(), other.getModelJson())
                    && Objects.equals(this.getModelYaml(), other.getModelYaml())
                    && Objects.equals(this.getTrainingJson(), other.getTrainingJson())
                    && Objects.equals(this.getWeights(), other.getWeights())
                    && this.isTrain() == other.isTrain();
        }

        /** Lets subclasses that add state opt out of equality with plain builders. */
        protected boolean canEqual(Object other) {
            return other instanceof ModelBuilder;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            String modelJson = this.getModelJson();
            result = result * prime + (modelJson == null ? 43 : modelJson.hashCode());
            String modelYaml = this.getModelYaml();
            result = result * prime + (modelYaml == null ? 43 : modelYaml.hashCode());
            String trainingJson = this.getTrainingJson();
            result = result * prime + (trainingJson == null ? 43 : trainingJson.hashCode());
            Map<String, Map<String, INDArray>> weights = this.getWeights();
            result = result * prime + (weights == null ? 43 : weights.hashCode());
            result = result * prime + (this.isTrain() ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "KerasModel.ModelBuilder(modelJson=" + this.getModelJson() + ", modelYaml=" + this.getModelYaml() + ", trainingJson=" + this.getTrainingJson() + ", weights=" + this.getWeights() + ", train=" + this.isTrain() + ")";
        }
    }
}

