/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.TreeSet;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.OperatorComparatorFactory;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hive.com.google.common.collect.Maps;
import org.apache.hive.com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Physical-plan resolver that merges equivalent {@link BaseWork}s inside each {@link SparkWork}.
 *
 * <p>For every {@link SparkTask} reached from the root tasks, the works of its {@link SparkWork}
 * are examined level by level, starting at the roots. Works at the same level are grouped into
 * equivalence classes. Every class with more than one member is collapsed into its
 * lexicographically smallest-named member. The comparison then continues with the children of the
 * surviving works.
 */
public class CombineEquivalentWorkResolver
implements PhysicalPlanResolver {
    protected static transient Logger LOG = LoggerFactory.getLogger(CombineEquivalentWorkResolver.class);

    /**
     * Walks the task graph from the root tasks of {@code pctx} and combines equivalent works in
     * every {@link SparkTask} encountered.
     *
     * @param pctx the physical context whose root tasks are walked
     * @return the same {@code pctx}; the Spark works it references are modified in place
     * @throws SemanticException propagated from the task-graph walk
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getRootTasks());
        TaskGraphWalker taskWalker = new TaskGraphWalker(new EquivalentWorkMatcher());
        HashMap<Node, Object> nodeOutput = new HashMap<Node, Object>();
        taskWalker.startWalking(topNodes, nodeOutput);
        return pctx;
    }

    /**
     * Dispatcher that combines equivalent works in the {@link SparkWork} of each visited
     * {@link SparkTask}. Nodes that are not {@code SparkTask}s are ignored.
     */
    class EquivalentWorkMatcher
    implements Dispatcher {
        /**
         * Orders works by name. Equivalence groups are kept in this order, so the first
         * (smallest-named) member of a group is the one that survives a merge.
         */
        private final Comparator<BaseWork> baseWorkComparator = new Comparator<BaseWork>() {

            @Override
            public int compare(BaseWork o1, BaseWork o2) {
                return o1.getName().compareTo(o2.getName());
            }
        };

        EquivalentWorkMatcher() {
        }

        /**
         * Combines equivalent works of {@code nd} when it is a {@link SparkTask}.
         *
         * @return always {@code null}
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            if (nd instanceof SparkTask) {
                SparkTask sparkTask = (SparkTask) nd;
                SparkWork sparkWork = (SparkWork) sparkTask.getWork();
                Set<BaseWork> roots = sparkWork.getRoots();
                this.compareWorksRecursively(roots, sparkWork);
            }
            return null;
        }

        /**
         * Merges equivalent works within {@code works}, then recurses into the children of every
         * work that survived the merge.
         */
        private void compareWorksRecursively(Set<BaseWork> works, SparkWork sparkWork) {
            List<Set<BaseWork>> equivalentWorks = this.compareChildWorks(works, sparkWork);
            Set<BaseWork> removedWorks = this.combineEquivalentWorks(equivalentWorks, sparkWork);
            for (BaseWork work : works) {
                if (removedWorks.contains(work)) {
                    continue;
                }
                // Guard against a null child list, consistent with the null checks in replaceWork.
                List<BaseWork> directChildren = sparkWork.getChildren(work);
                if (directChildren == null || directChildren.isEmpty()) {
                    continue;
                }
                this.compareWorksRecursively(new HashSet<BaseWork>(directChildren), sparkWork);
            }
        }

        /**
         * Partitions {@code children} into groups of equivalent works. A work joins the first
         * existing group whose first member it matches; otherwise it starts a new group.
         *
         * @return one name-ordered set per group; empty when {@code children} has fewer than two
         *     elements, because there is nothing to combine
         */
        private List<Set<BaseWork>> compareChildWorks(Set<BaseWork> children, SparkWork sparkWork) {
            // A List rather than a HashSet: each group is mutated after it has been added, which
            // changes its hashCode and would break the invariants of a hash-based container.
            List<Set<BaseWork>> equivalentChildren = new ArrayList<Set<BaseWork>>();
            if (children.size() <= 1) {
                return equivalentChildren;
            }
            for (BaseWork work : children) {
                Set<BaseWork> target = null;
                for (Set<BaseWork> group : equivalentChildren) {
                    if (this.belongToSet(group, work, sparkWork)) {
                        target = group;
                        break;
                    }
                }
                if (target == null) {
                    target = new TreeSet<BaseWork>(this.baseWorkComparator);
                    equivalentChildren.add(target);
                }
                target.add(work);
            }
            return equivalentChildren;
        }

        /** Returns true if {@code set} is empty or its first member is equivalent to {@code work}. */
        private boolean belongToSet(Set<BaseWork> set, BaseWork work, SparkWork sparkWork) {
            if (set.isEmpty()) {
                return true;
            }
            return this.compareWork(set.iterator().next(), work, sparkWork);
        }

        /**
         * Collapses every group with more than one member into its first member.
         *
         * @return the works that were removed from {@code sparkWork}
         */
        private Set<BaseWork> combineEquivalentWorks(List<Set<BaseWork>> equivalentWorks, SparkWork sparkWork) {
            Set<BaseWork> removedWorks = new HashSet<BaseWork>();
            for (Set<BaseWork> workSet : equivalentWorks) {
                if (workSet.size() <= 1) {
                    continue;
                }
                Iterator<BaseWork> iterator = workSet.iterator();
                BaseWork first = iterator.next();
                while (iterator.hasNext()) {
                    BaseWork next = iterator.next();
                    this.replaceWork(next, first, sparkWork);
                    removedWorks.add(next);
                }
            }
            return removedWorks;
        }

        /**
         * Replaces {@code previous} with {@code current} in {@code sparkWork}.
         *
         * <p>MapJoin references to {@code previous} are renamed. Its outgoing edges are moved to
         * {@code current` with the same edge properties. {@code previous} is then removed. Incoming
         * edges are only disconnected: equivalent works are required to have the same parents, so
         * {@code current} is already connected to them.
         */
        private void replaceWork(BaseWork previous, BaseWork current, SparkWork sparkWork) {
            this.updateReference(previous, current, sparkWork);
            List<BaseWork> parents = sparkWork.getParents(previous);
            List<BaseWork> children = sparkWork.getChildren(previous);
            // Iterate over snapshots, so the graph edits below cannot invalidate the iteration
            // even if SparkWork hands out live views of its adjacency lists.
            if (parents != null) {
                for (BaseWork parent : new ArrayList<BaseWork>(parents)) {
                    sparkWork.disconnect(parent, previous);
                }
            }
            if (children != null) {
                for (BaseWork child : new ArrayList<BaseWork>(children)) {
                    SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(previous, child);
                    sparkWork.disconnect(previous, child);
                    sparkWork.connect(current, child, edgeProperty);
                }
            }
            sparkWork.remove(previous);
        }

        /**
         * In every {@link MapJoinOperator} of every work, rewrites each parentToInput entry naming
         * {@code previous} so that it names {@code current} instead.
         */
        private void updateReference(BaseWork previous, BaseWork current, SparkWork sparkWork) {
            String previousName = previous.getName();
            String currentName = current.getName();
            List<BaseWork> allWorks = sparkWork.getAllWork();
            for (BaseWork work : allWorks) {
                Set<Operator<?>> allOperators = work.getAllOperators();
                for (Operator<?> operator : allOperators) {
                    if (!(operator instanceof MapJoinOperator)) {
                        continue;
                    }
                    MapJoinDesc mapJoinDesc = (MapJoinDesc) ((MapJoinOperator) operator).getConf();
                    Map<Integer, String> parentToInput = mapJoinDesc.getParentToInput();
                    // Only existing keys are overwritten, so the key-set iteration stays valid.
                    for (Integer id : parentToInput.keySet()) {
                        if (previousName.equals(parentToInput.get(id))) {
                            parentToInput.put(id, currentName);
                        }
                    }
                }
            }
        }

        /**
         * Decides whether two works can be merged. Two works can be merged when all of the
         * following hold:
         * <ul>
         *   <li>they are of the same concrete class;</li>
         *   <li>they have the same parents;</li>
         *   <li>they are not both leaves;</li>
         *   <li>for map works, their path-to-partition maps match;</li>
         *   <li>their root operator trees compare equal.</li>
         * </ul>
         */
        private boolean compareWork(BaseWork first, BaseWork second, SparkWork sparkWork) {
            if (!first.getClass().getName().equals(second.getClass().getName())) {
                return false;
            }
            if (!this.hasSameParent(first, second, sparkWork)) {
                return false;
            }
            // Two leaf works are never merged.
            // NOTE(review): presumably because each leaf produces a distinct output -- confirm.
            if (sparkWork.getLeaves().contains(first) && sparkWork.getLeaves().contains(second)) {
                return false;
            }
            if (first instanceof MapWork && !this.compareMapWork((MapWork) first, (MapWork) second)) {
                return false;
            }
            Set<Operator<? extends OperatorDesc>> firstRootOperators = first.getAllRootOperators();
            Set<Operator<? extends OperatorDesc>> secondRootOperators = second.getAllRootOperators();
            if (firstRootOperators.size() != secondRootOperators.size()) {
                return false;
            }
            // Root operators are paired in the iteration order of the two sets.
            // NOTE(review): if that order is not stable across works, equivalent works may fail to
            // match (a missed merge, never a wrong one).
            Iterator<Operator<? extends OperatorDesc>> firstIterator = firstRootOperators.iterator();
            Iterator<Operator<? extends OperatorDesc>> secondIterator = secondRootOperators.iterator();
            while (firstIterator.hasNext()) {
                if (!this.compareOperatorChain(firstIterator.next(), secondIterator.next())) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns true if both map works have the same number of path-to-partition entries and an
         * equal {@link PartitionDesc} for every path of {@code first}.
         */
        private boolean compareMapWork(MapWork first, MapWork second) {
            Map<Path, PartitionDesc> pathToPartition1 = first.getPathToPartitionInfo();
            Map<Path, PartitionDesc> pathToPartition2 = second.getPathToPartitionInfo();
            if (pathToPartition1.size() != pathToPartition2.size()) {
                return false;
            }
            for (Map.Entry<Path, PartitionDesc> entry : pathToPartition1.entrySet()) {
                PartitionDesc partitionDesc2 = pathToPartition2.get(entry.getKey());
                if (!entry.getValue().equals(partitionDesc2)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns true if both works have parent lists of equal size, and every parent of
         * {@code first} is also a parent of {@code second}.
         */
        private boolean hasSameParent(BaseWork first, BaseWork second, SparkWork sparkWork) {
            List<BaseWork> firstParents = sparkWork.getParents(first);
            List<BaseWork> secondParents = sparkWork.getParents(second);
            if (firstParents.size() != secondParents.size()) {
                return false;
            }
            return secondParents.containsAll(firstParents);
        }

        /**
         * Recursively compares two operator trees. The current operators must compare equal, and
         * their child lists must be both null, or of equal size with pairwise-equal children at each
         * position.
         */
        private boolean compareOperatorChain(Operator<?> firstOperator, Operator<?> secondOperator) {
            if (!this.compareCurrentOperator(firstOperator, secondOperator)) {
                return false;
            }
            List<Operator<OperatorDesc>> firstChildren = firstOperator.getChildOperators();
            List<Operator<OperatorDesc>> secondChildren = secondOperator.getChildOperators();
            if (firstChildren == null || secondChildren == null) {
                // Equal only when both are null.
                return firstChildren == secondChildren;
            }
            if (firstChildren.size() != secondChildren.size()) {
                return false;
            }
            for (int i = 0; i < firstChildren.size(); ++i) {
                if (!this.compareOperatorChain(firstChildren.get(i), secondChildren.get(i))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns true if both operators are of the same concrete class, and the comparator
         * registered for that class in {@link OperatorComparatorFactory} deems them equal.
         */
        private boolean compareCurrentOperator(Operator<?> firstOperator, Operator<?> secondOperator) {
            if (!firstOperator.getClass().getName().equals(secondOperator.getClass().getName())) {
                return false;
            }
            OperatorComparatorFactory.OperatorComparator operatorComparator = OperatorComparatorFactory.getOperatorComparator(firstOperator.getClass());
            return operatorComparator.equals(firstOperator, secondOperator);
        }
    }
}

