/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.onnxruntime.runner;

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.Onnx;
import org.apache.commons.io.FileUtils;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.CharPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.onnxruntime.AllocatorWithDefaultOptions;
import org.bytedeco.onnxruntime.Env;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.RunOptions;
import org.bytedeco.onnxruntime.Session;
import org.bytedeco.onnxruntime.SessionOptions;
import org.bytedeco.onnxruntime.Value;
import org.bytedeco.onnxruntime.ValueVector;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.runner.enums.ONNXType;
import org.nd4j.onnxruntime.util.ONNXUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs an ONNX model through the bytedeco ONNX Runtime bindings, converting between ND4J
 * {@link INDArray}s / {@link SDValue}s and ONNX Runtime {@link Value}s.
 *
 * <p>A single native {@link Env} is stored in a static field and created by the first constructor
 * call; every instance owns its own session, session options, run options, allocator and CPU
 * memory info. NOTE(review): the lazy {@code env} initialization is not synchronized, so two
 * threads constructing the first runners concurrently could each create an {@code Env}.
 */
public class OnnxRuntimeRunner implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(OnnxRuntimeRunner.class);
    private Session session;
    private RunOptions runOptions;
    private MemoryInfo memoryInfo;
    private AllocatorWithDefaultOptions allocator;
    private SessionOptions sessionOptions;
    private static Env env;
    private Pointer bp;
    private Onnx.ModelProto modelProto;

    /**
     * Creates a runner and, when {@code modelUri} is non-null, opens an ONNX Runtime session for it
     * and parses its {@link Onnx.ModelProto}.
     *
     * @param modelUri local filesystem path of the ONNX model (read via {@link File}); may be
     *                 {@code null}, in which case no session is created and {@link #exec(Map)} /
     *                 {@link #execValues(Map)} will fail with {@link IllegalStateException}
     */
    public OnnxRuntimeRunner(String modelUri) {
        if (env == null) {
            env = new Env(ONNXUtils.getOnnxLogLevelFromLogger(log), new BytePointer("nd4j-serving-onnx-session-" + UUID.randomUUID()));
            env.retainReference();
        }
        sessionOptions = new SessionOptions();
        sessionOptions.SetGraphOptimizationLevel(2);
        sessionOptions.SetIntraOpNumThreads(1);
        sessionOptions.SetLogSeverityLevel(0);
        sessionOptions.retainReference();
        allocator = new AllocatorWithDefaultOptions();
        allocator.retainReference();
        if (modelUri != null) {
            // Windows gets a 16-bit character path, presumably because ORT's path type is wide
            // (wchar_t) there; other platforms use a byte string.
            bp = Loader.getPlatform().toLowerCase().startsWith("windows") ? new CharPointer(modelUri) : new BytePointer(modelUri);
            session = new Session(env, bp, sessionOptions);
            session.retainReference();
            try {
                modelProto = Onnx.ModelProto.parseFrom(FileUtils.readFileToByteArray(new File(modelUri)));
            } catch (IOException e) {
                // Best effort: the session is already usable, only the parsed proto is unavailable.
                log.error("Unable to read ONNX model proto from {}; getModelProto() will return null", modelUri, e);
            }
        }
        runOptions = new RunOptions();
        memoryInfo = MemoryInfo.CreateCpu(1, 0);
    }

    /**
     * Closes the session (if one was created) and releases the references this runner took on
     * its session options, allocator and run options.
     */
    @Override
    public void close() {
        if (session != null) {
            session.close();
        }
        sessionOptions.releaseReference();
        allocator.releaseReference();
        runOptions.releaseReference();
    }

    /**
     * Runs the model on tensor and/or sequence inputs.
     *
     * <p>Each model input is looked up by name in {@code input}. Its value is taken from
     * {@link SDValue#getListValue()}. A one-element list fed to a tensor input becomes a tensor.
     * A multi-element list, or any non-empty list fed to a sequence input, becomes an ONNX
     * sequence.
     *
     * @param input values keyed by model input name; must contain every model input
     * @return outputs keyed by model output name, in the session's output order; tensors become
     *         single-array {@link SDValue}s, everything else is read as a sequence of arrays
     * @throws IllegalStateException    if this runner was created without a model
     * @throws IllegalArgumentException if an input is missing, empty, or of an unsupported ONNX type
     */
    public Map<String, SDValue> execValues(Map<String, SDValue> input) {
        requireSession();
        long numInputNodes = session.GetInputCount();
        long numOutputNodes = session.GetOutputCount();
        PointerPointer inputNodeNames = new PointerPointer(numInputNodes);
        PointerPointer outputNodeNames = new PointerPointer(numOutputNodes);
        Value inputVal = new Value(numInputNodes);
        for (int i = 0; i < numInputNodes; i++) {
            BytePointer inputName = session.GetInputName(i, allocator.asOrtAllocator());
            inputNodeNames.put(i, inputName);
            String name = inputName.getString();
            ONNXType typeForInput = ONNXUtils.getTypeForInput(session, i);
            SDValue value = input.get(name);
            if (value == null) {
                throw new IllegalArgumentException("No value provided for model input " + name);
            }
            List<INDArray> arr = value.getListValue();
            if (arr == null) {
                throw new IllegalArgumentException("Value for model input " + name + " holds no tensor or sequence");
            }
            if (arr.size() == 1 && typeForInput == ONNXType.ONNX_TYPE_TENSOR) {
                Value inputTensor = ONNXUtils.getTensor(arr.get(0), memoryInfo);
                Preconditions.checkState(inputTensor.IsTensor(), "Input must be a tensor.");
                inputVal.position(i).put(inputTensor);
            } else if (arr.isEmpty()) {
                throw new IllegalArgumentException("Onnx Runtime does not support empty sequences! Found at input name " + name);
            } else if (arr.size() > 1 || typeForInput == ONNXType.ONNX_TYPE_SEQUENCE) {
                ValueVector sequence = ONNXUtils.getSequence(arr, memoryInfo);
                inputVal.position(i).put(Value.CreateSequence(sequence));
            } else {
                // Previously this slot was silently left unset and handed to ORT uninitialized.
                throw new IllegalArgumentException("Unsupported ONNX type " + typeForInput + " for model input " + name
                        + "; only tensors and sequences are supported.");
            }
        }
        inputVal.position(0L);
        String[] outputNames = collectOutputNames(outputNodeNames, numOutputNodes);
        ValueVector outputVector = session.Run(runOptions, inputNodeNames, inputVal, numInputNodes, outputNodeNames, numOutputNodes);
        // The native outputs are retained and never released here, presumably so arrays built
        // from them stay valid — TODO confirm whether ONNXUtils copies or wraps native memory.
        outputVector.retainReference();
        Map<String, SDValue> ret = new LinkedHashMap<>();
        for (int i = 0; i < numOutputNodes; i++) {
            Value outValue = outputVector.get(i);
            outValue.retainReference();
            if (outValue.IsTensor()) {
                INDArray arr = ONNXUtils.getArray(outValue);
                ret.put(outputNames[i], SDValue.create(arr));
            } else {
                INDArray[] seq = ONNXUtils.ndarraysFromSequence(outValue, allocator.asOrtAllocator());
                ret.put(outputNames[i], SDValue.create(Arrays.asList(seq)));
            }
        }
        return ret;
    }

    /**
     * Runs the model on tensor inputs and returns its tensor outputs.
     *
     * <p>Sequence outputs cannot be represented in a {@code Map<String, INDArray>} and are
     * omitted (a warning is logged); use {@link #execValues(Map)} to receive them.
     *
     * @param input arrays keyed by model input name; must contain every model input
     * @return tensor outputs keyed by model output name, in the session's output order
     * @throws IllegalStateException    if this runner was created without a model, or an output is
     *                                  neither a tensor nor a sequence
     * @throws IllegalArgumentException if a model input is missing from {@code input}
     */
    public Map<String, INDArray> exec(Map<String, INDArray> input) {
        requireSession();
        long numInputNodes = session.GetInputCount();
        long numOutputNodes = session.GetOutputCount();
        PointerPointer inputNodeNames = new PointerPointer(numInputNodes);
        PointerPointer outputNodeNames = new PointerPointer(numOutputNodes);
        Value inputVal = new Value(numInputNodes);
        for (int i = 0; i < numInputNodes; i++) {
            BytePointer inputName = session.GetInputName(i, allocator.asOrtAllocator());
            inputNodeNames.put(i, inputName);
            String name = inputName.getString();
            INDArray arr = input.get(name);
            if (arr == null) {
                throw new IllegalArgumentException("No input array provided for model input " + name);
            }
            Value inputTensor = ONNXUtils.getTensor(arr, memoryInfo);
            Preconditions.checkState(inputTensor.IsTensor(), "Input must be a tensor.");
            inputVal.position(i).put(inputTensor);
        }
        inputVal.position(0L);
        String[] outputNames = collectOutputNames(outputNodeNames, numOutputNodes);
        ValueVector outputVector = session.Run(runOptions, inputNodeNames, inputVal, numInputNodes, outputNodeNames, numOutputNodes);
        // See execValues: outputs are retained, presumably to keep any wrapped native memory alive.
        outputVector.retainReference();
        Map<String, INDArray> ret = new LinkedHashMap<>();
        for (int i = 0; i < numOutputNodes; i++) {
            Value outValue = outputVector.get(i);
            outValue.retainReference();
            ONNXType typeForOutput = ONNXUtils.getTypeForOutput(session, i);
            switch (typeForOutput) {
                case ONNX_TYPE_SEQUENCE:
                    // Previously dropped silently; keep omitting it but make the omission visible.
                    log.warn("Skipping sequence output {} ({} elements): exec(Map) only returns tensors, use execValues(Map) instead",
                            outputNames[i], outValue.GetCount());
                    break;
                case ONNX_TYPE_TENSOR:
                    ret.put(outputNames[i], tensorToArray(outValue));
                    break;
                default:
                    throw new IllegalStateException("Unable to get type " + typeForOutput + " only accepts tensors and sequences.");
            }
        }
        return ret;
    }

    /** Fails fast when no session exists (runner built with a null modelUri). */
    private void requireSession() {
        Preconditions.checkState(session != null, "No ONNX model loaded: this runner was created without a modelUri");
    }

    /**
     * Writes each session output name into {@code outputNodeNames} (for {@code Session.Run}) and
     * returns the same names as Java strings, indexed by output position.
     */
    private String[] collectOutputNames(PointerPointer outputNodeNames, long numOutputNodes) {
        String[] names = new String[(int) numOutputNodes];
        for (int i = 0; i < numOutputNodes; i++) {
            BytePointer outputName = session.GetOutputName(i, allocator.asOrtAllocator());
            outputNodeNames.put(i, outputName);
            names[i] = outputName.getString();
        }
        return names;
    }

    /**
     * Converts a tensor {@link Value} into an {@link INDArray}, reshaping it to the tensor's
     * reported shape when ORT provides one.
     */
    private static INDArray tensorToArray(Value outValue) {
        DataBuffer buffer = ONNXUtils.getDataBuffer(outValue);
        LongPointer longPointer = outValue.GetTensorTypeAndShapeInfo().GetShape();
        if (longPointer == null) {
            return Nd4j.create(buffer);
        }
        long[] shape = new long[(int) longPointer.capacity()];
        longPointer.get(shape);
        return Nd4j.create(buffer).reshape(shape);
    }

    /** @return a new builder for {@link OnnxRuntimeRunner} */
    public static OnnxRuntimeRunnerBuilder builder() {
        return new OnnxRuntimeRunnerBuilder();
    }

    /** @return the ONNX Runtime session, or {@code null} when no model was loaded */
    public Session getSession() {
        return this.session;
    }

    /** @return the run options passed to every {@code Session.Run} call */
    public RunOptions getRunOptions() {
        return this.runOptions;
    }

    /** @return the CPU memory info used when creating input tensors */
    public MemoryInfo getMemoryInfo() {
        return this.memoryInfo;
    }

    /** @return the allocator used to fetch input/output names and read sequence outputs */
    public AllocatorWithDefaultOptions getAllocator() {
        return this.allocator;
    }

    /** @return the session options this runner's session was created with */
    public SessionOptions getSessionOptions() {
        return this.sessionOptions;
    }

    /** @return the native model path pointer, or {@code null} when no model was loaded */
    public Pointer getBp() {
        return this.bp;
    }

    /** @return the parsed model proto, or {@code null} if no model was loaded or it could not be read */
    public Onnx.ModelProto getModelProto() {
        return this.modelProto;
    }

    /** Builder for {@link OnnxRuntimeRunner}; only the model path is configurable. */
    public static class OnnxRuntimeRunnerBuilder {
        private String modelUri;

        OnnxRuntimeRunnerBuilder() {
        }

        /**
         * @param modelUri local filesystem path of the ONNX model to load
         * @return this builder
         */
        public OnnxRuntimeRunnerBuilder modelUri(String modelUri) {
            this.modelUri = modelUri;
            return this;
        }

        /** @return a new runner for the configured model path */
        public OnnxRuntimeRunner build() {
            return new OnnxRuntimeRunner(this.modelUri);
        }

        @Override
        public String toString() {
            return "OnnxRuntimeRunner.OnnxRuntimeRunnerBuilder(modelUri=" + this.modelUri + ")";
        }
    }
}

