package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.function.QualifiedFunctionName;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.BuiltInFunctionNamespaceManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.class */
/**
 * Rewrites an aggregation that mixes exactly one masked (DISTINCT) input with unmasked aggregations,
 * replacing the {@code MarkDistinctNode} that produces the mask with a {@code GroupIdNode} plus a
 * pre-aggregation.
 *
 * <p>Shape of the rewritten plan, from the bottom up:
 * <ol>
 *   <li>GroupId with two grouping sets: g0 = group-by keys + non-distinct aggregation arguments,
 *       g1 = group-by keys + the distinct argument. A BIGINT "group" column identifies the set.</li>
 *   <li>A SINGLE-step aggregation grouped by group-by keys + distinct argument + group column that
 *       evaluates every non-distinct aggregation.</li>
 *   <li>A projection exposing the distinct argument as IF(group = 1, arg, NULL) and every
 *       non-distinct result as IF(group = 0, result, NULL).</li>
 *   <li>An aggregation with the original grouping sets and step that applies the distinct function to
 *       the projected distinct column and {@code arbitrary} to each projected non-distinct result.</li>
 * </ol>
 *
 * <p>Results of {@code count}, {@code count_if} and {@code approx_distinct} are additionally wrapped
 * in {@code COALESCE(x, 0)}. The rule only runs when the session's optimize-distinct-aggregation
 * property is enabled.
 */
public class OptimizeMixedDistinctAggregations implements PlanOptimizer {
    private final Metadata metadata;
    private final StandardFunctionResolution functionResolution;

    /**
     * State shared between the aggregation visit and the MarkDistinct visit while one aggregation is
     * being rewritten. The MarkDistinct rewrite fills in the "new" variables that the aggregation
     * rewrite then reads.
     */
    public static class AggregateInfo {
        private final List<VariableReferenceExpression> groupByVariables;
        // Marker variable produced by the MarkDistinct node feeding the distinct aggregation.
        private final VariableReferenceExpression mask;
        private final Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations;

        // Original non-distinct output -> projected pre-aggregated value (set by the MarkDistinct rewrite).
        private Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables;
        // Projected column carrying the distinct argument (set by the MarkDistinct rewrite).
        private VariableReferenceExpression newDistinctAggregateVariable;
        private boolean foundMarkDistinct;

        public AggregateInfo(
                List<VariableReferenceExpression> groupByVariables,
                VariableReferenceExpression mask,
                Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations) {
            this.groupByVariables = ImmutableList.copyOf(groupByVariables);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
        }

        /** Returns the distinct set of arguments used by unmasked aggregations, in first-seen order (immutable). */
        public List<VariableReferenceExpression> getOriginalNonDistinctAggregateArgs() {
            return collectArguments(false);
        }

        /** Returns the distinct set of arguments used by masked aggregations, in first-seen order (immutable). */
        public List<VariableReferenceExpression> getOriginalDistinctAggregateArgs() {
            return collectArguments(true);
        }

        // Every argument is cast to a variable; a non-variable argument fails with ClassCastException.
        // NOTE(review): assumes aggregation arguments were already projected to variables — confirm.
        private List<VariableReferenceExpression> collectArguments(boolean masked) {
            return aggregations.values().stream()
                    .filter(aggregation -> aggregation.getMask().isPresent() == masked)
                    .flatMap(aggregation -> aggregation.getArguments().stream())
                    .distinct()
                    .map(VariableReferenceExpression.class::cast)
                    .collect(ImmutableList.toImmutableList());
        }

        public VariableReferenceExpression getNewDistinctAggregateVariable() {
            return newDistinctAggregateVariable;
        }

        public void setNewDistinctAggregateSymbol(VariableReferenceExpression newDistinctAggregateVariable) {
            this.newDistinctAggregateVariable = newDistinctAggregateVariable;
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getNewNonDistinctAggregateVariables() {
            return newNonDistinctAggregateVariables;
        }

        public void setNewNonDistinctAggregateSymbols(Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables) {
            this.newNonDistinctAggregateVariables = newNonDistinctAggregateVariables;
        }

        public VariableReferenceExpression getMask() {
            return mask;
        }

        public List<VariableReferenceExpression> getGroupByVariables() {
            return groupByVariables;
        }

        public Map<VariableReferenceExpression, AggregationNode.Aggregation> getAggregations() {
            return aggregations;
        }

        /** Records that the MarkDistinct producing {@link #getMask()} was found and rewritten. */
        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return foundMarkDistinct;
        }
    }

