/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tvm.runner;

import java.io.Closeable;
import java.util.LinkedHashMap;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.tvm.DLContext;
import org.bytedeco.tvm.DLTensor;
import org.bytedeco.tvm.Module;
import org.bytedeco.tvm.PackedFunc;
import org.bytedeco.tvm.TVMArgs;
import org.bytedeco.tvm.TVMArgsSetter;
import org.bytedeco.tvm.TVMRetValue;
import org.bytedeco.tvm.TVMValue;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.tvm.util.TVMUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executes a model compiled with Apache TVM through the JavaCPP TVM bindings.
 *
 * <p>The constructor loads the compiled module, calls its {@code "default"} factory function
 * with a shared device context to obtain a graph module, and resolves that module's
 * {@code get_num_inputs}, {@code get_num_outputs}, {@code set_input}, {@code get_output}
 * and {@code run} functions. All native handles are retained until {@link #close()}.
 *
 * <p>Every call reuses the same argument/return buffers ({@code values}, {@code codes},
 * {@code rv}), so a single instance must not be used by several threads at once.
 */
public class TvmRunner
implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(TvmRunner.class);
    /**
     * Device context shared by all runners (device_type 1, device_id 0). It is created by the
     * first constructor call, retained, and never released.
     */
    private static DLContext ctx;
    private Module modFactory;
    private TVMValue values;
    private IntPointer codes;
    private TVMArgsSetter setter;
    private TVMRetValue rv;
    private Module gmod;
    private PackedFunc getNumInputs;
    private PackedFunc getNumOutputs;
    private PackedFunc setInput;
    private PackedFunc getOutput;
    private PackedFunc run;

    /**
     * Loads the compiled TVM module at {@code modelUri} and prepares it for execution.
     *
     * @param modelUri path handed to {@code Module.LoadFromFile}; must not be null
     * @throws IllegalArgumentException if {@code modelUri} is null
     */
    public TvmRunner(String modelUri) {
        // Fail fast in Java instead of handing a null string to native code.
        Preconditions.checkArgument(modelUri != null, "modelUri must not be null");
        if (ctx == null) {
            // device_type 1 is the CPU device in DLPack's enumeration (NOTE(review): confirm
            // against the DLPack version bundled with these bindings).
            ctx = new DLContext().device_type(1).device_id(0);
            ctx.retainReference();
        }
        try (PointerScope scope = new PointerScope()) {
            this.modFactory = Module.LoadFromFile(modelUri);
            // Two argument slots: enough for set_input(name, tensor), the widest call made here.
            this.values = new TVMValue(2L);
            this.codes = new IntPointer(2L);
            this.setter = new TVMArgsSetter(this.values, this.codes);
            this.setter.apply(0L, ctx);
            this.rv = new TVMRetValue();
            // default(ctx) builds the graph module bound to the device context.
            this.modFactory.GetFunction("default").CallPacked(new TVMArgs(this.values, this.codes, 1), this.rv);
            this.gmod = this.rv.asModule();
            this.getNumInputs = this.gmod.GetFunction("get_num_inputs");
            this.getNumOutputs = this.gmod.GetFunction("get_num_outputs");
            this.setInput = this.gmod.GetFunction("set_input");
            this.getOutput = this.gmod.GetFunction("get_output");
            this.run = this.gmod.GetFunction("run");
            // The scope deallocates everything created inside it unless it is retained; keep the
            // long-lived handles alive until close() releases them. Temporaries (the "default"
            // PackedFunc, TVMArgs) are intentionally left to the scope.
            this.modFactory.retainReference();
            this.values.retainReference();
            this.codes.retainReference();
            this.setter.retainReference();
            this.rv.retainReference();
            this.gmod.retainReference();
            this.getNumInputs.retainReference();
            this.getNumOutputs.retainReference();
            this.setInput.retainReference();
            this.getOutput.retainReference();
            this.run.retainReference();
        }
    }

    /**
     * Releases every native handle retained by the constructor.
     *
     * <p>Each field is cleared after release, so calling this method more than once has no
     * further effect. Calling it again no longer releases the same native memory twice.
     */
    @Override
    public void close() {
        // Release in reverse order of acquisition.
        if (this.run != null) {
            this.run.releaseReference();
            this.run = null;
        }
        if (this.getOutput != null) {
            this.getOutput.releaseReference();
            this.getOutput = null;
        }
        if (this.setInput != null) {
            this.setInput.releaseReference();
            this.setInput = null;
        }
        if (this.getNumOutputs != null) {
            this.getNumOutputs.releaseReference();
            this.getNumOutputs = null;
        }
        if (this.getNumInputs != null) {
            this.getNumInputs.releaseReference();
            this.getNumInputs = null;
        }
        if (this.gmod != null) {
            this.gmod.releaseReference();
            this.gmod = null;
        }
        if (this.rv != null) {
            this.rv.releaseReference();
            this.rv = null;
        }
        if (this.setter != null) {
            this.setter.releaseReference();
            this.setter = null;
        }
        if (this.codes != null) {
            this.codes.releaseReference();
            this.codes = null;
        }
        if (this.values != null) {
            this.values.releaseReference();
            this.values = null;
        }
        if (this.modFactory != null) {
            this.modFactory.releaseReference();
            this.modFactory = null;
        }
    }

    /**
     * Sets the given named inputs, runs the graph once, and collects every output.
     *
     * @param input map from graph input name to the array bound to that input
     * @return outputs in index order, keyed by their index as a decimal string ({@code "0"},
     *         {@code "1"}, ...)
     * @throws IllegalStateException if this runner has been closed, or if an input array cannot
     *         be converted to a tensor
     */
    public Map<String, INDArray> exec(Map<String, INDArray> input) {
        Preconditions.checkState(this.run != null, "TvmRunner has already been closed");
        try (PointerScope scope = new PointerScope()) {
            this.getNumOutputs.CallPacked(new TVMArgs(this.values, this.codes, 0), this.rv);
            long numOutputNodes = this.rv.asLong();
            for (Map.Entry<String, INDArray> e : input.entrySet()) {
                DLTensor inputTensor = TVMUtils.getTensor(e.getValue(), ctx);
                Preconditions.checkState(inputTensor != null, "Input must be a tensor.");
                // set_input(name, tensor)
                this.setter.apply(0L, new BytePointer(e.getKey()));
                this.setter.apply(1L, inputTensor);
                this.setInput.CallPacked(new TVMArgs(this.values, this.codes, 2), this.rv);
            }
            this.run.CallPacked(new TVMArgs(this.values, this.codes, 0), this.rv);
            Map<String, INDArray> ret = new LinkedHashMap<>();
            for (long i = 0; i < numOutputNodes; i++) {
                // get_output(i)
                this.setter.apply(0L, i);
                this.getOutput.CallPacked(new TVMArgs(this.values, this.codes, 1), this.rv);
                DLTensor outputTensor = this.rv.asDLTensor();
                ret.put(Long.toString(i), TVMUtils.getArray(outputTensor));
            }
            return ret;
        }
    }

    /**
     * Creates a builder for {@link TvmRunner}.
     *
     * @return a new, empty builder
     */
    public static TvmRunnerBuilder builder() {
        return new TvmRunnerBuilder();
    }

    /** Fluent builder for {@link TvmRunner}; {@link #modelUri(String)} must be set before {@link #build()}. */
    public static class TvmRunnerBuilder {
        private String modelUri;

        TvmRunnerBuilder() {
        }

        /**
         * Sets the path of the compiled TVM module to load.
         *
         * @param modelUri module path
         * @return this builder
         */
        public TvmRunnerBuilder modelUri(String modelUri) {
            this.modelUri = modelUri;
            return this;
        }

        /**
         * Builds a runner for the configured model.
         *
         * @return a new {@link TvmRunner}
         * @throws IllegalArgumentException if no model URI was set
         */
        public TvmRunner build() {
            return new TvmRunner(this.modelUri);
        }

        @Override
        public String toString() {
            return "TvmRunner.TvmRunnerBuilder(modelUri=" + this.modelUri + ")";
        }
    }
}

