/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tvm.util;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UIntIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.tvm.DLDataType;
import org.bytedeco.tvm.DLDevice;
import org.bytedeco.tvm.DLTensor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers that convert between TVM / DLPack structures ({@link DLTensor},
 * {@link DLDataType}) and their ND4J counterparts ({@link INDArray}, {@link DataBuffer},
 * {@link DataType}).
 *
 * <p>None of the conversions copy tensor data: buffers and tensors produced here wrap the
 * native pointer of their source object.
 */
public class TVMUtils {
    /** DLPack type code for signed integers ({@code kDLInt}). */
    private static final int DL_INT = 0;
    /** DLPack type code for unsigned integers ({@code kDLUInt}). */
    private static final int DL_UINT = 1;
    /** DLPack type code for IEEE floating point ({@code kDLFloat}). */
    private static final int DL_FLOAT = 2;
    /** DLPack type code for bfloat16 ({@code kDLBfloat}). */
    private static final int DL_BFLOAT = 4;

    /**
     * Maps a DLPack data type descriptor to the matching ND4J {@link DataType}.
     *
     * <p>Only the type code and bit width are inspected; the {@code lanes} field is not read.
     *
     * @param dataType the DLPack descriptor to translate
     * @return the ND4J data type with the same kind and bit width
     * @throws IllegalArgumentException if the code/bits combination has no ND4J equivalent
     */
    public static DataType dataTypeForTvmType(DLDataType dataType) {
        // Read each native field once rather than once per comparison.
        final int code = dataType.code();
        final int bits = dataType.bits();
        switch (code) {
            case DL_INT:
                switch (bits) {
                    case 8:
                        return DataType.INT8;
                    case 16:
                        return DataType.INT16;
                    case 32:
                        return DataType.INT32;
                    case 64:
                        return DataType.INT64;
                    default:
                        break;
                }
                break;
            case DL_UINT:
                switch (bits) {
                    case 8:
                        return DataType.UINT8;
                    case 16:
                        return DataType.UINT16;
                    case 32:
                        return DataType.UINT32;
                    case 64:
                        return DataType.UINT64;
                    default:
                        break;
                }
                break;
            case DL_FLOAT:
                switch (bits) {
                    case 16:
                        return DataType.FLOAT16;
                    case 32:
                        return DataType.FLOAT;
                    case 64:
                        return DataType.DOUBLE;
                    default:
                        break;
                }
                break;
            case DL_BFLOAT:
                if (bits == 16) {
                    return DataType.BFLOAT16;
                }
                break;
            default:
                break;
        }
        throw new IllegalArgumentException("Illegal data type code " + code + " with bits " + bits);
    }

    /**
     * Maps an ND4J {@link DataType} to a newly allocated single-lane DLPack descriptor.
     *
     * @param dataType the ND4J data type to translate
     * @return a new {@link DLDataType} with the matching code and bit width and {@code lanes == 1}
     * @throws IllegalArgumentException if the data type has no DLPack equivalent handled here
     */
    public static DLDataType tvmTypeForDataType(DataType dataType) {
        if (dataType == DataType.INT8) {
            return newDlType(DL_INT, 8);
        }
        if (dataType == DataType.INT16) {
            return newDlType(DL_INT, 16);
        }
        if (dataType == DataType.INT32) {
            return newDlType(DL_INT, 32);
        }
        if (dataType == DataType.INT64) {
            return newDlType(DL_INT, 64);
        }
        if (dataType == DataType.UINT8) {
            return newDlType(DL_UINT, 8);
        }
        if (dataType == DataType.UINT16) {
            return newDlType(DL_UINT, 16);
        }
        if (dataType == DataType.UINT32) {
            return newDlType(DL_UINT, 32);
        }
        if (dataType == DataType.UINT64) {
            return newDlType(DL_UINT, 64);
        }
        if (dataType == DataType.FLOAT16) {
            return newDlType(DL_FLOAT, 16);
        }
        if (dataType == DataType.FLOAT) {
            return newDlType(DL_FLOAT, 32);
        }
        if (dataType == DataType.DOUBLE) {
            return newDlType(DL_FLOAT, 64);
        }
        if (dataType == DataType.BFLOAT16) {
            return newDlType(DL_BFLOAT, 16);
        }
        throw new IllegalArgumentException("Illegal data type " + dataType);
    }

    /**
     * Wraps a {@link DLTensor} as an {@link INDArray} backed by the tensor's data pointer.
     *
     * <p>If the tensor has no shape pointer the array gets shape {@code [1]}; if it has no
     * strides pointer, strides are derived from the shape via {@link Nd4j#getStrides(long[])}.
     * The tensor's {@code byte_offset} is not read here.
     *
     * @param value the tensor to wrap
     * @return an array created over the tensor's data with the tensor's shape and strides
     * @throws IllegalArgumentException if the tensor's data type is not supported
     * @throws IllegalStateException if the created buffer's data type differs from the tensor's
     */
    public static INDArray getArray(DLTensor value) {
        final DataType expectedType = dataTypeForTvmType(value.dtype());
        final int rank = value.ndim();

        final LongPointer shapePointer = value.shape();
        final long[] shape;
        if (shapePointer != null) {
            shape = new long[rank];
            shapePointer.get(shape);
        } else {
            shape = new long[]{1L};
        }

        final LongPointer stridePointer = value.strides();
        final long[] strides;
        if (stridePointer != null) {
            strides = new long[rank];
            stridePointer.get(strides);
        } else {
            strides = Nd4j.getStrides(shape);
        }

        // getDataBuffer uses its size argument as the typed pointer capacity and as the buffer
        // length, so it must be a number of elements. Passing a byte count here would make
        // every multi-byte buffer report more elements than the tensor actually holds.
        final DataBuffer buffer = getDataBuffer(value, elementSpan(shape, strides));
        Preconditions.checkState(expectedType == buffer.dataType(), "Data type must be equivalent as specified by the tvm metadata.");
        return Nd4j.create(buffer, shape, strides, 0L);
    }

