/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.Buffer;
import net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.DictionaryBatch;
import net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.FieldNode;
import net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.Message;
import net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.RecordBatch;
import net.snowflake.client.jdbc.internal.apache.arrow.memory.BufferAllocator;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.ReadChannel;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.WriteChannel;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowBlock;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowBuffer;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowFieldNode;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowMessage;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.MessageChannelReader;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.MessageMetadataResult;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.message.MessageResult;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.types.pojo.Schema;
import net.snowflake.client.jdbc.internal.google.flatbuffers.FlatBufferBuilder;
import net.snowflake.client.jdbc.internal.io.netty.buffer.ArrowBuf;

public class MessageSerializer {
    /** {@code Message.headerType()} value used for schema messages. */
    private static final byte HEADER_TYPE_SCHEMA = 1;
    /** {@code Message.headerType()} value used for dictionary batch messages. */
    private static final byte HEADER_TYPE_DICTIONARY_BATCH = 2;
    /** {@code Message.headerType()} value used for record batch messages. */
    private static final byte HEADER_TYPE_RECORD_BATCH = 3;
    /** Metadata version written by {@link #serializeMessage} and required by {@link #deserializeMessageBatch(MessageChannelReader)}. */
    private static final short METADATA_VERSION = 3;
    /** Size in bytes of the little-endian int32 length prefix written before each message's metadata. */
    private static final int LENGTH_PREFIX_SIZE = 4;
    /** Messages and bodies are padded so they end on multiples of this many bytes. */
    private static final int ALIGNMENT = 8;

    /**
     * Decodes a little-endian 32-bit integer from the first four bytes of {@code bytes}.
     *
     * @param bytes array holding at least four bytes, least significant byte first
     * @return the decoded value
     */
    public static int bytesToInt(byte[] bytes) {
        return ((bytes[3] & 0xFF) << 24) | ((bytes[2] & 0xFF) << 16) | ((bytes[1] & 0xFF) << 8) | (bytes[0] & 0xFF);
    }

    /**
     * Encodes {@code value} as a little-endian 32-bit integer into the first four bytes of {@code bytes}.
     *
     * @param value the value to encode
     * @param bytes destination array with room for at least four bytes
     */
    public static void intToBytes(int value, byte[] bytes) {
        bytes[3] = (byte) (value >>> 24);
        bytes[2] = (byte) (value >>> 16);
        bytes[1] = (byte) (value >>> 8);
        bytes[0] = (byte) value;
    }

    /**
     * Writes a length-prefixed message.
     *
     * <p>{@code messageLength} is first rounded up so that the 4-byte prefix plus the message occupy
     * a multiple of 8 bytes. The padded length is then written as a little-endian int, followed by
     * the message bytes and {@code out.align()}. The resulting block is only 8-byte aligned in the
     * stream if the channel was aligned on entry (callers such as {@link #serialize(WriteChannel,
     * Schema)} assert this).
     *
     * @param out destination channel
     * @param messageLength unpadded length of {@code messageBuffer}
     * @param messageBuffer flatbuffer-encoded message
     * @return total bytes occupied by the prefix plus the padded message
     * @throws IOException if writing fails
     */
    public static int writeMessageBuffer(WriteChannel out, int messageLength, ByteBuffer messageBuffer) throws IOException {
        int remainder = (messageLength + LENGTH_PREFIX_SIZE) % ALIGNMENT;
        if (remainder != 0) {
            messageLength += ALIGNMENT - remainder;
        }
        out.writeIntLittleEndian(messageLength);
        out.write(messageBuffer);
        out.align();
        return messageLength + LENGTH_PREFIX_SIZE;
    }

