package com.facebook.presto.sql.relational;

import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionAnalysis;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Static helpers that turn the body of a SQL-invoked scalar function into either an AST
 * {@link Expression} or a {@link RowExpression}, with the call-site arguments substituted for
 * the function's declared parameter names.
 */
public final class SqlFunctionUtils {

    /**
     * Substitutes call-site argument values for references to a SQL function's declared parameter
     * names, either in an AST {@link Expression} or in a translated {@link RowExpression}.
     */
    public static final class SqlFunctionArgumentBinder {

        /**
         * AST rewriter that replaces each {@link Identifier} whose value is a bound parameter name with
         * the bound expression. Every other identifier is returned unchanged.
         */
        public static class ExpressionFunctionVisitor extends ExpressionRewriter<Void> {
            private final Map<String, Expression> argumentBindings;

            /**
             * @param argumentBindings parameter name to replacement expression; must not be null
             */
            public ExpressionFunctionVisitor(Map<String, Expression> argumentBindings) {
                this.argumentBindings = Objects.requireNonNull(argumentBindings, "argumentBindings is null");
            }

            @Override
            public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                return argumentBindings.getOrDefault(node.getValue(), node);
            }
        }

        private SqlFunctionArgumentBinder() {
        }

        /**
         * Replaces identifiers in {@code function} that name a parameter with the matching argument.
         *
         * @param function parsed function body
         * @param argumentNames declared parameter names, positionally aligned with {@code argumentValues}
         * @param argumentValues call-site argument expressions
         * @return the rewritten expression
         * @throws IllegalArgumentException if the two lists differ in size, or (from
         *     {@code ImmutableMap.Builder.build}) if a parameter name repeats
         */
        public static Expression bindFunctionArguments(Expression function, List<String> argumentNames, List<Expression> argumentValues) {
            Preconditions.checkArgument(
                    argumentNames.size() == argumentValues.size(),
                    "Expect same size for argumentNames (%s) and argumentValues (%s)",
                    argumentNames.size(),
                    argumentValues.size());
            ImmutableMap.Builder<String, Expression> argumentBindings = ImmutableMap.builder();
            for (int i = 0; i < argumentNames.size(); i++) {
                argumentBindings.put(argumentNames.get(i), argumentValues.get(i));
            }
            return ExpressionTreeRewriter.rewriteWith(new ExpressionFunctionVisitor(argumentBindings.build()), function);
        }

