package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableSet;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.class */
public class PruneJoinChildrenColumns implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(Predicates.not((v0) -> {
        return v0.isCrossJoin();
    }));

    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        ImmutableSet build = ImmutableSet.builder().addAll((Iterable) joinNode.getOutputVariables()).addAll((Iterable) joinNode.getFilter().map(OriginalExpressionUtils::castToExpression).map(expression -> {
            return VariablesExtractor.extractUnique(expression, context.getVariableAllocator().getTypes());
        }).orElse(ImmutableSet.of())).build();
        return (Rule.Result) Util.restrictChildOutputs(context.getIdAllocator(), joinNode, ImmutableSet.builder().addAll((Iterable) build).addAll(joinNode.getCriteria().stream().map((v0) -> {
            return v0.getLeft();
        }).iterator()).addAll((Iterable) joinNode.getLeftHashVariable().map((v0) -> {
            return ImmutableSet.of(v0);
        }).orElse(ImmutableSet.of())).build(), ImmutableSet.builder().addAll((Iterable) build).addAll(joinNode.getCriteria().stream().map((v0) -> {
            return v0.getRight();
        }).iterator()).addAll((Iterable) joinNode.getRightHashVariable().map((v0) -> {
            return ImmutableSet.of(v0);
        }).orElse(ImmutableSet.of())).build()).map(Rule.Result::ofPlanNode).orElse(Rule.Result.empty());
    }
}