    /**
     * Serializes {@code schema} as a schema message.
     *
     * @param out destination channel; its current position is asserted to be 8-byte aligned
     * @param schema schema to write
     * @return number of bytes written (a multiple of 8)
     * @throws IOException if writing fails
     */
    public static long serialize(WriteChannel out, Schema schema) throws IOException {
        long start = out.getCurrentPosition();
        assert start % ALIGNMENT == 0L;
        FlatBufferBuilder builder = new FlatBufferBuilder();
        int schemaOffset = schema.getSchema(builder);
        ByteBuffer serializedMessage = serializeMessage(builder, HEADER_TYPE_SCHEMA, schemaOffset, 0);
        int messageLength = serializedMessage.remaining();
        int bytesWritten = writeMessageBuffer(out, messageLength, serializedMessage);
        assert bytesWritten % ALIGNMENT == 0;
        return bytesWritten;
    }

    /**
     * Converts the schema header of an already-parsed message into a {@link Schema}.
     *
     * @param schemaMessage message whose header is expected to be a schema (not checked here)
     * @return the converted schema
     */
    public static Schema deserializeSchema(Message schemaMessage) {
        net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.Schema schemaFB =
                (net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.Schema) schemaMessage.header(
                        new net.snowflake.client.jdbc.internal.apache.arrow.flatbuf.Schema());
        return Schema.convertSchema(schemaFB);
    }

    /**
     * Reads the next message from {@code in} and decodes it as a schema.
     *
     * @param in source channel
     * @return the decoded schema
     * @throws IOException on end of input, or if the next message is not a schema message
     */
    public static Schema deserializeSchema(ReadChannel in) throws IOException {
        MessageMetadataResult result = readMessage(in);
        if (result == null) {
            throw new IOException("Unexpected end of input when reading Schema");
        }
        if (result.getMessage().headerType() != HEADER_TYPE_SCHEMA) {
            throw new IOException("Expected schema but header was " + result.getMessage().headerType());
        }
        return deserializeSchema(result.getMessage());
    }

