package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
// NOTE(review): the two shaded "$internal" imports below come from the hive plugin's bundled
// jodd/hadoop copies; the core planner should not depend on them. Remove once unreferenced.
import com.facebook.presto.hive.$internal.jodd.util.StringPool;
import com.facebook.presto.hive.jdbc.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExceptNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.IntersectNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SetOperationNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.class */
public class ImplementIntersectAndExceptAsUnion implements PlanOptimizer {

    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        private static final String MARKER = "marker";
        private static final Signature COUNT_AGGREGATION = new Signature(Count.NAME, FunctionKind.AGGREGATE, TypeSignature.parseTypeSignature("bigint"), TypeSignature.parseTypeSignature("boolean"));
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = (PlanNodeIdAllocator) Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.symbolAllocator = (SymbolAllocator) Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitIntersect(IntersectNode intersectNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            Stream<PlanNode> stream = intersectNode.getSources().stream();
            rewriteContext.getClass();
            List<PlanNode> list = (List) stream.map(rewriteContext::rewrite).collect(Collectors.toList());
            List<Symbol> allocateSymbols = allocateSymbols(list.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> appendMarkers = appendMarkers(allocateSymbols, list, intersectNode);
            List<Symbol> outputSymbols = intersectNode.getOutputSymbols();
            return project(addFilterForIntersect(computeCounts(union(appendMarkers, ImmutableList.copyOf(Iterables.concat(outputSymbols, allocateSymbols))), outputSymbols, allocateSymbols, allocateSymbols(allocateSymbols.size(), Count.NAME, BigintType.BIGINT))), outputSymbols);
        }

        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitExcept(ExceptNode exceptNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            Stream<PlanNode> stream = exceptNode.getSources().stream();
            rewriteContext.getClass();
            List<PlanNode> list = (List) stream.map(rewriteContext::rewrite).collect(Collectors.toList());
            List<Symbol> allocateSymbols = allocateSymbols(list.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> appendMarkers = appendMarkers(allocateSymbols, list, exceptNode);
            List<Symbol> outputSymbols = exceptNode.getOutputSymbols();
            UnionNode union = union(appendMarkers, ImmutableList.copyOf(Iterables.concat(outputSymbols, allocateSymbols)));
            List<Symbol> allocateSymbols2 = allocateSymbols(allocateSymbols.size(), Count.NAME, BigintType.BIGINT);
            return project(addFilterForExcept(computeCounts(union, outputSymbols, allocateSymbols, allocateSymbols2), allocateSymbols2.get(0), allocateSymbols2.subList(1, allocateSymbols2.size())), outputSymbols);
        }

        private List<Symbol> allocateSymbols(int i, String str, Type type) {
            ImmutableList.Builder builder = ImmutableList.builder();
            for (int i2 = 0; i2 < i; i2++) {
                builder.add((ImmutableList.Builder) this.symbolAllocator.newSymbol(str, type));
            }
            return builder.build();
        }

        private List<PlanNode> appendMarkers(List<Symbol> list, List<PlanNode> list2, SetOperationNode setOperationNode) {
            ImmutableList.Builder builder = ImmutableList.builder();
            for (int i = 0; i < list2.size(); i++) {
                builder.add((ImmutableList.Builder) appendMarkers(list2.get(i), i, list, setOperationNode.sourceSymbolMap(i)));
            }
            return builder.build();
        }

        private PlanNode appendMarkers(PlanNode planNode, int i, List<Symbol> list, Map<Symbol, SymbolReference> map) {
            Assignments.Builder builder = Assignments.builder();
            for (Map.Entry<Symbol, SymbolReference> entry : map.entrySet()) {
                builder.put(this.symbolAllocator.newSymbol(entry.getKey().getName(), this.symbolAllocator.getTypes().get(entry.getKey())), entry.getValue());
            }
            int i2 = 0;
            while (i2 < list.size()) {
                builder.put(this.symbolAllocator.newSymbol(list.get(i2).getName(), BooleanType.BOOLEAN), i2 == i ? BooleanLiteral.TRUE_LITERAL : new Cast(new NullLiteral(), "boolean"));
                i2++;
            }
            return new ProjectNode(this.idAllocator.getNextId(), planNode, builder.build());
        }

        private UnionNode union(List<PlanNode> list, List<Symbol> list2) {
            ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder();
            for (PlanNode planNode : list) {
                for (int i = 0; i < planNode.getOutputSymbols().size(); i++) {
                    builder.put((ImmutableListMultimap.Builder) list2.get(i), planNode.getOutputSymbols().get(i));
                }
            }
            return new UnionNode(this.idAllocator.getNextId(), list, builder.build(), list2);
        }

        private AggregationNode computeCounts(UnionNode unionNode, List<Symbol> list, List<Symbol> list2, List<Symbol> list3) {
            ImmutableMap.Builder builder = ImmutableMap.builder();
            for (int i = 0; i < list2.size(); i++) {
                builder.put(list3.get(i), new AggregationNode.Aggregation(new FunctionCall(QualifiedName.of(Count.NAME), ImmutableList.of(list2.get(i).toSymbolReference())), COUNT_AGGREGATION, Optional.empty()));
            }
            return new AggregationNode(this.idAllocator.getNextId(), unionNode, builder.build(), ImmutableList.of(list), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
        }

        private FilterNode addFilterForIntersect(AggregationNode aggregationNode) {
            return new FilterNode(this.idAllocator.getNextId(), aggregationNode, ExpressionUtils.and((ImmutableList) aggregationNode.getAggregations().keySet().stream().map(symbol -> {
                return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, symbol.toSymbolReference(), new GenericLiteral("BIGINT", StringPool.ONE));
            }).collect(ImmutableList.toImmutableList())));
        }

        private FilterNode addFilterForExcept(AggregationNode aggregationNode, Symbol symbol, List<Symbol> list) {
            ImmutableList.Builder builder = ImmutableList.builder();
            builder.add((ImmutableList.Builder) new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, symbol.toSymbolReference(), new GenericLiteral("BIGINT", StringPool.ONE)));
            Iterator<Symbol> it2 = list.iterator();
            while (it2.hasNext()) {
                builder.add((ImmutableList.Builder) new ComparisonExpression(ComparisonExpression.Operator.EQUAL, it2.next().toSymbolReference(), new GenericLiteral("BIGINT", "0")));
            }
            return new FilterNode(this.idAllocator.getNextId(), aggregationNode, ExpressionUtils.and(builder.build()));
        }

        private ProjectNode project(PlanNode planNode, List<Symbol> list) {
            return new ProjectNode(this.idAllocator.getNextId(), planNode, Assignments.identity(list));
        }
    }

    @Override // com.facebook.presto.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(planNodeIdAllocator, symbolAllocator), planNode);
    }
}
