/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.deeplearning4j.omnihub;

import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.Locale;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.omnihub.OmnihubConfig;
import org.eclipse.deeplearning4j.omnihub.Framework;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter;
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter;

/**
 * Command-line bootstrapper that converts models already present in the local omnihub home
 * directory into the format of their output framework.
 *
 * <p>The omnihub home ({@link OmnihubConfig#getOmnihubHome()}) is scanned for entries whose
 * names match (case-insensitively) a {@link Framework} constant. For every such entry that is
 * an input framework ({@link Framework#isInput(Framework)}), each contained file is imported
 * and written to {@code <home>/<output framework name in lower case>}, where the output
 * framework is {@link Framework#outputFrameworkFor(Framework)}:
 * <ul>
 *   <li>SameDiff output: {@code .onnx} files (ONNX/PyTorch) and {@code .pb} files (TensorFlow)
 *       are written as {@code <basename>.fb} flat files;</li>
 *   <li>DL4J output: Keras {@code .h5} files are written as {@code <basename>.zip}.</li>
 * </ul>
 * Files with other extensions are skipped with a message on stderr. A failure on one model is
 * reported on stderr and does not stop the remaining imports.
 */
public class BootstrapFromLocal {

    /**
     * Imports every model found under the local omnihub home.
     *
     * @param args unused
     */
    public static void main(String... args) {
        File localOmnihubHome = OmnihubConfig.getOmnihubHome();
        File[] frameworks = localOmnihubHome.listFiles();
        if (frameworks == null) {
            // listFiles() returns null when the path is missing, not a directory, or unreadable.
            System.err.println("Omnihub home " + localOmnihubHome.getAbsolutePath()
                    + " is not a readable directory; nothing to import");
            return;
        }
        OnnxFrameworkImporter onnxFrameworkImporter = new OnnxFrameworkImporter();
        TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
        for (File frameworkFile : frameworks) {
            Framework framework;
            try {
                // Locale.ROOT keeps the name -> enum mapping stable under locales such as
                // Turkish, where toUpperCase() maps 'i' to a dotted capital I.
                framework = Framework.valueOf(frameworkFile.getName().toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                // A stray entry in the home directory must not abort the whole bootstrap.
                System.err.println("Skipping unrecognized framework entry " + frameworkFile.getAbsolutePath());
                continue;
            }
            if (!Framework.isInput(framework)) {
                continue;
            }
            File[] inputFiles = frameworkFile.listFiles();
            if (inputFiles == null) {
                System.err.println("Skipping " + frameworkFile.getAbsolutePath() + ": not a readable directory");
                continue;
            }
            for (File inputFile : inputFiles) {
                try {
                    extracted(localOmnihubHome, onnxFrameworkImporter, tensorflowFrameworkImporter, framework, inputFile);
                } catch (Exception e) {
                    // Best effort: report and continue with the next model.
                    System.err.println("Failed to import model at path " + inputFile.getAbsolutePath());
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Imports a single model file and writes it into the directory of the output framework
     * associated with {@code framework}, creating that directory if needed.
     *
     * @param localOmnihubHome            omnihub home; the output directory is created beneath it
     * @param onnxFrameworkImporter       importer used for ONNX/PyTorch inputs
     * @param tensorflowFrameworkImporter importer used for TensorFlow inputs
     * @param framework                   input framework the file was found under
     * @param inputFile                   model file to import
     * @throws IOException if the output directory cannot be created
     * @throws Exception   any failure raised by the underlying importers or savers
     */
    private static void extracted(File localOmnihubHome, OnnxFrameworkImporter onnxFrameworkImporter, TensorflowFrameworkImporter tensorflowFrameworkImporter, Framework framework, File inputFile) throws Exception {
        String fileName = inputFile.getName();
        String inputFileNameMinusFormat = FilenameUtils.getBaseName(fileName);
        String format = FilenameUtils.getExtension(fileName);
        Framework outputFramework = Framework.outputFrameworkFor(framework);
        File saveModelDir = new File(localOmnihubHome, outputFramework.name().toLowerCase(Locale.ROOT));
        // mkdirs() returns false both on failure and when the directory already exists,
        // so confirm the directory is really there before continuing.
        if (!saveModelDir.mkdirs() && !saveModelDir.isDirectory()) {
            throw new IOException("Unable to create output directory " + saveModelDir.getAbsolutePath());
        }
        switch (outputFramework) {
            case SAMEDIFF:
                importTfOnnxSameDiff(onnxFrameworkImporter, tensorflowFrameworkImporter, framework, inputFile, inputFileNameMinusFormat, format, saveModelDir);
                break;
            case DL4J:
                if ("h5".equalsIgnoreCase(format)) {
                    importKerasDl4j(inputFile, new File(saveModelDir, inputFileNameMinusFormat + ".zip"));
                } else {
                    System.err.println("Skipping model " + inputFile.getAbsolutePath());
                }
                break;
            default:
                System.err.println("Skipping model " + inputFile.getAbsolutePath());
                break;
        }
    }

    /**
     * Imports an ONNX/PyTorch ({@code .onnx}) or TensorFlow ({@code .pb}) model into SameDiff
     * and saves it as {@code <inputFileNameMinusFormat>.fb} in {@code saveModelDir}.
     * Inputs whose extension does not match their framework are skipped with a stderr message.
     *
     * @param onnxFrameworkImporter       importer used for ONNX/PyTorch inputs
     * @param tensorflowFrameworkImporter importer used for TensorFlow inputs
     * @param framework                   input framework of {@code inputFile}
     * @param inputFile                   model file to import
     * @param inputFileNameMinusFormat    base name used for the output file
     * @param format                      file extension of {@code inputFile}
     * @param saveModelDir                directory the flat file is written to
     * @throws IOException if writing the flat file fails
     */
    private static void importTfOnnxSameDiff(OnnxFrameworkImporter onnxFrameworkImporter, TensorflowFrameworkImporter tensorflowFrameworkImporter, Framework framework, File inputFile, String inputFileNameMinusFormat, String format, File saveModelDir) throws IOException {
        String modelPath = inputFile.getAbsolutePath();
        SameDiff sameDiff = null;
        switch (framework) {
            case ONNX:
            case PYTORCH:
                // PyTorch models are expected to have been exported to ONNX beforehand.
                if ("onnx".equalsIgnoreCase(format)) {
                    sameDiff = onnxFrameworkImporter.runImport(modelPath, Collections.emptyMap(), true);
                }
                break;
            case TENSORFLOW:
                if ("pb".equalsIgnoreCase(format)) {
                    sameDiff = tensorflowFrameworkImporter.runImport(modelPath, Collections.emptyMap(), true);
                }
                break;
            default:
                break;
        }
        if (sameDiff == null) {
            System.err.println("Skipping model " + modelPath);
            return;
        }
        sameDiff.asFlatFile(new File(saveModelDir, inputFileNameMinusFormat + ".fb"), true);
    }

    /**
     * Imports a Keras {@code .h5} model into DL4J and saves it to {@code saveModel2}.
     *
     * <p>The model is first imported as a functional model ({@link ComputationGraph}). If that
     * fails with an {@link InvalidKerasConfigurationException}, the error is reported and the
     * model is skipped. Any other failure triggers a second attempt as a sequential model
     * ({@link MultiLayerNetwork}).
     *
     * @param inputFile  Keras model file
     * @param saveModel2 destination file for the saved DL4J model
     * @throws Exception if the sequential fallback also fails; the functional-import failure is
     *                   attached as a suppressed exception
     */
    private static void importKerasDl4j(File inputFile, File saveModel2) throws Exception {
        String modelPath = inputFile.getAbsolutePath();
        try {
            ComputationGraph computationGraph = KerasModelImport.importKerasModelAndWeights(modelPath, true);
            computationGraph.save(saveModel2, true);
        } catch (InvalidKerasConfigurationException e) {
            System.err.println("Invalid Keras configuration for model " + modelPath);
            e.printStackTrace();
        } catch (Exception functionalFailure) {
            try {
                MultiLayerNetwork multiLayerNetwork = KerasModelImport.importKerasSequentialModelAndWeights(modelPath, true);
                multiLayerNetwork.save(saveModel2, true);
            } catch (Exception sequentialFailure) {
                // Keep the first failure visible; otherwise only the fallback error would surface.
                sequentialFailure.addSuppressed(functionalFailure);
                throw sequentialFailure;
            }
        }
    }
}