    /**
     * Serializes a record batch.
     *
     * <p>Writes the padded, length-prefixed metadata followed by the batch's buffers.
     *
     * @param out destination channel
     * @param batch batch to write; its computed body length is asserted to be a multiple of 8
     * @return block describing where the batch was written and its metadata and body sizes
     * @throws IOException if writing fails
     */
    public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) throws IOException {
        long start = out.getCurrentPosition();
        int bodyLength = batch.computeBodyLength();
        assert bodyLength % ALIGNMENT == 0;
        FlatBufferBuilder builder = new FlatBufferBuilder();
        int batchOffset = batch.writeTo(builder);
        ByteBuffer serializedMessage = serializeMessage(builder, HEADER_TYPE_RECORD_BATCH, batchOffset, bodyLength);
        int metadataLength = writePaddedMetadata(out, start, serializedMessage);
        long bufferLength = writeBatchBuffers(out, batch);
        assert bufferLength % ALIGNMENT == 0L;
        return new ArrowBlock(start, metadataLength + LENGTH_PREFIX_SIZE, bufferLength);
    }

    /**
     * Writes the length prefix and metadata for a batch that begins at stream position {@code start}.
     *
     * <p>The advertised metadata length is padded so that {@code start + prefix + metadata} falls on
     * an 8-byte boundary. The real padding bytes are then emitted by {@code out.align()}.
     *
     * @return the padded metadata length written into the prefix (excluding the prefix itself)
     */
    private static int writePaddedMetadata(WriteChannel out, long start, ByteBuffer serializedMessage) throws IOException {
        int metadataLength = serializedMessage.remaining();
        int padding = (int) ((start + metadataLength + LENGTH_PREFIX_SIZE) % ALIGNMENT);
        if (padding != 0) {
            metadataLength += ALIGNMENT - padding;
        }
        out.writeIntLittleEndian(metadataLength);
        out.write(serializedMessage);
        out.align();
        return metadataLength;
    }

    /**
     * Writes the buffers of {@code batch} at the offsets recorded in its buffer layout.
     *
     * <p>Zero bytes are written into any gap before a buffer's recorded offset. The channel is
     * aligned after the last buffer.
     *
     * @param out destination channel
     * @param batch batch whose buffers and layout are written
     * @return number of bytes written, including gaps and trailing alignment
     * @throws IOException if writing fails
     * @throws IllegalStateException if a buffer's written size does not match its layout
     */
    public static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) throws IOException {
        long bufferStart = out.getCurrentPosition();
        List<ArrowBuf> buffers = batch.getBuffers();
        List<ArrowBuffer> buffersLayout = batch.getBuffersLayout();
        for (int i = 0; i < buffers.size(); ++i) {
            ArrowBuf buffer = buffers.get(i);
            ArrowBuffer layout = buffersLayout.get(i);
            long startPosition = bufferStart + layout.getOffset();
            if (startPosition != out.getCurrentPosition()) {
                out.writeZeros((int) (startPosition - out.getCurrentPosition()));
            }
            out.write(buffer);
            long expectedEnd = startPosition + layout.getSize();
            if (out.getCurrentPosition() != expectedEnd) {
                // Parenthesized sum: previously the two numbers were string-concatenated.
                throw new IllegalStateException("wrong buffer size: " + out.getCurrentPosition() + " != " + expectedEnd);
            }
        }
        out.align();
        return out.getCurrentPosition() - bufferStart;
    }

    /**
     * Builds a record batch from an already-parsed message and its body.
     *
     * @param recordBatchMessage message whose header is expected to be a record batch (not checked here)
     * @param bodyBuffer the message body; released once the batch has been built
     * @return the deserialized batch
     * @throws IOException if the batch exceeds the supported size limits
     */
    public static ArrowRecordBatch deserializeRecordBatch(Message recordBatchMessage, ArrowBuf bodyBuffer) throws IOException {
        RecordBatch recordBatchFB = (RecordBatch) recordBatchMessage.header(new RecordBatch());
        return deserializeRecordBatch(recordBatchFB, bodyBuffer);
    }

    /**
     * Reads the next message and its body from {@code in} and decodes them as a record batch.
     *
     * @param in source channel
     * @param allocator allocator for the body buffer
     * @return the deserialized batch
     * @throws IOException on end of input, wrong header type, or a body larger than 2GB
     */
    public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, BufferAllocator allocator) throws IOException {
        MessageMetadataResult result = readMessage(in);
        if (result == null) {
            throw new IOException("Unexpected end of input when reading a RecordBatch");
        }
        if (result.getMessage().headerType() != HEADER_TYPE_RECORD_BATCH) {
            throw new IOException("Expected RecordBatch but header was " + result.getMessage().headerType());
        }
        long bodyLength = result.getMessageBodyLength();
        if (bodyLength > Integer.MAX_VALUE) {
            // Guard the int cast below, matching deserializeMessageBatch.
            throw new IOException("Cannot currently deserialize record batches over 2GB");
        }
        ArrowBuf bodyBuffer = readMessageBody(in, (int) bodyLength, allocator);
        return deserializeRecordBatch(result.getMessage(), bodyBuffer);
    }

    /**
     * Reads a record batch located by {@code block}.
     *
     * <p>{@code block} describes a 4-byte prefix plus metadata, followed by the body.
     *
     * @param in source channel, positioned at the start of the block
     * @param block location and sizes of the batch
     * @param alloc allocator for the buffer holding the whole block
     * @return the deserialized batch
     * @throws IOException on short read, a block over 2GB, or a non-record-batch header
     */
    public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block, BufferAllocator alloc) throws IOException {
        ArrowBuf buffer = readBlock(in, block, alloc);
        Message messageFB = parseBlockMessage(buffer, block);
        if (messageFB.headerType() != HEADER_TYPE_RECORD_BATCH) {
            buffer.release();
            throw new IOException("Expected RecordBatch but header was " + messageFB.headerType());
        }
        RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch());
        return deserializeRecordBatch(recordBatchFB, sliceBlockBody(buffer, block));
    }

    /**
     * Builds an {@link ArrowRecordBatch} from its flatbuffer description.
     *
     * <p>Each vector buffer is sliced out of {@code body}. {@code body} is released only after the
     * batch has been successfully constructed. Presumably {@code ArrowRecordBatch} retains the
     * slices it is given; TODO confirm against ArrowRecordBatch.
     *
     * @param recordBatchFB flatbuffer record batch header
     * @param body buffer holding the batch body
     * @return the deserialized batch
     * @throws IOException if any length, count, or buffer offset does not fit in an int
     */
    public static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, ArrowBuf body) throws IOException {
        int nodesLength = recordBatchFB.nodesLength();
        List<ArrowFieldNode> nodes = new ArrayList<ArrowFieldNode>(nodesLength);
        for (int i = 0; i < nodesLength; ++i) {
            FieldNode node = recordBatchFB.nodes(i);
            if ((long) (int) node.length() != node.length() || (long) (int) node.nullCount() != node.nullCount()) {
                throw new IOException("Cannot currently deserialize record batches with node length larger than Int.MAX_VALUE");
            }
            nodes.add(new ArrowFieldNode((int) node.length(), (int) node.nullCount()));
        }
        int buffersLength = recordBatchFB.buffersLength();
        List<ArrowBuf> buffers = new ArrayList<ArrowBuf>(buffersLength);
        for (int i = 0; i < buffersLength; ++i) {
            Buffer bufferFB = recordBatchFB.buffers(i);
            long offset = bufferFB.offset();
            long length = bufferFB.length();
            if ((long) (int) offset != offset || (long) (int) length != length) {
                // Without this check an oversized offset/length would silently wrap when cast to int.
                throw new IOException("Cannot currently deserialize record batches over 2GB");
            }
            buffers.add(body.slice((int) offset, (int) length));
        }
        if ((long) (int) recordBatchFB.length() != recordBatchFB.length()) {
            throw new IOException("Cannot currently deserialize record batches over 2GB");
        }
        ArrowRecordBatch arrowRecordBatch = new ArrowRecordBatch((int) recordBatchFB.length(), nodes, buffers);
        body.release();
        return arrowRecordBatch;
    }

    /**
     * Serializes a dictionary batch.
     *
     * <p>Writes the padded, length-prefixed metadata followed by the dictionary's record batch
     * buffers.
     *
     * @param out destination channel
     * @param batch dictionary batch to write; its computed body length is asserted to be a multiple of 8
     * @return block describing where the batch was written and its metadata and body sizes
     * @throws IOException if writing fails
     */
    public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException {
        long start = out.getCurrentPosition();
        int bodyLength = batch.computeBodyLength();
        assert bodyLength % ALIGNMENT == 0;
        FlatBufferBuilder builder = new FlatBufferBuilder();
        int batchOffset = batch.writeTo(builder);
        ByteBuffer serializedMessage = serializeMessage(builder, HEADER_TYPE_DICTIONARY_BATCH, batchOffset, bodyLength);
        int metadataLength = writePaddedMetadata(out, start, serializedMessage);
        long bufferLength = writeBatchBuffers(out, batch.getDictionary());
        assert bufferLength % ALIGNMENT == 0L;
        return new ArrowBlock(start, metadataLength + LENGTH_PREFIX_SIZE, bufferLength);
    }

    /**
     * Builds a dictionary batch from an already-parsed message and its body.
     *
     * @param message message whose header is expected to be a dictionary batch (not checked here)
     * @param bodyBuffer the message body; released once the inner record batch has been built
     * @return the deserialized dictionary batch
     * @throws IOException if the batch exceeds the supported size limits
     */
    public static ArrowDictionaryBatch deserializeDictionaryBatch(Message message, ArrowBuf bodyBuffer) throws IOException {
        DictionaryBatch dictionaryBatchFB = (DictionaryBatch) message.header(new DictionaryBatch());
        ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), bodyBuffer);
        return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch);
    }

    /**
     * Reads the next message and its body from {@code in} and decodes them as a dictionary batch.
     *
     * @param in source channel
     * @param allocator allocator for the body buffer
     * @return the deserialized dictionary batch
     * @throws IOException on end of input, wrong header type, or a body larger than 2GB
     */
    public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, BufferAllocator allocator) throws IOException {
        MessageMetadataResult result = readMessage(in);
        if (result == null) {
            throw new IOException("Unexpected end of input when reading a DictionaryBatch");
        }
        if (result.getMessage().headerType() != HEADER_TYPE_DICTIONARY_BATCH) {
            throw new IOException("Expected DictionaryBatch but header was " + result.getMessage().headerType());
        }
        long bodyLength = result.getMessageBodyLength();
        if (bodyLength > Integer.MAX_VALUE) {
            // Guard the int cast below, matching deserializeMessageBatch.
            throw new IOException("Cannot currently deserialize record batches over 2GB");
        }
        ArrowBuf bodyBuffer = readMessageBody(in, (int) bodyLength, allocator);
        return deserializeDictionaryBatch(result.getMessage(), bodyBuffer);
    }

    /**
     * Reads a dictionary batch located by {@code block}.
     *
     * <p>{@code block} describes a 4-byte prefix plus metadata, followed by the body.
     *
     * @param in source channel, positioned at the start of the block
     * @param block location and sizes of the batch
     * @param alloc allocator for the buffer holding the whole block
     * @return the deserialized dictionary batch
     * @throws IOException on short read, a block over 2GB, or a non-dictionary-batch header
     */
    public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, ArrowBlock block, BufferAllocator alloc) throws IOException {
        ArrowBuf buffer = readBlock(in, block, alloc);
        Message messageFB = parseBlockMessage(buffer, block);
        if (messageFB.headerType() != HEADER_TYPE_DICTIONARY_BATCH) {
            buffer.release();
            throw new IOException("Expected DictionaryBatch but header was " + messageFB.headerType());
        }
        DictionaryBatch dictionaryBatchFB = (DictionaryBatch) messageFB.header(new DictionaryBatch());
        ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), sliceBlockBody(buffer, block));
        return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch);
    }

    /**
     * Reads the full contents of {@code block} (prefix, metadata and body) into a new buffer.
     *
     * <p>The buffer is released here if the read comes up short; otherwise the caller owns it.
     */
    private static ArrowBuf readBlock(ReadChannel in, ArrowBlock block, BufferAllocator alloc) throws IOException {
        long totalLen = (long) block.getMetadataLength() + block.getBodyLength();
        if (totalLen > Integer.MAX_VALUE) {
            throw new IOException("Cannot currently deserialize record batches over 2GB");
        }
        ArrowBuf buffer = alloc.buffer((int) totalLen);
        if ((long) in.readFully(buffer, (int) totalLen) != totalLen) {
            buffer.release();
            throw new IOException("Unexpected end of input trying to read batch.");
        }
        return buffer;
    }

    /** Parses the flatbuffer message that follows the 4-byte length prefix inside a block buffer. */
    private static Message parseBlockMessage(ArrowBuf buffer, ArrowBlock block) {
        ArrowBuf metadataBuffer = buffer.slice(LENGTH_PREFIX_SIZE, block.getMetadataLength() - LENGTH_PREFIX_SIZE);
        return Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer());
    }

    /**
     * Returns the body portion (everything after the metadata) of a block buffer read by
     * {@link #readBlock}.
     */
    private static ArrowBuf sliceBlockBody(ArrowBuf buffer, ArrowBlock block) {
        int metadataLength = block.getMetadataLength();
        int totalLen = (int) ((long) metadataLength + block.getBodyLength());
        return buffer.slice(metadataLength, totalLen - metadataLength);
    }

    /**
     * Reads the next message from {@code reader} and decodes it as a record batch or dictionary batch.
     *
     * @param reader source of messages
     * @return the decoded batch, or {@code null} when the reader has no more messages
     * @throws IOException if the body exceeds 2GB, the metadata version is not the supported one,
     *     or the header type is neither record batch nor dictionary batch
     */
    public static ArrowMessage deserializeMessageBatch(MessageChannelReader reader) throws IOException {
        MessageResult result = reader.readNext();
        if (result == null) {
            return null;
        }
        Message message = result.getMessage();
        if (message.bodyLength() > Integer.MAX_VALUE) {
            throw new IOException("Cannot currently deserialize record batches over 2GB");
        }
        if (message.version() != METADATA_VERSION) {
            throw new IOException("Received metadata with an incompatible version number");
        }
        switch (message.headerType()) {
            case HEADER_TYPE_RECORD_BATCH:
                return deserializeRecordBatch(message, result.getBodyBuffer());
            case HEADER_TYPE_DICTIONARY_BATCH:
                return deserializeDictionaryBatch(message, result.getBodyBuffer());
            default:
                throw new IOException("Unexpected message header type " + message.headerType());
        }
    }

    /**
     * Convenience overload that wraps {@code in} in a {@link MessageChannelReader}.
     *
     * @param in source channel
     * @param alloc allocator for message bodies
     * @return the decoded batch, or {@code null} at end of input
     * @throws IOException see {@link #deserializeMessageBatch(MessageChannelReader)}
     */
    public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocator alloc) throws IOException {
        return deserializeMessageBatch(new MessageChannelReader(in, alloc));
    }

    /**
     * Finishes {@code builder} with a {@code Message} table.
     *
     * <p>The table wraps the given header and is stamped with {@link #METADATA_VERSION}.
     *
     * @param builder builder already containing the header table
     * @param headerType header type tag (1 = schema, 2 = dictionary batch, 3 = record batch in this file)
     * @param headerOffset offset of the header table within {@code builder}
     * @param bodyLength length of the body that will follow the metadata
     * @return the finished flatbuffer bytes
     */
    public static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType, int headerOffset, int bodyLength) {
        Message.startMessage(builder);
        Message.addHeaderType(builder, headerType);
        Message.addHeader(builder, headerOffset);
        Message.addVersion(builder, METADATA_VERSION);
        Message.addBodyLength(builder, bodyLength);
        builder.finish(Message.endMessage(builder));
        return builder.dataBuffer();
    }

    /**
     * Reads one length-prefixed flatbuffer message (metadata only, not the body).
     *
     * @param in source channel
     * @return the parsed message, or {@code null} if fewer than 4 prefix bytes were available or the
     *     prefix was zero (callers treat both as end of input)
     * @throws IOException if the prefix is negative or the message bytes are truncated
     */
    public static MessageMetadataResult readMessage(ReadChannel in) throws IOException {
        ByteBuffer lengthBuffer = ByteBuffer.allocate(LENGTH_PREFIX_SIZE);
        if (in.readFully(lengthBuffer) != LENGTH_PREFIX_SIZE) {
            return null;
        }
        int messageLength = bytesToInt(lengthBuffer.array());
        if (messageLength == 0) {
            return null;
        }
        if (messageLength < 0) {
            // A corrupt prefix would otherwise surface as IllegalArgumentException from ByteBuffer.allocate.
            throw new IOException("Invalid message length: " + messageLength);
        }
        ByteBuffer messageBuffer = ByteBuffer.allocate(messageLength);
        if (in.readFully(messageBuffer) != messageLength) {
            throw new IOException("Unexpected end of stream trying to read message.");
        }
        messageBuffer.rewind();
        Message message = Message.getRootAsMessage(messageBuffer);
        return new MessageMetadataResult(messageLength, messageBuffer, message);
    }

    /**
     * Reads exactly {@code bodyLength} bytes from {@code in} into a newly allocated buffer.
     *
     * @param in source channel
     * @param bodyLength number of body bytes to read
     * @param allocator allocator for the returned buffer
     * @return buffer owned by the caller
     * @throws IOException on short read; the buffer is released before throwing
     */
    public static ArrowBuf readMessageBody(ReadChannel in, int bodyLength, BufferAllocator allocator) throws IOException {
        ArrowBuf bodyBuffer = allocator.buffer(bodyLength);
        if (in.readFully(bodyBuffer, bodyLength) != bodyLength) {
            bodyBuffer.release();
            throw new IOException("Unexpected end of input trying to read batch.");
        }
        return bodyBuffer;
    }
}

