package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;

/**
 * Stores a plan as a set of numbered groups.
 *
 * <p>Each group holds exactly one {@link PlanNode} member. The member's children have been
 * replaced by {@link GroupReference}s that point at other groups.
 *
 * <p>Every group records which groups reference it, as a multiset of referencing group ids.
 * When a group loses its last incoming reference it is deleted, which in turn releases the
 * references held by its own member. The root group carries a synthetic reference from
 * {@link #ROOT_GROUP_REF}, so it is never deleted this way.
 */
public class Memo {
    // Synthetic "referencing group" id that pins the root group.
    // Real group ids start at 1, so this value never collides with an actual group.
    private static final int ROOT_GROUP_REF = 0;

    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;
    private final Map<Integer, Group> groups = new HashMap<>();
    private int nextGroupId = ROOT_GROUP_REF + 1;

    /** A single memo entry: its current member node, who references it, and cached statistics. */
    public static final class Group {
        private PlanNode membership;

        // Ids of the groups whose member references this group; ROOT_GROUP_REF marks the root.
        private final Multiset<Integer> incomingReferences = HashMultiset.create();

        // Cached statistics for this group; reset to null by Memo#evictStatistics.
        @Nullable
        private PlanNodeStatsEstimate stats;

        static Group withMember(PlanNode member) {
            return new Group(member);
        }

        private Group(PlanNode member) {
            this.membership = Objects.requireNonNull(member, "member is null");
        }
    }

