/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.stream.StreamSupport;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.runtime.SqlFunctions;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceAccumulator;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;

@Internal
public class CovarianceFn<T extends Number>
extends Combine.CombineFn<KV<T, T>, CovarianceAccumulator, T> {
    static final MathContext MATH_CTX = new MathContext(10, RoundingMode.HALF_UP);
    private static final boolean SAMPLE = true;
    private static final boolean POP = false;
    private boolean isSample;
    private SerializableFunction<BigDecimal, T> decimalConverter;

    public static <V extends Number> CovarianceFn newPopulation(SerializableFunction<BigDecimal, V> decimalConverter) {
        return new CovarianceFn<V>(false, decimalConverter);
    }

    public static <V extends Number> CovarianceFn newSample(SerializableFunction<BigDecimal, V> decimalConverter) {
        return new CovarianceFn<V>(true, decimalConverter);
    }

    private CovarianceFn(boolean isSample, SerializableFunction<BigDecimal, T> decimalConverter) {
        this.isSample = isSample;
        this.decimalConverter = decimalConverter;
    }

    public CovarianceAccumulator createAccumulator() {
        return CovarianceAccumulator.ofZeroElements();
    }

    public CovarianceAccumulator addInput(CovarianceAccumulator currentVariance, KV<T, T> rawInput) {
        if (rawInput == null) {
            return currentVariance;
        }
        return currentVariance.combineWith(CovarianceAccumulator.ofSingleElement(SqlFunctions.toBigDecimal((Number)rawInput.getKey()), SqlFunctions.toBigDecimal((Number)rawInput.getValue())));
    }

    public CovarianceAccumulator mergeAccumulators(Iterable<CovarianceAccumulator> covariances) {
        return StreamSupport.stream(covariances.spliterator(), false).reduce(CovarianceAccumulator.ofZeroElements(), CovarianceAccumulator::combineWith);
    }

    public Coder<CovarianceAccumulator> getAccumulatorCoder(CoderRegistry registry, Coder<KV<T, T>> inputCoder) {
        return SerializableCoder.of(CovarianceAccumulator.class);
    }

    public T extractOutput(CovarianceAccumulator accumulator) {
        return (T)((Number)this.decimalConverter.apply((Object)this.getCovariance(accumulator)));
    }

    private BigDecimal getCovariance(CovarianceAccumulator covariance) {
        BigDecimal adjustedCount = this.isSample ? covariance.count().subtract(BigDecimal.ONE) : covariance.count();
        return covariance.covariance().divide(adjustedCount, MATH_CTX);
    }
}

