/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.TensorInfo;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

public final class OrtUtil {
    private static final Logger logger = Logger.getLogger(OrtUtil.class.getName());

    /** Maximum number of dimensions accepted by the shape conversion helpers in this class. */
    private static final int MAX_DIMENSIONS = 8;

    private OrtUtil() {
    }

    /**
     * Converts a {@code long} shape into the {@code int} dimensions required by
     * {@link Array#newInstance(Class, int...)}.
     *
     * @param shape the shape; must have between 1 and 8 dimensions, each in [0, Integer.MAX_VALUE]
     * @return the same dimensions as ints
     * @throws IllegalArgumentException if the rank or any dimension is out of range
     */
    public static int[] transformShape(long[] shape) {
        checkRank(shape.length);
        int[] newShape = new int[shape.length];
        for (int i = 0; i < shape.length; ++i) {
            long curDim = shape[i];
            if (curDim < 0L || curDim > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape));
            }
            newShape[i] = (int) curDim;
        }
        return newShape;
    }

    /**
     * Converts an {@code int} shape (as taken from a Java array) into a {@code long} shape.
     *
     * @param shape the shape; must have between 1 and 8 dimensions, each strictly positive
     * @return the same dimensions as longs
     * @throws IllegalArgumentException if the rank is out of range or any dimension is below 1
     */
    public static long[] transformShape(int[] shape) {
        checkRank(shape.length);
        long[] newShape = new long[shape.length];
        for (int i = 0; i < shape.length; ++i) {
            long curDim = shape[i];
            // NOTE(review): unlike the long[] overload this rejects zero-sized dimensions; confirm
            // whether that asymmetry is intended before relying on it.
            if (curDim < 1L) {
                throw new IllegalArgumentException("Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape));
            }
            newShape[i] = curDim;
        }
        return newShape;
    }

    /** Throws if {@code rank} is outside the supported range [1, MAX_DIMENSIONS]. */
    private static void checkRank(int rank) {
        if (rank == 0 || rank > MAX_DIMENSIONS) {
            throw new IllegalArgumentException("Arrays with less than 1 and greater than " + MAX_DIMENSIONS + " dimensions are not supported.");
        }
    }

    /** Allocates a new (possibly nested) {@code boolean} array with the given shape. */
    public static Object newBooleanArray(long[] shape) {
        return Array.newInstance(Boolean.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code byte} array with the given shape. */
    public static Object newByteArray(long[] shape) {
        return Array.newInstance(Byte.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code short} array with the given shape. */
    public static Object newShortArray(long[] shape) {
        return Array.newInstance(Short.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code int} array with the given shape. */
    public static Object newIntArray(long[] shape) {
        return Array.newInstance(Integer.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code long} array with the given shape. */
    public static Object newLongArray(long[] shape) {
        return Array.newInstance(Long.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code float} array with the given shape. */
    public static Object newFloatArray(long[] shape) {
        return Array.newInstance(Float.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code double} array with the given shape. */
    public static Object newDoubleArray(long[] shape) {
        return Array.newInstance(Double.TYPE, transformShape(shape));
    }

    /** Allocates a new (possibly nested) {@code String} array with the given shape. */
    public static Object newStringArray(long[] shape) {
        return Array.newInstance(String.class, transformShape(shape));
    }

    /**
     * Copies a flat {@code boolean} array into a new nested array of the given shape, in row-major
     * order.
     *
     * @param input the flat elements; its length must equal the element count of {@code shape}
     * @param shape the target shape
     * @return a new array with {@code shape.length} dimensions
     * @throws IllegalArgumentException if the shape is invalid or does not match {@code input.length}
     */
    public static Object reshape(boolean[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newBooleanArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code byte} array into a new nested array of the given shape. */
    public static Object reshape(byte[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newByteArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code short} array into a new nested array of the given shape. */
    public static Object reshape(short[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newShortArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code int} array into a new nested array of the given shape. */
    public static Object reshape(int[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newIntArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code long} array into a new nested array of the given shape. */
    public static Object reshape(long[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newLongArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code float} array into a new nested array of the given shape. */
    public static Object reshape(float[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newFloatArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code double} array into a new nested array of the given shape. */
    public static Object reshape(double[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newDoubleArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /** Copies a flat {@code String} array into a new nested array of the given shape. */
    public static Object reshape(String[] input, long[] shape) {
        checkReshapeLength(input.length, shape);
        Object output = newStringArray(shape);
        reshape(input, output, 0);
        return output;
    }

    /**
     * Checks that a flat array of {@code inputLength} elements exactly fills {@code shape}.
     * Checking before allocation means a bad input never triggers a large allocation. It also
     * prevents out-of-bounds copies and silent truncation.
     */
    private static void checkReshapeLength(int inputLength, long[] shape) {
        // Validate rank/dimension range first so invalid shapes report the same error as newXArray.
        transformShape(shape);
        long expected = elementCount(shape);
        if (expected != inputLength) {
            throw new IllegalArgumentException("Cannot reshape an array of " + inputLength + " elements into shape " + Arrays.toString(shape) + ", which requires " + expected + " elements.");
        }
    }

    /**
     * Copies elements from the flat {@code input} array into the (possibly nested) {@code output}
     * array in row-major order, starting at {@code position} in the input.
     *
     * @return the input position just after the last element copied
     */
    private static int reshape(Object input, Object output, int position) {
        Class<?> outputClass = output.getClass();
        if (!outputClass.isArray()) {
            throw new IllegalStateException("Found element type when expecting an array. Class " + outputClass);
        }
        Class<?> componentType = outputClass.getComponentType();
        if (componentType.isPrimitive() || componentType == String.class) {
            // Innermost dimension (at any depth, including a 1-D output): bulk-copy a contiguous run.
            int length = Array.getLength(output);
            System.arraycopy(input, position, output, 0, length);
            return position + length;
        }
        for (Object outputElement : (Object[]) output) {
            position = reshape(input, outputElement, position);
        }
        return position;
    }

    /**
     * Computes the number of elements in a tensor of the given shape (1 for an empty shape).
     *
     * @throws IllegalArgumentException if any dimension is negative or the product overflows a long
     */
    public static long elementCount(long[] shape) {
        boolean hasZero = false;
        for (long dim : shape) {
            if (dim < 0L) {
                throw new IllegalArgumentException("Received negative value in shape " + Arrays.toString(shape) + " .");
            }
            hasZero |= dim == 0L;
        }
        if (hasZero) {
            // A zero dimension makes the tensor empty however large the other dimensions are.
            return 0L;
        }
        long count = 1L;
        for (long dim : shape) {
            try {
                count = Math.multiplyExact(count, dim);
            } catch (ArithmeticException e) {
                throw new IllegalArgumentException("Element count of shape " + Arrays.toString(shape) + " overflows a long.", e);
            }
        }
        return count;
    }

    /**
     * Returns true if the shape can back a Java array. That means at most 8 dimensions, each in
     * [1, Integer.MAX_VALUE].
     */
    public static boolean validateShape(long[] shape) {
        if (shape.length > MAX_DIMENSIONS) {
            return false;
        }
        for (long dim : shape) {
            if (dim <= 0L || dim > Integer.MAX_VALUE) {
                return false;
            }
        }
        return true;
    }

    /**
     * Flattens a {@code String[]} or nested array of Strings into a single {@code String[]} in
     * row-major order.
     *
     * @param o a {@code String[]} (returned as-is) or an {@code Object[]} of nested String arrays
     * @throws IllegalStateException if a non-array, non-String element is encountered
     */
    public static String[] flattenString(Object o) {
        if (o instanceof String[]) {
            return (String[]) o;
        }
        ArrayList<String> output = new ArrayList<String>();
        flattenString((Object[]) o, output);
        return output.toArray(new String[0]);
    }

    /** Recursively appends every String in the nested array {@code input} to {@code output}. */
    private static void flattenString(Object[] input, ArrayList<String> output) {
        for (Object element : input) {
            Class<?> elementClass = element.getClass();
            if (!elementClass.isArray()) {
                throw new IllegalStateException("Found an element type where there should have been an array. Class = " + elementClass);
            }
            Class<?> componentType = elementClass.getComponentType();
            if (componentType.isArray()) {
                flattenString((Object[]) element, output);
            } else if (componentType.equals(String.class)) {
                output.addAll(Arrays.asList((String[]) element));
            } else {
                throw new IllegalStateException("Found a non-String, non-array element type, " + elementClass);
            }
        }
    }

    /**
     * Wraps a boxed primitive in a single-element primitive array of the matching type.
     *
     * @return the array, or {@code null} for types that have no primitive array mapping here
     */
    static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
        switch (javaType) {
            case FLOAT:
                return new float[]{((Float) data).floatValue()};
            case DOUBLE:
                return new double[]{((Double) data).doubleValue()};
            case UINT8:
            case INT8:
                return new byte[]{((Byte) data).byteValue()};
            case INT16:
                return new short[]{((Short) data).shortValue()};
            case INT32:
                return new int[]{((Integer) data).intValue()};
            case INT64:
                return new long[]{((Long) data).longValue()};
            case BOOL:
                return new boolean[]{((Boolean) data).booleanValue()};
            default:
                return null;
        }
    }

    /**
     * Wraps a boxed primitive in a single-element, native-order direct buffer of the matching type.
     * FLOAT16 and BFLOAT16 values are expected as {@code Short} bit patterns.
     *
     * @return the buffer, or {@code null} for types that have no primitive buffer mapping here
     */
    static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) {
        switch (javaType) {
            case FLOAT: {
                FloatBuffer buf = allocateNative(javaType.size).asFloatBuffer();
                buf.put(0, ((Float) data).floatValue());
                return buf;
            }
            case DOUBLE: {
                DoubleBuffer buf = allocateNative(javaType.size).asDoubleBuffer();
                buf.put(0, ((Double) data).doubleValue());
                return buf;
            }
            case BOOL: {
                ByteBuffer buf = allocateNative(javaType.size);
                buf.put(0, ((Boolean) data).booleanValue() ? (byte) 1 : (byte) 0);
                return buf;
            }
            case UINT8:
            case INT8: {
                ByteBuffer buf = allocateNative(javaType.size);
                buf.put(0, ((Byte) data).byteValue());
                return buf;
            }
            case INT16:
            case FLOAT16:
            case BFLOAT16: {
                ShortBuffer buf = allocateNative(javaType.size).asShortBuffer();
                buf.put(0, ((Short) data).shortValue());
                return buf;
            }
            case INT32: {
                IntBuffer buf = allocateNative(javaType.size).asIntBuffer();
                buf.put(0, ((Integer) data).intValue());
                return buf;
            }
            case INT64: {
                LongBuffer buf = allocateNative(javaType.size).asLongBuffer();
                buf.put(0, ((Long) data).longValue());
                return buf;
            }
            default:
                return null;
        }
    }

    /** Allocates a direct {@link ByteBuffer} of {@code byteCount} bytes in native byte order. */
    private static ByteBuffer allocateNative(int byteCount) {
        return ByteBuffer.allocateDirect(byteCount).order(ByteOrder.nativeOrder());
    }

    /**
     * Copies a (possibly nested) primitive Java array into a new native-order direct buffer whose
     * view type matches {@code info.type}. The returned buffer is rewound to position 0.
     *
     * @throws IllegalArgumentException if the type is not a primitive type, the byte size does not
     *     fit in a single buffer, or the array does not match {@code info.shape}
     */
    static Buffer convertArrayToBuffer(TensorInfo info, Object array) {
        ByteBuffer byteBuffer = allocateNative(directByteCount(info));
        Buffer buffer = switch (info.type) {
            case FLOAT -> byteBuffer.asFloatBuffer();
            case DOUBLE -> byteBuffer.asDoubleBuffer();
            case UINT8, INT8, BOOL -> byteBuffer;
            case INT16, FLOAT16, BFLOAT16 -> byteBuffer.asShortBuffer();
            case INT32 -> byteBuffer.asIntBuffer();
            case INT64 -> byteBuffer.asLongBuffer();
            default -> throw new IllegalArgumentException("Unexpected type, expected Java primitive found " + info.type);
        };
        fillBufferFromArray(info, array, 0, buffer);
        if (buffer.remaining() != 0) {
            throw new IllegalArgumentException("Failed to copy all elements into the buffer, expected to copy " + info.numElements + " into a buffer of capacity " + buffer.capacity() + " but had " + buffer.remaining() + " values left over.");
        }
        buffer.rewind();
        return buffer;
    }

    /**
     * Computes {@code info.numElements * info.type.size} as an int. The guard runs before
     * multiplying so a large element count cannot be truncated or overflow into a wrong-sized
     * allocation.
     */
    private static int directByteCount(TensorInfo info) {
        long elements = info.numElements;
        int elementSize = info.type.size;
        if (elements < 0L || (elementSize > 0 && elements > Integer.MAX_VALUE / elementSize)) {
            throw new IllegalArgumentException("Cannot allocate a direct buffer for " + elements + " elements of type " + info.type + ", it exceeds the maximum buffer size.");
        }
        return (int) (elements * elementSize);
    }

    /**
     * Checks that {@code array} has exactly {@code info.shape[curDim]} elements.
     *
     * @return the array length
     */
    private static int checkDimensionLength(TensorInfo info, Object array, int curDim) {
        long expectedSize = info.shape[curDim];
        int actualSize = Array.getLength(array);
        if (expectedSize != actualSize) {
            throw new IllegalArgumentException("Mismatch in array sizes, expected " + expectedSize + " at dim " + curDim + " from shape " + Arrays.toString(info.shape) + ", found " + actualSize);
        }
        return actualSize;
    }

    /**
     * Recursively writes the nested {@code array} into {@code buffer}, advancing its position.
     * Every dimension is length-checked, including the innermost, so ragged arrays are rejected.
     */
    private static void fillBufferFromArray(TensorInfo info, Object array, int curDim, Buffer buffer) {
        int length = checkDimensionLength(info, array, curDim);
        if (curDim != info.shape.length - 1) {
            for (int i = 0; i < length; ++i) {
                fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer);
            }
            return;
        }
        switch (info.type) {
            case FLOAT:
                ((FloatBuffer) buffer).put((float[]) array);
                break;
            case DOUBLE:
                ((DoubleBuffer) buffer).put((double[]) array);
                break;
            case UINT8:
            case INT8:
                ((ByteBuffer) buffer).put((byte[]) array);
                break;
            case INT16:
            case FLOAT16:
            case BFLOAT16:
                ((ShortBuffer) buffer).put((short[]) array);
                break;
            case INT32:
                ((IntBuffer) buffer).put((int[]) array);
                break;
            case INT64:
                ((LongBuffer) buffer).put((long[]) array);
                break;
            case BOOL: {
                ByteBuffer boolBuf = (ByteBuffer) buffer;
                for (boolean value : (boolean[]) array) {
                    boolBuf.put(value ? (byte) 1 : (byte) 0);
                }
                break;
            }
            default:
                throw new IllegalArgumentException("Unexpected type, expected Java primitive found " + info.type);
        }
    }

    /**
     * Recursively reads from {@code buffer} into the nested {@code array}, advancing the buffer's
     * position. FLOAT16 and BFLOAT16 are read from a {@link FloatBuffer} into {@code float[]}
     * leaves. Every dimension is length-checked against {@code info.shape}.
     */
    static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) {
        int length = checkDimensionLength(info, array, curDim);
        if (curDim != info.shape.length - 1) {
            for (int i = 0; i < length; ++i) {
                fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i));
            }
            return;
        }
        switch (info.type) {
            case FLOAT:
            case FLOAT16:
            case BFLOAT16:
                ((FloatBuffer) buffer).get((float[]) array);
                break;
            case DOUBLE:
                ((DoubleBuffer) buffer).get((double[]) array);
                break;
            case UINT8:
            case INT8:
                ((ByteBuffer) buffer).get((byte[]) array);
                break;
            case INT16:
                ((ShortBuffer) buffer).get((short[]) array);
                break;
            case INT32:
                ((IntBuffer) buffer).get((int[]) array);
                break;
            case INT64:
                ((LongBuffer) buffer).get((long[]) array);
                break;
            case BOOL: {
                boolean[] boolArr = (boolean[]) array;
                ByteBuffer boolBuf = (ByteBuffer) buffer;
                for (int i = 0; i < boolArr.length; ++i) {
                    boolArr[i] = boolBuf.get() != 0;
                }
                break;
            }
            default:
                throw new IllegalArgumentException("Unexpected type, expected Java primitive found " + info.type);
        }
    }

    /** Returns a hash-container capacity large enough to hold {@code size} entries at load factor 0.75. */
    static int capacityFromSize(int size) {
        return (int) ((double) size / 0.75 + 1.0);
    }

    /**
     * Returns a direct buffer holding the remaining contents of {@code data}.
     *
     * <p>Direct inputs are used in place. Heap inputs are copied into a new native-order direct
     * buffer, and the input's position is restored afterwards.
     *
     * @throws IllegalStateException if {@code type} is STRING/UNKNOWN or the data is too large
     */
    static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
        if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
            throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer");
        }
        // A ByteBuffer may carry data of another type, so it is measured in bytes, not elements.
        int elementSize = data instanceof ByteBuffer ? 1 : type.size;
        long bufferSizeLong = (long) data.remaining() * (long) elementSize;
        if (bufferSizeLong > Integer.MAX_VALUE - 8L * (long) elementSize) {
            throw new IllegalStateException("Cannot allocate a direct buffer of the requested size and type, size " + data.remaining() + ", type = " + type);
        }
        int bufferSize = (int) bufferSizeLong;
        Buffer tmp;
        int bufferPos;
        if (data.isDirect()) {
            tmp = data;
            bufferPos = data.position() * elementSize;
        } else {
            int origPosition = data.position();
            tmp = copyToDirect(data, type, bufferSize);
            // The bulk put advanced the source; restore it so the caller's buffer is left unchanged.
            data.position(origPosition);
            tmp.rewind();
            bufferPos = 0;
        }
        return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data);
    }

    /**
     * Copies the remaining contents of a heap buffer into a new direct buffer of {@code byteCount}
     * bytes. A ByteBuffer is copied byte-for-byte whatever {@code type} is, matching how the
     * direct path accepts it. Casting it to the typed view would throw ClassCastException.
     */
    private static Buffer copyToDirect(Buffer data, OnnxJavaType type, int byteCount) {
        ByteBuffer buffer = allocateNative(byteCount);
        if (data instanceof ByteBuffer) {
            return buffer.put((ByteBuffer) data);
        }
        switch (type) {
            case FLOAT:
                return buffer.asFloatBuffer().put((FloatBuffer) data);
            case DOUBLE:
                return buffer.asDoubleBuffer().put((DoubleBuffer) data);
            case INT16:
            case FLOAT16:
            case BFLOAT16:
                return buffer.asShortBuffer().put((ShortBuffer) data);
            case INT32:
                return buffer.asIntBuffer().put((IntBuffer) data);
            case INT64:
                return buffer.asLongBuffer().put((LongBuffer) data);
            default:
                throw new IllegalStateException("Impossible to reach here, managed to cast a buffer as an incorrect type, found " + type);
        }
    }

    /**
     * Result of {@link #prepareBuffer}. It holds the buffer, its starting byte offset, its byte
     * size, its remaining count, and whether the buffer is a copy of the input.
     */
    static final class BufferTuple {
        final Buffer data;
        final int pos;
        final long byteSize;
        final long size;
        final boolean isCopy;

        BufferTuple(Buffer data, int pos, long byteSize, long size, boolean isCopy) {
            this.data = data;
            this.pos = pos;
            this.byteSize = byteSize;
            this.size = size;
            this.isCopy = isCopy;
        }
    }
}

