/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.tensorflow.TFImportOverride;
import org.nd4j.imports.tensorflow.TFOpImportFilter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

public class ProtoBufToFlatBufConversion {
    public static void convert(String inFile, String outFile) throws IOException, ND4JIllegalStateException {
        SameDiff tg = TFGraphMapper.importGraph((File)new File(inFile));
        tg.asFlatFile(new File(outFile));
    }

    public static void convertBERT(String inFile, String outFile) throws IOException, ND4JIllegalStateException {
        int minibatchSize = 4;
        HashMap<String, TFImportOverride> m = new HashMap<String, TFImportOverride>();
        m.put("IteratorGetNext", (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) -> Arrays.asList(initWith.placeHolder("IteratorGetNext", DataType.INT, new long[]{minibatchSize, 128L}), initWith.placeHolder("IteratorGetNext:1", DataType.INT, new long[]{minibatchSize, 128L}), initWith.placeHolder("IteratorGetNext:4", DataType.INT, new long[]{minibatchSize, 128L})));
        TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> "IteratorV2".equals(nodeDef.getName());
        SameDiff sd = TFGraphMapper.importGraph((File)new File(inFile), m, (TFOpImportFilter)filter);
        SubGraphPredicate p = SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/mul")).withInputCount(2).withInputSubgraph(0, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/div"))).withInputSubgraph(1, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/Floor")).withInputSubgraph(0, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/add")).withInputSubgraph(1, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/random_uniform")).withInputSubgraph(0, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/random_uniform/mul")).withInputSubgraph(0, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/random_uniform/RandomUniform"))).withInputSubgraph(1, (OpPredicate)SubGraphPredicate.withRoot((OpPredicate)OpPredicate.nameMatches((String)".*/dropout/random_uniform/sub")))))));
        List subGraphs = GraphTransformUtil.getSubgraphsMatching((SameDiff)sd, (SubGraphPredicate)p);
        int subGraphCount = subGraphs.size();
        sd = GraphTransformUtil.replaceSubgraphsMatching((SameDiff)sd, (SubGraphPredicate)p, (SubGraphProcessor)new SubGraphProcessor(){

            public List<SDVariable> processSubgraph(SameDiff sd, SubGraph subGraph) {
                List inputs = subGraph.inputs();
                SDVariable newOut = null;
                for (SDVariable v : inputs) {
                    if (!v.getVarName().endsWith("/BiasAdd") && !v.getVarName().endsWith("/Softmax") && !v.getVarName().endsWith("/add_1") && !v.getVarName().endsWith("/Tanh")) continue;
                    newOut = v;
                    break;
                }
                if (newOut != null) {
                    return Collections.singletonList(newOut);
                }
                throw new RuntimeException("No pre-dropout input variable found");
            }
        });
        System.out.println("Exporting file " + outFile);
        sd.asFlatFile(new File(outFile));
    }

    public static void main(String[] args) throws IOException {
        if (args.length < 2) {
            System.err.println("Usage:\nmvn exec:java -Dexec.mainClass=\"org.nd4j.tensorflow.conversion.ProtoBufToFlatBufConversion\" -Dexec.args=\"<input_file.pb> <output_file.fb>\"\n");
        } else {
            ProtoBufToFlatBufConversion.convert(args[0], args[1]);
        }
    }
}

