/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Reads and writes {@link NDArray}s, both in DJL's own binary layout (tagged with the
 * {@code "NDAR"} magic string) and in the numpy {@code .npy} layout.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
final class NDSerializer {
    /** Version of the DJL layout written by {@link #encode(NDArray, OutputStream)}. */
    private static final int VERSION = 3;
    /** Chunk size used when copying array data to or from a stream. */
    private static final int BUFFER_SIZE = 0x100000;
    private static final String MAGIC_NUMBER = "NDAR";
    /** The six-byte prefix of a {@code .npy} stream: {@code 0x93} followed by ASCII "NUMPY". */
    private static final byte[] NUMPY_MAGIC = new byte[]{(byte) 0x93, 'N', 'U', 'M', 'P', 'Y'};
    /** Alignment, in bytes, of the array data that follows a written {@code .npy} header. */
    private static final int ARRAY_ALIGN = 64;
    /** Byte-order marker written for big-endian data (ASCII {@code '>'}). */
    private static final int BIG_ENDIAN_FLAG = '>';
    /** Byte-order marker written for little-endian data (ASCII {@code '<'}). */
    private static final int LITTLE_ENDIAN_FLAG = '<';
    private static final Pattern PATTERN = Pattern.compile("\\{'descr': '(.+)', 'fortran_order': False, 'shape': \\((.*)\\),");

    private NDSerializer() {
    }

