/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

class TensorConverter {
    TensorConverter() {
    }

    static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> tensorMap, OrtEnvironment env, OrtSession session) throws OrtException {
        HashMap<String, OnnxTensor> result = new HashMap<String, OnnxTensor>();
        for (String name : tensorMap.keySet()) {
            Tensor vespaTensor = tensorMap.get(name);
            name = TensorConverter.toOnnxName(name, session.getInputInfo().keySet());
            TensorInfo onnxTensorInfo = TensorConverter.toTensorInfo(((NodeInfo)session.getInputInfo().get(name)).getInfo());
            OnnxTensor onnxTensor = TensorConverter.toOnnxTensor(vespaTensor, onnxTensorInfo, env);
            result.put(name, onnxTensor);
        }
        return result;
    }

    static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) throws OrtException {
        if (!(vespaTensor instanceof IndexedTensor)) {
            throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
        }
        IndexedTensor tensor = (IndexedTensor)vespaTensor;
        ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
        if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.putFloat(tensor.getFloat((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (FloatBuffer)buffer.rewind().asFloatBuffer(), (long[])tensor.shape());
        }
        if (onnxTensorInfo.type == OnnxJavaType.DOUBLE) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.putDouble(tensor.get((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (DoubleBuffer)buffer.rewind().asDoubleBuffer(), (long[])tensor.shape());
        }
        if (onnxTensorInfo.type == OnnxJavaType.INT8) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.put((byte)tensor.get((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (ByteBuffer)buffer.rewind(), (long[])tensor.shape());
        }
        if (onnxTensorInfo.type == OnnxJavaType.INT16) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.putShort((short)tensor.get((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (ShortBuffer)buffer.rewind().asShortBuffer(), (long[])tensor.shape());
        }
        if (onnxTensorInfo.type == OnnxJavaType.INT32) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.putInt((int)tensor.get((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (IntBuffer)buffer.rewind().asIntBuffer(), (long[])tensor.shape());
        }
        if (onnxTensorInfo.type == OnnxJavaType.INT64) {
            int i = 0;
            while ((long)i < tensor.size()) {
                buffer.putLong((long)tensor.get((long)i));
                ++i;
            }
            return OnnxTensor.createTensor((OrtEnvironment)environment, (LongBuffer)buffer.rewind().asLongBuffer(), (long[])tensor.shape());
        }
        throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
    }

    static Tensor toVespaTensor(OnnxValue onnxValue) {
        if (!(onnxValue instanceof OnnxTensor)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        OnnxTensor onnxTensor = (OnnxTensor)onnxValue;
        TensorInfo tensorInfo = onnxTensor.getInfo();
        TensorType type = TensorConverter.toVespaType((ValueInfo)onnxTensor.getInfo());
        DimensionSizes sizes = TensorConverter.sizesFromType(type);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of((TensorType)type, (DimensionSizes)sizes);
        if (tensorInfo.type == OnnxJavaType.FLOAT) {
            FloatBuffer buffer = onnxTensor.getFloatBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, buffer.get());
            }
        } else if (tensorInfo.type == OnnxJavaType.DOUBLE) {
            DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, buffer.get());
            }
        } else if (tensorInfo.type == OnnxJavaType.INT8) {
            ByteBuffer buffer = onnxTensor.getByteBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, (float)buffer.get());
            }
        } else if (tensorInfo.type == OnnxJavaType.INT16) {
            ShortBuffer buffer = onnxTensor.getShortBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, (float)buffer.get());
            }
        } else if (tensorInfo.type == OnnxJavaType.INT32) {
            IntBuffer buffer = onnxTensor.getIntBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, (float)buffer.get());
            }
        } else if (tensorInfo.type == OnnxJavaType.INT64) {
            LongBuffer buffer = onnxTensor.getLongBuffer();
            for (long i = 0L; i < sizes.totalSize(); ++i) {
                builder.cellByDirectIndex(i, (float)buffer.get());
            }
        } else {
            throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
        }
        return builder.build();
    }

    private static DimensionSizes sizesFromType(TensorType type) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        for (int i = 0; i < type.dimensions().size(); ++i) {
            builder.set(i, ((Long)((TensorType.Dimension)type.dimensions().get(i)).size().get()).longValue());
        }
        return builder.build();
    }

    static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
        return infoMap.entrySet().stream().collect(Collectors.toMap(e -> TensorConverter.asValidName((String)e.getKey()), e -> TensorConverter.toVespaType(((NodeInfo)e.getValue()).getInfo())));
    }

    static String asValidName(String name) {
        return OnnxImporter.asValidIdentifier(name);
    }

    static String toOnnxName(String name, Set<String> onnxNames) {
        if (onnxNames.contains(name)) {
            return name;
        }
        for (String onnxName : onnxNames) {
            if (!TensorConverter.asValidName(onnxName).equals(name)) continue;
            return onnxName;
        }
        throw new IllegalArgumentException("ONNX model has no input with name " + name);
    }

    static TensorType toVespaType(ValueInfo valueInfo) {
        TensorInfo tensorInfo = TensorConverter.toTensorInfo(valueInfo);
        TensorType.Builder builder = new TensorType.Builder(TensorConverter.toVespaValueType(tensorInfo.onnxType));
        long[] shape = tensorInfo.getShape();
        for (int i = 0; i < shape.length; ++i) {
            long dimSize = shape[i];
            String dimName = "d" + i;
            if (dimSize > 0L) {
                builder.indexed(dimName, dimSize);
                continue;
            }
            builder.indexed(dimName);
        }
        return builder.build();
    }

    private static TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
        switch (onnxType) {
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
                return TensorType.Value.INT8;
            }
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: {
                return TensorType.Value.BFLOAT16;
            }
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
                return TensorType.Value.FLOAT;
            }
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
                return TensorType.Value.DOUBLE;
            }
        }
        return TensorType.Value.DOUBLE;
    }

    private static TensorInfo toTensorInfo(ValueInfo valueInfo) {
        if (!(valueInfo instanceof TensorInfo)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        return (TensorInfo)valueInfo;
    }
}

