/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.utils;

import jadx.api.IDecompileScheduler;
import jadx.api.JadxDecompiler;
import jadx.api.JavaClass;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Groups classes into ordered decompilation batches so that a class is scheduled in the same
 * batch as (and after) the not-yet-scheduled top-level classes it depends on.
 * <p>
 * Classes whose {@code getTotalDepsCount()} is 0 are packed into shared batches of up to
 * {@link #MERGED_BATCH_SIZE} entries. If building batches fails for any reason, a fallback
 * schedule with one class per batch is returned instead of propagating the failure.
 */
public class DecompilerScheduler
implements IDecompileScheduler {
    private static final Logger LOG = LoggerFactory.getLogger(DecompilerScheduler.class);

    /** Maximum number of dependency-free classes packed into one shared batch. */
    private static final int MERGED_BATCH_SIZE = 16;

    /**
     * Compile-time debug switch: when {@code true}, the result is validated with
     * {@link #check(List, List)} and batch statistics are logged.
     */
    private static final boolean DEBUG_BATCHES = false;

    private final JadxDecompiler decompiler;

    public DecompilerScheduler(JadxDecompiler decompiler) {
        this.decompiler = decompiler;
    }

    /**
     * Splits {@code classes} into ordered decompilation batches.
     *
     * @param classes classes to schedule
     * @return batches in scheduling order; if anything goes wrong while building them,
     *         the one-class-per-batch schedule from {@link #buildFallback(List)}
     */
    @Override
    public List<List<JavaClass>> buildBatches(List<JavaClass> classes) {
        try {
            long start = System.currentTimeMillis();
            List<List<ClassNode>> batches = internalBatches(Utils.collectionMap(classes, JavaClass::getClassNode));
            // Map nodes back to API wrappers; collectionMapNoNull presumably skips nodes
            // for which no JavaClass is found (verify against Utils).
            List<List<JavaClass>> result = Utils.collectionMap(batches,
                    batch -> Utils.collectionMapNoNull(batch, decompiler::getJavaClassByNode));
            if (DEBUG_BATCHES) {
                check(result, classes);
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Build decompilation batches in {}ms", System.currentTimeMillis() - start);
            }
            return result;
        } catch (Throwable e) {
            // Deliberately broad: any failure here degrades to the simple fallback schedule
            // instead of aborting the caller.
            LOG.warn("Build batches failed (continue with fallback)", e);
            return buildFallback(classes);
        }
    }

    /**
     * Builds batches of class nodes.
     * <p>
     * Classes are visited in ascending order of the score computed by
     * {@link #sumDependencies(List)}. Each node is placed in at most one batch:
     * <ul>
     *   <li>a class whose {@code getTotalDepsCount()} is 0 is appended to a shared batch,
     *       which is emitted once it holds {@link #MERGED_BATCH_SIZE} classes;</li>
     *   <li>any other class gets its own batch: the top-level parents of its
     *       {@code getDependencies()} that are not yet scheduled, sorted by ascending
     *       {@code getTotalDepsCount()}, followed by the class itself.</li>
     * </ul>
     * NOTE(review): top-level dependencies are added without checking that they are part of
     * {@code classes}, so the result may contain nodes that were not in the input.
     *
     * @param classes class nodes to schedule
     * @return batches in scheduling order; a partially filled shared batch, if any, is last
     */
    public List<List<ClassNode>> internalBatches(List<ClassNode> classes) {
        List<DepInfo> deps = sumDependencies(classes);
        HashSet<ClassNode> added = new HashSet<>(classes.size());
        Comparator<ClassNode> cmpDepSize = Comparator.comparingInt(ClassNode::getTotalDepsCount);
        List<List<ClassNode>> result = new ArrayList<>();
        List<ClassNode> mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
        for (DepInfo depInfo : deps) {
            ClassNode cls = depInfo.getCls();
            if (!added.add(cls)) {
                // already scheduled, e.g. pulled in as a dependency of an earlier class
                continue;
            }
            int depsSize = cls.getTotalDepsCount();
            if (depsSize == 0) {
                mergedBatch.add(cls);
                if (mergedBatch.size() >= MERGED_BATCH_SIZE) {
                    result.add(mergedBatch);
                    mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
                }
                continue;
            }
            List<ClassNode> batch = new ArrayList<>(depsSize + 1);
            for (ClassNode dep : cls.getDependencies()) {
                ClassNode topDep = dep.getTopParentClass();
                // Set.add returns false when the node is already scheduled somewhere
                if (added.add(topDep)) {
                    batch.add(topDep);
                }
            }
            batch.sort(cmpDepSize);
            // the class itself goes last, after the dependencies collected above
            batch.add(cls);
            result.add(batch);
        }
        if (!mergedBatch.isEmpty()) {
            result.add(mergedBatch);
        }
        if (DEBUG_BATCHES) {
            dumpBatchesStats(classes, result, deps);
        }
        return result;
    }

    /**
     * Scores every class by summing {@code 1 + dep.getTotalDepsCount()} over the nodes returned
     * by its {@code getDependencies()}.
     *
     * @param classes class nodes to score
     * @return one {@link DepInfo} per input class, sorted by ascending score, with ties broken
     *         by the natural order of {@code ClassNode}
     */
    private static List<DepInfo> sumDependencies(List<ClassNode> classes) {
        List<DepInfo> deps = new ArrayList<>(classes.size());
        for (ClassNode cls : classes) {
            int count = 0;
            for (ClassNode dep : cls.getDependencies()) {
                count += 1 + dep.getTotalDepsCount();
            }
            deps.add(new DepInfo(cls, count));
        }
        Collections.sort(deps);
        return deps;
    }

    /**
     * Fallback schedule: one single-class batch per input class, ordered by ascending
     * {@code getTotalDepsCount()} of the underlying class node.
     */
    private static List<List<JavaClass>> buildFallback(List<JavaClass> classes) {
        return classes.stream()
                .sorted(Comparator.comparingInt(c -> c.getClassNode().getTotalDepsCount()))
                .map(Collections::singletonList)
                .collect(Collectors.toList());
    }

    /** Debug helper: logs batch count, average batch size and maximum dependency counts. */
    private void dumpBatchesStats(List<ClassNode> classes, List<List<ClassNode>> result, List<DepInfo> deps) {
        double avg = result.stream().mapToInt(List::size).average().orElse(-1.0);
        int maxSingleDeps = classes.stream().mapToInt(ClassNode::getTotalDepsCount).max().orElse(-1);
        int maxSubDeps = deps.stream().mapToInt(DepInfo::getDepsCount).max().orElse(-1);
        LOG.info("Batches stats:\n input classes: {},\n batches: {},\n average batch size: {},"
                        + "\n max single deps count: {},\n max sub deps count: {}",
                classes.size(), result.size(), String.format("%.2f", avg), maxSingleDeps, maxSubDeps);
    }

    /**
     * Debug helper: verifies that the batches contain exactly as many classes as the input.
     *
     * @throws JadxRuntimeException if the total number of classes in {@code result} differs
     *                              from {@code classes.size()}
     */
    private static void check(List<List<JavaClass>> result, List<JavaClass> classes) {
        int classInBatches = result.stream().mapToInt(List::size).sum();
        if (classes.size() != classInBatches) {
            throw new JadxRuntimeException("Incorrect number of classes in result batch: " + classInBatches + ", expected: " + classes.size());
        }
    }

    /** A class node paired with its dependency score; ordered by score, then by class node. */
    private static final class DepInfo
    implements Comparable<DepInfo> {
        private final ClassNode cls;
        private final int depsCount;

        private DepInfo(ClassNode cls, int depsCount) {
            this.cls = cls;
            this.depsCount = depsCount;
        }

        public ClassNode getCls() {
            return this.cls;
        }

        public int getDepsCount() {
            return this.depsCount;
        }

        @Override
        public int compareTo(@NotNull DepInfo o) {
            int deps = Integer.compare(this.depsCount, o.depsCount);
            if (deps == 0) {
                return this.cls.compareTo(o.cls);
            }
            return deps;
        }

        @Override
        public String toString() {
            return this.cls + ":" + this.depsCount;
        }
    }
}

