package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.StatisticAggregations;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.class */
/**
 * A set of iterative-optimizer rules that apply a {@link PlanRowExpressionRewriter} to the
 * {@link RowExpression}s held by values, filter, project, apply, window, join, spatial join,
 * aggregation, table finish and table writer nodes.
 *
 * <p>Every rule rebuilds its node only when the rewriter changed at least one expression
 * (compared with {@code equals}); otherwise it returns {@code Rule.Result.empty()}.
 */
public class RowExpressionRewriteRuleSet {
    protected final PlanRowExpressionRewriter rewriter;

    /** Rewrites the call and optional filter of every aggregation in an {@link AggregationNode}. */
    public final class AggregationRowExpressionRewrite implements Rule<AggregationNode> {
        private AggregationRowExpressionRewrite() {
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return Patterns.aggregation();
        }

        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            Preconditions.checkState(aggregationNode.getSource() != null);
            boolean anyRewritten = false;
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
                AggregationNode.Aggregation rewritten = rewriteAggregation(entry.getValue(), context);
                aggregations.put(entry.getKey(), rewritten);
                if (!rewritten.equals(entry.getValue())) {
                    anyRewritten = true;
                }
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new AggregationNode(
                    aggregationNode.getId(),
                    aggregationNode.getSource(),
                    aggregations.build(),
                    aggregationNode.getGroupingSets(),
                    aggregationNode.getPreGroupedVariables(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashVariable(),
                    aggregationNode.getGroupIdVariable()));
        }
    }

    /** Rewrites the subquery assignments of an {@link ApplyNode}. */
    public final class ApplyRowExpressionRewrite implements Rule<ApplyNode> {
        private ApplyRowExpressionRewrite() {
        }

        @Override
        public Pattern<ApplyNode> getPattern() {
            return Patterns.applyNode();
        }

        @Override
        public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
            Optional<Assignments> rewrittenAssignments = translateAssignments(applyNode.getSubqueryAssignments(), context);
            if (!rewrittenAssignments.isPresent()) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ApplyNode(
                    applyNode.getId(),
                    applyNode.getInput(),
                    applyNode.getSubquery(),
                    rewrittenAssignments.get(),
                    applyNode.getCorrelation(),
                    applyNode.getOriginSubqueryError()));
        }
    }

    /** Rewrites the predicate of a {@link FilterNode}. */
    public final class FilterRowExpressionRewrite implements Rule<FilterNode> {
        private FilterRowExpressionRewrite() {
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return Patterns.filter();
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            Preconditions.checkState(filterNode.getSource() != null);
            RowExpression predicate = filterNode.getPredicate();
            RowExpression rewritten = rewriter.rewrite(predicate, context);
            if (predicate.equals(rewritten)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), filterNode.getSource(), rewritten));
        }
    }

    /** Rewrites the optional non-equi filter of a {@link JoinNode}; joins without a filter are left alone. */
    public final class JoinRowExpressionRewrite implements Rule<JoinNode> {
        private JoinRowExpressionRewrite() {
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return Patterns.join();
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            if (!joinNode.getFilter().isPresent()) {
                return Rule.Result.empty();
            }
            RowExpression filter = joinNode.getFilter().get();
            RowExpression rewritten = rewriter.rewrite(filter, context);
            if (filter.equals(rewritten)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    joinNode.getRight(),
                    joinNode.getCriteria(),
                    joinNode.getOutputVariables(),
                    Optional.of(rewritten),
                    joinNode.getLeftHashVariable(),
                    joinNode.getRightHashVariable(),
                    joinNode.getDistributionType()));
        }
    }

    /** Strategy that rewrites a single {@link RowExpression} in the context of a rule invocation. */
    @FunctionalInterface
    public interface PlanRowExpressionRewriter {
        RowExpression rewrite(RowExpression rowExpression, Rule.Context context);
    }

    /** Rewrites every assignment expression of a {@link ProjectNode}. */
    public final class ProjectRowExpressionRewrite implements Rule<ProjectNode> {
        private ProjectRowExpressionRewrite() {
        }

        @Override
        public Pattern<ProjectNode> getPattern() {
            return Patterns.project();
        }

        @Override
        public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
            Assignments.Builder assignments = Assignments.builder();
            boolean anyRewritten = false;
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : projectNode.getAssignments().getMap().entrySet()) {
                RowExpression rewritten = rewriter.rewrite(entry.getValue(), context);
                if (!rewritten.equals(entry.getValue())) {
                    anyRewritten = true;
                }
                assignments.put(entry.getKey(), rewritten);
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), assignments.build()));
        }
    }

    /** Rewrites the filter of a {@link SpatialJoinNode}. */
    public final class SpatialJoinRowExpressionRewrite implements Rule<SpatialJoinNode> {
        private SpatialJoinRowExpressionRewrite() {
        }

        @Override
        public Pattern<SpatialJoinNode> getPattern() {
            return Patterns.spatialJoin();
        }

        @Override
        public Rule.Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Rule.Context context) {
            RowExpression filter = spatialJoinNode.getFilter();
            RowExpression rewritten = rewriter.rewrite(filter, context);
            if (filter.equals(rewritten)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new SpatialJoinNode(
                    spatialJoinNode.getId(),
                    spatialJoinNode.getType(),
                    spatialJoinNode.getLeft(),
                    spatialJoinNode.getRight(),
                    spatialJoinNode.getOutputVariables(),
                    rewritten,
                    spatialJoinNode.getLeftPartitionVariable(),
                    spatialJoinNode.getRightPartitionVariable(),
                    spatialJoinNode.getKdbTree()));
        }
    }

    /** Rewrites the optional statistics aggregations of a {@link TableFinishNode}. */
    public final class TableFinishRowExpressionRewrite implements Rule<TableFinishNode> {
        private TableFinishRowExpressionRewrite() {
        }

        @Override
        public Pattern<TableFinishNode> getPattern() {
            return Patterns.tableFinish();
        }

        @Override
        public Rule.Result apply(TableFinishNode tableFinishNode, Captures captures, Rule.Context context) {
            Preconditions.checkState(tableFinishNode.getSource() != null);
            if (!tableFinishNode.getStatisticsAggregation().isPresent()) {
                return Rule.Result.empty();
            }
            Optional<StatisticAggregations> rewrittenStatistics =
                    translateStatisticAggregation(tableFinishNode.getStatisticsAggregation().get(), context);
            if (!rewrittenStatistics.isPresent()) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new TableFinishNode(
                    tableFinishNode.getId(),
                    tableFinishNode.getSource(),
                    tableFinishNode.getTarget(),
                    tableFinishNode.getRowCountVariable(),
                    rewrittenStatistics,
                    tableFinishNode.getStatisticsAggregationDescriptor()));
        }
    }

    /** Rewrites the optional statistics aggregations of a {@link TableWriterNode}. */
    public final class TableWriterRowExpressionRewrite implements Rule<TableWriterNode> {
        private TableWriterRowExpressionRewrite() {
        }

        @Override
        public Pattern<TableWriterNode> getPattern() {
            return Patterns.tableWriterNode();
        }

        @Override
        public Rule.Result apply(TableWriterNode tableWriterNode, Captures captures, Rule.Context context) {
            Preconditions.checkState(tableWriterNode.getSource() != null);
            if (!tableWriterNode.getStatisticsAggregation().isPresent()) {
                return Rule.Result.empty();
            }
            Optional<StatisticAggregations> rewrittenStatistics =
                    translateStatisticAggregation(tableWriterNode.getStatisticsAggregation().get(), context);
            if (!rewrittenStatistics.isPresent()) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new TableWriterNode(
                    tableWriterNode.getId(),
                    tableWriterNode.getSource(),
                    tableWriterNode.getTarget(),
                    tableWriterNode.getRowCountVariable(),
                    tableWriterNode.getFragmentVariable(),
                    tableWriterNode.getTableCommitContextVariable(),
                    tableWriterNode.getColumns(),
                    tableWriterNode.getColumnNames(),
                    tableWriterNode.getTablePartitioningScheme(),
                    tableWriterNode.getPreferredShufflePartitioningScheme(),
                    rewrittenStatistics));
        }
    }

    /** Rewrites every cell expression of every row of a {@link ValuesNode}. */
    public final class ValuesRowExpressionRewrite implements Rule<ValuesNode> {
        private ValuesRowExpressionRewrite() {
        }

        @Override
        public Pattern<ValuesNode> getPattern() {
            return Patterns.values();
        }

        @Override
        public Rule.Result apply(ValuesNode valuesNode, Captures captures, Rule.Context context) {
            boolean anyRewritten = false;
            ImmutableList.Builder<List<RowExpression>> rows = ImmutableList.builder();
            for (List<RowExpression> row : valuesNode.getRows()) {
                ImmutableList.Builder<RowExpression> newRow = ImmutableList.builder();
                for (RowExpression cell : row) {
                    RowExpression rewritten = rewriter.rewrite(cell, context);
                    if (!rewritten.equals(cell)) {
                        anyRewritten = true;
                    }
                    newRow.add(rewritten);
                }
                rows.add(newRow.build());
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputVariables(), rows.build()));
        }
    }

    /** Rewrites the arguments of every window function call in a {@link WindowNode}. */
    public final class WindowRowExpressionRewrite implements Rule<WindowNode> {
        private WindowRowExpressionRewrite() {
        }

        @Override
        public Pattern<WindowNode> getPattern() {
            return Patterns.window();
        }

        @Override
        public Rule.Result apply(WindowNode windowNode, Captures captures, Rule.Context context) {
            Preconditions.checkState(windowNode.getSource() != null);
            boolean anyRewritten = false;
            ImmutableMap.Builder<VariableReferenceExpression, WindowNode.Function> functions = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, WindowNode.Function> entry : windowNode.getWindowFunctions().entrySet()) {
                ImmutableList.Builder<RowExpression> newArguments = ImmutableList.builder();
                CallExpression functionCall = entry.getValue().getFunctionCall();
                for (RowExpression argument : functionCall.getArguments()) {
                    RowExpression rewritten = rewriter.rewrite(argument, context);
                    // Compare by value (as every other rule here does), not by reference: a rewriter
                    // returning an equal-but-new instance must not be reported as a change, or the
                    // rule would keep firing on its own output.
                    if (!rewritten.equals(argument)) {
                        anyRewritten = true;
                    }
                    newArguments.add(rewritten);
                }
                ImmutableList<RowExpression> arguments = newArguments.build();
                CallExpression newCall = Expressions.call(
                        functionCall.getDisplayName(),
                        functionCall.getFunctionHandle(),
                        functionCall.getType(),
                        arguments);
                functions.put(entry.getKey(), new WindowNode.Function(newCall, entry.getValue().getFrame(), entry.getValue().isIgnoreNulls()));
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new WindowNode(
                    windowNode.getId(),
                    windowNode.getSource(),
                    windowNode.getSpecification(),
                    functions.build(),
                    windowNode.getHashVariable(),
                    windowNode.getPrePartitionedInputs(),
                    windowNode.getPreSortedOrderPrefix()));
        }
    }

    /**
     * Creates the rule set.
     *
     * @param planRowExpressionRewriter rewriter applied to every expression visited by the rules
     * @throws NullPointerException if {@code planRowExpressionRewriter} is null
     */
    public RowExpressionRewriteRuleSet(PlanRowExpressionRewriter planRowExpressionRewriter) {
        this.rewriter = Objects.requireNonNull(planRowExpressionRewriter, "rewriter is null");
    }

    /** Returns one instance of every rule in this set. */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                valueRowExpressionRewriteRule(),
                filterRowExpressionRewriteRule(),
                projectRowExpressionRewriteRule(),
                applyNodeRowExpressionRewriteRule(),
                windowRowExpressionRewriteRule(),
                joinRowExpressionRewriteRule(),
                spatialJoinRowExpressionRewriteRule(),
                aggregationRowExpressionRewriteRule(),
                tableFinishRowExpressionRewriteRule(),
                tableWriterRowExpressionRewriteRule());
    }

    public Rule<ValuesNode> valueRowExpressionRewriteRule() {
        return new ValuesRowExpressionRewrite();
    }

    public Rule<FilterNode> filterRowExpressionRewriteRule() {
        return new FilterRowExpressionRewrite();
    }

    public Rule<ProjectNode> projectRowExpressionRewriteRule() {
        return new ProjectRowExpressionRewrite();
    }

    public Rule<ApplyNode> applyNodeRowExpressionRewriteRule() {
        return new ApplyRowExpressionRewrite();
    }

    public Rule<WindowNode> windowRowExpressionRewriteRule() {
        return new WindowRowExpressionRewrite();
    }

    public Rule<JoinNode> joinRowExpressionRewriteRule() {
        return new JoinRowExpressionRewrite();
    }

    public Rule<SpatialJoinNode> spatialJoinRowExpressionRewriteRule() {
        return new SpatialJoinRowExpressionRewrite();
    }

    public Rule<TableFinishNode> tableFinishRowExpressionRewriteRule() {
        return new TableFinishRowExpressionRewrite();
    }

    public Rule<TableWriterNode> tableWriterRowExpressionRewriteRule() {
        return new TableWriterRowExpressionRewrite();
    }

    public Rule<AggregationNode> aggregationRowExpressionRewriteRule() {
        return new AggregationRowExpressionRewrite();
    }

    /**
     * Rewrites every expression of {@code assignments}.
     *
     * @return the rewritten assignments, or empty if the result equals the input
     */
    public Optional<Assignments> translateAssignments(Assignments assignments, Rule.Context context) {
        Assignments.Builder builder = Assignments.builder();
        for (Map.Entry<VariableReferenceExpression, RowExpression> entry : assignments.getMap().entrySet()) {
            builder.put(entry.getKey(), rewriter.rewrite(entry.getValue(), context));
        }
        Assignments rewritten = builder.build();
        return rewritten.equals(assignments) ? Optional.empty() : Optional.of(rewritten);
    }

    /**
     * Rewrites every aggregation of {@code statisticAggregations}, keeping its grouping variables.
     *
     * @return the rewritten aggregations, or empty if no aggregation changed
     */
    public Optional<StatisticAggregations> translateStatisticAggregation(StatisticAggregations statisticAggregations, Rule.Context context) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        boolean anyRewritten = false;
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : statisticAggregations.getAggregations().entrySet()) {
            AggregationNode.Aggregation rewritten = rewriteAggregation(entry.getValue(), context);
            aggregations.put(entry.getKey(), rewritten);
            if (!rewritten.equals(entry.getValue())) {
                anyRewritten = true;
            }
        }
        if (!anyRewritten) {
            return Optional.empty();
        }
        return Optional.of(new StatisticAggregations(aggregations.build(), statisticAggregations.getGroupingVariables()));
    }

    /**
     * Rewrites the call and the optional filter of a single aggregation; order-by, distinct flag
     * and mask are carried over unchanged.
     *
     * @throws IllegalArgumentException if the rewriter turns the aggregation call into something
     *         other than a {@link CallExpression}
     */
    public AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Rule.Context context) {
        // The rewriter's result is typed as RowExpression; verify before casting.
        RowExpression rewrittenCall = rewriter.rewrite(aggregation.getCall(), context);
        Preconditions.checkArgument(rewrittenCall instanceof CallExpression, "Aggregation CallExpression must be rewritten to CallExpression");
        return new AggregationNode.Aggregation(
                (CallExpression) rewrittenCall,
                aggregation.getFilter().map(filter -> rewriter.rewrite(filter, context)),
                aggregation.getOrderBy(),
                aggregation.isDistinct(),
                aggregation.getMask());
    }
}
