/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.onnxruntime.runner;

import java.io.Closeable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.CharPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.onnxruntime.AllocatorWithDefaultOptions;
import org.bytedeco.onnxruntime.Env;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.RunOptions;
import org.bytedeco.onnxruntime.Session;
import org.bytedeco.onnxruntime.SessionOptions;
import org.bytedeco.onnxruntime.TypeInfo;
import org.bytedeco.onnxruntime.Value;
import org.bytedeco.onnxruntime.ValueVector;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.util.ONNXUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs ONNX models through the bytedeco ONNX Runtime bindings, converting {@link INDArray}
 * inputs into ONNX Runtime tensors and the resulting output tensors back into {@link INDArray}s.
 *
 * <p>All runners share one lazily created {@link Env} that is never released. Each runner owns
 * its own {@link Session} and the native option/allocator objects created for it; those are
 * released by {@link #close()}.
 */
public class OnnxRuntimeRunner implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(OnnxRuntimeRunner.class);
    /** Process-wide ONNX Runtime environment, created once by {@link #getOrCreateEnv()}. */
    private static Env env;

    private Session session;
    private RunOptions runOptions;
    private MemoryInfo memoryInfo;
    private AllocatorWithDefaultOptions allocator;
    private SessionOptions sessionOptions;
    /** Native copy of the model path that was handed to the {@link Session} constructor. */
    private Pointer bp;

    /**
     * Creates a runner and loads the model at {@code modelUri} into a new ONNX Runtime session.
     *
     * @param modelUri path of the ONNX model to load; must not be {@code null}
     * @throws IllegalArgumentException if {@code modelUri} is {@code null}
     */
    public OnnxRuntimeRunner(String modelUri) {
        if (modelUri == null) {
            throw new IllegalArgumentException("modelUri must not be null");
        }
        Env sharedEnv = getOrCreateEnv();
        boolean initialized = false;
        try {
            this.sessionOptions = new SessionOptions();
            // NOTE(review): 2 presumably maps to ORT_ENABLE_EXTENDED in ONNX Runtime's
            // GraphOptimizationLevel enum — confirm.
            this.sessionOptions.SetGraphOptimizationLevel(2);
            this.sessionOptions.SetIntraOpNumThreads(1);
            this.sessionOptions.retainReference();
            this.allocator = new AllocatorWithDefaultOptions();
            this.allocator.retainReference();
            // Windows gets a CharPointer path, other platforms a BytePointer path; presumably this
            // matches ONNX Runtime's platform-specific path character type — confirm.
            this.bp = Loader.getPlatform().toLowerCase().startsWith("windows")
                    ? new CharPointer(modelUri)
                    : new BytePointer(modelUri);
            this.runOptions = new RunOptions();
            // NOTE(review): (1, 0) presumably mean OrtArenaAllocator / OrtMemTypeDefault — confirm.
            this.memoryInfo = MemoryInfo.CreateCpu(1, 0);
            this.session = new Session(sharedEnv, this.bp, this.sessionOptions);
            this.session.retainReference();
            initialized = true;
        } finally {
            if (!initialized) {
                // Session creation failed (e.g. unreadable model): release whatever was already
                // allocated instead of leaking it.
                releaseNativeResources();
            }
        }
    }

    /**
     * Returns the shared ONNX Runtime environment, creating and retaining it on first use.
     * Synchronized so that runners constructed concurrently cannot each create their own Env.
     */
    private static synchronized Env getOrCreateEnv() {
        if (env == null) {
            Env created = new Env(ONNXUtils.getOnnxLogLevelFromLogger(log),
                    new BytePointer("nd4j-serving-onnx-session-" + UUID.randomUUID().toString()));
            created.retainReference();
            env = created;
        }
        return env;
    }

    /**
     * Releases the session and the native objects owned by this runner. The shared {@link Env}
     * is not released. Calling this more than once is harmless; {@link #exec(Map)} throws
     * {@link IllegalStateException} once the runner is closed.
     */
    @Override
    public void close() {
        releaseNativeResources();
    }

    /**
     * Drops this runner's reference to every native object it still holds and clears the field,
     * so repeated calls are no-ops. Private so the constructor can use it without invoking an
     * overridable method.
     */
    private void releaseNativeResources() {
        if (this.session != null) {
            this.session.close();
            this.session = null;
        }
        if (this.sessionOptions != null) {
            this.sessionOptions.releaseReference();
            this.sessionOptions = null;
        }
        if (this.allocator != null) {
            this.allocator.releaseReference();
            this.allocator = null;
        }
        if (this.runOptions != null) {
            this.runOptions.releaseReference();
            this.runOptions = null;
        }
        if (this.memoryInfo != null) {
            this.memoryInfo.releaseReference();
            this.memoryInfo = null;
        }
        if (this.bp != null) {
            this.bp.releaseReference();
            this.bp = null;
        }
    }

    /**
     * Runs the model once on the given inputs.
     *
     * @param input arrays keyed by model input name; must contain an entry for every model input
     *              (extra entries are ignored)
     * @return output arrays keyed by model output name, in the session's output index order
     * @throws IllegalStateException    if this runner has been closed
     * @throws IllegalArgumentException if {@code input} is {@code null} or lacks a model input
     */
    public Map<String, INDArray> exec(Map<String, INDArray> input) {
        if (this.session == null) {
            throw new IllegalStateException("OnnxRuntimeRunner has already been closed");
        }
        if (input == null) {
            throw new IllegalArgumentException("input map must not be null");
        }
        long numInputNodes = this.session.GetInputCount();
        long numOutputNodes = this.session.GetOutputCount();
        PointerPointer inputNodeNames = new PointerPointer(numInputNodes);
        PointerPointer outputNodeNames = new PointerPointer(numOutputNodes);
        Value inputVal = new Value(numInputNodes);

        // Convert each model input, in the session's input order, into an ONNX tensor.
        for (long i = 0; i < numInputNodes; i++) {
            BytePointer inputName = this.session.GetInputName(i, this.allocator.asOrtAllocator());
            inputNodeNames.put(i, inputName);
            String name = inputName.getString();
            INDArray arr = input.get(name);
            if (arr == null) {
                throw new IllegalArgumentException("No array provided for model input \"" + name
                        + "\"; provided keys: " + input.keySet());
            }
            Value inputTensor = ONNXUtils.getTensor(arr, this.memoryInfo);
            Preconditions.checkState(inputTensor.IsTensor(), "Input must be a tensor.");
            inputVal.position(i).put(inputTensor);
        }
        // Rewind so Run() sees the Value array from its first element.
        inputVal.position(0L);

        // Collect output names once, both as native pointers for Run() and as Java map keys.
        String[] outputNames = new String[(int) numOutputNodes];
        for (long i = 0; i < numOutputNodes; i++) {
            BytePointer outputName = this.session.GetOutputName(i, this.allocator.asOrtAllocator());
            outputNodeNames.put(i, outputName);
            outputNames[(int) i] = outputName.getString();
        }

        ValueVector outputVector = this.session.Run(this.runOptions, inputNodeNames, inputVal,
                numInputNodes, outputNodeNames, numOutputNodes);
        // NOTE(review): the output vector and values are retained and never released here,
        // presumably so the native memory behind the returned arrays stays alive
        // (ONNXUtils.getDataBuffer is not visible) — confirm before changing.
        outputVector.retainReference();

        Map<String, INDArray> ret = new LinkedHashMap<>();
        for (long i = 0; i < numOutputNodes; i++) {
            Value outValue = outputVector.get(i);
            outValue.retainReference();
            DataBuffer buffer = ONNXUtils.getDataBuffer(outValue);
            LongPointer shapePointer = outValue.GetTensorTypeAndShapeInfo().GetShape();
            INDArray arr;
            if (shapePointer != null) {
                long[] shape = new long[(int) shapePointer.capacity()];
                shapePointer.get(shape);
                arr = Nd4j.create(buffer).reshape(shape);
            } else {
                arr = Nd4j.create(buffer);
            }
            ret.put(outputNames[(int) i], arr);
        }
        return ret;
    }

    /** Returns a new builder for {@link OnnxRuntimeRunner}. */
    public static OnnxRuntimeRunnerBuilder builder() {
        return new OnnxRuntimeRunnerBuilder();
    }

    /** Fluent builder; {@link #build()} delegates to {@link OnnxRuntimeRunner#OnnxRuntimeRunner(String)}. */
    public static class OnnxRuntimeRunnerBuilder {
        private String modelUri;

        OnnxRuntimeRunnerBuilder() {
        }

        /** Sets the path of the ONNX model to load. */
        public OnnxRuntimeRunnerBuilder modelUri(String modelUri) {
            this.modelUri = modelUri;
            return this;
        }

        /** Creates the runner, loading the configured model. */
        public OnnxRuntimeRunner build() {
            return new OnnxRuntimeRunner(this.modelUri);
        }

        @Override
        public String toString() {
            return "OnnxRuntimeRunner.OnnxRuntimeRunnerBuilder(modelUri=" + this.modelUri + ")";
        }
    }
}

