/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.redshift;

import com.google.common.base.Verify;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.aggregation.AggregateFunctionPatterns;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.DecimalType;
import java.util.Optional;

public class ImplementRedshiftAvgDecimal
implements AggregateFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final Capture<Variable> INPUT = Capture.newCapture();

    public Pattern<AggregateFunction> getPattern() {
        return AggregateFunctionPatterns.basicAggregation().with(AggregateFunctionPatterns.functionName().equalTo((Object)"avg")).with(AggregateFunctionPatterns.singleArgument().matching(ConnectorExpressionPatterns.variable().with(ConnectorExpressionPatterns.type().matching(DecimalType.class::isInstance)).capturedAs(INPUT)));
    }

    public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, AggregateFunctionRule.RewriteContext<ParameterizedExpression> context) {
        Variable input = (Variable)captures.get(INPUT);
        JdbcColumnHandle columnHandle = (JdbcColumnHandle)context.getAssignment(input.getName());
        DecimalType type = (DecimalType)columnHandle.getColumnType();
        Verify.verify((boolean)aggregateFunction.getOutputType().equals(type));
        ParameterizedExpression rewrittenArgument = (ParameterizedExpression)context.rewriteExpression((ConnectorExpression)input).orElseThrow();
        if (type.getPrecision() == 38) {
            return Optional.of(new JdbcExpression(String.format("avg(CAST(%s AS decimal(%s, %s)))", rewrittenArgument.expression(), type.getPrecision(), type.getScale()), rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle()));
        }
        return Optional.of(new JdbcExpression(String.format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", rewrittenArgument.expression(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle()));
    }
}

