/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.BigIntToFixedConverter;
import net.snowflake.client.core.arrow.BigIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.BigIntToTimeConverter;
import net.snowflake.client.core.arrow.BigIntToTimestampLTZConverter;
import net.snowflake.client.core.arrow.BigIntToTimestampNTZConverter;
import net.snowflake.client.core.arrow.DecimalToScaledFixedConverter;
import net.snowflake.client.core.arrow.DoubleToRealConverter;
import net.snowflake.client.core.arrow.IntToDateConverter;
import net.snowflake.client.core.arrow.IntToFixedConverter;
import net.snowflake.client.core.arrow.IntToScaledFixedConverter;
import net.snowflake.client.core.arrow.SmallIntToFixedConverter;
import net.snowflake.client.core.arrow.SmallIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter;
import net.snowflake.client.core.arrow.TinyIntToBooleanConverter;
import net.snowflake.client.core.arrow.TinyIntToFixedConverter;
import net.snowflake.client.core.arrow.TinyIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampLTZConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampNTZConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampTZConverter;
import net.snowflake.client.core.arrow.VarBinaryToBinaryConverter;
import net.snowflake.client.core.arrow.VarCharToTextConverter;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeResultChunk;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.jdbc.SnowflakeType;
import net.snowflake.client.jdbc.internal.apache.arrow.memory.BufferAllocator;
import net.snowflake.client.jdbc.internal.apache.arrow.memory.RootAllocator;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.FieldVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ValueVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.VectorSchemaRoot;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.ArrowStreamReader;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.types.Types;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.TransferPair;

/**
 * A result chunk whose rows are held as Arrow record batches.
 *
 * <p>Each element of {@code batchOfVectors} holds the column vectors of one record batch, in the
 * order the batches were read. Rows are consumed through an {@link ArrowChunkIterator}. The iterator
 * builds one {@link ArrowVectorConverter} per column for the batch it is currently positioned on.
 */
public class ArrowResultChunk extends SnowflakeResultChunk {
    /** Column vectors of each record batch, in the order the batches were read. */
    private final List<List<ValueVector>> batchOfVectors = new ArrayList<>();

    /**
     * Allocator shared by every chunk. It is constructed with a limit of {@code Integer.MAX_VALUE}.
     * It is used both by the stream reader and as the destination of the vector transfers in
     * {@link #readArrowStream}.
     */
    private static final RootAllocator rootAllocator = new RootAllocator(Integer.MAX_VALUE);

    /**
     * Creates a chunk with no record batches. Batches are added by {@link #readArrowStream}.
     *
     * @param url chunk URL, passed to the superclass
     * @param rowCount row count, passed to the superclass
     * @param colCount column count, passed to the superclass
     * @param uncompressedSize uncompressed size, passed to the superclass and reported by
     *     {@link #computeNeededChunkMemory()}
     */
    public ArrowResultChunk(String url, int rowCount, int colCount, int uncompressedSize) {
        super(url, rowCount, colCount, uncompressedSize);
    }

    /** Appends the column vectors of one record batch. */
    private void addBatchData(List<ValueVector> batch) {
        this.batchOfVectors.add(batch);
    }

    /**
     * Reads every record batch from an Arrow IPC stream into {@code resultChunk}.
     *
     * <p>For each loaded batch, every field vector is transferred into a new vector owned by the
     * shared allocator. The transferred vectors are stored in the chunk. The reader (and therefore
     * {@code is}) is closed when this method returns or throws. Presumably the transfer is what lets
     * the data outlive the reader's own {@link VectorSchemaRoot}.
     *
     * @param is Arrow stream to read; closed by this method
     * @param resultChunk chunk that receives one vector list per record batch
     * @throws IOException if reading the stream fails
     */
    public static void readArrowStream(InputStream is, ArrowResultChunk resultChunk) throws IOException {
        try (ArrowStreamReader reader = new ArrowStreamReader(is, rootAllocator)) {
            while (reader.loadNextBatch()) {
                VectorSchemaRoot root = reader.getVectorSchemaRoot();
                List<FieldVector> fieldVectors = root.getFieldVectors();
                List<ValueVector> valueVectors = new ArrayList<>(fieldVectors.size());
                for (FieldVector fieldVector : fieldVectors) {
                    TransferPair transferPair = fieldVector.getTransferPair(rootAllocator);
                    transferPair.transfer();
                    valueVectors.add(transferPair.getTo());
                }
                resultChunk.addBatchData(valueVectors);
            }
        }
    }

    /** Reports the chunk's uncompressed size as the memory it needs. */
    @Override
    public long computeNeededChunkMemory() {
        return this.getUncompressedSize();
    }

    /**
     * Clears every vector of every batch.
     *
     * <p>The batch lists themselves are kept, so an iterator created earlier still indexes valid
     * (now cleared) batches rather than running off the end of the list.
     */
    @Override
    public void freeData() {
        this.batchOfVectors.forEach(list -> list.forEach(ValueVector::clear));
    }

