package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.BuiltInFunctionNamespaceManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.FullyQualifiedName;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.IntegerType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.class */
public class RewriteSpatialPartitioningAggregation implements Rule<AggregationNode> {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = TypeSignature.parseTypeSignature("Geometry");
    private static final FullyQualifiedName NAME = FullyQualifiedName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "spatial_partitioning");
    private final Pattern<AggregationNode> pattern = Patterns.aggregation().matching(this::hasSpatialPartitioningAggregation);
    private final Metadata metadata;

    public RewriteSpatialPartitioningAggregation(Metadata metadata) {
        this.metadata = (Metadata) Objects.requireNonNull(metadata, "metadata is null");
    }

    private boolean hasSpatialPartitioningAggregation(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().anyMatch(aggregation -> {
            return this.metadata.getFunctionManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName().equals(NAME) && aggregation.getArguments().size() == 1;
        });
    }

    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Pattern<AggregationNode> getPattern() {
        return this.pattern;
    }

    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ImmutableMap.Builder builder = ImmutableMap.builder();
        VariableReferenceExpression newVariable = context.getVariableAllocator().newVariable("partition_count", IntegerType.INTEGER);
        ImmutableMap.Builder builder2 = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation value = entry.getValue();
            FullyQualifiedName name = this.metadata.getFunctionManager().getFunctionMetadata(value.getFunctionHandle()).getName();
            Type type = this.metadata.getType(GEOMETRY_TYPE_SIGNATURE);
            if (name.equals(NAME) && value.getArguments().size() == 1) {
                RowExpression rowExpression = (RowExpression) Iterables.getOnlyElement(value.getArguments());
                VariableReferenceExpression newVariable2 = context.getVariableAllocator().newVariable("envelope", type);
                if (isFunctionNameMatch(rowExpression, "ST_Envelope")) {
                    builder2.put(newVariable2, rowExpression);
                } else {
                    builder2.put(newVariable2, OriginalExpressionUtils.castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(OriginalExpressionUtils.castToExpression(rowExpression)))));
                }
                builder.put(entry.getKey(), new AggregationNode.Aggregation(new CallExpression(name.getSuffix(), this.metadata.getFunctionManager().lookupFunction(NAME.getSuffix(), TypeSignatureProvider.fromTypes(type, IntegerType.INTEGER)), entry.getKey().getType(), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(newVariable2)), OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(newVariable)))), Optional.empty(), Optional.empty(), false, value.getMask()));
            } else {
                builder.put(entry);
            }
        }
        return Rule.Result.ofPlanNode(new AggregationNode(aggregationNode.getId(), new ProjectNode(context.getIdAllocator().getNextId(), aggregationNode.getSource(), Assignments.builder().putAll(AssignmentUtils.identitiesAsSymbolReferences(aggregationNode.getSource().getOutputVariables())).put(newVariable, OriginalExpressionUtils.castToRowExpression(new LongLiteral(Integer.toString(SystemSessionProperties.getHashPartitionCount(context.getSession()))))).putAll(builder2.build()).build()), builder.build(), aggregationNode.getGroupingSets(), aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), aggregationNode.getGroupIdVariable()));
    }

    private static boolean isFunctionNameMatch(RowExpression rowExpression, String str) {
        if (OriginalExpressionUtils.castToExpression(rowExpression) instanceof FunctionCall) {
            return ((FunctionCall) OriginalExpressionUtils.castToExpression(rowExpression)).getName().toString().equalsIgnoreCase(str);
        }
        return false;
    }
}