    /**
     * Plan rewriter carrying the {@link AggregateInfo} of the aggregation currently being rewritten down
     * to the MarkDistinct node that produces its mask.
     */
    private static class Optimizer extends SimplePlanRewriter<Optional<AggregateInfo>> {
        // Functions whose re-aggregated result is wrapped in COALESCE(x, 0). Presumably because they
        // return 0 rather than NULL on empty input, while arbitrary() over all-NULL input yields NULL.
        private static final List<QualifiedFunctionName> COALESCE_TO_ZERO_FUNCTIONS = ImmutableList.of(
                QualifiedFunctionName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "count"),
                QualifiedFunctionName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "count_if"),
                QualifiedFunctionName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "approx_distinct"));

        private final PlanNodeIdAllocator idAllocator;
        private final PlanVariableAllocator variableAllocator;
        private final Metadata metadata;
        private final StandardFunctionResolution functionResolution;

        private Optimizer(
                PlanNodeIdAllocator idAllocator,
                PlanVariableAllocator variableAllocator,
                Metadata metadata,
                StandardFunctionResolution functionResolution) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.functionResolution = Objects.requireNonNull(functionResolution, "functionResolution is null");
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Map<VariableReferenceExpression, AggregationNode.Aggregation> originalAggregations = node.getAggregations();
            List<VariableReferenceExpression> masks = originalAggregations.values().stream()
                    .map(AggregationNode.Aggregation::getMask)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableList.toImmutableList());
            Set<VariableReferenceExpression> uniqueMasks = ImmutableSet.copyOf(masks);
            // Applies only when exactly one mask is used AND at least one aggregation is unmasked.
            if (uniqueMasks.size() != 1 || masks.size() == originalAggregations.size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            // Aggregation FILTER clauses and ORDER BY inside aggregations are not handled by this rewrite.
            boolean anyFilter = originalAggregations.values().stream()
                    .map(AggregationNode.Aggregation::getFilter)
                    .anyMatch(Optional::isPresent);
            if (anyFilter || node.hasOrderings()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), originalAggregations);
            // The rewrite groups by the non-distinct aggregation arguments, so they must be comparable.
            if (!checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }

            // Rewrites the matching MarkDistinct below this node (if any) and fills in aggregateInfo.
            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            // Uncoalesced arbitrary() output -> original output variable.
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> coalesceBuilder = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : originalAggregations.entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                AggregationNode.Aggregation original = entry.getValue();
                if (original.getMask().isPresent()) {
                    // The distinct aggregation now reads the projected distinct column and drops its mask.
                    CallExpression call = original.getCall();
                    aggregations.put(output, new AggregationNode.Aggregation(
                            new CallExpression(
                                    call.getDisplayName(),
                                    call.getFunctionHandle(),
                                    call.getType(),
                                    ImmutableList.<RowExpression>of(aggregateInfo.getNewDistinctAggregateVariable())),
                            Optional.empty(),
                            Optional.empty(),
                            false,
                            Optional.empty()));
                    continue;
                }

                // The non-distinct result was already computed below; arbitrary() presumably picks the
                // single non-NULL (group = 0) value per group.
                VariableReferenceExpression preAggregated = aggregateInfo.getNewNonDistinctAggregateVariables().get(output);
                AggregationNode.Aggregation arbitrary = new AggregationNode.Aggregation(
                        new CallExpression(
                                "arbitrary",
                                metadata.getFunctionManager().lookupFunction("arbitrary", TypeSignatureProvider.fromTypes(ImmutableList.of(preAggregated.getType()))),
                                output.getType(),
                                ImmutableList.<RowExpression>of(preAggregated)),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty());
                QualifiedFunctionName functionName = metadata.getFunctionManager().getFunctionMetadata(original.getFunctionHandle()).getName();
                if (COALESCE_TO_ZERO_FUNCTIONS.contains(functionName)) {
                    VariableReferenceExpression uncoalesced = variableAllocator.newVariable("expr", output.getType());
                    aggregations.put(uncoalesced, arbitrary);
                    coalesceBuilder.put(uncoalesced, output);
                }
                else {
                    aggregations.put(output, arbitrary);
                }
            }

            Map<VariableReferenceExpression, VariableReferenceExpression> coalesceVariables = coalesceBuilder.build();
            AggregationNode aggregationNode = new AggregationNode(
                    idAllocator.getNextId(),
                    source,
                    aggregations.build(),
                    node.getGroupingSets(),
                    ImmutableList.of(),
                    node.getStep(),
                    Optional.empty(),
                    node.getGroupIdVariable());
            if (coalesceVariables.isEmpty()) {
                return aggregationNode;
            }
            return coalesceToZero(aggregationNode, coalesceVariables);
        }

        // Projects every output of aggregationNode through unchanged, except the keys of coalesceVariables,
        // which are exposed under their mapped (original) variable as COALESCE(value, 0).
        private ProjectNode coalesceToZero(
                AggregationNode aggregationNode,
                Map<VariableReferenceExpression, VariableReferenceExpression> coalesceVariables) {
            Assignments.Builder assignments = Assignments.builder();
            for (VariableReferenceExpression variable : aggregationNode.getOutputVariables()) {
                VariableReferenceExpression original = coalesceVariables.get(variable);
                if (original == null) {
                    assignments.put(variable, variable);
                }
                else {
                    assignments.put(original, new SpecialFormExpression(
                            SpecialFormExpression.Form.COALESCE,
                            BigintType.BIGINT,
                            variable,
                            Expressions.constant(0L, BigintType.BIGINT)));
                }
            }
            return new ProjectNode(idAllocator.getNextId(), aggregationNode, assignments.build());
        }

        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> optionalAggregateInfo = context.get();
            // Only the MarkDistinct producing the mask of the aggregation being rewritten is replaced.
            if (!optionalAggregateInfo.isPresent() || !optionalAggregateInfo.get().getMask().equals(node.getMarkerVariable())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = optionalAggregateInfo.get();
            aggregateInfo.foundMarkDistinct();

            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<VariableReferenceExpression> groupByVariables = aggregateInfo.getGroupByVariables();
            // Mutable copy: the distinct argument may be swapped for a duplicate below.
            List<VariableReferenceExpression> nonDistinctAggregateVariables = new ArrayList<>(aggregateInfo.getOriginalNonDistinctAggregateArgs());
            VariableReferenceExpression distinctVariable = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());

            // When the distinct argument is also a non-distinct argument (e.g. sum(a), count(DISTINCT a)),
            // the non-distinct side gets its own copy of the column so each grouping set can populate it.
            VariableReferenceExpression duplicatedDistinctVariable = distinctVariable;
            int distinctIndex = nonDistinctAggregateVariables.indexOf(distinctVariable);
            if (distinctIndex >= 0) {
                duplicatedDistinctVariable = variableAllocator.newVariable(distinctVariable);
                nonDistinctAggregateVariables.set(distinctIndex, duplicatedDistinctVariable);
            }

            Set<VariableReferenceExpression> allVariables = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(groupByVariables)
                    .addAll(nonDistinctAggregateVariables)
                    .add(distinctVariable)
                    .build();
            VariableReferenceExpression groupVariable = variableAllocator.newVariable("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = createGroupIdNode(
                    groupByVariables,
                    nonDistinctAggregateVariables,
                    distinctVariable,
                    duplicatedDistinctVariable,
                    groupVariable,
                    allVariables,
                    source);

            Set<VariableReferenceExpression> groupByKeys = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(groupByVariables)
                    .add(distinctVariable)
                    .add(groupVariable)
                    .build();
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputToOriginal = ImmutableMap.builder();
            AggregationNode aggregationNode = createNonDistinctAggregation(
                    aggregateInfo,
                    distinctVariable,
                    duplicatedDistinctVariable,
                    groupByKeys,
                    groupIdNode,
                    aggregationOutputToOriginal);
            return createProjectNode(
                    aggregationNode,
                    aggregateInfo,
                    distinctVariable,
                    groupVariable,
                    groupByVariables,
                    aggregationOutputToOriginal.build());
        }

        // True when every non-distinct aggregation argument and the mask have comparable types.
        private static boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            boolean argumentsComparable = aggregateInfo.getOriginalNonDistinctAggregateArgs().stream()
                    .allMatch(variable -> variable.getType().isComparable());
            return argumentsComparable && aggregateInfo.getMask().getType().isComparable();
        }

        /**
         * Projects the pre-aggregation's outputs. It exposes the distinct column as IF(group = 1, x, NULL)
         * and each non-distinct result as IF(group = 0, x, NULL), passes group-by keys through, and
         * assigns NULL to the mask. Records the new variables in {@code aggregateInfo}.
         */
        private ProjectNode createProjectNode(
                AggregationNode source,
                AggregateInfo aggregateInfo,
                VariableReferenceExpression distinctVariable,
                VariableReferenceExpression groupVariable,
                List<VariableReferenceExpression> groupByVariables,
                Map<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputToOriginal) {
            Assignments.Builder assignments = Assignments.builder();
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> nonDistinctOutputs = ImmutableMap.builder();
            for (VariableReferenceExpression variable : source.getOutputVariables()) {
                if (distinctVariable.equals(variable)) {
                    VariableReferenceExpression newVariable = variableAllocator.newVariable("expr", variable.getType());
                    aggregateInfo.setNewDistinctAggregateSymbol(newVariable);
                    assignments.put(newVariable, ifGroupEquals(groupVariable, 1L, variable));
                }
                else if (aggregationOutputToOriginal.containsKey(variable)) {
                    VariableReferenceExpression newVariable = variableAllocator.newVariable("expr", variable.getType());
                    nonDistinctOutputs.put(aggregationOutputToOriginal.get(variable), newVariable);
                    assignments.put(newVariable, ifGroupEquals(groupVariable, 0L, variable));
                }
                if (groupByVariables.contains(variable)) {
                    assignments.put(variable, variable);
                }
            }
            // The mask is no longer produced by a MarkDistinct. Presumably it is assigned NULL so that any
            // remaining reference to it above still resolves.
            assignments.put(aggregateInfo.getMask(), Expressions.constantNull(aggregateInfo.getMask().getType()));
            aggregateInfo.setNewNonDistinctAggregateSymbols(nonDistinctOutputs.build());
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        // Builds IF(groupVariable = groupId, value, NULL).
        private RowExpression ifGroupEquals(VariableReferenceExpression groupVariable, long groupId, VariableReferenceExpression value) {
            RowExpression condition = Expressions.call(
                    OperatorType.EQUAL.name(),
                    functionResolution.comparisonFunction(OperatorType.EQUAL, BigintType.BIGINT, BigintType.BIGINT),
                    BooleanType.BOOLEAN,
                    ImmutableList.<RowExpression>of(groupVariable, Expressions.constant(groupId, BigintType.BIGINT)));
            return new SpecialFormExpression(
                    SpecialFormExpression.Form.IF,
                    value.getType(),
                    condition,
                    value,
                    Expressions.constantNull(value.getType()));
        }

        /**
         * Creates the GroupId node with grouping sets g0 = group-by + non-distinct arguments and
         * g1 = group-by + distinct argument. The duplicated distinct column (if any) is populated from
         * the original distinct column.
         */
        private GroupIdNode createGroupIdNode(
                List<VariableReferenceExpression> groupByVariables,
                List<VariableReferenceExpression> nonDistinctAggregateVariables,
                VariableReferenceExpression distinctVariable,
                VariableReferenceExpression duplicatedDistinctVariable,
                VariableReferenceExpression groupVariable,
                Set<VariableReferenceExpression> allVariables,
                PlanNode source) {
            // Insertion-ordered sets keep the grouping set column order deterministic.
            List<VariableReferenceExpression> group0 = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(groupByVariables)
                    .addAll(nonDistinctAggregateVariables)
                    .build()
                    .asList();
            List<VariableReferenceExpression> group1 = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(groupByVariables)
                    .add(distinctVariable)
                    .build()
                    .asList();
            Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns = allVariables.stream()
                    .collect(Collectors.toMap(
                            Function.identity(),
                            variable -> variable.equals(duplicatedDistinctVariable) ? distinctVariable : variable));
            return new GroupIdNode(
                    idAllocator.getNextId(),
                    source,
                    ImmutableList.of(group0, group1),
                    groupingColumns,
                    ImmutableList.of(),
                    groupVariable);
        }

        /**
         * Creates the SINGLE-step pre-aggregation over the GroupId node that evaluates every unmasked
         * aggregation. Each new output is recorded in {@code aggregationOutputToOriginal}.
         */
        private AggregationNode createNonDistinctAggregation(
                AggregateInfo aggregateInfo,
                VariableReferenceExpression distinctVariable,
                VariableReferenceExpression duplicatedDistinctVariable,
                Set<VariableReferenceExpression> groupByKeys,
                GroupIdNode groupIdNode,
                ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputToOriginal) {
            boolean distinctDuplicated = !duplicatedDistinctVariable.equals(distinctVariable);
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    continue;
                }
                VariableReferenceExpression newVariable = variableAllocator.newVariable(entry.getKey());
                aggregationOutputToOriginal.put(newVariable, entry.getKey());

                List<RowExpression> arguments = aggregation.getArguments();
                if (distinctDuplicated) {
                    // Non-distinct aggregations read the duplicated copy of the distinct column.
                    arguments = arguments.stream()
                            .map(argument -> argument instanceof VariableReferenceExpression && argument.equals(distinctVariable)
                                    ? duplicatedDistinctVariable
                                    : argument)
                            .collect(ImmutableList.toImmutableList());
                }
                CallExpression call = aggregation.getCall();
                aggregations.put(newVariable, new AggregationNode.Aggregation(
                        new CallExpression(call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()));
            }
            // The MarkDistinct node's hash variable is not an output of the GroupId node, so it must not
            // be referenced here.
            return new AggregationNode(
                    idAllocator.getNextId(),
                    groupIdNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(ImmutableList.copyOf(groupByKeys)),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
        }
    }

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionResolution = new FunctionResolution(metadata.getFunctionManager());
    }

    /** Applies the rewrite to {@code plan} when the session enables distinct-aggregation optimization. */
    @Override
    public PlanNode optimize(
            PlanNode plan,
            Session session,
            TypeProvider types,
            PlanVariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator,
            WarningCollector warningCollector) {
        if (!SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(
                new Optimizer(idAllocator, variableAllocator, metadata, functionResolution),
                plan,
                Optional.empty());
    }
}
