package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.StatisticAggregations;
import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor;
import com.facebook.presto.sql.planner.plan.StatisticsWriterNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableWriterMergeNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/SymbolMapper.class */
/**
 * Rewrites symbols / variable references (and the plan-node fragments that carry them)
 * according to a name-to-name mapping.
 *
 * <p>Mappings are followed transitively: if {@code a -> b} and {@code b -> c} are both present,
 * {@code a} is rewritten to {@code c}. An entry that maps a name to itself terminates the chain.
 */
public class SymbolMapper {
    /** Source name to target name; chains are resolved transitively by {@link #canonicalName(String)}. */
    private final Map<String, String> mapping;
    /** Used to look up the type of a target name when a new variable reference has to be created. */
    private final TypeProvider types;

    /**
     * Accumulates variable-to-variable mappings and produces a {@link SymbolMapper}.
     */
    public static class Builder {
        private final ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> mappingsBuilder = ImmutableMap.builder();

        /**
         * @return a mapper over all mappings registered so far
         */
        public SymbolMapper build() {
            return new SymbolMapper(this.mappingsBuilder.build());
        }

        /**
         * Registers a mapping from the first variable to the second.
         */
        public void put(VariableReferenceExpression variableReferenceExpression, VariableReferenceExpression variableReferenceExpression2) {
            this.mappingsBuilder.put(variableReferenceExpression, variableReferenceExpression2);
        }
    }

    /**
     * Creates a mapper from a variable-to-variable mapping. The name mapping is derived from the
     * variables' names and the type information from every variable appearing as key or value.
     *
     * @param map source variable to target variable; must not be null
     */
    public SymbolMapper(Map<VariableReferenceExpression, VariableReferenceExpression> map) {
        Objects.requireNonNull(map, "mapping is null");
        this.mapping = map.entrySet().stream().collect(ImmutableMap.toImmutableMap(
                entry -> entry.getKey().getName(),
                entry -> entry.getValue().getName()));
        ImmutableSet.Builder<VariableReferenceExpression> variables = ImmutableSet.builder();
        map.forEach((source, target) -> {
            variables.add(source);
            variables.add(target);
        });
        this.types = TypeProvider.fromVariables(variables.build());
    }

    /**
     * Creates a mapper from a name-to-name mapping and an explicit type provider.
     *
     * @param map source name to target name; must not be null
     * @param typeProvider type lookup for target names; must not be null
     */
    public SymbolMapper(Map<String, String> map, TypeProvider typeProvider) {
        Objects.requireNonNull(map, "mapping is null");
        this.mapping = ImmutableMap.copyOf(map);
        this.types = Objects.requireNonNull(typeProvider, "types is null");
    }

    /**
     * @return a symbol carrying the fully resolved (canonical) name of {@code symbol}
     */
    public Symbol map(Symbol symbol) {
        return new Symbol(canonicalName(symbol.getName()));
    }

    /**
     * Resolves a variable reference to its canonical variable.
     *
     * @return the same instance when the name does not change, otherwise a new reference whose
     *         type is looked up from this mapper's type provider
     */
    public VariableReferenceExpression map(VariableReferenceExpression variableReferenceExpression) {
        String originalName = variableReferenceExpression.getName();
        String canonical = canonicalName(originalName);
        if (canonical.equals(originalName)) {
            return variableReferenceExpression;
        }
        return new VariableReferenceExpression(canonical, this.types.get(new SymbolReference(canonical)));
    }

