package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/cost/SpatialJoinStatsRule.class */
public class SpatialJoinStatsRule extends SimpleStatsRule<SpatialJoinNode> {
    private static final Pattern<SpatialJoinNode> PATTERN = Patterns.spatialJoin();

    /** Applies the spatial join's filter to the cross-product statistics of the two inputs. */
    private final FilterStatsCalculator statsCalculator;

    /**
     * Creates the rule.
     *
     * @param filterStatsCalculator calculator used to apply the spatial join filter; must not be null
     * @param statsNormalizer normalizer handed to the {@link SimpleStatsRule} superclass
     * @throws NullPointerException if {@code filterStatsCalculator} is null
     */
    public SpatialJoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
        this.statsCalculator = Objects.requireNonNull(filterStatsCalculator, "statsCalculator is null");
    }

    /**
     * Estimates output statistics for a spatial join.
     *
     * <p>For {@code INNER} joins the cross-product statistics are narrowed by the join filter.
     * For {@code LEFT} joins an unknown estimate is returned. Any other join type is rejected.
     *
     * @throws IllegalArgumentException if the join type is neither {@code INNER} nor {@code LEFT}
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(SpatialJoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        // Both inputs are resolved (left first, then right) before the join type is inspected,
        // so this happens even when the cross-product estimate is not used afterwards.
        PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
        PlanNodeStatsEstimate crossProduct = crossJoinStats(node, leftStats, rightStats);

        switch (node.getType()) {
            case INNER:
                // Only rows of the cross product that satisfy the spatial filter are kept.
                return Optional.of(statsCalculator.filterStats(crossProduct, node.getFilter(), session, types));
            case LEFT:
                // No estimation is attempted for left spatial joins.
                return Optional.of(PlanNodeStatsEstimate.unknown());
            default:
                throw new IllegalArgumentException("Unknown spatial join type: " + node.getType());
        }
    }

    /**
     * Builds statistics for the unfiltered cross product of the two join inputs.
     *
     * <p>The row count is the product of the input row counts. Each output symbol of a side
     * keeps the statistics reported for it by that side's estimate.
     */
    private PlanNodeStatsEstimate crossJoinStats(SpatialJoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());

        node.getLeft().getOutputSymbols().forEach(leftSymbol ->
                builder.addSymbolStatistics(leftSymbol, leftStats.getSymbolStatistics(leftSymbol)));
        node.getRight().getOutputSymbols().forEach(rightSymbol ->
                builder.addSymbolStatistics(rightSymbol, rightStats.getSymbolStatistics(rightSymbol)));

        return builder.build();
    }

    @Override
    public Pattern<SpatialJoinNode> getPattern() {
        return PATTERN;
    }
}