    /**
     * Builds a {@link DLTensor} that describes the given array without copying its data.
     *
     * <p>NOTE(review): the shape and strides {@link LongPointer}s allocated here are referenced
     * only by the returned tensor's native fields; whether they stay alive as long as the tensor
     * is used should be confirmed against JavaCPP's ownership rules.
     *
     * @param ndArray the array whose buffer, shape, strides and offset are exposed
     * @param ctx the device to record on the tensor
     * @return a new tensor pointing at the array's data buffer
     * @throws IllegalArgumentException if the array's data type has no DLPack equivalent
     */
    public static DLTensor getTensor(INDArray ndArray, DLDevice ctx) {
        final DLTensor tensor = new DLTensor();
        tensor.data(ndArray.data().pointer());
        tensor.device(ctx);
        tensor.ndim(ndArray.rank());
        final DLDataType dlType = tvmTypeForDataType(ndArray.dataType());
        tensor.dtype(dlType);
        tensor.shape(new LongPointer(ndArray.shape()));
        tensor.strides(new LongPointer(ndArray.stride()));
        // byte_offset is expressed in bytes, so the array offset is scaled by the element
        // width. NOTE(review): assumes INDArray.offset() is an element index - confirm.
        final long bytesPerElement = dlType.bits() / 8;
        tensor.byte_offset(ndArray.offset() * bytesPerElement);
        return tensor;
    }

    /**
     * Creates a {@link DataBuffer} that wraps the tensor's data pointer using the pointer and
     * indexer types matching the tensor's data type.
     *
     * @param tens the tensor whose {@code data()} pointer is wrapped
     * @param size used as the capacity of the typed pointer and as the length of the buffer
     * @return a buffer over the tensor's data
     * @throws IllegalArgumentException if the tensor's DLPack type has no ND4J equivalent
     * @throws RuntimeException if the resolved ND4J type is not handled by this method
     */
    public static DataBuffer getDataBuffer(DLTensor tens, long size) {
        final DataType type = dataTypeForTvmType(tens.dtype());
        switch (type) {
            case BYTE: {
                BytePointer pointer = new BytePointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, ByteIndexer.create(pointer));
            }
            case SHORT: {
                ShortPointer pointer = new ShortPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, ShortIndexer.create(pointer));
            }
            case INT: {
                IntPointer pointer = new IntPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, IntIndexer.create(pointer));
            }
            case LONG: {
                LongPointer pointer = new LongPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, LongIndexer.create(pointer));
            }
            case UBYTE: {
                BytePointer pointer = new BytePointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, UByteIndexer.create(pointer));
            }
            case UINT16: {
                ShortPointer pointer = new ShortPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, UShortIndexer.create(pointer));
            }
            case UINT32: {
                IntPointer pointer = new IntPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, UIntIndexer.create(pointer));
            }
            case UINT64: {
                // Unsigned 64-bit data is accessed through the signed LongIndexer here.
                LongPointer pointer = new LongPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, LongIndexer.create(pointer));
            }
            case HALF: {
                ShortPointer pointer = new ShortPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, HalfIndexer.create(pointer));
            }
            case FLOAT: {
                FloatPointer pointer = new FloatPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, FloatIndexer.create(pointer));
            }
            case DOUBLE: {
                DoublePointer pointer = new DoublePointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, DoubleIndexer.create(pointer));
            }
            case BFLOAT16: {
                ShortPointer pointer = new ShortPointer(tens.data()).capacity(size);
                return Nd4j.createBuffer(pointer, type, size, Bfloat16Indexer.create(pointer));
            }
            default: {
                throw new RuntimeException("Unsupported data type encountered");
            }
        }
    }

    /** Allocates a single-lane DLPack type descriptor with the given type code and bit width. */
    private static DLDataType newDlType(int code, int bits) {
        return new DLDataType().code((byte) code).bits((byte) bits).lanes((short) 1);
    }

    /**
     * Returns how many elements a buffer needs so that every index addressable through
     * {@code shape} and {@code strides} falls inside it: the product of the dimensions, or the
     * largest strided index plus one when that is greater.
     */
    private static long elementSpan(long[] shape, long[] strides) {
        long count = 1L;
        for (long dim : shape) {
            count *= dim;
        }
        // Nothing to widen for empty tensors, or when strides do not line up with the shape.
        if (count <= 0L || strides.length != shape.length) {
            return count;
        }
        long lastIndex = 0L;
        for (int i = 0; i < shape.length; i++) {
            lastIndex += (shape[i] - 1L) * strides[i];
        }
        return Math.max(count, lastIndex + 1L);
    }
}

