/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.onnxruntime.util;

import org.bytedeco.javacpp.BoolPointer;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.OrtMemoryInfo;
import org.bytedeco.onnxruntime.Value;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;

public class ONNXUtils {

    // ONNX Runtime tensor element type codes (ONNXTensorElementDataType) that this class
    // converts to and from ND4J data types.
    private static final int ONNX_FLOAT = 1;
    private static final int ONNX_UINT8 = 2;
    private static final int ONNX_INT8 = 3;
    private static final int ONNX_UINT16 = 4;
    private static final int ONNX_INT16 = 5;
    private static final int ONNX_INT32 = 6;
    private static final int ONNX_INT64 = 7;
    private static final int ONNX_STRING = 8;
    private static final int ONNX_BOOL = 9;
    private static final int ONNX_FLOAT16 = 10;
    private static final int ONNX_DOUBLE = 11;
    private static final int ONNX_UINT32 = 12;
    private static final int ONNX_UINT64 = 13;
    private static final int ONNX_BFLOAT16 = 16;

    // ONNX Runtime logging severity levels (OrtLoggingLevel), from most to least verbose.
    private static final int ORT_LOGGING_LEVEL_VERBOSE = 0;
    private static final int ORT_LOGGING_LEVEL_INFO = 1;
    private static final int ORT_LOGGING_LEVEL_WARNING = 2;
    private static final int ORT_LOGGING_LEVEL_ERROR = 3;
    private static final int ORT_LOGGING_LEVEL_FATAL = 4;

    /**
     * Checks that an array has the data type required by an ONNX input.
     *
     * @param expected the data type the ONNX model requires
     * @param array    the array to validate
     * @throws RuntimeException if the array's data type differs from {@code expected}
     */
    public static void validateType(DataType expected, INDArray array) {
        DataType actual = array.dataType();
        if (actual != expected) {
            throw new RuntimeException("INDArray data type (" + actual + ") does not match required ONNX data type (" + expected + ")");
        }
    }

    /**
     * Maps an ONNX Runtime tensor element type code to the equivalent ND4J data type.
     *
     * @param dataType ONNX tensor element type code (e.g. 1 = float, 11 = double)
     * @return the matching ND4J {@link DataType}
     * @throws IllegalArgumentException if the code has no ND4J equivalent (including string tensors)
     */
    public static DataType dataTypeForOnnxType(int dataType) {
        switch (dataType) {
            case ONNX_FLOAT:
                return DataType.FLOAT;
            case ONNX_INT8:
                return DataType.INT8;
            case ONNX_DOUBLE:
                return DataType.DOUBLE;
            case ONNX_BOOL:
                return DataType.BOOL;
            case ONNX_UINT8:
                return DataType.UINT8;
            case ONNX_UINT16:
                return DataType.UINT16;
            case ONNX_INT16:
                return DataType.INT16;
            case ONNX_INT32:
                return DataType.INT32;
            case ONNX_INT64:
                return DataType.INT64;
            case ONNX_FLOAT16:
                return DataType.FLOAT16;
            case ONNX_UINT32:
                return DataType.UINT32;
            case ONNX_UINT64:
                return DataType.UINT64;
            case ONNX_BFLOAT16:
                return DataType.BFLOAT16;
            default:
                throw new IllegalArgumentException("Illegal data type " + dataType);
        }
    }

    /**
     * Maps an ND4J data type to the equivalent ONNX Runtime tensor element type code.
     *
     * <p>An if-chain is used rather than an enum switch because several ND4J names
     * (e.g. {@code INT8}, {@code FLOAT16}) may be aliases rather than enum constants,
     * and aliases cannot be used as switch labels.
     *
     * @param dataType the ND4J data type
     * @return the ONNX tensor element type code
     * @throws IllegalArgumentException if the type has no ONNX equivalent, or is {@code null}
     */
    public static int onnxTypeForDataType(DataType dataType) {
        if (dataType == DataType.FLOAT) {
            return ONNX_FLOAT;
        }
        if (dataType == DataType.INT8) {
            return ONNX_INT8;
        }
        if (dataType == DataType.DOUBLE) {
            return ONNX_DOUBLE;
        }
        if (dataType == DataType.BOOL) {
            return ONNX_BOOL;
        }
        if (dataType == DataType.UINT8) {
            return ONNX_UINT8;
        }
        if (dataType == DataType.UINT16) {
            return ONNX_UINT16;
        }
        if (dataType == DataType.INT16) {
            return ONNX_INT16;
        }
        if (dataType == DataType.INT32) {
            return ONNX_INT32;
        }
        if (dataType == DataType.INT64) {
            return ONNX_INT64;
        }
        if (dataType == DataType.FLOAT16) {
            return ONNX_FLOAT16;
        }
        if (dataType == DataType.UINT32) {
            return ONNX_UINT32;
        }
        if (dataType == DataType.UINT64) {
            return ONNX_UINT64;
        }
        if (dataType == DataType.BFLOAT16) {
            return ONNX_BFLOAT16;
        }
        throw new IllegalArgumentException("Illegal data type " + dataType);
    }

