/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion.graphrunner;

import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.InvalidProtocolBufferException;
import com.github.os72.protobuf351.Message;
import com.github.os72.protobuf351.MessageOrBuilder;
import com.github.os72.protobuf351.util.JsonFormat;
import java.io.Closeable;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.tensorflow;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.tensorflow.conversion.TensorflowConversion;
import org.nd4j.tensorflow.conversion.graphrunner.SavedModelConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Runs a TensorFlow graph through the TensorFlow C API (JavaCPP {@code tensorflow} bindings),
 * converting ND4J {@link INDArray}s to TensorFlow tensors on the way in and back on the way out.
 *
 * <p>A runner is built either from a frozen {@link GraphDef} (an existing {@code TF_Graph} plus its
 * definition, raw protobuf bytes, or a file path) or from a {@link SavedModelConfig}. When no output
 * names are supplied for a frozen graph, they are inferred from the graph structure.
 *
 * <p>Native session, session-options and status objects are released by {@link #close()}.
 * NOTE(review): a single {@code TF_Status} instance is reused by every native call, so concurrent use
 * of one runner is presumably unsafe.
 */
public class GraphRunner implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(GraphRunner.class);
    /** Saved-model configuration; non-null only when the runner was created from a saved model. */
    private SavedModelConfig savedModelConfig;
    /** Graph executed by {@link #session}. */
    private tensorflow.TF_Graph graph;
    private TensorflowConversion conversion = TensorflowConversion.getInstance();
    private tensorflow.TF_Session session;
    private tensorflow.TF_SessionOptions options;
    /** Status object shared by all native calls made by this runner. */
    private tensorflow.TF_Status status;
    /**
     * Input names. For a frozen graph these are graph tensor names ({@code node} or {@code node:index});
     * for a saved model they are the keys used to look up arrays in the map passed to {@link #run(Map)}.
     */
    private List<String> inputOrder;
    /** Output names, used as the keys of the map returned by {@link #run(Map)}. */
    private List<String> outputOrder;
    /** Session configuration applied when the session options are created; may be {@code null}. */
    private ConfigProto protoBufConfigProto;

    /**
     * Creates a runner for an already-loaded graph without applying any {@link ConfigProto}.
     *
     * @param inputNames  graph input names, optionally suffixed with {@code :index}
     * @param outputNames graph output names; {@code null} to infer them from {@code graphDef}
     * @param graph       the loaded graph; it is not freed by {@link #close()}
     * @param graphDef    protobuf definition of {@code graph}
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, tensorflow.TF_Graph graph, GraphDef graphDef) {
        this(inputNames, outputNames, graph, graphDef, null);
    }

    /**
     * Creates a runner for an already-loaded graph.
     *
     * @param inputNames  graph input names, optionally suffixed with {@code :index}
     * @param outputNames graph output names; {@code null} to infer them from {@code graphDef}
     * @param graph       the loaded graph; it is not freed by {@link #close()}
     * @param graphDef    protobuf definition of {@code graph}
     * @param configProto session configuration, or {@code null} for TensorFlow's defaults
     * @throws IllegalStateException if the session cannot be opened or the configuration rejected
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, tensorflow.TF_Graph graph, GraphDef graphDef, ConfigProto configProto) {
        this.graph = graph;
        this.protoBufConfigProto = configProto;
        this.inputOrder = inputNames;
        this.outputOrder = outputNames;
        this.initSessionAndStatusIfNeeded(graphDef);
    }

    /** Creates a runner from a serialized {@link GraphDef} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(byte[] graphToUse, List<String> inputNames, List<String> outputNames) {
        this(graphToUse, inputNames, outputNames, GraphRunner.getAlignedWithNd4j());
    }

    /** Creates a runner from a serialized {@link GraphDef} file using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(String filePath, List<String> inputNames, List<String> outputNames) {
        this(filePath, inputNames, outputNames, GraphRunner.getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a file containing a serialized {@link GraphDef}.
     *
     * @throws IllegalArgumentException if the file cannot be read or the graph cannot be loaded
     */
    public GraphRunner(String filePath, List<String> inputNames, List<String> outputNames, ConfigProto sessionOptionsConfiguration) {
        this(GraphRunner.readGraphFile(filePath), inputNames, outputNames, sessionOptionsConfiguration);
    }

    /**
     * Creates a runner from a serialized {@link GraphDef}.
     *
     * @param graphToUse                  the serialized graph definition
     * @param inputNames                  graph input names, optionally suffixed with {@code :index}
     * @param outputNames                 graph output names; {@code null} to infer them from the graph
     * @param sessionOptionsConfiguration session configuration, or {@code null} for TensorFlow's defaults
     * @throws IllegalArgumentException if the graph cannot be loaded or parsed
     */
    public GraphRunner(byte[] graphToUse, List<String> inputNames, List<String> outputNames, ConfigProto sessionOptionsConfiguration) {
        try {
            this.inputOrder = inputNames;
            this.outputOrder = outputNames;
            this.protoBufConfigProto = sessionOptionsConfiguration;
            this.initOptionsIfNeeded();
            this.graph = this.conversion.loadGraph(graphToUse, this.status);
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
        this.initSessionAndStatusIfNeeded(graphToUse);
    }

    /** Creates a runner from a saved model using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(List<String> inputNames, List<String> outputNames, SavedModelConfig savedModelConfig) {
        this(inputNames, outputNames, savedModelConfig, GraphRunner.getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a saved model.
     *
     * <p>{@code inputNames} and {@code outputNames} are currently ignored: the input and output orders
     * are always replaced by the names discovered while loading the saved model.
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, SavedModelConfig savedModelConfig, ConfigProto sessionOptionsConfiguration) {
        this(savedModelConfig, sessionOptionsConfiguration);
    }

    /** Creates a runner for an existing graph whose outputs are inferred from {@code graphDef}. */
    public GraphRunner(List<String> inputNames, tensorflow.TF_Graph graph, GraphDef graphDef) {
        this(inputNames, null, graph, graphDef, null);
    }

    /** Creates a runner for an existing graph whose outputs are inferred from {@code graphDef}. */
    public GraphRunner(List<String> inputNames, tensorflow.TF_Graph graph, GraphDef graphDef, ConfigProto configProto) {
        this(inputNames, null, graph, graphDef, configProto);
    }

    /** Creates a runner from a serialized graph, inferring outputs, using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(byte[] graphToUse, List<String> inputNames) {
        this(graphToUse, inputNames, GraphRunner.getAlignedWithNd4j());
    }

    /** Creates a runner from a graph file, inferring outputs, using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(String filePath, List<String> inputNames) {
        this(filePath, inputNames, GraphRunner.getAlignedWithNd4j());
    }

    /** Creates a runner from a graph file, inferring outputs. */
    public GraphRunner(String filePath, List<String> inputNames, ConfigProto sessionOptionsConfiguration) {
        this(filePath, inputNames, null, sessionOptionsConfiguration);
    }

    /** Creates a runner from a serialized graph, inferring outputs. */
    public GraphRunner(byte[] graphToUse, List<String> inputNames, ConfigProto sessionOptionsConfiguration) {
        this(graphToUse, inputNames, null, sessionOptionsConfiguration);
    }

    /** Creates a runner from a saved model using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(SavedModelConfig savedModelConfig) {
        this(savedModelConfig, GraphRunner.getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a saved model. The input/output orders of this runner are set to the keys
     * reported by the loader, and the matching graph tensor names are stored back into
     * {@code savedModelConfig}.
     *
     * @throws IllegalArgumentException if the saved model cannot be loaded
     */
    public GraphRunner(SavedModelConfig savedModelConfig, ConfigProto sessionOptionsConfiguration) {
        try {
            this.savedModelConfig = savedModelConfig;
            this.protoBufConfigProto = sessionOptionsConfiguration;
            this.initOptionsIfNeeded();
            LinkedHashMap<String, String> inputsMap = new LinkedHashMap<String, String>();
            LinkedHashMap<String, String> outputsMap = new LinkedHashMap<String, String>();
            this.graph = tensorflow.TF_NewGraph();
            this.session = this.conversion.loadSavedModel(savedModelConfig, this.options, null, this.graph, inputsMap, outputsMap, this.status);
            this.inputOrder = new ArrayList<String>(inputsMap.keySet());
            this.outputOrder = new ArrayList<String>(outputsMap.keySet());
            savedModelConfig.setSavedModelInputOrder(new ArrayList<String>(inputsMap.values()));
            savedModelConfig.setSaveModelOutputOrder(new ArrayList<String>(outputsMap.values()));
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
    }

    /**
     * Feeds {@code inputs} to the session and fetches the configured outputs.
     *
     * @param inputs arrays keyed by input name (see {@link #getInputOrder()})
     * @return the fetched arrays keyed by output name, in output order
     * @throws IllegalStateException    if the runner is not initialized or the session run fails
     * @throws IllegalArgumentException if the inputs do not match the expected inputs or a named op is
     *                                  missing from the graph
     */
    public Map<String, INDArray> run(Map<String, INDArray> inputs) {
        if (this.graph == null) {
            throw new IllegalStateException("Graph not initialized.");
        }
        if (this.session == null) {
            throw new IllegalStateException("Session not initialized.");
        }
        if (inputs.size() != this.inputOrder.size()) {
            throw new IllegalArgumentException("Number of inputs specified do not match number of arrays specified.");
        }
        if (this.savedModelConfig != null) {
            // Saved models: graph tensor names come from the config; the user-facing keys are this
            // runner's input/output orders, falling back to the tensor names when those are empty.
            List<String> inputNodes = this.savedModelConfig.getSavedModelInputOrder();
            List<String> outputNodes = this.savedModelConfig.getSaveModelOutputOrder();
            List<String> inputKeys = GraphRunner.isNullOrEmpty(this.inputOrder) ? inputNodes : this.inputOrder;
            List<String> outputKeys = GraphRunner.isNullOrEmpty(this.outputOrder) ? outputNodes : this.outputOrder;
            return this.runSession(inputs, inputNodes, inputKeys, outputNodes, outputKeys);
        }
        // Frozen graphs: the names are both graph tensor names and map keys.
        return this.runSession(inputs, this.inputOrder, this.inputOrder, this.outputOrder, this.outputOrder);
    }

    /**
     * Executes one {@code TF_SessionRun}.
     *
     * @param inputs      arrays keyed by {@code inputKeys}
     * @param inputNodes  graph tensor names to feed ({@code node} or {@code node:index})
     * @param inputKeys   key in {@code inputs} for each entry of {@code inputNodes}
     * @param outputNodes graph tensor names to fetch
     * @param outputKeys  result-map key for each fetched tensor
     */
    private Map<String, INDArray> runSession(Map<String, INDArray> inputs, List<String> inputNodes, List<String> inputKeys, List<String> outputNodes, List<String> outputKeys) {
        int numInputs = inputNodes.size();
        tensorflow.TF_Output inputOut = new tensorflow.TF_Output((long) numInputs);
        tensorflow.TF_Tensor[] inputTensors = new tensorflow.TF_Tensor[numInputs];
        for (int i = 0; i < numInputs; ++i) {
            String tensorName = inputNodes.get(i);
            String[] parts = tensorName.split(":");
            tensorflow.TF_Operation inputOp = tensorflow.TF_GraphOperationByName(this.graph, parts[0]);
            if (inputOp == null) {
                throw new IllegalArgumentException("Illegal input found " + tensorName + " - no op found! Mis specified name perhaps?");
            }
            inputOut.position((long) i).oper(inputOp).index(GraphRunner.outputIndex(parts));
            String key = inputKeys.get(i);
            INDArray array = inputs.get(key);
            if (array == null) {
                throw new IllegalArgumentException("No array specified for input " + key);
            }
            inputTensors[i] = this.conversion.tensorFromNDArray(array);
        }
        // position() moves the native array cursor; reset it before handing the array to TensorFlow.
        inputOut.position(0L);

        int numOutputs = outputNodes.size();
        tensorflow.TF_Output outputOut = new tensorflow.TF_Output((long) numOutputs);
        for (int i = 0; i < numOutputs; ++i) {
            String tensorName = outputNodes.get(i);
            String[] parts = tensorName.split(":");
            tensorflow.TF_Operation outputOp = tensorflow.TF_GraphOperationByName(this.graph, parts[0]);
            if (outputOp == null) {
                throw new IllegalArgumentException("Illegal output found " + tensorName + " - no op found! Mis specified name perhaps?");
            }
            outputOut.position((long) i).oper(outputOp).index(GraphRunner.outputIndex(parts));
        }
        outputOut.position(0L);

        PointerPointer inputTensorsPointer = new PointerPointer((Pointer[]) inputTensors);
        PointerPointer outputTensorsPointer = new PointerPointer((long) numOutputs);
        tensorflow.TF_SessionRun(this.session, null, inputOut, inputTensorsPointer, inputTensors.length, outputOut, outputTensorsPointer, numOutputs, null, 0, null, this.status);
        if (tensorflow.TF_GetCode(this.status) != 0) {
            throw new IllegalStateException("ERROR: Unable to run session " + tensorflow.TF_Message(this.status).getString());
        }

        // Never read past the fetched tensors, even if more keys than fetched outputs were supplied.
        int numResults = Math.min(numOutputs, outputKeys.size());
        LinkedHashMap<String, INDArray> outputArrays = new LinkedHashMap<String, INDArray>();
        for (int i = 0; i < numResults; ++i) {
            INDArray result = this.conversion.ndArrayFromTensor(new tensorflow.TF_Tensor(outputTensorsPointer.get((long) i)));
            outputArrays.put(outputKeys.get(i), result);
        }
        return outputArrays;
    }

    /** Returns the {@code :index} suffix of a split tensor name, or 0 when there is none. */
    private static int outputIndex(String[] nameParts) {
        return nameParts.length > 1 ? Integer.parseInt(nameParts[1]) : 0;
    }

    private static boolean isNullOrEmpty(List<String> names) {
        return names == null || names.isEmpty();
    }

    /** Reads a serialized graph from {@code filePath}. */
    private static byte[] readGraphFile(String filePath) {
        try {
            return IOUtils.toByteArray(new File(filePath).toURI());
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Unable to read graph from " + filePath, e);
        }
    }

    /**
     * Lazily creates the status and session options, applying {@link #protoBufConfigProto} if set.
     *
     * @throws IllegalStateException if TensorFlow rejects the configuration
     */
    private void initOptionsIfNeeded() {
        if (this.status == null) {
            this.status = tensorflow.TF_NewStatus();
        }
        if (this.options != null) {
            return;
        }
        this.options = tensorflow.TF_NewSessionOptions();
        if (this.protoBufConfigProto == null) {
            return;
        }
        // Use the real serialized length: binary protobuf may contain NUL bytes, so a
        // NUL-terminated length (BytePointer.getStringBytes()) would silently truncate the config.
        byte[] serializedConfig = this.protoBufConfigProto.toByteArray();
        BytePointer configPointer = new BytePointer(serializedConfig);
        tensorflow.TF_SetConfig(this.options, (Pointer) configPointer, (long) serializedConfig.length, this.status);
        if (tensorflow.TF_GetCode(this.status) != 0) {
            throw new IllegalStateException("ERROR: Unable to set value configuration:" + tensorflow.TF_Message(this.status).getString());
        }
    }

    /**
     * Infers output names when none were given, then opens the session if it does not exist yet.
     *
     * @throws IllegalStateException if the session cannot be opened
     */
    private void initSessionAndStatusIfNeeded(GraphDef graphDef1) {
        if (this.outputOrder == null) {
            log.trace("Attempting to automatically resolve tensorflow output names..");
            this.outputOrder = GraphRunner.resolveOutputNames(graphDef1);
        }
        if (this.session == null) {
            this.initOptionsIfNeeded();
            this.session = tensorflow.TF_NewSession(this.graph, this.options, this.status);
            if (tensorflow.TF_GetCode(this.status) != 0) {
                throw new IllegalStateException("ERROR: Unable to open session " + tensorflow.TF_Message(this.status).getString());
            }
        }
    }

    /**
     * Returns the names of nodes that are not consumed by any other node and are not placeholders.
     * When more than one candidate remains, scoped names (containing {@code /}) are dropped.
     */
    private static List<String> resolveOutputNames(GraphDef graphDef) {
        LinkedHashSet<String> consumedNodes = new LinkedHashSet<String>();
        for (int n = 0; n < graphDef.getNodeCount(); ++n) {
            NodeDef node = graphDef.getNode(n);
            for (int in = 0; in < node.getInputCount(); ++in) {
                consumedNodes.add(GraphRunner.toNodeName(node.getInput(in)));
            }
        }
        List<String> outputs = new ArrayList<String>();
        for (int n = 0; n < graphDef.getNodeCount(); ++n) {
            NodeDef node = graphDef.getNode(n);
            if (consumedNodes.contains(node.getName()) || node.getOp().equals("Placeholder")) {
                continue;
            }
            outputs.add(node.getName());
        }
        if (outputs.size() > 1) {
            List<String> unscoped = new ArrayList<String>();
            for (String name : outputs) {
                if (!name.contains("/")) {
                    unscoped.add(name);
                }
            }
            outputs = unscoped;
        }
        return outputs;
    }

    /**
     * Strips the control-dependency prefix ({@code ^node}) and output-index suffix ({@code node:1})
     * from a GraphDef input reference, leaving the bare node name.
     */
    private static String toNodeName(String inputReference) {
        String name = inputReference.startsWith("^") ? inputReference.substring(1) : inputReference;
        int colon = name.indexOf(':');
        return colon >= 0 ? name.substring(0, colon) : name;
    }

    /**
     * Parses {@code graphToUse} as a {@link GraphDef} and initializes the session from it.
     *
     * @throws IllegalArgumentException if the bytes are not a valid GraphDef
     */
    private void initSessionAndStatusIfNeeded(byte[] graphToUse) {
        GraphDef graphDef1;
        try {
            graphDef1 = GraphDef.parseFrom(graphToUse);
        }
        catch (InvalidProtocolBufferException e) {
            // Previously swallowed, which left the runner without a session; fail fast instead.
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
        this.initSessionAndStatusIfNeeded(graphDef1);
    }

    /**
     * Builds a session configuration that restricts TensorFlow to
     * {@link TensorflowConversion#defaultDeviceForThread()}. When the ND4J backend class name contains
     * {@code "jcu"} (presumably the CUDA backend), GPU memory growth is enabled with a per-process
     * memory fraction of 0.5.
     */
    public static ConfigProto getAlignedWithNd4j() {
        ConfigProto.Builder builder = ConfigProto.newBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
        try {
            String backendName = Nd4j.getBackend().getClass().getName().toLowerCase(Locale.ROOT);
            if (backendName.contains("jcu")) {
                builder.setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true).setPerProcessGpuMemoryFraction(0.5).build());
            }
        }
        catch (Exception e) {
            log.warn("Unable to determine the ND4J backend; GPU options were not configured", e);
        }
        return builder.build();
    }

    /**
     * Parses a {@link ConfigProto} from its protobuf JSON representation.
     *
     * @return the parsed configuration, or {@code null} if {@code json} cannot be parsed
     */
    public static ConfigProto fromJson(String json) {
        ConfigProto.Builder builder = ConfigProto.newBuilder();
        try {
            JsonFormat.parser().merge(json, builder);
            return builder.build();
        }
        catch (Exception e) {
            log.error("Unable to parse ConfigProto from json", e);
            return null;
        }
    }

    /**
     * Prints the session configuration as protobuf JSON.
     *
     * @return the JSON string, or {@code null} if it cannot be printed (e.g. no configuration is set)
     */
    public String sessionOptionsToJson() {
        try {
            return JsonFormat.printer().print(this.protoBufConfigProto);
        }
        catch (Exception e) {
            log.error("Unable to print session options as json", e);
            return null;
        }
    }

    /**
     * Closes and deletes the session, then deletes the session options and status. Safe to call more
     * than once. The graph is not deleted here because it may be owned by the caller.
     *
     * @throws IllegalStateException if TensorFlow reports an error while deleting the session; the
     *                               remaining native objects are still released first
     */
    @Override
    public void close() {
        String failure = null;
        if (this.session != null && this.status != null) {
            tensorflow.TF_CloseSession(this.session, this.status);
            tensorflow.TF_DeleteSession(this.session, this.status);
            if (tensorflow.TF_GetCode(this.status) != 0) {
                failure = tensorflow.TF_Message(this.status).getString();
            }
        }
        this.session = null;
        if (this.options != null) {
            tensorflow.TF_DeleteSessionOptions(this.options);
            this.options = null;
        }
        if (this.status != null) {
            tensorflow.TF_DeleteStatus(this.status);
            this.status = null;
        }
        if (failure != null) {
            throw new IllegalStateException("ERROR: Unable to delete session " + failure);
        }
    }

    public List<String> getInputOrder() {
        return this.inputOrder;
    }

    public List<String> getOutputOrder() {
        return this.outputOrder;
    }

    public void setInputOrder(List<String> inputOrder) {
        this.inputOrder = inputOrder;
    }

    public void setOutputOrder(List<String> outputOrder) {
        this.outputOrder = outputOrder;
    }

    public ConfigProto getProtoBufConfigProto() {
        return this.protoBufConfigProto;
    }
}