    /**
     * Rewrites every symbol reference inside an AST expression to its canonical symbol.
     */
    public Expression map(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference symbolReference, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                return SymbolMapper.this.map(Symbol.from(symbolReference)).toSymbolReference();
            }
        }, expression);
    }

    /**
     * Rewrites every variable reference inside a row expression. Row expressions that wrap an
     * original AST expression are unwrapped, rewritten as an {@link Expression}, and wrapped again.
     */
    public RowExpression map(RowExpression rowExpression) {
        if (OriginalExpressionUtils.isExpression(rowExpression)) {
            Expression rewritten = map(OriginalExpressionUtils.castToExpression(rowExpression));
            return OriginalExpressionUtils.castToRowExpression(rewritten);
        }
        return RowExpressionTreeRewriter.rewriteWith(new RowExpressionRewriter<Void>() {
            @Override
            public RowExpression rewriteVariableReference(VariableReferenceExpression variableReferenceExpression, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
                return SymbolMapper.this.map(variableReferenceExpression);
            }
        }, rowExpression);
    }

    /**
     * Rewrites the variables of an ordering scheme, preserving their relative order.
     *
     * <p>When several order-by variables resolve to the same canonical variable only the first
     * occurrence (with its sort order) is kept, so the result never contains the same variable twice.
     */
    public OrderingScheme map(OrderingScheme orderingScheme) {
        ImmutableList.Builder<Ordering> orderings = ImmutableList.builder();
        HashSet<VariableReferenceExpression> seen = new HashSet<>(orderingScheme.getOrderByVariables().size());
        for (VariableReferenceExpression variable : orderingScheme.getOrderByVariables()) {
            VariableReferenceExpression canonical = map(variable);
            // Set.add returns false for an already-seen canonical variable; skip those duplicates.
            if (seen.add(canonical)) {
                orderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable)));
            }
        }
        return new OrderingScheme(orderings.build());
    }

    /**
     * Rewrites an aggregation node on top of {@code planNode}, keeping the node's id.
     */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode) {
        return map(aggregationNode, planNode, aggregationNode.getId());
    }

    /**
     * Rewrites an aggregation node on top of {@code planNode}, assigning a freshly allocated id.
     */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode, PlanNodeIdAllocator planNodeIdAllocator) {
        return map(aggregationNode, planNode, planNodeIdAllocator.getNextId());
    }

    /** Shared implementation of the two public aggregation-node overloads. */
    private AggregationNode map(AggregationNode aggregationNode, PlanNode planNode, PlanNodeId planNodeId) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            aggregations.put(map(entry.getKey()), map(entry.getValue()));
        }
        // NOTE(review): the fifth constructor argument is always an empty list, so whatever the
        // original node carried in that position is not propagated — confirm this is intended.
        return new AggregationNode(
                planNodeId,
                planNode,
                aggregations.build(),
                AggregationNode.groupingSets(
                        mapAndDistinctVariable(aggregationNode.getGroupingKeys()),
                        aggregationNode.getGroupingSetCount(),
                        aggregationNode.getGlobalGroupingSets()),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashVariable().map(this::map),
                aggregationNode.getGroupIdVariable().map(this::map));
    }

    /** Rewrites the arguments, filter, ordering and mask of a single aggregation. */
    private AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        CallExpression originalCall = aggregation.getCall();
        List<RowExpression> arguments = aggregation.getArguments().stream()
                .map(this::map)
                .collect(ImmutableList.toImmutableList());
        CallExpression rewrittenCall = new CallExpression(
                originalCall.getDisplayName(),
                originalCall.getFunctionHandle(),
                originalCall.getType(),
                arguments);
        return new AggregationNode.Aggregation(
                rewrittenCall,
                aggregation.getFilter().map(this::map),
                aggregation.getOrderBy().map(this::map),
                aggregation.isDistinct(),
                aggregation.getMask().map(this::map));
    }

    /**
     * Rewrites a TopN node on top of {@code planNode} with the given id. Order-by variables that
     * collapse to the same canonical variable are de-duplicated (first occurrence wins).
     */
    public TopNNode map(TopNNode topNNode, PlanNode planNode, PlanNodeId planNodeId) {
        return new TopNNode(planNodeId, planNode, topNNode.getCount(), map(topNNode.getOrderingScheme()), topNNode.getStep());
    }

    /**
     * Rewrites a table-writer node on top of {@code planNode}, keeping the node's id.
     */
    public TableWriterNode map(TableWriterNode tableWriterNode, PlanNode planNode) {
        return map(tableWriterNode, planNode, tableWriterNode.getId());
    }

    /**
     * Rewrites a table-writer node on top of {@code planNode} with the given id. The column list is
     * mapped element-by-element (no de-duplication) while the column names are passed through.
     */
    public TableWriterNode map(TableWriterNode tableWriterNode, PlanNode planNode, PlanNodeId planNodeId) {
        List<VariableReferenceExpression> columns = tableWriterNode.getColumns().stream()
                .map(this::map)
                .collect(ImmutableList.toImmutableList());
        return new TableWriterNode(
                planNodeId,
                planNode,
                tableWriterNode.getTarget(),
                map(tableWriterNode.getRowCountVariable()),
                map(tableWriterNode.getFragmentVariable()),
                map(tableWriterNode.getTableCommitContextVariable()),
                columns,
                tableWriterNode.getColumnNames(),
                tableWriterNode.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, planNode)),
                tableWriterNode.getStatisticsAggregation().map(this::map));
    }

    /**
     * Rewrites the descriptor of a statistics-writer node on top of {@code planNode}.
     */
    public StatisticsWriterNode map(StatisticsWriterNode statisticsWriterNode, PlanNode planNode) {
        // NOTE(review): unlike the table writer/finish overloads, the row-count variable is passed
        // through unmapped here — confirm that is intentional.
        return new StatisticsWriterNode(
                statisticsWriterNode.getId(),
                planNode,
                statisticsWriterNode.getTarget(),
                statisticsWriterNode.getRowCountVariable(),
                statisticsWriterNode.isRowCountEnabled(),
                statisticsWriterNode.getDescriptor().map(this::map));
    }

    /**
     * Rewrites a table-finish node on top of {@code planNode}, keeping the node's id.
     */
    public TableFinishNode map(TableFinishNode tableFinishNode, PlanNode planNode) {
        return new TableFinishNode(
                tableFinishNode.getId(),
                planNode,
                tableFinishNode.getTarget(),
                map(tableFinishNode.getRowCountVariable()),
                tableFinishNode.getStatisticsAggregation().map(this::map),
                tableFinishNode.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map)));
    }

    /**
     * Rewrites a table-writer-merge node on top of {@code planNode}, keeping the node's id.
     */
    public TableWriterMergeNode map(TableWriterMergeNode tableWriterMergeNode, PlanNode planNode) {
        return new TableWriterMergeNode(
                tableWriterMergeNode.getId(),
                planNode,
                map(tableWriterMergeNode.getRowCountVariable()),
                map(tableWriterMergeNode.getFragmentVariable()),
                map(tableWriterMergeNode.getTableCommitContextVariable()),
                tableWriterMergeNode.getStatisticsAggregation().map(this::map));
    }

    /**
     * Rebuilds a partitioning scheme with mapped partitioning and hash column; the output layout
     * is taken from {@code planNode}'s output variables (mapped and de-duplicated).
     */
    private PartitioningScheme canonicalize(PartitioningScheme partitioningScheme, PlanNode planNode) {
        return new PartitioningScheme(
                partitioningScheme.getPartitioning().translateVariable(this::map),
                mapAndDistinctVariable(planNode.getOutputVariables()),
                partitioningScheme.getHashColumn().map(this::map),
                partitioningScheme.isReplicateNullsAndAny(),
                partitioningScheme.getBucketToPartition());
    }

    /** Rewrites both the output variables and the aggregations of a statistics aggregation. */
    private StatisticAggregations map(StatisticAggregations statisticAggregations) {
        Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = statisticAggregations.getAggregations().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        entry -> map(entry.getKey()),
                        entry -> map(entry.getValue())));
        return new StatisticAggregations(aggregations, mapAndDistinctVariable(statisticAggregations.getGroupingVariables()));
    }

    /** Rewrites every variable held by a statistics-aggregation descriptor. */
    private StatisticAggregationsDescriptor<VariableReferenceExpression> map(StatisticAggregationsDescriptor<VariableReferenceExpression> statisticAggregationsDescriptor) {
        return statisticAggregationsDescriptor.map(this::map);
    }

    /** Maps each symbol and drops repeats, keeping first-occurrence order. */
    private List<Symbol> mapAndDistinctSymbol(List<Symbol> list) {
        return list.stream()
                .map(this::map)
                .distinct()
                .collect(ImmutableList.toImmutableList());
    }

    /** Maps each variable and drops repeats, keeping first-occurrence order. */
    private List<VariableReferenceExpression> mapAndDistinctVariable(List<VariableReferenceExpression> list) {
        return list.stream()
                .map(this::map)
                .distinct()
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Follows the mapping chain starting at {@code name} until reaching a name that is either
     * unmapped or mapped to itself.
     *
     * @throws IllegalStateException if the mapping contains a cycle reachable from {@code name}
     */
    private String canonicalName(String name) {
        String current = name;
        String next = this.mapping.get(current);
        int hops = 0;
        while (next != null && !next.equals(current)) {
            // An acyclic chain can traverse each mapping entry at most once, so exceeding
            // mapping.size() hops proves a cycle; fail fast instead of looping forever.
            if (++hops > this.mapping.size()) {
                throw new IllegalStateException("Cyclic symbol mapping detected while resolving: " + name);
            }
            current = next;
            next = this.mapping.get(current);
        }
        return current;
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
