package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.hive.$internal.jodd.util.StringPool;
import com.facebook.presto.hive.jdbc.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SetOperationNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.class */
/**
 * Plan optimizer that rewrites {@link IntersectNode} and {@link ExceptNode} into a plan made of
 * a union, an aggregation, a filter and a projection.
 *
 * <p>For a set operation with N sources the rewritten plan has this shape (bottom-up):
 * <ol>
 *   <li>every source is wrapped in a projection that re-exposes its columns and appends N boolean
 *       marker columns; for source {@code i} the {@code i}-th marker is {@code TRUE} and all the
 *       others are {@code CAST(NULL AS boolean)};</li>
 *   <li>the projections are combined with a {@link UnionNode};</li>
 *   <li>an {@link AggregationNode} groups by the original output columns and applies
 *       {@code count} to each marker column;</li>
 *   <li>a {@link FilterNode} keeps the groups that satisfy the set operation
 *       (INTERSECT: every count {@code >= 1}; EXCEPT: first count {@code >= 1} and every other
 *       count {@code = 0});</li>
 *   <li>a final identity projection restores the original output columns.</li>
 * </ol>
 */
public class ImplementIntersectAndExceptAsUnion implements PlanOptimizer {
    private final FunctionManager functionManager;

    /**
     * Rewriter that performs the actual INTERSECT / EXCEPT replacement while walking the plan.
     */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        /** Name hint used for the per-source boolean marker variables. */
        private static final String MARKER = "marker";
        /** Aggregate function name, also used as the name hint of the count output variables. */
        private static final String COUNT = "count";
        /** SQL type name used for the literals the counts are compared against. */
        private static final String BIGINT_TYPE_NAME = "BIGINT";

        private final Session session;
        private final StandardFunctionResolution functionResolution;
        private final PlanNodeIdAllocator idAllocator;
        private final PlanVariableAllocator variableAllocator;

        private Rewriter(Session session, FunctionManager functionManager, PlanNodeIdAllocator planNodeIdAllocator, PlanVariableAllocator planVariableAllocator) {
            Objects.requireNonNull(functionManager, "functionManager is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionResolution = new FunctionResolution(functionManager);
            this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(planVariableAllocator, "variableAllocator is null");
        }

        /**
         * Replaces an INTERSECT with union + count aggregation + filter requiring that every
         * source contributed at least one row to the group.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode intersectNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            List<PlanNode> sources = intersectNode.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(Collectors.toList());

            List<VariableReferenceExpression> markers = allocateVariables(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = appendMarkers(markers, sources, intersectNode);

            List<VariableReferenceExpression> outputs = intersectNode.getOutputVariables();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

            List<VariableReferenceExpression> counts = allocateVariables(markers.size(), COUNT, BigintType.BIGINT);
            AggregationNode aggregation = computeCounts(union, intersectNode.getOutputVariables(), markers, counts);

            return project(addFilterForIntersect(aggregation), outputs);
        }

        /**
         * Replaces an EXCEPT with union + count aggregation + filter requiring that the first
         * source contributed at least one row to the group and no other source contributed any.
         */
        @Override
        public PlanNode visitExcept(ExceptNode exceptNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            List<PlanNode> sources = exceptNode.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(Collectors.toList());

            List<VariableReferenceExpression> markers = allocateVariables(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = appendMarkers(markers, sources, exceptNode);

            List<VariableReferenceExpression> outputs = exceptNode.getOutputVariables();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

            List<VariableReferenceExpression> counts = allocateVariables(markers.size(), COUNT, BigintType.BIGINT);
            AggregationNode aggregation = computeCounts(union, exceptNode.getOutputVariables(), markers, counts);

            // counts.get(0) belongs to the first source; the remaining counts belong to the other sources
            return project(addFilterForExcept(aggregation, counts.get(0), counts.subList(1, counts.size())), outputs);
        }

        /** Allocates {@code i} fresh variables that share the given name hint and type. */
        private List<VariableReferenceExpression> allocateVariables(int i, String str, Type type) {
            ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
            for (int index = 0; index < i; index++) {
                variables.add(this.variableAllocator.newVariable(str, type));
            }
            return variables.build();
        }

        /**
         * Wraps each source in a projection that exposes the source columns mapped by the set
         * operation plus one marker column per source.
         *
         * @param list marker variables, one per source
         * @param list2 the (already rewritten) sources of the set operation
         * @param setOperationNode the node being replaced; supplies the per-source variable mapping
         */
        private List<PlanNode> appendMarkers(List<VariableReferenceExpression> list, List<PlanNode> list2, SetOperationNode setOperationNode) {
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < list2.size(); i++) {
                Map<VariableReferenceExpression, SymbolReference> projections = Maps.transformValues(
                        setOperationNode.sourceVariableMap(i),
                        variable -> new SymbolReference(variable.getName()));
                result.add(appendMarkers(list2.get(i), i, list, projections));
            }
            return result.build();
        }

        /**
         * Builds the projection for a single source.
         *
         * @param planNode the source to wrap
         * @param i index of this source; only the marker at this index is {@code TRUE}
         * @param list marker variables, one per source
         * @param map set-operation output variable to the source symbol that feeds it
         */
        private PlanNode appendMarkers(PlanNode planNode, int i, List<VariableReferenceExpression> list, Map<VariableReferenceExpression, SymbolReference> map) {
            Assignments.Builder assignments = Assignments.builder();

            // re-expose the source columns under freshly allocated variables
            for (Map.Entry<VariableReferenceExpression, SymbolReference> entry : map.entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                assignments.put(
                        this.variableAllocator.newVariable(output.getName(), output.getType()),
                        OriginalExpressionUtils.castToRowExpression(entry.getValue()));
            }

            // append the marker columns: TRUE for this source, typed NULL for all the others
            for (int markerIndex = 0; markerIndex < list.size(); markerIndex++) {
                Expression value = (markerIndex == i) ? BooleanLiteral.TRUE_LITERAL : new Cast(new NullLiteral(), "boolean");
                assignments.put(
                        this.variableAllocator.newVariable(list.get(markerIndex).getName(), BooleanType.BOOLEAN),
                        OriginalExpressionUtils.castToRowExpression(value));
            }

            return new ProjectNode(this.idAllocator.getNextId(), planNode, assignments.build());
        }

        /**
         * Unions the given nodes, mapping the {@code k}-th output variable of every node to the
         * {@code k}-th entry of {@code list2}.
         */
        private UnionNode union(List<PlanNode> list, List<VariableReferenceExpression> list2) {
            ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : list) {
                List<VariableReferenceExpression> sourceOutputs = source.getOutputVariables();
                for (int i = 0; i < sourceOutputs.size(); i++) {
                    outputsToInputs.put(list2.get(i), sourceOutputs.get(i));
                }
            }
            ImmutableListMultimap<VariableReferenceExpression, VariableReferenceExpression> mapping = outputsToInputs.build();
            return new UnionNode(this.idAllocator.getNextId(), list, ImmutableList.copyOf(mapping.keySet()), SetOperationNodeUtils.fromListMultimap(mapping));
        }