    /**
     * Wraps an ONNX Runtime tensor value as an INDArray.
     *
     * <p>The data type is taken from the tensor's element type. The shape comes from the
     * tensor's shape info; it falls back to {@code [1]} when no shape is reported. The
     * returned array is backed by the buffer from {@link #getDataBuffer(Value)}, which
     * points at the tensor's own memory rather than a copy.
     *
     * @param value a tensor-valued ONNX Runtime value
     * @return an INDArray over the tensor's data, with C-order strides
     * @throws IllegalArgumentException if the tensor's element type is not supported
     * @throws IllegalStateException    if the buffer's data type disagrees with the tensor metadata
     */
    public static INDArray getArray(Value value) {
        DataType dataType;
        long[] shape;
        // Query the metadata inside a scope so that native objects allocated while
        // reading it are released as soon as the Java copies have been taken.
        try (PointerScope scope = new PointerScope()) {
            // Use the tensor *element* type. GetTypeInfo().GetONNXType() is the value kind
            // (tensor/sequence/map), not the element type.
            dataType = dataTypeForOnnxType(value.GetTensorTypeAndShapeInfo().GetElementType());
            LongPointer dims = value.GetTensorTypeAndShapeInfo().GetShape();
            if (dims != null) {
                shape = new long[(int) value.GetTensorTypeAndShapeInfo().GetDimensionsCount()];
                dims.get(shape);
            } else {
                shape = new long[]{1L};
            }
        }
        DataBuffer buffer = getDataBuffer(value);
        Preconditions.checkState(dataType.equals(buffer.dataType()), "Data type must be equivalent as specified by the onnx metadata.");
        return Nd4j.create(buffer, shape, Nd4j.getStrides(shape), 0L);
    }

    /**
     * Chooses an ONNX Runtime logging level that matches the most verbose level enabled on an SLF4J logger.
     *
     * @param logger the logger whose enabled levels are inspected
     * @return 0 (verbose) for trace/debug, 1 (info), 2 (warning) or 3 (error). Returns 4 (fatal)
     *         when no SLF4J level is enabled.
     */
    public static int getOnnxLogLevelFromLogger(Logger logger) {
        if (logger.isTraceEnabled() || logger.isDebugEnabled()) {
            return ORT_LOGGING_LEVEL_VERBOSE;
        }
        if (logger.isInfoEnabled()) {
            return ORT_LOGGING_LEVEL_INFO;
        }
        if (logger.isWarnEnabled()) {
            return ORT_LOGGING_LEVEL_WARNING;
        }
        if (logger.isErrorEnabled()) {
            return ORT_LOGGING_LEVEL_ERROR;
        }
        // Every SLF4J level is disabled, so keep ONNX Runtime as quiet as possible
        // instead of falling back to the comparatively chatty INFO level.
        return ORT_LOGGING_LEVEL_FATAL;
    }

    /**
     * Creates an ONNX Runtime tensor over an INDArray's existing data buffer.
     *
     * <p>No copy is made here. The tensor is created over the pointer returned by
     * {@code ndArray.data().pointer()}, so keep {@code ndArray} alive while the returned
     * value is in use.
     *
     * <p>NOTE(review): the byte size is {@code length() * elementSize}, starting at the
     * buffer's base pointer. A view's offset and strides are not taken into account. Confirm
     * that callers pass contiguous, C-ordered, non-view arrays.
     *
     * @param ndArray    the array whose data backs the tensor
     * @param memoryInfo memory info describing where the data lives
     * @return a tensor value with the array's shape and matching element type
     * @throws IllegalArgumentException if the array's data type has no ONNX equivalent
     */
    public static Value getTensor(INDArray ndArray, MemoryInfo memoryInfo) {
        Pointer data = ndArray.data().pointer();
        long sizeInBytes = ndArray.length() * (long) ndArray.data().getElementSize();
        LongPointer dims = new LongPointer(ndArray.shape());
        return Value.CreateTensor(memoryInfo.asOrtMemoryInfo(), data, sizeInBytes, dims, (long) ndArray.rank(), onnxTypeForDataType(ndArray.dataType()));
    }

