/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import com.google.common.base.Preconditions;
import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BytesUtils;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

/**
 * Grouped accumulator for the variance family of aggregations (VAR_POP, VAR_SAMP, STDDEV_POP,
 * STDDEV_SAMP).
 *
 * <p>For every group id it keeps a running count, mean and M2 (sum of squared deviations from the
 * mean). Raw values are folded in one at a time with Welford's online update. Partial states
 * produced by {@link #evaluateIntermediate} are merged in {@link #addIntermediate} with the pairwise
 * (Chan et al.) combination formula.
 *
 * <p>Serialized intermediate state (24 bytes, encoded with {@link BytesUtils}): count ({@code long})
 * at offset 0, mean ({@code double}) at offset 8, M2 ({@code double}) at offset 16. A group that has
 * seen no values is emitted as null instead.
 */
public class GroupedVarianceAccumulator implements GroupedAccumulator {
    private static final long INSTANCE_SIZE =
            RamUsageEstimator.shallowSizeOfInstance(GroupedVarianceAccumulator.class);

    /** Byte offset of the count within a serialized intermediate state. */
    private static final int COUNT_OFFSET = 0;
    /** Byte offset of the mean within a serialized intermediate state. */
    private static final int MEAN_OFFSET = COUNT_OFFSET + Long.BYTES;
    /** Byte offset of M2 within a serialized intermediate state. */
    private static final int M2_OFFSET = MEAN_OFFSET + Double.BYTES;
    /** Total size in bytes of a serialized intermediate state. */
    private static final int INTERMEDIATE_SIZE = M2_OFFSET + Double.BYTES;

    private final TSDataType seriesDataType;
    private final VarianceAccumulator.VarianceType varianceType;
    /** Number of non-null values folded into each group. */
    private final LongBigArray counts = new LongBigArray();
    /** Running mean of each group. */
    private final DoubleBigArray means = new DoubleBigArray();
    /** Running sum of squared deviations from the mean (M2) of each group. */
    private final DoubleBigArray m2s = new DoubleBigArray();

    /**
     * @param seriesDataType type of the input column; INT32, INT64, FLOAT and DOUBLE are accepted by
     *     {@link #addInput}
     * @param varianceType which statistic {@link #evaluateFinal} produces
     */
    public GroupedVarianceAccumulator(TSDataType seriesDataType, VarianceAccumulator.VarianceType varianceType) {
        this.seriesDataType = seriesDataType;
        this.varianceType = varianceType;
    }

    /** Returns the shallow instance size plus the {@code sizeOf()} of each backing array. */
    @Override
    public long getEstimatedSize() {
        return INSTANCE_SIZE + this.counts.sizeOf() + this.means.sizeOf() + this.m2s.sizeOf();
    }

    /** Ensures every backing array has capacity for {@code groupCount} groups. */
    @Override
    public void setGroupCount(long groupCount) {
        this.counts.ensureCapacity(groupCount);
        this.means.ensureCapacity(groupCount);
        this.m2s.ensureCapacity(groupCount);
    }

