package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.class */
/**
 * Pushes a {@link ProjectNode} that computes at least one non-trivial expression below the
 * {@link ExchangeNode} it sits on.
 *
 * <p>For every exchange source a new {@link ProjectNode} is created. It forwards the partitioning
 * columns and the hash column (which the exchange still needs) and evaluates the projection's
 * expressions, rewritten in terms of that source's input symbols. The rewritten exchange then
 * publishes the projected symbols directly. Whatever the original projection did not output is
 * trimmed afterwards with {@code Util.restrictOutputs}.
 */
public class PushProjectionThroughExchange implements Rule<ProjectNode> {
    private static final Capture<ExchangeNode> CHILD = Capture.newCapture();

    // Only fire when the projection does real work; pure symbol-to-symbol renames gain nothing.
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .matching(projectNode -> !isSymbolToSymbolProjection(projectNode))
            .with(Patterns.source().matching(Patterns.exchange().capturedAs(CHILD)));

    @Override
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code Project(Exchange(s1..sn))} into {@code Exchange(Project(s1)..Project(sn))},
     * followed by an output restriction when needed.
     *
     * @param projectNode the matched projection
     * @param captures holds the captured {@link ExchangeNode} child
     * @param context supplies the symbol and plan-node id allocators
     * @return the rewritten plan
     */
    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        ExchangeNode exchangeNode = captures.get(CHILD);
        PartitioningScheme partitioningScheme = exchangeNode.getPartitioningScheme();
        Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns();

        // Symbols the exchange must keep publishing regardless of the projection:
        // the partitioning columns followed by the hash column, if any.
        ImmutableList.Builder<Symbol> retainedBuilder = ImmutableList.builder();
        retainedBuilder.addAll(partitioningColumns);
        partitioningScheme.getHashColumn().ifPresent(retainedBuilder::add);
        List<Symbol> retainedSymbols = retainedBuilder.build();

        ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
        ImmutableList.Builder<List<Symbol>> newInputs = ImmutableList.builder();
        for (int sourceIndex = 0; sourceIndex < exchangeNode.getSources().size(); sourceIndex++) {
            Map<Symbol, SymbolReference> outputToInput = extractExchangeOutputToInput(exchangeNode, sourceIndex);
            Assignments.Builder projections = Assignments.builder();
            ImmutableList.Builder<Symbol> inputs = ImmutableList.builder();

            // Forward the partitioning columns under this source's own input symbol names.
            for (Symbol column : partitioningColumns) {
                SymbolReference inputReference = outputToInput.get(column);
                Symbol inputSymbol = Symbol.from(inputReference);
                projections.put(inputSymbol, inputReference);
                inputs.add(inputSymbol);
            }

            // Forward the hash column unchanged.
            partitioningScheme.getHashColumn().ifPresent(hash -> {
                projections.put(hash, hash.toSymbolReference());
                inputs.add(hash);
            });

            for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
                if (retainedSymbols.contains(assignment.getKey())) {
                    // Identity projection of a symbol that is already forwarded above; emitting it
                    // again would duplicate that symbol in the exchange's output layout.
                    continue;
                }
                Expression translated = translateExpression(assignment.getValue(), outputToInput);
                Symbol symbol = context.getSymbolAllocator().newSymbol(
                        translated,
                        context.getSymbolAllocator().getTypes().get(assignment.getKey()));
                projections.put(symbol, translated);
                inputs.add(symbol);
            }

            newSources.add(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    exchangeNode.getSources().get(sourceIndex),
                    projections.build()));
            newInputs.add(inputs.build());
        }

        // Output layout must line up position-for-position with every per-source input list above.
        ImmutableList.Builder<Symbol> outputLayout = ImmutableList.builder();
        outputLayout.addAll(retainedSymbols);
        for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
            if (!retainedSymbols.contains(assignment.getKey())) {
                outputLayout.add(assignment.getKey());
            }
        }

        PartitioningScheme newPartitioningScheme = new PartitioningScheme(
                partitioningScheme.getPartitioning(),
                outputLayout.build(),
                partitioningScheme.getHashColumn(),
                partitioningScheme.isReplicateNullsAndAny(),
                partitioningScheme.getBucketToPartition());

        ExchangeNode newExchange = new ExchangeNode(
                exchangeNode.getId(),
                exchangeNode.getType(),
                exchangeNode.getScope(),
                newPartitioningScheme,
                newSources.build(),
                newInputs.build());

        // Drop any retained symbols (partitioning/hash) the original projection did not output.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), newExchange, ImmutableSet.copyOf(projectNode.getOutputSymbols()))
                        .orElse(newExchange));
    }

    /**
     * Returns {@code true} when every assignment of {@code projectNode} is a bare symbol reference.
     */
    private static boolean isSymbolToSymbolProjection(ProjectNode projectNode) {
        return projectNode.getAssignments().getExpressions().stream()
                .allMatch(SymbolReference.class::isInstance);
    }

    /**
     * Maps each exchange output symbol to a reference to the corresponding input symbol of the
     * source at {@code sourceIndex}. Correspondence is positional.
     */
    private static Map<Symbol, SymbolReference> extractExchangeOutputToInput(ExchangeNode exchangeNode, int sourceIndex) {
        List<Symbol> outputs = exchangeNode.getOutputSymbols();
        List<Symbol> sourceInputs = exchangeNode.getInputs().get(sourceIndex);
        Map<Symbol, SymbolReference> outputToInput = new HashMap<>();
        for (int position = 0; position < outputs.size(); position++) {
            outputToInput.put(outputs.get(position), sourceInputs.get(position).toSymbolReference());
        }
        return outputToInput;
    }

    /**
     * Rewrites {@code expression} by replacing each symbol with its mapped input reference.
     */
    private static Expression translateExpression(Expression expression, Map<Symbol, SymbolReference> outputToInput) {
        Function<Symbol, Expression> mapping = outputToInput::get;
        return new ExpressionSymbolInliner(mapping).rewrite(expression);
    }
}