    /**
     * Wraps the data of an ONNX Runtime tensor in an ND4J data buffer without copying.
     *
     * <p>The buffer points directly at the tensor's mutable data. Its length equals the
     * tensor's element count and its data type matches the tensor's element type.
     *
     * @param tens a non-null tensor value
     * @return a DataBuffer backed by the tensor's memory
     * @throws IllegalArgumentException      if {@code tens} is null or wraps a null native value
     * @throws UnsupportedOperationException if the tensor holds strings
     * @throws RuntimeException              for any other unsupported element type
     */
    public static DataBuffer getDataBuffer(Value tens) {
        if (tens == null || tens.isNull()) {
            throw new IllegalArgumentException("Native underlying tensor value was null!");
        }
        int type;
        long size;
        // Only the metadata queries run inside the scope. The returned buffer is built
        // after the scope closes, so nothing it depends on is attached to the scope
        // and deallocated along with it.
        try (PointerScope scope = new PointerScope()) {
            type = tens.GetTensorTypeAndShapeInfo().GetElementType();
            size = tens.GetTensorTypeAndShapeInfo().GetElementCount();
        }
        switch (type) {
            case ONNX_FLOAT: {
                FloatPointer pFloat = tens.GetTensorMutableDataFloat().capacity(size);
                return Nd4j.createBuffer(pFloat, DataType.FLOAT, size, FloatIndexer.create(pFloat));
            }
            case ONNX_UINT8: {
                BytePointer pUint8 = tens.GetTensorMutableDataUByte().capacity(size);
                return Nd4j.createBuffer(pUint8, DataType.UINT8, size, ByteIndexer.create(pUint8));
            }
            case ONNX_INT8: {
                BytePointer pInt8 = tens.GetTensorMutableDataByte().capacity(size);
                // Signed int8: must be INT8 (not UINT8) to agree with dataTypeForOnnxType.
                return Nd4j.createBuffer(pInt8, DataType.INT8, size, ByteIndexer.create(pInt8));
            }
            case ONNX_UINT16: {
                ShortPointer pUint16 = tens.GetTensorMutableDataUShort().capacity(size);
                return Nd4j.createBuffer(pUint16, DataType.UINT16, size, ShortIndexer.create(pUint16));
            }
            case ONNX_INT16: {
                ShortPointer pInt16 = tens.GetTensorMutableDataShort().capacity(size);
                return Nd4j.createBuffer(pInt16, DataType.INT16, size, ShortIndexer.create(pInt16));
            }
            case ONNX_INT32: {
                IntPointer pInt32 = tens.GetTensorMutableDataInt().capacity(size);
                return Nd4j.createBuffer(pInt32, DataType.INT32, size, IntIndexer.create(pInt32));
            }
            case ONNX_INT64: {
                LongPointer pInt64 = tens.GetTensorMutableDataLong().capacity(size);
                return Nd4j.createBuffer(pInt64, DataType.INT64, size, LongIndexer.create(pInt64));
            }
            case ONNX_STRING:
                // String tensor elements are not a flat byte array, so reading them as int8
                // produces garbage. Reject them explicitly.
                throw new UnsupportedOperationException("String tensors are not supported");
            case ONNX_BOOL: {
                BoolPointer pBool = tens.GetTensorMutableDataBool().capacity(size);
                BooleanIndexer boolIndexer = BooleanIndexer.create(new BooleanPointer(pBool));
                return Nd4j.createBuffer(pBool, DataType.BOOL, size, boolIndexer);
            }
            case ONNX_FLOAT16: {
                // Half-precision values are exposed as raw 16-bit shorts.
                ShortPointer pFloat16 = tens.GetTensorMutableDataShort().capacity(size);
                return Nd4j.createBuffer(pFloat16, DataType.FLOAT16, size, ShortIndexer.create(pFloat16));
            }
            case ONNX_DOUBLE: {
                DoublePointer pDouble = tens.GetTensorMutableDataDouble().capacity(size);
                return Nd4j.createBuffer(pDouble, DataType.DOUBLE, size, DoubleIndexer.create(pDouble));
            }
            case ONNX_UINT32: {
                IntPointer pUint32 = tens.GetTensorMutableDataUInt().capacity(size);
                return Nd4j.createBuffer(pUint32, DataType.UINT32, size, IntIndexer.create(pUint32));
            }
            case ONNX_UINT64: {
                LongPointer pUint64 = tens.GetTensorMutableDataULong().capacity(size);
                return Nd4j.createBuffer(pUint64, DataType.UINT64, size, LongIndexer.create(pUint64));
            }
            case ONNX_BFLOAT16: {
                // bfloat16 values are exposed as raw 16-bit shorts.
                ShortPointer pBfloat16 = tens.GetTensorMutableDataShort().capacity(size);
                return Nd4j.createBuffer(pBfloat16, DataType.BFLOAT16, size, ShortIndexer.create(pBfloat16));
            }
            default:
                throw new RuntimeException("Unsupported data type encountered");
        }
    }
}