        /**
         * Replaces variable references in {@code function} whose name is bound with the matching argument.
         *
         * @param function translated function body
         * @param argumentNames per-position variable name; an empty entry means that position's argument
         *     is not bound (its value is ignored)
         * @param argumentValues call-site argument row expressions, positionally aligned with
         *     {@code argumentNames}
         * @return the rewritten row expression
         * @throws IllegalArgumentException if the two lists differ in size, or (from
         *     {@code ImmutableMap.Builder.build}) if a present name repeats
         */
        public static RowExpression bindFunctionArguments(RowExpression function, List<Optional<String>> argumentNames, List<RowExpression> argumentValues) {
            Preconditions.checkArgument(
                    argumentNames.size() == argumentValues.size(),
                    "Expect same size for argumentNames (%s) and argumentValues (%s)",
                    argumentNames.size(),
                    argumentValues.size());
            ImmutableMap.Builder<String, RowExpression> argumentBindings = ImmutableMap.builder();
            for (int i = 0; i < argumentNames.size(); i++) {
                Optional<String> name = argumentNames.get(i);
                if (name.isPresent()) {
                    argumentBindings.put(name.get(), argumentValues.get(i));
                }
            }
            return RowExpressionTreeRewriter.rewriteWith(new RowExpressionRewriter<Map<String, RowExpression>>() {
                @Override
                public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Map<String, RowExpression> context, RowExpressionTreeRewriter<Map<String, RowExpression>> treeRewriter) {
                    return context.getOrDefault(variable.getName(), variable);
                }
            }, function, argumentBindings.build());
        }
    }

    private SqlFunctionUtils() {
    }

    /**
     * Parses a SQL function body and substitutes {@code arguments} for its parameter names.
     *
     * @throws IllegalArgumentException if the function is not implemented in SQL, if its argument
     *     names are absent, or if {@code arguments} does not match the argument names in size
     */
    public static Expression getSqlFunctionExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, SqlFunctionProperties sqlFunctionProperties, List<Expression> arguments) {
        Preconditions.checkArgument(
                functionMetadata.getImplementationType().equals(FunctionImplementationType.SQL),
                "Expect SQL function, get %s",
                functionMetadata.getImplementationType());
        Preconditions.checkArgument(functionMetadata.getArgumentNames().isPresent(), "Argument name is missing");
        Expression body = parseSqlFunctionExpression(implementation, sqlFunctionProperties);
        return SqlFunctionArgumentBinder.bindFunctionArguments(body, functionMetadata.getArgumentNames().get(), arguments);
    }

    /**
     * Translates a SQL function body into a {@link RowExpression} and binds {@code arguments} to it.
     *
     * <p>The body is parsed, wrapped in casts wherever analysis reports a coercion, and has parameter and
     * lambda-argument identifiers replaced by freshly allocated variables. Lambda captures are then
     * desugared, and the result is analyzed and translated. Finally each call-site argument is bound to
     * the variable allocated for its parameter. A parameter that received no variable (it did not appear
     * as a typed identifier in the body) gets an empty name, and its argument is not bound.
     *
     * @throws IllegalArgumentException if the function's argument names are absent
     */
    public static RowExpression getSqlFunctionRowExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, Metadata metadata, SqlFunctionProperties sqlFunctionProperties, List<RowExpression> arguments) {
        // Same precondition as getSqlFunctionExpression; without it an absent value fails later as an
        // opaque NoSuchElementException inside getFunctionArgumentTypes.
        Preconditions.checkArgument(functionMetadata.getArgumentNames().isPresent(), "Argument name is missing");

        Expression coerced = coerceIfNecessary(functionMetadata, parseSqlFunctionExpression(implementation, sqlFunctionProperties), sqlFunctionProperties, metadata);

        PlanVariableAllocator variableAllocator = new PlanVariableAllocator();
        Map<Identifier, VariableReferenceExpression> variables = buildIdentifierToVariableMap(functionMetadata, coerced, sqlFunctionProperties, metadata, variableAllocator);
        Expression rewritten = LambdaCaptureDesugaringRewriter.rewrite(rewriteSqlFunctionExpressionWithVariables(coerced, variables), variableAllocator);

        Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer
                .analyzeSqlFunctionExpression(metadata, sqlFunctionProperties, rewritten, variableAllocator.getTypes().allTypes())
                .getExpressionTypes();
        RowExpression translated = SqlToRowExpressionTranslator.translate(
                rewritten,
                expressionTypes,
                ImmutableMap.of(),
                metadata.getFunctionManager(),
                metadata.getTypeManager(),
                Optional.empty(),
                Optional.empty(),
                sqlFunctionProperties);

        // Map each declared parameter (by position) to the name of the variable allocated for it, if any.
        List<Optional<String>> argumentVariableNames = functionMetadata.getArgumentNames().get().stream()
                .map(Identifier::new)
                .map(variables::get)
                .map(Optional::ofNullable)
                .map(variable -> variable.map(VariableReferenceExpression::getName))
                .collect(ImmutableList.toImmutableList());
        return SqlFunctionArgumentBinder.bindFunctionArguments(translated, argumentVariableNames, arguments);
    }

    /**
     * Parses the routine body of {@code implementation}. Decimal literals are read as DOUBLE or
     * DECIMAL according to {@code sqlFunctionProperties}.
     */
    private static Expression parseSqlFunctionExpression(SqlInvokedScalarFunctionImplementation implementation, SqlFunctionProperties sqlFunctionProperties) {
        ParsingOptions.DecimalLiteralTreatment decimalLiteralTreatment = sqlFunctionProperties.isParseDecimalLiteralAsDouble()
                ? ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE
                : ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
        ParsingOptions parsingOptions = ParsingOptions.builder()
                .setDecimalLiteralTreatment(decimalLiteralTreatment)
                .build();
        return new SqlParser().createRoutineBody(implementation.getImplementation(), parsingOptions).getExpression();
    }

    /**
     * Builds an ordered map from each declared parameter name to its resolved {@link Type}.
     *
     * <p>Callers must ensure {@code functionMetadata.getArgumentNames()} is present.
     *
     * @throws IllegalStateException if names and types differ in size
     */
    private static Map<String, Type> getFunctionArgumentTypes(FunctionMetadata functionMetadata, Metadata metadata) {
        List<String> argumentNames = functionMetadata.getArgumentNames().get();
        List<Type> argumentTypes = functionMetadata.getArgumentTypes().stream()
                .map(metadata::getType)
                .collect(ImmutableList.toImmutableList());
        Preconditions.checkState(
                argumentNames.size() == argumentTypes.size(),
                "Expect argumentNames (size %s) and argumentTypes (size %s) to be of the same size",
                argumentNames.size(),
                argumentTypes.size());
        ImmutableMap.Builder<String, Type> builder = ImmutableMap.builder();
        for (int i = 0; i < argumentNames.size(); i++) {
            builder.put(argumentNames.get(i), argumentTypes.get(i));
        }
        return builder.build();
    }

    /**
     * Allocates one variable per distinct lambda-argument name and per distinct identifier that names a
     * function parameter. Iteration order follows {@code getExpressionTypes()}.
     *
     * <p>The first occurrence of an {@link Identifier} key wins. Distinct lambdas that reuse an argument
     * name therefore share one variable, assuming {@code Identifier.equals} compares by name — NOTE(review):
     * confirm.
     */
    private static Map<Identifier, VariableReferenceExpression> buildIdentifierToVariableMap(FunctionMetadata functionMetadata, Expression expression, SqlFunctionProperties sqlFunctionProperties, Metadata metadata, PlanVariableAllocator variableAllocator) {
        Map<String, Type> argumentTypes = getFunctionArgumentTypes(functionMetadata, metadata);
        Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer
                .analyzeSqlFunctionExpression(metadata, sqlFunctionProperties, expression, argumentTypes)
                .getExpressionTypes();
        Map<Identifier, VariableReferenceExpression> variables = new LinkedHashMap<>();
        for (Map.Entry<NodeRef<Expression>, Type> entry : expressionTypes.entrySet()) {
            Expression node = entry.getKey().getNode();
            Type type = entry.getValue();
            if (node instanceof LambdaArgumentDeclaration) {
                Identifier name = ((LambdaArgumentDeclaration) node).getName();
                variables.computeIfAbsent(name, key -> variableAllocator.newVariable(key, type));
            }
            else if (node instanceof Identifier && argumentTypes.containsKey(((Identifier) node).getValue())) {
                Identifier identifier = (Identifier) node;
                variables.computeIfAbsent(identifier, key -> variableAllocator.newVariable(key, type));
            }
        }
        return variables;
    }

    /**
     * Renames lambda arguments to their allocated variable names and replaces every {@link Identifier}
     * with a {@link SymbolReference} to its allocated variable.
     *
     * <p>Throws {@link NullPointerException} if an identifier or lambda argument has no entry in
     * {@code variables}.
     */
    private static Expression rewriteSqlFunctionExpressionWithVariables(Expression expression, Map<Identifier, VariableReferenceExpression> variables) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Map<Identifier, VariableReferenceExpression>>() {
            @Override
            public Expression rewriteLambdaExpression(LambdaExpression node, Map<Identifier, VariableReferenceExpression> context, ExpressionTreeRewriter<Map<Identifier, VariableReferenceExpression>> treeRewriter) {
                ImmutableList.Builder<LambdaArgumentDeclaration> newArguments = ImmutableList.builder();
                for (LambdaArgumentDeclaration argument : node.getArguments()) {
                    VariableReferenceExpression variable = context.get(argument.getName());
                    newArguments.add(new LambdaArgumentDeclaration(new Identifier(variable.getName())));
                }
                return new LambdaExpression(newArguments.build(), treeRewriter.rewrite(node.getBody(), context));
            }

            @Override
            public Expression rewriteIdentifier(Identifier node, Map<Identifier, VariableReferenceExpression> context, ExpressionTreeRewriter<Map<Identifier, VariableReferenceExpression>> treeRewriter) {
                return new SymbolReference(context.get(node).getName());
            }
        }, expression, variables);
    }

    /**
     * Wraps each sub-expression for which analysis recorded a coercion in a {@link Cast} to the coerced
     * type. Children are rewritten before their parent.
     */
    private static Expression coerceIfNecessary(FunctionMetadata functionMetadata, Expression expression, SqlFunctionProperties sqlFunctionProperties, Metadata metadata) {
        ExpressionAnalysis analysis = ExpressionAnalyzer.analyzeSqlFunctionExpression(metadata, sqlFunctionProperties, expression, getFunctionArgumentTypes(functionMetadata, metadata));
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<ExpressionAnalysis>() {
            @Override
            public Expression rewriteExpression(Expression node, ExpressionAnalysis context, ExpressionTreeRewriter<ExpressionAnalysis> treeRewriter) {
                Expression rewritten = treeRewriter.defaultRewrite(node, context);
                // Look up coercions on the ORIGINAL node: the analysis is keyed by the nodes it analyzed.
                Type coercion = analysis.getCoercion(node);
                if (coercion == null) {
                    return rewritten;
                }
                return new Cast(rewritten, coercion.getTypeSignature().toString(), false, analysis.isTypeOnlyCoercion(node));
            }
        }, expression, analysis);
    }
}
