/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.functions.aggregate;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.utils.DataTypeUtils;

/**
 * Aggregate function with LAG semantics: returns the value seen {@code offset} rows before the
 * current one, or a default value when fewer rows have been seen.
 *
 * <p>Supported call forms are {@code (value)}, {@code (value, offset)} and
 * {@code (value, offset, defaultValue)}. The offset defaults to 1 and the default value to
 * {@code null}.
 *
 * <p>The accumulator keeps a sliding buffer of at most {@code offset + 1} values. Once the buffer
 * is full, its first element is the lagged value.
 *
 * @param <T> internal (converted) type of the lagged value
 */
public class LagAggFunction<T>
        extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {

    /**
     * Internal data types of the call arguments: index 0 is the value, 1 the offset (if given)
     * and 2 the default value (if given).
     */
    private final transient DataType[] valueDataTypes;

    /**
     * Creates the function for the given argument types.
     *
     * @param valueTypes logical types of the call arguments (value, optional offset, optional
     *     default value); presumably never empty — TODO confirm against the planner
     * @throws TableException if a non-NULL default value has a different conversion class than
     *     the lagged value
     */
    public LagAggFunction(LogicalType[] valueTypes) {
        this.valueDataTypes =
                Arrays.stream(valueTypes)
                        .map(DataTypeUtils::toInternalDataType)
                        .toArray(DataType[]::new);

        DataType valueType = valueDataTypes[0];
        boolean hasDefault = valueDataTypes.length == 3;

        // Without a default (or with a nullable default), getValue() can return NULL for the
        // first rows, so the result type has to be nullable.
        boolean resultMayBeNull =
                valueDataTypes.length < 3
                        || (hasDefault && valueDataTypes[2].getLogicalType().isNullable());
        if (!valueType.getLogicalType().isNullable() && resultMayBeNull) {
            valueDataTypes[0] = valueType.nullable();
        }

        // A NULL literal default fits any type; any other default must convert to the same Java
        // class as the value, because both are returned through the same T.
        if (hasDefault
                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL
                && valueDataTypes[0].getConversionClass()
                        != valueDataTypes[2].getConversionClass()) {
            throw new TableException(
                    String.format(
                            "Please explicitly cast default value %s to %s.",
                            valueDataTypes[2], valueDataTypes[0]));
        }
    }

    /** Returns a view of the (internal) argument data types. */
    @Override
    public List<DataType> getArgumentDataTypes() {
        return Arrays.asList(valueDataTypes);
    }

    /**
     * Describes {@link LagAcc} as a structured type. The field names must stay in sync with the
     * public fields of {@link LagAcc}.
     */
    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                LagAcc.class,
                DataTypes.FIELD("offset", DataTypes.INT()),
                DataTypes.FIELD("defaultValue", valueDataTypes[0].nullable()),
                DataTypes.FIELD("buffer", getLinkedListType()));
    }

    /** Builds the RAW type used for the accumulator's {@code LinkedList} buffer. */
    // Raw types are needed because DataTypes.RAW is keyed on the raw LinkedList class.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataType getLinkedListType() {
        TypeSerializer serializer =
                InternalSerializers.create(getOutputDataType().getLogicalType());
        return DataTypes.RAW(LinkedList.class, new LinkedListSerializer(serializer));
    }

    /** Returns the value type, possibly widened to nullable by the constructor. */
    @Override
    public DataType getOutputDataType() {
        return valueDataTypes[0];
    }

    /**
     * Appends {@code value} to the buffer and trims it to at most {@code offset + 1} elements.
     */
    public void accumulate(LagAcc<T> acc, T value) throws Exception {
        acc.buffer.add(value);
        // long arithmetic so that offset == Integer.MAX_VALUE cannot overflow into a negative
        // capacity (which would drain the buffer and fail on removeFirst()).
        long capacity = (long) acc.offset + 1;
        while (acc.buffer.size() > capacity) {
            acc.buffer.removeFirst();
        }
    }

    /**
     * Sets the lag offset, then accumulates {@code value}.
     *
     * @throws TableException if {@code offset} is negative
     */
    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
        if (offset < 0) {
            throw new TableException(String.format("Offset(%d) should be positive.", offset));
        }
        acc.offset = offset;
        accumulate(acc, value);
    }

    /**
     * Sets the default value returned while fewer than {@code offset + 1} rows have been seen,
     * then accumulates {@code value} with the given offset.
     */
    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue)
            throws Exception {
        acc.defaultValue = defaultValue;
        accumulate(acc, value, offset);
    }

    /** Restores {@code acc} to the state of a freshly created accumulator. */
    public void resetAccumulator(LagAcc<T> acc) throws Exception {
        acc.offset = 1;
        acc.defaultValue = null;
        acc.buffer.clear();
    }

    /**
     * Returns the lagged value, or the default value if fewer than {@code offset + 1} rows have
     * been accumulated.
     *
     * @throws TableException if the buffer holds more elements than {@code offset + 1}, which
     *     indicates a broken accumulator invariant
     */
    public T getValue(LagAcc<T> acc) {
        long capacity = (long) acc.offset + 1;
        int size = acc.buffer.size();
        if (size < capacity) {
            return acc.defaultValue;
        }
        if (size == capacity) {
            return acc.buffer.getFirst();
        }
        throw new TableException("Too many elements: " + acc);
    }

    /** Creates an empty accumulator with offset 1 and a {@code null} default value. */
    public LagAcc<T> createAccumulator() {
        return new LagAcc<>();
    }

    /**
     * Mutable accumulator for {@link LagAggFunction}. The public field names must match the fields
     * declared in {@link LagAggFunction#getAccumulatorDataType()}.
     *
     * @param <T> internal type of the lagged value
     */
    public static class LagAcc<T> {
        /** How many rows back to look; never negative once set through accumulate(). */
        public int offset = 1;

        /** Value returned while not enough rows have been seen. */
        public T defaultValue = null;

        /** Sliding window of the most recent {@code offset + 1} values, oldest first. */
        public LinkedList<T> buffer = new LinkedList<>();

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LagAcc<?> other = (LagAcc<?>) o;
            return offset == other.offset
                    && Objects.equals(defaultValue, other.defaultValue)
                    && Objects.equals(buffer, other.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(offset, defaultValue, buffer);
        }

        @Override
        public String toString() {
            return "LagAcc{offset="
                    + offset
                    + ", defaultValue="
                    + defaultValue
                    + ", buffer="
                    + buffer
                    + "}";
        }
    }
}