    /**
     * Folds raw input values into their groups' running state.
     *
     * @param groupIds group id for each input position
     * @param arguments input columns; only {@code arguments[0]} is read
     * @param mask selects which positions participate
     * @throws UnSupportedDataTypeException if the series type is not INT32, INT64, FLOAT or DOUBLE
     */
    @Override
    public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
        switch (this.seriesDataType) {
            case INT32:
                addIntInput(groupIds, arguments[0], mask);
                return;
            case INT64:
                addLongInput(groupIds, arguments[0], mask);
                return;
            case FLOAT:
                addFloatInput(groupIds, arguments[0], mask);
                return;
            case DOUBLE:
                addDoubleInput(groupIds, arguments[0], mask);
                return;
            default:
                throw new UnSupportedDataTypeException(
                        String.format("Unsupported data type in aggregation variance : %s", this.seriesDataType));
        }
    }

    /**
     * Merges serialized partial states (see class Javadoc for the layout) into the groups named by
     * {@code groupIds}. Null positions are skipped.
     *
     * @throws IllegalArgumentException if {@code argument} is not a (run-length encoded) BinaryColumn
     */
    @Override
    public void addIntermediate(int[] groupIds, Column argument) {
        Preconditions.checkArgument(
                argument instanceof BinaryColumn
                        || (argument instanceof RunLengthEncodedColumn
                                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
                "intermediate input and output should be BinaryColumn");
        for (int i = 0; i < argument.getPositionCount(); i++) {
            if (argument.isNull(i)) {
                continue;
            }
            int groupId = groupIds[i];
            byte[] bytes = argument.getBinary(i).getValues();
            // bytesToLong(bytes, len) is assumed to decode the leading `len` bytes, i.e. the count
            // written at COUNT_OFFSET by evaluateIntermediate. Here the second argument is a length,
            // whereas bytesToDouble takes an offset. NOTE(review): confirm against BytesUtils.
            long intermediateCount = BytesUtils.bytesToLong(bytes, Long.BYTES);
            double intermediateMean = BytesUtils.bytesToDouble(bytes, MEAN_OFFSET);
            double intermediateM2 = BytesUtils.bytesToDouble(bytes, M2_OFFSET);

            long count = this.counts.get(groupId);
            double mean = this.means.get(groupId);
            // Non-null partial states always carry count > 0, so newCount > 0 here.
            long newCount = count + intermediateCount;
            double newMean = ((double) intermediateCount * intermediateMean + (double) count * mean) / (double) newCount;
            double delta = intermediateMean - mean;
            // Pairwise combination: M2 = M2_a + M2_b + delta^2 * n_a * n_b / n.
            this.m2s.add(groupId, intermediateM2 + delta * delta * (double) intermediateCount * (double) count / (double) newCount);
            this.counts.set(groupId, newCount);
            this.means.set(groupId, newMean);
        }
    }

    /**
     * Serializes the group's count, mean and M2 into a 24-byte binary value. If the group has seen no
     * values, a null is appended instead.
     *
     * @throws IllegalArgumentException if {@code columnBuilder} is not a BinaryColumnBuilder
     */
    @Override
    public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
        Preconditions.checkArgument(
                columnBuilder instanceof BinaryColumnBuilder,
                "intermediate input and output should be BinaryColumn");
        long count = this.counts.get(groupId);
        if (count == 0L) {
            columnBuilder.appendNull();
            return;
        }
        byte[] bytes = new byte[INTERMEDIATE_SIZE];
        BytesUtils.longToBytes(count, bytes, COUNT_OFFSET);
        BytesUtils.doubleToBytes(this.means.get(groupId), bytes, MEAN_OFFSET);
        BytesUtils.doubleToBytes(this.m2s.get(groupId), bytes, M2_OFFSET);
        columnBuilder.writeBinary(new Binary(bytes));
    }

    /**
     * Writes the configured statistic for the group.
     *
     * <p>The population variants (M2 / n) yield null when n == 0. The sample variants
     * (M2 / (n - 1)) yield null when n &lt; 2. The STDDEV variants take the square root of the
     * corresponding variance.
     *
     * @throws EnumConstantNotPresentException for an unrecognized variance type
     */
    @Override
    public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
        long count = this.counts.get(groupId);
        double m2 = this.m2s.get(groupId);
        switch (this.varianceType) {
            case STDDEV_POP:
                if (count == 0L) {
                    columnBuilder.appendNull();
                } else {
                    columnBuilder.writeDouble(Math.sqrt(m2 / (double) count));
                }
                break;
            case STDDEV_SAMP:
                if (count < 2L) {
                    columnBuilder.appendNull();
                } else {
                    columnBuilder.writeDouble(Math.sqrt(m2 / (double) (count - 1L)));
                }
                break;
            case VAR_POP:
                if (count == 0L) {
                    columnBuilder.appendNull();
                } else {
                    columnBuilder.writeDouble(m2 / (double) count);
                }
                break;
            case VAR_SAMP:
                if (count < 2L) {
                    columnBuilder.appendNull();
                } else {
                    columnBuilder.writeDouble(m2 / (double) (count - 1L));
                }
                break;
            default:
                throw new EnumConstantNotPresentException(
                        VarianceAccumulator.VarianceType.class, this.varianceType.name());
        }
    }

    /** No-op: results are computed directly from the running state in {@link #evaluateFinal}. */
    @Override
    public void prepareFinal() {
    }

    /** Delegates to {@code reset()} on each backing array. */
    @Override
    public void reset() {
        this.counts.reset();
        this.means.reset();
        this.m2s.reset();
    }

    /** Folds the selected non-null INT32 values of {@code column} into their groups. */
    private void addIntInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int position = 0; position < positionCount; position++) {
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getInt(position));
                }
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; i++) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getInt(position));
                }
            }
        }
    }

    /** Folds the selected non-null INT64 values of {@code column} into their groups. */
    private void addLongInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int position = 0; position < positionCount; position++) {
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getLong(position));
                }
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; i++) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getLong(position));
                }
            }
        }
    }

    /** Folds the selected non-null FLOAT values of {@code column} into their groups. */
    private void addFloatInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int position = 0; position < positionCount; position++) {
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getFloat(position));
                }
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; i++) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getFloat(position));
                }
            }
        }
    }

    /** Folds the selected non-null DOUBLE values of {@code column} into their groups. */
    private void addDoubleInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int position = 0; position < positionCount; position++) {
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getDouble(position));
                }
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; i++) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    accumulate(groupIds[position], column.getDouble(position));
                }
            }
        }
    }

    /**
     * Applies one Welford update step to {@code groupId}'s running count, mean and M2.
     *
     * <p>INT32, INT64 and FLOAT inputs are widened to {@code double} at the call site. {@code long}
     * values with magnitude above 2^53 may lose precision in that conversion.
     */
    private void accumulate(int groupId, double value) {
        this.counts.increment(groupId);
        double delta = value - this.means.get(groupId);
        this.means.add(groupId, delta / (double) this.counts.get(groupId));
        // Deliberately uses the already-updated mean, as Welford's recurrence requires.
        this.m2s.add(groupId, delta * (value - this.means.get(groupId)));
    }
}