    /**
     * Builds a memo by inserting {@code plan} recursively, one new group per non-reference node.
     *
     * @param idAllocator allocator used for the ids of the {@link GroupReference} nodes created here
     * @param plan the plan to insert; its root becomes {@link #getRootGroup()}
     */
    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan) {
        this.idAllocator = idAllocator;
        this.rootGroup = insertRecursive(plan);
        // Pin the root so that reference counting never deletes it.
        groups.get(rootGroup).incomingReferences.add(ROOT_GROUP_REF);
    }

    /** Returns the id of the group holding the plan root. */
    public int getRootGroup() {
        return rootGroup;
    }

    /**
     * Looks up a group by id.
     *
     * @throws IllegalArgumentException if no group with that id exists
     */
    private Group getGroup(int group) {
        Preconditions.checkArgument(groups.containsKey(group), "Invalid group: %s", group);
        return groups.get(group);
    }

    /**
     * Returns the current member of {@code group}.
     *
     * @throws IllegalArgumentException if no group with that id exists
     */
    public PlanNode getNode(int group) {
        return getGroup(group).membership;
    }

    /** Returns the member of the group that {@code groupReference} points at. */
    public PlanNode resolve(GroupReference groupReference) {
        return getNode(groupReference.getGroupId());
    }

    /** Rebuilds a plan tree from the root group, resolving every group reference. */
    public PlanNode extract() {
        return extract(getNode(rootGroup));
    }

    /** Resolves the group references reachable from {@code node} back into plan nodes. */
    private PlanNode extract(PlanNode node) {
        return Plans.resolveGroupReferences(node, Lookup.from(groupReference -> Stream.of(resolve(groupReference))));
    }

    /**
     * Replaces the member of {@code group} with {@code node}.
     *
     * <p>If {@code node} is a {@link GroupReference}, the referenced group's current member is
     * adopted. Otherwise the children of {@code node} are inserted as groups first. Statistics
     * for {@code group} and every group that references it, transitively, are evicted.
     *
     * @param group id of the group to update
     * @param node the replacement; it must produce the same set of output symbols as the old member
     * @param reason text used as the prefix of the error message when the outputs differ
     * @return the node now stored as the group's member
     * @throws IllegalArgumentException if the output symbol sets differ, or if the group is unknown
     */
    public PlanNode replace(int group, PlanNode node, String reason) {
        PlanNode old = getGroup(group).membership;
        Preconditions.checkArgument(
                new HashSet<>(old.getOutputSymbols()).equals(new HashSet<>(node.getOutputSymbols())),
                "%s: transformed expression doesn't produce same outputs: %s vs %s",
                reason,
                old.getOutputSymbols(),
                node.getOutputSymbols());

        PlanNode newMember = node instanceof GroupReference
                ? getNode(((GroupReference) node).getGroupId())
                : insertChildrenAndRewrite(node);

        // Increment before decrementing, so that children shared by the old and new members
        // never drop to zero references in between and get deleted.
        incrementReferenceCounts(newMember, group);
        getGroup(group).membership = newMember;
        decrementReferenceCounts(old, group);
        evictStatistics(group);

        return newMember;
    }

    /** Clears cached stats for {@code group} and, recursively, for every group referencing it. */
    private void evictStatistics(int group) {
        Group target = getGroup(group);
        target.stats = null;
        for (int referencingGroup : target.incomingReferences.elementSet()) {
            // The root's synthetic reference does not name a real group.
            if (referencingGroup != ROOT_GROUP_REF) {
                evictStatistics(referencingGroup);
            }
        }
    }

    /**
     * Returns the cached statistics for {@code group}, if any.
     *
     * @throws IllegalArgumentException if no group with that id exists
     */
    public Optional<PlanNodeStatsEstimate> getStats(int group) {
        return Optional.ofNullable(getGroup(group).stats);
    }

    /**
     * Caches {@code stats} for {@code group}.
     *
     * <p>If stats were already cached, the group and the groups referencing it are evicted
     * first, and then the new value is stored.
     *
     * @throws NullPointerException if {@code stats} is null
     */
    public void storeStats(int group, PlanNodeStatsEstimate stats) {
        Group target = getGroup(group);
        if (target.stats != null) {
            evictStatistics(group);
        }
        target.stats = Objects.requireNonNull(stats, "stats is null");
    }

    /** Records that {@code fromGroup} references each distinct child group of {@code fromNode}. */
    private void incrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        for (int child : getAllReferences(fromNode)) {
            groups.get(child).incomingReferences.add(fromGroup);
        }
    }

    /**
     * Removes one reference from {@code fromGroup} to each distinct child group of {@code fromNode}.
     * Any child group left with no incoming references is deleted.
     *
     * @throws IllegalStateException if an expected reference is missing
     */
    private void decrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        for (int child : getAllReferences(fromNode)) {
            Group childGroup = groups.get(child);
            Preconditions.checkState(childGroup.incomingReferences.remove(fromGroup), "Reference to remove not found");
            if (childGroup.incomingReferences.isEmpty()) {
                deleteGroup(child);
            }
        }
    }

    /**
     * Returns the distinct group ids referenced by the sources of {@code node}.
     *
     * <p>Every source is cast to {@link GroupReference}; a source of any other type causes a
     * {@link ClassCastException}.
     */
    private Set<Integer> getAllReferences(PlanNode node) {
        return node.getSources().stream()
                .map(GroupReference.class::cast)
                .map(GroupReference::getGroupId)
                .collect(Collectors.toSet());
    }

    /**
     * Removes an unreferenced group and releases the references held by its member.
     *
     * @throws IllegalArgumentException if the group still has incoming references
     */
    private void deleteGroup(int group) {
        Preconditions.checkArgument(getGroup(group).incomingReferences.isEmpty(), "Cannot delete group that has incoming references");
        PlanNode deletedMember = groups.remove(group).membership;
        decrementReferenceCounts(deletedMember, group);
    }

    /**
     * Inserts each child of {@code node} as a group.
     *
     * @return a copy of {@code node} whose children are {@link GroupReference}s to those groups
     */
    private PlanNode insertChildrenAndRewrite(PlanNode node) {
        return node.replaceChildren(
                node.getSources().stream()
                        .map(child -> new GroupReference(
                                idAllocator.getNextId(),
                                insertRecursive(child),
                                child.getOutputSymbols()))
                        .collect(Collectors.toList()));
    }

    /**
     * Inserts {@code node} and its subtree, returning the id of the group that holds {@code node}.
     * A {@link GroupReference} is not re-inserted; its existing group id is returned.
     */
    private int insertRecursive(PlanNode node) {
        if (node instanceof GroupReference) {
            return ((GroupReference) node).getGroupId();
        }

        // Allocate this group's id before inserting the children, so a parent gets a lower id
        // than its children.
        int group = nextGroupId();
        PlanNode rewritten = insertChildrenAndRewrite(node);
        groups.put(group, Group.withMember(rewritten));
        incrementReferenceCounts(rewritten, group);
        return group;
    }

    /** Returns a fresh group id and advances the counter. */
    private int nextGroupId() {
        return nextGroupId++;
    }

    /** Returns the number of live groups. */
    public int getGroupCount() {
        return groups.size();
    }
}
