package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.planner.ExpressionNodeInliner;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.SpatialJoinUtils;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Planner rules that rewrite a join whose filter contains a supported spatial predicate into a
 * form suitable for a spatial join.
 *
 * <p>Two rules are provided (see {@link #rules()}):
 * <ul>
 *   <li>{@link TransformSpatialPredicateToJoin} matches a {@link FilterNode} directly on top of a
 *       cross join and folds the filter predicate into a single join;</li>
 *   <li>{@link TransformSpatialPredicateToLeftJoin} matches a LEFT join with no equi-join criteria,
 *       a present filter, and which is not already a spatial join.</li>
 * </ul>
 *
 * <p>For each candidate spatial function call and spatial comparison reported by
 * {@link SpatialJoinUtils}, the rules check that the function's two arguments can be evaluated on
 * opposite join sides. Arguments that are not plain symbol references are pushed into a
 * {@link ProjectNode} on the side that can compute them, so the rewritten join filter references
 * only symbols. Both rules are enabled only when
 * {@link SystemSessionProperties#isSpatialJoinEnabled} is true.
 */
public class TransformSpatialPredicates {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = TypeSignature.parseTypeSignature("Geometry");
    private final Metadata metadata;

    /**
     * Rewrites {@code Filter(CrossJoin(left, right))} into one join that carries the filter
     * predicate, provided one of its spatial predicates can be aligned with the join sides.
     */
    public static final class TransformSpatialPredicateToJoin implements Rule<FilterNode> {
        private static final Capture<JoinNode> JOIN = Capture.newCapture();
        private static final Pattern<FilterNode> PATTERN = Patterns.filter()
                .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN).matching(joinNode -> joinNode.isCrossJoin())));

        private final Metadata metadata;

        public TransformSpatialPredicateToJoin(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isSpatialJoinEnabled(session);
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            JoinNode joinNode = captures.get(JOIN);
            // The new join replaces the filter node, so it takes over the filter's id and outputs.
            return TransformSpatialPredicates.trySpatialPredicates(
                    context, joinNode, filterNode.getPredicate(), filterNode.getId(), filterNode.getOutputSymbols(), this.metadata);
        }
    }

    /**
     * Rewrites a LEFT join without equi-join criteria, whose filter contains a spatial predicate,
     * so that the predicate's arguments are plain symbols from the matching join sides.
     */
    public static final class TransformSpatialPredicateToLeftJoin implements Rule<JoinNode> {
        private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(joinNode ->
                joinNode.getCriteria().isEmpty()
                        && joinNode.getFilter().isPresent()
                        && joinNode.getType() == JoinNode.Type.LEFT
                        && !joinNode.isSpatialJoin());

        private final Metadata metadata;

        public TransformSpatialPredicateToLeftJoin(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isSpatialJoinEnabled(session);
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            // PATTERN only matches joins whose filter is present.
            Expression filter = joinNode.getFilter().get();
            return TransformSpatialPredicates.trySpatialPredicates(
                    context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), this.metadata);
        }
    }

    public TransformSpatialPredicates(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /** Returns both spatial-join rules, sharing this instance's {@link Metadata}. */
    public Set<Rule<?>> rules() {
        return ImmutableSet.<Rule<?>>of(
                new TransformSpatialPredicateToJoin(this.metadata),
                new TransformSpatialPredicateToLeftJoin(this.metadata));
    }

    /**
     * Tries every supported spatial function in {@code filter}, then every supported spatial
     * comparison, and returns the first successful rewrite.
     *
     * @return the first non-empty result, or {@link Rule.Result#empty()} if no candidate aligns
     */
    private static Rule.Result trySpatialPredicates(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, Metadata metadata) {
        for (FunctionCall spatialFunction : SpatialJoinUtils.extractSupportedSpatialFunctions(filter)) {
            Rule.Result result = tryCreateSpatialJoin(context, joinNode, filter, nodeId, outputSymbols, spatialFunction, metadata);
            if (!result.isEmpty()) {
                return result;
            }
        }
        for (ComparisonExpression spatialComparison : SpatialJoinUtils.extractSupportedSpatialComparisons(filter)) {
            Rule.Result result = tryCreateSpatialJoin(context, joinNode, filter, nodeId, outputSymbols, spatialComparison, metadata);
            if (!result.isEmpty()) {
                return result;
            }
        }
        return Rule.Result.empty();
    }

    /**
     * Handles a spatial comparison between a spatial function and a radius.
     *
     * <p>The comparison is normalized so that the spatial function is on the left. A radius that
     * is not a plain symbol is projected onto the right join side. The resulting function call is
     * then handed to the {@link FunctionCall} overload.
     *
     * @return a rewritten join, or {@link Rule.Result#empty()} if the radius references any
     *     left-side symbol or any symbol not produced by the right side, or if the function
     *     arguments cannot be aligned
     */
    public static Rule.Result tryCreateSpatialJoin(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, ComparisonExpression spatialComparison, Metadata metadata) {
        PlanNode leftNode = joinNode.getLeft();
        PlanNode rightNode = joinNode.getRight();
        List<Symbol> leftSymbols = leftNode.getOutputSymbols();
        List<Symbol> rightSymbols = rightNode.getOutputSymbols();

        ComparisonExpression.Operator operator = spatialComparison.getOperator();
        // For "<" and "<=" the radius is the right operand. For any other operator the radius is
        // the left operand and the operator gets flipped below.
        boolean radiusOnRight = operator == ComparisonExpression.Operator.LESS_THAN
                || operator == ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
        Expression radius = radiusOnRight ? spatialComparison.getRight() : spatialComparison.getLeft();

        // The radius must be symbol-free, or computable from right-side symbols only.
        Set<Symbol> radiusSymbols = SymbolsExtractor.extractUnique(radius);
        if (!radiusSymbols.isEmpty() && (!rightSymbols.containsAll(radiusSymbols) || !containsNone(leftSymbols, radiusSymbols))) {
            return Rule.Result.empty();
        }

        Optional<Symbol> newRadiusSymbol = newRadiusSymbol(context, radius);
        Expression newRadius = toExpression(newRadiusSymbol, radius);
        // Normalize so the spatial function is always the left operand.
        ComparisonExpression newComparison = radiusOnRight
                ? new ComparisonExpression(operator, spatialComparison.getLeft(), newRadius)
                : new ComparisonExpression(operator.flip(), spatialComparison.getRight(), newRadius);

        Expression newFilter = ExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialComparison, newComparison));
        PlanNode newRightNode = newRadiusSymbol
                .map(symbol -> addProjection(context, rightNode, symbol, radius))
                .orElse(rightNode);

        JoinNode joinWithRadius = new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                leftNode,
                newRightNode,
                joinNode.getCriteria(),
                joinNode.getOutputSymbols(),
                Optional.of(newFilter),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType());

        // NOTE(review): this cast relies on extractSupportedSpatialComparisons only yielding
        // comparisons whose non-radius operand is a FunctionCall — confirm.
        return tryCreateSpatialJoin(context, joinWithRadius, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), metadata);
    }

    /**
     * Handles a two-argument spatial function by aligning its arguments with the join sides.
     *
     * <p>Each argument must reference at least one symbol. All of an argument's symbols must come
     * from exactly one side, and the two arguments must come from different sides. Arguments that
     * are not plain symbol references are projected onto their side as new Geometry-typed symbols.
     *
     * @return a join with id {@code nodeId} and outputs {@code outputSymbols} whose filter
     *     references the projected symbols, or {@link Rule.Result#empty()} if alignment fails
     */
    public static Rule.Result tryCreateSpatialJoin(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, FunctionCall spatialFunction, Metadata metadata) {
        List<Expression> arguments = spatialFunction.getArguments();
        Verify.verify(arguments.size() == 2);
        Expression firstArgument = arguments.get(0);
        Expression secondArgument = arguments.get(1);

        Set<Symbol> firstSymbols = SymbolsExtractor.extractUnique(firstArgument);
        Set<Symbol> secondSymbols = SymbolsExtractor.extractUnique(secondArgument);
        if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) {
            return Rule.Result.empty();
        }

        int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols);
        if (alignment == 0) {
            return Rule.Result.empty();
        }

        // Allocate new symbols only once the candidate is known to align.
        Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, metadata);
        Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, metadata);

        PlanNode leftNode = joinNode.getLeft();
        PlanNode rightNode = joinNode.getRight();
        PlanNode newLeftNode;
        PlanNode newRightNode;
        if (alignment > 0) {
            // First argument comes from the left side, second from the right side.
            newLeftNode = newFirstSymbol.map(symbol -> addProjection(context, leftNode, symbol, firstArgument)).orElse(leftNode);
            newRightNode = newSecondSymbol.map(symbol -> addProjection(context, rightNode, symbol, secondArgument)).orElse(rightNode);
        } else {
            // Arguments are swapped relative to the join sides.
            newLeftNode = newSecondSymbol.map(symbol -> addProjection(context, leftNode, symbol, secondArgument)).orElse(leftNode);
            newRightNode = newFirstSymbol.map(symbol -> addProjection(context, rightNode, symbol, firstArgument)).orElse(rightNode);
        }

        FunctionCall newSpatialFunction = new FunctionCall(
                spatialFunction.getName(),
                ImmutableList.of(toExpression(newFirstSymbol, firstArgument), toExpression(newSecondSymbol, secondArgument)));
        Expression newFilter = ExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));

        return Rule.Result.ofPlanNode(new JoinNode(
                nodeId,
                joinNode.getType(),
                newLeftNode,
                newRightNode,
                joinNode.getCriteria(),
                outputSymbols,
                Optional.of(newFilter),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType()));
    }

    /**
     * Determines how the two argument symbol sets map onto the join sides.
     *
     * @return {@code 1} if {@code firstSymbols} come only from the left side and
     *     {@code secondSymbols} come only from the right side; {@code -1} if the reverse holds;
     *     {@code 0} otherwise
     */
    private static int checkAlignment(JoinNode joinNode, Set<Symbol> firstSymbols, Set<Symbol> secondSymbols) {
        List<Symbol> leftSymbols = joinNode.getLeft().getOutputSymbols();
        List<Symbol> rightSymbols = joinNode.getRight().getOutputSymbols();
        if (leftSymbols.containsAll(firstSymbols) && containsNone(leftSymbols, secondSymbols)
                && rightSymbols.containsAll(secondSymbols) && containsNone(rightSymbols, firstSymbols)) {
            return 1;
        }
        if (leftSymbols.containsAll(secondSymbols) && containsNone(leftSymbols, firstSymbols)
                && rightSymbols.containsAll(firstSymbols) && containsNone(rightSymbols, secondSymbols)) {
            return -1;
        }
        return 0;
    }

    /** Returns a reference to {@code optionalSymbol} if present, otherwise {@code defaultExpression}. */
    private static Expression toExpression(Optional<Symbol> optionalSymbol, Expression defaultExpression) {
        return optionalSymbol.<Expression>map(Symbol::toSymbolReference).orElse(defaultExpression);
    }

    /**
     * Allocates a Geometry-typed symbol for {@code expression}, unless the expression is already
     * a plain symbol reference.
     */
    private static Optional<Symbol> newGeometrySymbol(Rule.Context context, Expression expression, Metadata metadata) {
        if (expression instanceof SymbolReference) {
            return Optional.empty();
        }
        return Optional.of(context.getSymbolAllocator().newSymbol(expression, metadata.getType(GEOMETRY_TYPE_SIGNATURE)));
    }

    /**
     * Allocates a DOUBLE-typed symbol for a radius {@code expression}, unless the expression is
     * already a plain symbol reference.
     */
    private static Optional<Symbol> newRadiusSymbol(Rule.Context context, Expression expression) {
        if (expression instanceof SymbolReference) {
            return Optional.empty();
        }
        return Optional.of(context.getSymbolAllocator().newSymbol(expression, DoubleType.DOUBLE));
    }

    /**
     * Wraps {@code node} in a {@link ProjectNode}. The projection passes through all of the node's
     * outputs and adds {@code symbol := expression}.
     */
    public static PlanNode addProjection(Rule.Context context, PlanNode node, Symbol symbol, Expression expression) {
        Assignments.Builder projections = Assignments.builder();
        for (Symbol outputSymbol : node.getOutputSymbols()) {
            projections.putIdentity(outputSymbol);
        }
        projections.put(symbol, expression);
        return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build());
    }

    /** Returns true if no element of {@code values} is contained in {@code testValues}. */
    private static boolean containsNone(Collection<Symbol> values, Collection<Symbol> testValues) {
        // The set copy is built once, when the method reference is evaluated.
        return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains);
    }
}
