/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * An expression node which evaluates one named output of an ONNX model.
 *
 * <p>Each input spec of the model becomes one child node: a {@link ReferenceNode} for the
 * reference parsed from the spec's {@code source}. On evaluation every child is evaluated in the
 * given context, converted to a tensor, and passed to the model under the spec's ONNX input
 * name. The result is the tensor the model produces for {@code onnxOutputName}.
 *
 * <p>Nodes are not modified after construction. {@link #setChildren(List)} returns a new node
 * instead of mutating this one.
 */
class OnnxExpressionNode
extends CompositeNode {
    private final OnnxModel model;
    private final String onnxOutputName;
    private final TensorType expectedType;
    private final String outputAs;
    /** ONNX input names, index-aligned with {@link #inputRefs}. Never mutated after construction. */
    private final List<String> modelInputs;
    /** Unmodifiable child nodes, one per ONNX input, index-aligned with {@link #modelInputs}. */
    private final List<ExpressionNode> inputRefs;

    /**
     * Creates a node evaluating the given output of the given model.
     *
     * @param model the ONNX model; one child reference is created per entry in its input specs
     * @param onnxOutputName the model output passed to {@code unmappedEvaluate}
     * @param expectedType the type reported by {@link #type(TypeContext)}; it is not computed here
     * @param outputAs optional output alias used only when rendering this node as text
     *                 (may be null or empty)
     * @throws IllegalArgumentException if some input spec's source cannot be parsed as a reference
     */
    OnnxExpressionNode(OnnxModel model, String onnxOutputName, TensorType expectedType, String outputAs) {
        this.model = model;
        this.onnxOutputName = onnxOutputName;
        this.expectedType = expectedType;
        this.outputAs = outputAs;
        List<String> names = new ArrayList<>();
        List<ExpressionNode> refs = new ArrayList<>();
        for (OnnxModel.InputSpec input : model.inputSpecs) {
            Reference ref = parseOnnxInput(input.source).orElseThrow(() -> new IllegalArgumentException(
                    "Bad input source for ONNX model " + model.name() + ": '" + String.valueOf(input) + "'"));
            names.add(input.onnxName);
            refs.add(new ReferenceNode(ref));
        }
        this.modelInputs = names;
        this.inputRefs = List.copyOf(refs);
    }

    /**
     * Copy constructor used by {@link #setChildren(List)}.
     *
     * <p>Shares every field of {@code original} except the children, which are replaced by an
     * unmodifiable copy of {@code inputRefs}. Sharing {@code modelInputs} is safe because that
     * list is never mutated after construction.
     */
    private OnnxExpressionNode(OnnxExpressionNode original, List<ExpressionNode> inputRefs) {
        this.model = original.model;
        this.onnxOutputName = original.onnxOutputName;
        this.expectedType = original.expectedType;
        this.outputAs = original.outputAs;
        this.modelInputs = original.modelInputs;
        this.inputRefs = List.copyOf(inputRefs);
    }

    /**
     * Parses an ONNX input source string into a reference.
     *
     * <p>The string is first tried with {@code Reference.simple}. If that yields nothing, it is
     * tried with {@code Reference.fromIdentifier}.
     *
     * @param input the input source string
     * @return the parsed reference, or empty if neither form could parse it
     */
    static Optional<Reference> parseOnnxInput(String input) {
        Optional<Reference> simple = Reference.simple(input);
        if (simple.isPresent()) {
            return simple;
        }
        try {
            return Optional.of(Reference.fromIdentifier(input));
        }
        catch (Exception ignored) {
            // Deliberately best-effort: any failure means "not parseable", so the caller can
            // report the problem with model-specific context.
            // NOTE(review): the exact exception type(s) fromIdentifier throws are not visible
            // here; narrow this catch once confirmed.
            return Optional.empty();
        }
    }

    /** Returns the child nodes (one per ONNX input, in input-spec order) as an unmodifiable list. */
    public List<ExpressionNode> children() {
        return this.inputRefs;
    }

    /**
     * Returns a new node identical to this one except that its children are {@code children}.
     * This node is left unchanged.
     *
     * @param children the replacement children; must have the same size as {@link #children()}
     * @return a new node with the given children
     * @throws IllegalArgumentException if the number of children differs from the current number
     */
    public CompositeNode setChildren(List<ExpressionNode> children) {
        if (this.inputRefs.size() != children.size()) {
            throw new IllegalArgumentException("bad setChildren: expected " + this.inputRefs.size()
                    + " children, got " + children.size());
        }
        return new OnnxExpressionNode(this, children);
    }

    /**
     * Evaluates every child, maps each result (as a tensor) to its ONNX input name, and runs the
     * model for {@code onnxOutputName}.
     */
    public Value evaluate(Context context) {
        HashMap<String, Tensor> inputs = new HashMap<>();
        for (int i = 0; i < this.modelInputs.size(); ++i) {
            inputs.put(this.modelInputs.get(i), this.inputRefs.get(i).evaluate(context).asTensor());
        }
        return new TensorValue(this.model.unmappedEvaluate(inputs, this.onnxOutputName));
    }

    /** Returns the expected type supplied at construction; the context is not consulted. */
    public TensorType type(TypeContext<Reference> context) {
        return this.expectedType;
    }

    /**
     * Hashes the model name and ONNX output name only.
     *
     * <p>NOTE(review): {@code outputAs} and the children are not included. Confirm this is
     * consistent with the equality used for expression nodes (defined outside this class).
     */
    public int hashCode() {
        return Objects.hash("OnnxExpressionNode", this.model.name(), this.onnxOutputName);
    }

    /**
     * Renders this node as {@code onnx_expression_node(<model name>)}. When {@code outputAs} is
     * non-empty, {@code .<outputAs>} is appended.
     */
    public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) {
        b.append("onnx_expression_node(").append(this.model.name()).append(")");
        if (this.outputAs != null && !this.outputAs.isEmpty()) {
            b.append(".").append(this.outputAs);
        }
        return b;
    }
}