    /**
     * Encodes an array into a new byte array using the DJL layout.
     *
     * @param array the array to encode
     * @return the encoded bytes, as produced by {@link #encode(NDArray, OutputStream)}
     */
    static byte[] encode(NDArray array) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream(BUFFER_SIZE)) {
            encode(array, baos);
            return baos.toByteArray();
        } catch (IOException e) {
            // Writing to an in-memory stream is not expected to fail.
            throw new AssertionError("This should never happen", e);
        }
    }

    /**
     * Writes an array to a stream using the DJL layout.
     *
     * <p>Fields, in order: the magic string (modified UTF-8), the int version, a name-present byte
     * (0 or 1) optionally followed by the name, the sparse format name, the data type name, the
     * encoded shape, a byte-order marker byte, the int data length, and the raw data bytes.
     *
     * @param array the array to write
     * @param os the destination stream; it is flushed but not closed
     * @throws IOException if writing to {@code os} fails
     */
    static void encode(NDArray array, OutputStream os) throws IOException {
        DataOutputStream dos = new DataOutputStream(os);
        dos.writeUTF(MAGIC_NUMBER);
        dos.writeInt(VERSION);
        String name = array.getName();
        if (name == null) {
            dos.write(0);
        } else {
            dos.write(1);
            dos.writeUTF(name);
        }
        dos.writeUTF(array.getSparseFormat().name());
        dos.writeUTF(array.getDataType().name());
        Shape shape = array.getShape();
        dos.write(shape.getEncoded());
        ByteBuffer bb = array.toByteBuffer();
        dos.write(bb.order() == ByteOrder.BIG_ENDIAN ? BIG_ENDIAN_FLAG : LITTLE_ENDIAN_FLAG);
        int length = bb.remaining();
        dos.writeInt(length);
        if (length > 0) {
            // Copy through a bounded scratch buffer; small arrays don't pay for a full chunk.
            byte[] buf = new byte[Math.min(length, BUFFER_SIZE)];
            while (bb.hasRemaining()) {
                int chunk = Math.min(bb.remaining(), buf.length);
                bb.get(buf, 0, chunk);
                dos.write(buf, 0, chunk);
            }
        }
        dos.flush();
    }

    /**
     * Writes an array to a stream in numpy {@code .npy} format, version 1.0.
     *
     * <p>The header is padded with spaces and terminated with a newline, so that the array data
     * starts at a multiple of {@link #ARRAY_ALIGN} bytes.
     *
     * @param array the array to write
     * @param os the destination stream; it is neither flushed nor closed
     * @throws IOException if writing to {@code os} fails
     */
    static void encodeAsNumpy(NDArray array, OutputStream os) throws IOException {
        StringBuilder sb = new StringBuilder(80);
        sb.append("{'descr': '").append(array.getDataType().asNumpy()).append("', 'fortran_order': False, 'shape': ");
        long[] shape = array.getShape().getShape();
        if (shape.length == 1) {
            // A Python 1-tuple needs a trailing comma: (n,)
            sb.append('(').append(shape[0]).append(",)");
        } else {
            sb.append(array.getShape());
        }
        sb.append(", }");
        // +1 for the terminating newline, which counts toward the header length.
        int len = sb.length() + 1;
        // Prefix = magic + 2 version bytes + 2 header-length bytes. Padding is in [1, ARRAY_ALIGN].
        int padding = ARRAY_ALIGN - (NUMPY_MAGIC.length + len + 4) % ARRAY_ALIGN;
        ByteBuffer bb = ByteBuffer.allocate(2);
        bb.order(ByteOrder.LITTLE_ENDIAN);
        bb.putShort((short) (padding + len));
        os.write(NUMPY_MAGIC);
        os.write(1); // major version
        os.write(0); // minor version
        os.write(bb.array());
        os.write(sb.toString().getBytes(StandardCharsets.US_ASCII));
        for (int i = 0; i < padding; ++i) {
            os.write(' ');
        }
        os.write('\n');
        os.write(array.toByteArray());
    }

    /**
     * Reads an array written in the DJL layout (versions 1 through {@link #VERSION}).
     *
     * @param manager the manager used to allocate the data buffer and create the array
     * @param is the source stream
     * @return the decoded array, with its name set if one was encoded
     * @throws IOException if reading from {@code is} fails
     * @throws IllegalArgumentException if the magic string or version is not recognized
     */
    static NDArray decode(NDManager manager, InputStream is) throws IOException {
        DataInputStream dis = is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);
        if (!MAGIC_NUMBER.equals(dis.readUTF())) {
            throw new IllegalArgumentException("Malformed NDArray data");
        }
        int version = dis.readInt();
        if (version < 1 || version > VERSION) {
            throw new IllegalArgumentException("Unexpected NDArray encode version " + version);
        }
        String name = null;
        // Version 1 has no name field; later versions write a presence byte, then the name if 1.
        if (version > 1 && dis.readByte() == 1) {
            name = dis.readUTF();
        }
        // Sparse format name: consumed to advance the stream but not used here.
        dis.readUTF();
        DataType dataType = DataType.valueOf(dis.readUTF());
        Shape shape = Shape.decode(dis);
        ByteOrder order;
        if (version > 2) {
            order = dis.readByte() == BIG_ENDIAN_FLAG ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN;
        } else {
            // Versions 1 and 2 carry no byte-order marker.
            order = ByteOrder.nativeOrder();
        }
        int length = dis.readInt();
        ByteBuffer data = manager.allocateDirect(length);
        data.order(order);
        readData(dis, data, length);
        NDArray array = manager.create(dataType.asDataType(data), shape, dataType);
        array.setName(name);
        return array;
    }

    /**
     * Reads an array in numpy {@code .npy} format (format versions 1.0, 2.0 and 3.0).
     *
     * <p>Only C-ordered arrays ({@code 'fortran_order': False}) are accepted.
     *
     * @param manager the manager used to allocate the data buffer and create the array
     * @param is the source stream
     * @return the decoded array
     * @throws IOException if reading from {@code is} fails
     * @throws IllegalArgumentException if the magic bytes, version or header are not recognized
     */
    static NDArray decodeNumpy(NDManager manager, InputStream is) throws IOException {
        DataInputStream dis = is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);
        byte[] buf = new byte[NUMPY_MAGIC.length];
        dis.readFully(buf);
        if (!Arrays.equals(buf, NUMPY_MAGIC)) {
            throw new IllegalArgumentException("Malformed numpy data");
        }
        byte major = dis.readByte();
        byte minor = dis.readByte();
        if (major < 1 || major > 3 || minor != 0) {
            throw new IllegalArgumentException("Unknown numpy version: " + major + '.' + minor);
        }
        // Version 1.0 stores the header length in 2 bytes; versions 2.0 and 3.0 use 4 bytes.
        int len = major == 1 ? 2 : 4;
        dis.readFully(buf, 0, len);
        ByteBuffer bb = ByteBuffer.wrap(buf, 0, len);
        bb.order(ByteOrder.LITTLE_ENDIAN);
        // The 2-byte length is unsigned; a signed read turns headers of 32 KiB or more negative.
        len = major == 1 ? Short.toUnsignedInt(bb.getShort()) : bb.getInt();
        if (len < 0) {
            throw new IllegalArgumentException("Invalid numpy header length: " + len);
        }
        buf = new byte[len];
        dis.readFully(buf);
        String header = new String(buf, StandardCharsets.UTF_8).trim();
        Matcher m = PATTERN.matcher(header);
        if (!m.find()) {
            throw new IllegalArgumentException("Invalid numpy header: " + header);
        }
        String typeStr = m.group(1);
        DataType dataType = DataType.fromNumpy(typeStr);
        String shapeStr = m.group(2);
        long[] longs;
        if (shapeStr.isEmpty()) {
            longs = new long[]{};
        } else {
            // "2, 3" or "3," (1-tuple); split drops the trailing empty token.
            String[] tokens = shapeStr.split(", ?");
            longs = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
        }
        Shape shape = new Shape(longs);
        len = Math.toIntExact(shape.size() * (long) dataType.getNumOfBytes());
        ByteBuffer data = manager.allocateDirect(len);
        // The first descr character is the numpy byte-order code; other codes leave the buffer as allocated.
        char order = typeStr.charAt(0);
        if (order == '>') {
            data.order(ByteOrder.BIG_ENDIAN);
        } else if (order == '<') {
            data.order(ByteOrder.LITTLE_ENDIAN);
        }
        readData(dis, data, len);
        return manager.create(dataType.asDataType(data), shape, dataType);
    }

    /**
     * Copies exactly {@code len} bytes from the stream into {@code data}, then rewinds
     * {@code data}. Does nothing when {@code len <= 0}.
     *
     * @param dis the source stream
     * @param data the destination buffer; it must have at least {@code len} bytes remaining
     * @param len the number of bytes to copy
     * @throws IOException if reading fails, including an {@link java.io.EOFException} on a short stream
     */
    private static void readData(DataInputStream dis, ByteBuffer data, int len) throws IOException {
        if (len <= 0) {
            return;
        }
        // Bounded scratch buffer: no 1 MiB allocation for small payloads.
        byte[] buf = new byte[Math.min(len, BUFFER_SIZE)];
        int remaining = len;
        while (remaining > 0) {
            int chunk = Math.min(remaining, buf.length);
            dis.readFully(buf, 0, chunk);
            data.put(buf, 0, chunk);
            remaining -= chunk;
        }
        data.rewind();
    }
}