        /**
         * Adds an aggregation that groups by the original columns and applies {@code count} to
         * each marker column.
         *
         * @param unionNode the union of the marker-augmented sources
         * @param list grouping keys (the original output columns of the set operation)
         * @param list2 marker variables to count
         * @param list3 output variables for the counts, positionally matching {@code list2}
         */
        private AggregationNode computeCounts(UnionNode unionNode, List<VariableReferenceExpression> list, List<VariableReferenceExpression> list2, List<VariableReferenceExpression> list3) {
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (int i = 0; i < list2.size(); i++) {
                VariableReferenceExpression marker = list2.get(i);
                CallExpression countCall = new CallExpression(
                        COUNT,
                        this.functionResolution.countFunction(marker.getType()),
                        BigintType.BIGINT,
                        ImmutableList.of(OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(marker))));
                aggregations.put(
                        list3.get(i),
                        new AggregationNode.Aggregation(countCall, Optional.empty(), Optional.empty(), false, Optional.empty()));
            }
            return new AggregationNode(
                    this.idAllocator.getNextId(),
                    unionNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(list),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
        }

        /** Keeps only the groups for which every aggregation output (count) is {@code >= 1}. */
        private FilterNode addFilterForIntersect(AggregationNode aggregationNode) {
            ImmutableList<Expression> predicates = aggregationNode.getAggregations().keySet().stream()
                    .map(count -> countComparison(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, count, "1"))
                    .collect(ImmutableList.toImmutableList());
            return new FilterNode(this.idAllocator.getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(ExpressionUtils.and(predicates)));
        }

        /**
         * Keeps only the groups whose first count is {@code >= 1} and whose remaining counts are
         * all {@code = 0}.
         *
         * @param variableReferenceExpression count variable of the first source
         * @param list count variables of the remaining sources
         */
        private FilterNode addFilterForExcept(AggregationNode aggregationNode, VariableReferenceExpression variableReferenceExpression, List<VariableReferenceExpression> list) {
            ImmutableList.Builder<Expression> predicates = ImmutableList.builder();
            predicates.add(countComparison(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, variableReferenceExpression, "1"));
            for (VariableReferenceExpression count : list) {
                predicates.add(countComparison(ComparisonExpression.Operator.EQUAL, count, "0"));
            }
            return new FilterNode(this.idAllocator.getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(ExpressionUtils.and(predicates.build())));
        }

        /** Builds {@code <count> <operator> BIGINT '<value>'} for the given count variable. */
        private static Expression countComparison(ComparisonExpression.Operator operator, VariableReferenceExpression count, String value) {
            return new ComparisonExpression(operator, new SymbolReference(count.getName()), new GenericLiteral(BIGINT_TYPE_NAME, value));
        }

        /** Adds an identity projection that exposes exactly the given variables. */
        private ProjectNode project(PlanNode planNode, List<VariableReferenceExpression> list) {
            return new ProjectNode(this.idAllocator.getNextId(), planNode, AssignmentUtils.identityAssignmentsAsSymbolReferences(list));
        }
    }

    /**
     * @param functionManager used to resolve the {@code count} aggregate function
     * @throws NullPointerException if {@code functionManager} is null
     */
    public ImplementIntersectAndExceptAsUnion(FunctionManager functionManager) {
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Rewrites every INTERSECT and EXCEPT node reachable from {@code planNode}.
     *
     * @return the rewritten plan
     * @throws NullPointerException if the plan, session, types, variable allocator or id allocator is null
     */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, PlanVariableAllocator planVariableAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(planVariableAllocator, "variableAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(session, this.functionManager, planNodeIdAllocator, planVariableAllocator), planNode);
    }
}
