/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.triton.TritonOnnxClient;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link OnnxEvaluator} that delegates evaluation of an ONNX model to Triton through a
 * {@link TritonOnnxClient}.
 *
 * <p>Model metadata (input/output names and tensor types) is fetched through the client when the
 * evaluator is constructed. When {@code isExplicitControl} is set, a failed evaluation triggers a
 * reload of the model, a metadata refresh, and a single retry; otherwise the failure is rethrown.
 */
class TritonOnnxEvaluator implements OnnxEvaluator {

    private static final Logger log = Logger.getLogger(TritonOnnxEvaluator.class.getName());

    private final String modelName;
    private final ResourceReference modelReference;
    private final TritonOnnxClient tritonClient;
    // Gates the reload-and-retry path in evaluate(Map); presumably mirrors Triton's explicit
    // model-control mode — NOTE(review): confirm against the factory that creates this evaluator.
    private final boolean isExplicitControl;
    // Replaced after a reload in evaluate(Map); volatile so the refreshed metadata is visible to
    // threads reading it concurrently (evaluate, getInputs, getOutputs, get*Info).
    private volatile TritonOnnxClient.ModelMetadata modelMetadata;

    /**
     * Creates an evaluator for {@code modelName} and fetches its metadata through the client.
     *
     * @param modelName         name of the model as known to Triton
     * @param modelReference    resource reference released by {@link #close()}
     * @param tritonClient      client used for metadata lookups, evaluation and reloads
     * @param isExplicitControl whether a failed evaluation may be recovered by reloading the model
     */
    TritonOnnxEvaluator(String modelName, ResourceReference modelReference, TritonOnnxClient tritonClient, boolean isExplicitControl) {
        this.modelName = modelName;
        this.modelReference = modelReference;
        this.tritonClient = tritonClient;
        this.isExplicitControl = isExplicitControl;
        this.modelMetadata = tritonClient.getModelMetadata(modelName);
    }

    /**
     * Evaluates the model and returns a single named output.
     *
     * @param inputs input tensors keyed by input name
     * @param output name of the output to return
     * @return the tensor for {@code output}, or whatever {@code Map.get} yields for a missing key
     *         (presumably {@code null})
     */
    @Override
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        return evaluate(inputs).get(output);
    }

    /**
     * Evaluates the model on Triton.
     *
     * <p>If evaluation fails with a {@link TritonOnnxClient.TritonException} and explicit control is
     * enabled, the model is reloaded, its metadata is refreshed, and the evaluation is retried once.
     * A failure of the retry (or of the reload) propagates to the caller.
     *
     * @param inputs input tensors keyed by input name
     * @return output tensors as returned by the client
     * @throws TritonOnnxClient.TritonException if evaluation fails and explicit control is disabled,
     *         or if the retry fails
     */
    @Override
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        try {
            return tritonClient.evaluate(modelName, modelMetadata, inputs);
        } catch (TritonOnnxClient.TritonException e) {
            if (!isExplicitControl) {
                throw e;
            }
            log.log(Level.WARNING, "Failed to evaluate ONNX model " + modelName + " in Triton, will retry after reload.", e);
            tritonClient.loadUntilModelReady(modelName);
            // Retry with exactly the metadata fetched here, even if another thread updates the field.
            TritonOnnxClient.ModelMetadata refreshed = tritonClient.getModelMetadata(modelName);
            modelMetadata = refreshed;
            return tritonClient.evaluate(modelName, refreshed, inputs);
        }
    }

    /** Returns a new map of the model's inputs, each paired with its name and tensor type. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getInputs() {
        return toIdAndTypes(modelMetadata.inputs);
    }

    /** Returns a new map of the model's outputs, each paired with its name and tensor type. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getOutputs() {
        return toIdAndTypes(modelMetadata.outputs);
    }

    /** Returns the input name to tensor type map held by the current metadata (not a copy). */
    @Override
    public Map<String, TensorType> getInputInfo() {
        return modelMetadata.inputs;
    }

    /** Returns the output name to tensor type map held by the current metadata (not a copy). */
    @Override
    public Map<String, TensorType> getOutputInfo() {
        return modelMetadata.outputs;
    }

    /**
     * Releases {@code modelReference}. Nothing else is done here; whether that release unloads the
     * model from Triton depends on the reference — NOTE(review): confirm.
     */
    @Override
    public void close() {
        modelReference.close();
    }

    // Builds a fresh mutable map keyed by name, using the name itself as the IdAndType id.
    private static Map<String, OnnxEvaluator.IdAndType> toIdAndTypes(Map<String, TensorType> types) {
        Map<String, OnnxEvaluator.IdAndType> result = new HashMap<>();
        types.forEach((name, type) -> result.put(name, new OnnxEvaluator.IdAndType(name, type)));
        return result;
    }
}

