/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.parse.spark;

import io.trino.hive.$internal.com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorUtils;
import org.apache.hadoop.hive.ql.exec.SerializationUtilities;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkProcContext;
import org.apache.hadoop.hive.ql.parse.spark.SparkPartitionPruningSinkOperator;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;

/**
 * Node processor that splits the partition-pruning (DPP) branch of a
 * {@link SparkPartitionPruningSinkOperator} out of the operator tree it belongs to.
 *
 * <p>The tree from the table-scan roots down to the sink's branching operator is cloned, the
 * pruning branches hanging off the original branching operator are re-parented under the cloned
 * branching operator and removed from the original tree, and the cloned roots plus every pruning
 * sink found in the moved branches are recorded in the {@link GenSparkProcContext}.
 */
public class SplitOpTreeForDPP
implements NodeProcessor {
    /**
     * Splits the pruning branch of the visited sink into its own cloned operator tree.
     *
     * <p>Returns immediately if a sink with the same operator id is already in
     * {@code context.pruningSinkSet}. Sinks for which {@code isWithMapjoin()} is true are only
     * added to that set and otherwise left untouched.
     *
     * @param nd the {@link SparkPartitionPruningSinkOperator} being visited
     * @param stack the walker's node stack (not used)
     * @param procCtx the {@link GenSparkProcContext} that collects handled sinks and cloned scans
     * @param nodeOutputs outputs of previously processed nodes (not used)
     * @return always {@code null}
     * @throws SemanticException per the {@link NodeProcessor} contract
     * @throws NullPointerException if the branching operator cannot be found in the cloned tree
     */
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs)
            throws SemanticException {
        SparkPartitionPruningSinkOperator pruningSinkOp = (SparkPartitionPruningSinkOperator) nd;
        GenSparkProcContext context = (GenSparkProcContext) procCtx;

        // Operator ids identify sinks that were already handled (e.g. collected from a moved branch).
        for (Operator<?> op : context.pruningSinkSet) {
            if (pruningSinkOp.getOperatorId().equals(op.getOperatorId())) {
                return null;
            }
        }
        // Map-join sinks are not split; just remember that they were seen.
        if (pruningSinkOp.isWithMapjoin()) {
            context.pruningSinkSet.add(pruningSinkOp);
            return null;
        }

        // ArrayList, not LinkedList: roots is indexed by position below.
        List<Operator<?>> roots = new ArrayList<>();
        collectRoots(roots, pruningSinkOp);

        Operator<?> branchingOp = pruningSinkOp.getBranchingOp();
        // findOperatorByMarker relies on this marker to locate branchingOp's copy in the clone.
        String marker = "SPARK_DPP_BRANCH_POINT_" + branchingOp.getOperatorId();
        branchingOp.setMarker(marker);

        List<Operator<? extends OperatorDesc>> savedChildOps = branchingOp.getChildOperators();
        List<Operator<? extends OperatorDesc>> firstNodesOfPruningBranch =
                findFirstNodesOfPruningBranch(branchingOp);

        // Detach branchingOp's children so the clone stops at branchingOp. The children are
        // restored in a finally block so a failed clone cannot leave the original tree broken.
        List<Operator<?>> newRoots;
        branchingOp.setChildOperators(null);
        try {
            newRoots = SerializationUtilities.cloneOperatorTree(roots);
        } finally {
            branchingOp.setChildOperators(savedChildOps);
        }

        for (int i = 0; i < roots.size(); ++i) {
            TableScanOperator newTs = (TableScanOperator) newRoots.get(i);
            TableScanOperator oldTs = (TableScanOperator) roots.get(i);
            // Give each cloned scan the same table metadata object as its original.
            newTs.getConf().setTableMetadata(oldTs.getConf().getTableMetadata());
        }
        context.clonedPruningTableScanSet.addAll(newRoots);

        Operator<?> newBranchingOp = null;
        for (int i = 0; i < newRoots.size() && newBranchingOp == null; ++i) {
            newBranchingOp = OperatorUtils.findOperatorByMarker(newRoots.get(i), marker);
        }
        Preconditions.checkNotNull(newBranchingOp, "Cannot find the branching operator in cloned tree.");

        // Move the pruning branches: hang them under the clone and drop them from the original.
        newBranchingOp.setChildOperators(firstNodesOfPruningBranch);
        for (Operator<? extends OperatorDesc> selOp : firstNodesOfPruningBranch) {
            branchingOp.removeChild(selOp);
        }

        Set<Operator<?>> sinkSet = new LinkedHashSet<>();
        for (Operator<? extends OperatorDesc> sel : firstNodesOfPruningBranch) {
            SparkUtilities.collectOp(sinkSet, sel, SparkPartitionPruningSinkOperator.class);
            sel.setParentOperators(Utilities.makeList(newBranchingOp));
        }
        context.pruningSinkSet.addAll(sinkSet);
        return null;
    }

    /**
     * Returns, as a new list in child order, the children of {@code branchingOp} accepted by
     * {@link SparkUtilities#isDirectDPPBranch}.
     *
     * @param branchingOp the operator whose children are inspected
     * @return the first operator of each direct pruning branch; possibly empty
     */
    private List<Operator<? extends OperatorDesc>> findFirstNodesOfPruningBranch(Operator<?> branchingOp) {
        List<Operator<? extends OperatorDesc>> res = new ArrayList<>();
        for (Operator<? extends OperatorDesc> child : branchingOp.getChildOperators()) {
            if (SparkUtilities.isDirectDPPBranch(child)) {
                res.add(child);
            }
        }
        return res;
    }

    /**
     * Appends to {@code result} every parentless operator reachable upward from {@code op},
     * walking parents depth-first in list order. A root reachable along several paths is
     * appended once per path.
     *
     * @param result list receiving the roots
     * @param op operator to start the upward walk from
     */
    private void collectRoots(List<Operator<?>> result, Operator<?> op) {
        if (op.getNumParent() == 0) {
            result.add(op);
            return;
        }
        for (Operator<? extends OperatorDesc> parentOp : op.getParentOperators()) {
            collectRoots(result, parentOp);
        }
    }
}

