/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;

import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import onnx.Onnx;

public class TypeConverter {
    public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
        if (shape != null) {
            if (shape.getDimCount() != type.rank()) {
                throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
            }
            for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
                int vespaIndex = type.dimensionMap(onnxIndex);
                Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
                TensorType.Dimension vespaDimension = (TensorType.Dimension)type.type().dimensions().get(vespaIndex);
                if (onnxDimension.getDimValue() == vespaDimension.size().orElse(-1L).longValue()) continue;
                throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
            }
        }
    }

    public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
        return TypeConverter.fromOnnxType(type, "d");
    }

    public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
        Onnx.TensorShapeProto shape = type.getTensorType().getShape();
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
        for (int i = 0; i < shape.getDimCount(); ++i) {
            String dimensionName = dimensionPrefix + i;
            Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
            if (onnxDimension.getDimValue() >= 0L) {
                builder.add(TensorType.Dimension.indexed((String)dimensionName, (long)onnxDimension.getDimValue()));
                continue;
            }
            builder.add(TensorType.Dimension.indexed((String)dimensionName));
        }
        return builder.build();
    }
}