    /**
     * Builds one converter per column vector, in column order. The i-th converter always belongs
     * to the i-th vector.
     *
     * @param vectors column vectors of one record batch
     * @param context conversion context handed to every converter
     * @return converters, one per vector, aligned by index
     * @throws SnowflakeSQLException if a column's Arrow type or Snowflake logical type is not supported
     */
    private static List<ArrowVectorConverter> initConverters(List<ValueVector> vectors, DataConversionContext context) throws SnowflakeSQLException {
        List<ArrowVectorConverter> converters = new ArrayList<>(vectors.size());
        for (int i = 0; i < vectors.size(); ++i) {
            converters.add(createConverter(vectors.get(i), i, context));
        }
        return converters;
    }

    /**
     * Chooses the converter for a single column.
     *
     * <p>The choice depends on the Arrow minor type, the {@code logicalType} field-metadata entry,
     * and (for timestamps) the number of child fields. DECIMAL vectors are handled before the
     * metadata is consulted.
     *
     * @throws SnowflakeSQLException if the column has no {@code logicalType} metadata or an
     *     unsupported type/layout
     */
    private static ArrowVectorConverter createConverter(ValueVector vector, int index, DataConversionContext context) throws SnowflakeSQLException {
        Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());
        if (type == Types.MinorType.DECIMAL) {
            return new DecimalToScaledFixedConverter(vector, index, context);
        }
        Map<String, String> customMeta = vector.getField().getMetadata();
        String logicalType = customMeta.get("logicalType");
        if (logicalType == null) {
            // Covers both empty metadata and metadata lacking logicalType.
            // SnowflakeType.valueOf(null) would otherwise throw an NPE.
            throw internalError("Unexpected Arrow Field for ", String.valueOf(type));
        }
        SnowflakeType st = SnowflakeType.valueOf(logicalType);
        switch (st) {
            case ANY:
            case ARRAY:
            case CHAR:
            case TEXT:
            case OBJECT:
            case VARIANT:
                return new VarCharToTextConverter(vector, index, context);
            case BINARY:
                return new VarBinaryToBinaryConverter(vector, index, context);
            case BOOLEAN:
                return new TinyIntToBooleanConverter(vector, index, context);
            case DATE:
                return new IntToDateConverter(vector, index, context);
            case FIXED:
                return createFixedConverter(vector, index, context, type);
            case REAL:
                return new DoubleToRealConverter(vector, index, context);
            case TIME:
                return new BigIntToTimeConverter(vector, index, context);
            case TIMESTAMP_LTZ: {
                int children = vector.getField().getChildren().size();
                if (children == 0) {
                    return new BigIntToTimestampLTZConverter(vector, index, context);
                }
                if (children == 2) {
                    return new TwoFieldStructToTimestampLTZConverter(vector, index, context);
                }
                throw internalError("Unexpected Arrow Field for ", st.name());
            }
            case TIMESTAMP_NTZ: {
                int children = vector.getField().getChildren().size();
                if (children == 0) {
                    return new BigIntToTimestampNTZConverter(vector, index, context);
                }
                if (children == 2) {
                    return new TwoFieldStructToTimestampNTZConverter(vector, index, context);
                }
                throw internalError("Unexpected Arrow Field for ", st.name());
            }
            case TIMESTAMP_TZ: {
                int children = vector.getField().getChildren().size();
                if (children == 2) {
                    return new TwoFieldStructToTimestampTZConverter(vector, index, context);
                }
                if (children == 3) {
                    return new ThreeFieldStructToTimestampTZConverter(vector, index, context);
                }
                throw internalError("Unexpected SnowflakeType ", st.name());
            }
            default:
                throw internalError("Unexpected Arrow Field for ", st.name());
        }
    }

    /**
     * Chooses the converter for a FIXED column from its integer width and the {@code scale}
     * metadata. A scale of 0 selects the plain converter; any other scale selects the scaled one.
     *
     * <p>Previously an unsupported width added no converter at all. That silently shifted every
     * later column onto the wrong converter, so it is now reported as an error.
     *
     * @throws SnowflakeSQLException if the Arrow type is not TINYINT, SMALLINT, INT or BIGINT
     */
    private static ArrowVectorConverter createFixedConverter(ValueVector vector, int index, DataConversionContext context, Types.MinorType type) throws SnowflakeSQLException {
        // NOTE(review): a missing "scale" entry still surfaces as NumberFormatException, as before.
        int sfScale = Integer.parseInt(vector.getField().getMetadata().get("scale"));
        switch (type) {
            case TINYINT:
                if (sfScale == 0) {
                    return new TinyIntToFixedConverter(vector, index, context);
                }
                return new TinyIntToScaledFixedConverter(vector, index, context, sfScale);
            case SMALLINT:
                if (sfScale == 0) {
                    return new SmallIntToFixedConverter(vector, index, context);
                }
                return new SmallIntToScaledFixedConverter(vector, index, context, sfScale);
            case INT:
                if (sfScale == 0) {
                    return new IntToFixedConverter(vector, index, context);
                }
                return new IntToScaledFixedConverter(vector, index, context, sfScale);
            case BIGINT:
                if (sfScale == 0) {
                    return new BigIntToFixedConverter(vector, index, context);
                }
                return new BigIntToScaledFixedConverter(vector, index, context, sfScale);
            default:
                throw internalError("Unexpected Arrow Field for ", SnowflakeType.FIXED.name());
        }
    }

    /** Builds the internal-error exception used for every unsupported column layout. */
    private static SnowflakeSQLException internalError(String message, String detail) {
        return new SnowflakeSQLException("XX000", ErrorCode.INTERNAL_ERROR.getMessageCode(), message, detail);
    }

    /**
     * Returns an iterator over this chunk's rows.
     *
     * @param dataConversionContext context passed to the converters of each batch
     */
    public ArrowChunkIterator getIterator(DataConversionContext dataConversionContext) {
        return new ArrowChunkIterator(this, dataConversionContext);
    }

    /** Returns an iterator over a chunk with no batches; its first {@code next()} returns false. */
    public static ArrowChunkIterator getEmptyChunkIterator() {
        return new ArrowChunkIterator(new EmptyArrowResultChunk());
    }

    /** A chunk with no data that reports zero memory and has nothing to free. */
    private static class EmptyArrowResultChunk extends ArrowResultChunk {
        EmptyArrowResultChunk() {
            super("", 0, 0, 0);
        }

        @Override
        public final long computeNeededChunkMemory() {
            return 0L;
        }

        @Override
        public final void freeData() {
        }
    }

    /**
     * Forward-only cursor over the rows of an {@link ArrowResultChunk}.
     *
     * <p>The cursor starts before the first row. Each {@link #next()} advances one row and moves to
     * the next record batch when the current one is exhausted. On entering a batch, it rebuilds the
     * per-column converters for that batch.
     */
    public static class ArrowChunkIterator {
        private final ArrowResultChunk resultChunk;
        private int currentRecordBatchIndex;
        private final int totalRecordBatch;
        private int currentRowInRecordBatch;
        private int rowCountInCurrentRecordBatch;
        private List<ArrowVectorConverter> currentConverters;
        private final DataConversionContext dataConversionContext;

        ArrowChunkIterator(ArrowResultChunk resultChunk, DataConversionContext dataConversionContext) {
            this.resultChunk = resultChunk;
            this.currentRecordBatchIndex = -1;
            this.totalRecordBatch = resultChunk.batchOfVectors.size();
            this.currentRowInRecordBatch = -1;
            this.rowCountInCurrentRecordBatch = 0;
            this.dataConversionContext = dataConversionContext;
        }

        ArrowChunkIterator(EmptyArrowResultChunk emptyArrowResultChunk) {
            this.resultChunk = emptyArrowResultChunk;
            this.currentRecordBatchIndex = 0;
            this.totalRecordBatch = 0;
            this.currentRowInRecordBatch = -1;
            this.rowCountInCurrentRecordBatch = 0;
            this.currentConverters = Collections.emptyList();
            this.dataConversionContext = null;
        }

        /**
         * Advances to the next row.
         *
         * <p>Record batches that contain no rows, or no columns, are skipped. Previously such a batch
         * produced a phantom row 0, or an {@code IndexOutOfBoundsException} on its first column.
         *
         * @return {@code true} if the cursor is now on a row, {@code false} once all batches are exhausted
         * @throws SnowflakeSQLException if converters cannot be built for the batch being entered
         */
        public boolean next() throws SnowflakeSQLException {
            ++this.currentRowInRecordBatch;
            if (this.currentRowInRecordBatch < this.rowCountInCurrentRecordBatch) {
                return true;
            }
            while (++this.currentRecordBatchIndex < this.totalRecordBatch) {
                List<ValueVector> batch = this.resultChunk.batchOfVectors.get(this.currentRecordBatchIndex);
                int rowCount = batch.isEmpty() ? 0 : batch.get(0).getValueCount();
                if (rowCount == 0) {
                    continue;
                }
                this.currentRowInRecordBatch = 0;
                this.rowCountInCurrentRecordBatch = rowCount;
                this.currentConverters = ArrowResultChunk.initConverters(batch, this.dataConversionContext);
                return true;
            }
            return false;
        }

        /** Returns whether the cursor is on the last row of the last record batch. */
        public boolean isLast() {
            return this.currentRecordBatchIndex + 1 == this.totalRecordBatch && this.currentRowInRecordBatch + 1 == this.rowCountInCurrentRecordBatch;
        }

        /** Returns whether the cursor has moved past every row. */
        public boolean isAfterLast() {
            return this.currentRecordBatchIndex >= this.totalRecordBatch && this.currentRowInRecordBatch >= this.rowCountInCurrentRecordBatch;
        }

        /** Returns the chunk this iterator walks. */
        public ArrowResultChunk getChunk() {
            return this.resultChunk;
        }

        /**
         * Returns the converter for a column of the current record batch.
         *
         * @param columnIndex 0-based column index
         */
        public ArrowVectorConverter getCurrentConverter(int columnIndex) {
            return this.currentConverters.get(columnIndex);
        }

        /** Returns the 0-based row position within the current record batch. */
        public int getCurrentRowInRecordBatch() {
            return this.currentRowInRecordBatch;
        }
    }
}

