/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.sql.SqlTypeCoder;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlInputRefExpression;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAggregations;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.utils.BigDecimalConverter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.RowType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.com.google.common.collect.ImmutableList;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableBitSet;
import org.joda.time.Instant;

public class BeamAggregationTransforms
implements Serializable {

    public static class AggregationAccumulatorCoder
    extends CustomCoder<AggregationAccumulator> {
        private VarIntCoder sizeCoder = VarIntCoder.of();
        private List<Coder> elementCoders;

        public AggregationAccumulatorCoder(List<Coder> elementCoders) {
            this.elementCoders = elementCoders;
        }

        public void encode(AggregationAccumulator value, OutputStream outStream) throws CoderException, IOException {
            this.sizeCoder.encode(Integer.valueOf(value.accumulatorElements.size()), outStream);
            for (int idx = 0; idx < value.accumulatorElements.size(); ++idx) {
                this.elementCoders.get(idx).encode(value.accumulatorElements.get(idx), outStream);
            }
        }

        public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException {
            AggregationAccumulator accu = new AggregationAccumulator();
            int size = this.sizeCoder.decode(inStream);
            for (int idx = 0; idx < size; ++idx) {
                accu.accumulatorElements.add(this.elementCoders.get(idx).decode(inStream));
            }
            return accu;
        }
    }

    /**
     * Composite accumulator: an ordered list of per-aggregation accumulators.
     *
     * <p>{@link AggregationAdaptor} stores the accumulator of its {@code i}-th
     * aggregator at index {@code i}. The element types differ per aggregator, hence
     * {@code Object}.
     */
    public static class AggregationAccumulator {
        private final List<Object> accumulatorElements = new ArrayList<>();
    }

    /**
     * A {@link Combine.CombineFn} that evaluates several aggregation calls over the
     * same input {@link Row} in one pass.
     *
     * <p>For every {@link AggregateCall} it keeps, at the same list index, the
     * aggregator to run ({@code aggregators}) and the expression(s) that select the
     * aggregator's input from the row ({@code sourceFieldExps}). A source entry is
     * either a single {@link BeamSqlInputRefExpression} or, for calls with exactly
     * two arguments, a {@link KV} of two such expressions.
     */
    public static class AggregationAdaptor
    extends Combine.CombineFn<Row, AggregationAccumulator, Row> {
        private final List<Combine.CombineFn> aggregators = new ArrayList<Combine.CombineFn>();
        private final List<Object> sourceFieldExps = new ArrayList<Object>();
        private final RowType finalRowType;

        /**
         * @param aggregationCalls the aggregation calls to evaluate, in output order
         * @param sourceRowType row type of the input rows, used to type the input references
         * @throws UnsupportedOperationException if a call is neither one of the built-in
         *     aggregations handled here nor a {@link SqlUserDefinedAggFunction}
         * @throws IllegalStateException if a user-defined aggregation fails to provide
         *     its combine function
         */
        public AggregationAdaptor(List<AggregateCall> aggregationCalls, RowType sourceRowType) {
            ImmutableList.Builder fields = ImmutableList.builder();
            for (AggregateCall call : aggregationCalls) {
                this.sourceFieldExps.add(createSourceExpression(call, sourceRowType));
                SqlTypeCoder outFieldType = CalciteUtils.toCoder(call.type.getSqlTypeName());
                fields.add(RowType.newField(call.name, outFieldType));
                this.aggregators.add(createAggregator(call, outFieldType));
            }
            this.finalRowType = (RowType)fields.build().stream().collect(RowType.toRowType());
        }

        /**
         * Builds the source-field entry for one call: a {@link KV} of two input
         * references when the call has exactly two arguments, otherwise a single
         * input reference to the first argument (field 0 when there are no arguments).
         */
        private static Object createSourceExpression(AggregateCall call, RowType sourceRowType) {
            List<Integer> args = call.getArgList();
            if (args.size() == 2) {
                BeamSqlInputRefExpression keyExp = inputRef(sourceRowType, args.get(0));
                BeamSqlInputRefExpression valueExp = inputRef(sourceRowType, args.get(1));
                return KV.of(keyExp, valueExp);
            }
            int refIndex = args.size() > 0 ? args.get(0) : 0;
            return inputRef(sourceRowType, refIndex);
        }

        /** Creates an input reference to field {@code fieldIndex} of {@code sourceRowType}. */
        private static BeamSqlInputRefExpression inputRef(RowType sourceRowType, int fieldIndex) {
            return new BeamSqlInputRefExpression(CalciteUtils.getFieldCalciteType(sourceRowType, fieldIndex), fieldIndex);
        }

        /**
         * Selects the combine function for one call. Built-in names are matched first;
         * any other name is accepted only if the aggregation is a user-defined one.
         */
        private static Combine.CombineFn createAggregator(AggregateCall call, SqlTypeCoder outFieldType) {
            switch (call.getAggregation().getName()) {
                case "COUNT":
                    return Count.combineFn();
                case "MAX":
                    return BeamBuiltinAggregations.createMax(call.type.getSqlTypeName());
                case "MIN":
                    return BeamBuiltinAggregations.createMin(call.type.getSqlTypeName());
                case "SUM":
                    return BeamBuiltinAggregations.createSum(call.type.getSqlTypeName());
                case "AVG":
                    return BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName());
                case "VAR_POP":
                    return VarianceFn.newPopulation(BigDecimalConverter.forSqlType(outFieldType));
                case "VAR_SAMP":
                    return VarianceFn.newSample(BigDecimalConverter.forSqlType(outFieldType));
                case "COVAR_POP":
                    return CovarianceFn.newPopulation(BigDecimalConverter.forSqlType(outFieldType));
                case "COVAR_SAMP":
                    return CovarianceFn.newSample(BigDecimalConverter.forSqlType(outFieldType));
                default:
                    break;
            }
            if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
                SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction)call.getAggregation();
                UdafImpl fn = (UdafImpl)udaf.function;
                try {
                    return fn.getCombineFn();
                }
                catch (Exception e) {
                    throw new IllegalStateException(e);
                }
            }
            throw new UnsupportedOperationException(String.format("Aggregator [%s] is not supported", call.getAggregation().getName()));
        }

        /** Creates a composite accumulator holding one fresh accumulator per aggregator. */
        @Override
        @SuppressWarnings("unchecked") // aggregators are stored as raw CombineFn
        public AggregationAccumulator createAccumulator() {
            AggregationAccumulator initialAccu = new AggregationAccumulator();
            for (Combine.CombineFn agg : this.aggregators) {
                initialAccu.accumulatorElements.add(agg.createAccumulator());
            }
            return initialAccu;
        }

        /**
         * Feeds {@code input} to every aggregator and returns a new composite
         * accumulator with the results; the passed-in composite is not modified here.
         */
        @Override
        @SuppressWarnings("unchecked") // aggregators are stored as raw CombineFn
        public AggregationAccumulator addInput(AggregationAccumulator accumulator, Row input) {
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                Object sourceExp = this.sourceFieldExps.get(idx);
                if (sourceExp instanceof BeamSqlInputRefExpression) {
                    BeamSqlInputRefExpression exp = (BeamSqlInputRefExpression)sourceExp;
                    deltaAcc.accumulatorElements.add(this.aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx), exp.evaluate(input, null).getValue()));
                } else if (sourceExp instanceof KV) {
                    // Two-argument aggregation: the aggregator consumes a KV of both field values.
                    KV<?, ?> pair = (KV<?, ?>)sourceExp;
                    BeamSqlInputRefExpression keyExp = (BeamSqlInputRefExpression)pair.getKey();
                    BeamSqlInputRefExpression valueExp = (BeamSqlInputRefExpression)pair.getValue();
                    Object pairInput = KV.of(keyExp.evaluate(input, null).getValue(), valueExp.evaluate(input, null).getValue());
                    deltaAcc.accumulatorElements.add(this.aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx), pairInput));
                }
                // The constructor only ever stores the two shapes handled above.
            }
            return deltaAcc;
        }

        /**
         * Merges composite accumulators position by position: for each aggregator, the
         * elements at its index are collected from every composite and merged by that
         * aggregator. {@code accumulators} is iterated once per aggregator.
         */
        @Override
        @SuppressWarnings("unchecked") // aggregators are stored as raw CombineFn
        public AggregationAccumulator mergeAccumulators(Iterable<AggregationAccumulator> accumulators) {
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                List<Object> accs = new ArrayList<>();
                for (AggregationAccumulator accumulator : accumulators) {
                    accs.add(accumulator.accumulatorElements.get(idx));
                }
                deltaAcc.accumulatorElements.add(this.aggregators.get(idx).mergeAccumulators(accs));
            }
            return deltaAcc;
        }

        /** Extracts each aggregator's output, in order, into a row of the final row type. */
        @Override
        public Row extractOutput(AggregationAccumulator accumulator) {
            return (Row)IntStream.range(0, this.aggregators.size()).mapToObj(idx -> this.getAggregatorOutput(accumulator, idx)).collect(Row.toRow(this.finalRowType));
        }

        /** Returns the output of the aggregator at {@code idx} for its accumulator element. */
        @SuppressWarnings("unchecked") // aggregators are stored as raw CombineFn
        private Object getAggregatorOutput(AggregationAccumulator accumulator, int idx) {
            return this.aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx));
        }

        /**
         * Builds an {@link AggregationAccumulatorCoder} from the accumulator coder of
         * every aggregator, each derived from the coder(s) of its source field(s).
         *
         * <p>Also registers {@link BigDecimalCoder} for {@link BigDecimal} in
         * {@code registry}.
         *
         * @param inputCoder coder of the input rows; must be a {@link RowCoder}
         * @throws CannotProvideCoderException if an aggregator cannot provide its accumulator coder
         */
        @Override
        @SuppressWarnings("unchecked") // aggregators and field coders are raw types
        public Coder<AggregationAccumulator> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder) throws CannotProvideCoderException {
            RowCoder rowCoder = (RowCoder)inputCoder;
            registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of());
            List<Coder> aggAccuCoderList = new ArrayList<>();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                Object sourceExp = this.sourceFieldExps.get(idx);
                if (sourceExp instanceof BeamSqlInputRefExpression) {
                    int srcFieldIndex = ((BeamSqlInputRefExpression)sourceExp).getInputRef();
                    Coder srcFieldCoder = (Coder)rowCoder.getCoders().get(srcFieldIndex);
                    aggAccuCoderList.add(this.aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder));
                } else if (sourceExp instanceof KV) {
                    // Two-argument aggregation: its input coder is a KvCoder over both field coders.
                    KV<?, ?> pair = (KV<?, ?>)sourceExp;
                    int srcFieldIndexKey = ((BeamSqlInputRefExpression)pair.getKey()).getInputRef();
                    int srcFieldIndexValue = ((BeamSqlInputRefExpression)pair.getValue()).getInputRef();
                    Coder srcFieldCoderKey = (Coder)rowCoder.getCoders().get(srcFieldIndexKey);
                    Coder srcFieldCoderValue = (Coder)rowCoder.getCoders().get(srcFieldIndexValue);
                    aggAccuCoderList.add(this.aggregators.get(idx).getAccumulatorCoder(registry, KvCoder.of(srcFieldCoderKey, srcFieldCoderValue)));
                }
            }
            return new AggregationAccumulatorCoder(aggAccuCoderList);
        }
    }

    public static class WindowTimestampFn
    implements SerializableFunction<Row, Instant> {
        private int windowFieldIdx = -1;

        public WindowTimestampFn(int windowFieldIdx) {
            this.windowFieldIdx = windowFieldIdx;
        }

        public Instant apply(Row input) {
            return new Instant(input.getDate(this.windowFieldIdx).getTime());
        }
    }

    public static class AggregationGroupByKeyFn
    implements SerializableFunction<Row, Row> {
        private List<Integer> groupByKeys = new ArrayList<Integer>();

        public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
            for (int i : groupSet.asList()) {
                if (i == windowFieldIdx) continue;
                this.groupByKeys.add(i);
            }
        }

        public Row apply(Row input) {
            RowType typeOfKey = this.exTypeOfKeyRow(input.getRowType());
            return (Row)this.groupByKeys.stream().map(arg_0 -> ((Row)input).getValue(arg_0)).collect(Row.toRow((RowType)typeOfKey));
        }

        private RowType exTypeOfKeyRow(RowType dataType) {
            return (RowType)this.groupByKeys.stream().map(i -> RowType.newField((String)dataType.getFieldName(i.intValue()), (Coder)dataType.getFieldCoder(i.intValue()))).collect(RowType.toRowType());
        }
    }

    /**
     * Flattens a {@code KV<key row, aggregation row>} into a single output
     * {@link Row}: key values first, then aggregation values, with the window start
     * optionally inserted at a fixed position.
     */
    public static class MergeAggregationRecord
    extends DoFn<KV<Row, Row>, Row> {
        private final RowType outRowType;
        // Populated from the aggregate calls; not read anywhere within this class.
        private final List<String> aggFieldNames;
        private final int windowStartFieldIdx;

        /**
         * @param outRowType row type of the emitted rows
         * @param aggList aggregate calls whose names are recorded
         * @param windowStartFieldIdx position at which the window start is inserted
         *     into the output values, or {@code -1} to insert nothing
         */
        public MergeAggregationRecord(RowType outRowType, List<AggregateCall> aggList, int windowStartFieldIdx) {
            this.outRowType = outRowType;
            this.aggFieldNames = new ArrayList<String>();
            for (AggregateCall ac : aggList) {
                this.aggFieldNames.add(ac.getName());
            }
            this.windowStartFieldIdx = windowStartFieldIdx;
        }

        /**
         * Emits one row per input element, combining the key row's values with the
         * aggregation row's values.
         *
         * <p>When {@code windowStartFieldIdx} is not {@code -1}, {@code window} is
         * cast to {@link IntervalWindow} and its start is inserted as a date at that
         * index.
         */
        @DoFn.ProcessElement
        public void processElement(ProcessContext c, BoundedWindow window) {
            KV<Row, Row> kvRow = c.element();
            // Values are heterogeneous (any field type), so the list must be of Object.
            List<Object> fieldValues = new ArrayList<>();
            fieldValues.addAll(kvRow.getKey().getValues());
            fieldValues.addAll(kvRow.getValue().getValues());
            if (this.windowStartFieldIdx != -1) {
                fieldValues.add(this.windowStartFieldIdx, ((IntervalWindow)window).start().toDate());
            }
            c.output(Row.withRowType(this.outRowType).addValues(fieldValues).build());
        }
    }
}

