/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.utils;

import jadx.api.IDecompileScheduler;
import jadx.api.JavaClass;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link IDecompileScheduler}: groups classes into decompilation batches.
 * <p>
 * Classes are visited in ascending order of their summed dependency count
 * (see {@link #sumDependencies(List)}, ties broken by class-node order).
 * <ul>
 *   <li>A class with no dependencies is put into a shared "merged" batch that holds up to
 *       {@link #MERGED_BATCH_SIZE} such classes.</li>
 *   <li>A class with dependencies gets its own batch. The batch holds the top-level classes of
 *       its dependencies that were not already scheduled, ordered by ascending dependency
 *       count, followed by the class itself.</li>
 * </ul>
 * Every class is scheduled at most once.
 */
public class DecompilerScheduler
implements IDecompileScheduler {
    private static final Logger LOG = LoggerFactory.getLogger(DecompilerScheduler.class);
    /** Maximum number of dependency-free classes grouped into one merged batch. */
    private static final int MERGED_BATCH_SIZE = 16;
    /** When {@code true}, validate the batch contents and log statistics after building batches. */
    private static final boolean DEBUG_BATCHES = false;

    /**
     * Builds decompilation batches for {@code classes}.
     * <p>
     * If batch building throws, a fallback is used instead: one class per batch
     * (see {@link #buildFallback(List)}).
     *
     * @param classes classes to schedule
     * @return list of batches; each batch is a list of classes
     */
    @Override
    public List<List<JavaClass>> buildBatches(List<JavaClass> classes) {
        try {
            long start = System.currentTimeMillis();
            List<List<JavaClass>> result = this.internalBatches(classes);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Build decompilation batches in {}ms for {} classes", System.currentTimeMillis() - start, classes.size());
            }
            return result;
        } catch (BootstrapMethodError | StackOverflowError ignored) {
            // The error is deliberately not attached to the log record; only a short message is emitted.
            // NOTE(review): BootstrapMethodError is presumably caught because a stack overflow during
            // lambda/invokedynamic linkage surfaces wrapped in it — confirm.
            LOG.warn("Stack overflow while building decompile batches, continue with fallback");
        } catch (Exception e) {
            LOG.warn("Build batches failed (continue with fallback)", e);
        }
        return buildFallback(classes);
    }

    /**
     * Groups {@code classes} into batches without any fallback handling.
     *
     * @param classes classes to schedule
     * @return batches in scheduling order. Full merged batches appear where they filled up, and
     *         the last partial merged batch (if any) is appended at the end.
     */
    public List<List<JavaClass>> internalBatches(List<JavaClass> classes) {
        List<DepInfo> deps = sumDependencies(classes);
        Set<JavaClass> added = new HashSet<>(classes.size());
        Comparator<JavaClass> cmpDepSize = Comparator.comparingInt(JavaClass::getTotalDepsCount);
        List<List<JavaClass>> result = new ArrayList<>();
        List<JavaClass> mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
        for (DepInfo depInfo : deps) {
            JavaClass cls = depInfo.getCls();
            if (!added.add(cls)) {
                // already scheduled (e.g. pulled in as a dependency of an earlier class)
                continue;
            }
            if (cls.getTotalDepsCount() == 0) {
                // NOTE(review): merged batches stay mutable lists while dependency batches go
                // through Utils.lockList — confirm this asymmetry is intended.
                mergedBatch.add(cls);
                if (mergedBatch.size() >= MERGED_BATCH_SIZE) {
                    result.add(mergedBatch);
                    mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
                }
                continue;
            }
            List<JavaClass> batch = new ArrayList<>();
            for (JavaClass dep : cls.getDependencies()) {
                JavaClass topDep = dep.getTopParentClass();
                // Set.add returns false when topDep has already been scheduled
                if (added.add(topDep)) {
                    batch.add(topDep);
                }
            }
            batch.sort(cmpDepSize);
            // the class itself goes last, after its dependencies
            batch.add(cls);
            result.add(Utils.lockList(batch));
        }
        if (!mergedBatch.isEmpty()) {
            result.add(mergedBatch);
        }
        if (DEBUG_BATCHES) {
            check(result, classes);
            dumpBatchesStats(classes, result, deps);
        }
        return result;
    }

    /**
     * Computes a dependency weight for each class and sorts the classes by it.
     * <p>
     * For each class, the weight is the sum of {@code 1 + dep.getTotalDepsCount()} over its direct
     * dependencies.
     *
     * @param classes input classes
     * @return one {@link DepInfo} per input class, sorted ascending by weight and then by class node
     */
    private static List<DepInfo> sumDependencies(List<JavaClass> classes) {
        List<DepInfo> deps = new ArrayList<>(classes.size());
        for (JavaClass cls : classes) {
            int count = 0;
            for (JavaClass dep : cls.getDependencies()) {
                count += 1 + dep.getTotalDepsCount();
            }
            deps.add(new DepInfo(cls, count));
        }
        Collections.sort(deps);
        return deps;
    }

    /**
     * Fallback schedule: every class goes into its own singleton batch.
     * <p>
     * The classes are ordered by ascending class-node dependency count.
     *
     * @param classes input classes
     * @return one single-element batch per class
     */
    private static List<List<JavaClass>> buildFallback(List<JavaClass> classes) {
        return classes.stream()
                .sorted(Comparator.comparingInt(c -> c.getClassNode().getTotalDepsCount()))
                .map(Collections::singletonList)
                .collect(Collectors.toList());
    }

    /**
     * Logs batch statistics at info level. Called only when {@link #DEBUG_BATCHES} is enabled.
     *
     * @param classes input classes
     * @param result  built batches
     * @param deps    dependency weights computed for {@code classes}
     */
    private void dumpBatchesStats(List<JavaClass> classes, List<List<JavaClass>> result, List<DepInfo> deps) {
        int clsInBatches = result.stream().mapToInt(List::size).sum();
        double avg = result.stream().mapToInt(List::size).average().orElse(-1.0);
        int maxSingleDeps = classes.stream().mapToInt(JavaClass::getTotalDepsCount).max().orElse(-1);
        int maxSubDeps = deps.stream().mapToInt(DepInfo::getDepsCount).max().orElse(-1);
        LOG.info("Batches stats:\n input classes: {},\n classes in batches: {},\n batches: {},\n average batch size: {},\n max single deps count: {},\n max sub deps count: {}",
                classes.size(), clsInBatches, result.size(), String.format("%.2f", avg), maxSingleDeps, maxSubDeps);
    }

    /**
     * Checks that the total number of classes across all batches equals the number of input classes.
     *
     * @param result  built batches
     * @param classes input classes
     * @throws JadxRuntimeException if the counts differ
     */
    private static void check(List<List<JavaClass>> result, List<JavaClass> classes) {
        int classInBatches = result.stream().mapToInt(List::size).sum();
        if (classes.size() != classInBatches) {
            throw new JadxRuntimeException("Incorrect number of classes in result batch: " + classInBatches + ", expected: " + classes.size());
        }
    }

    /**
     * Pairs a class with its summed dependency weight.
     * <p>
     * The ordering is ascending by weight, with ties broken by class-node order. It is used only
     * for sorting. {@code equals} is not overridden, so this ordering is not consistent with
     * {@code equals}.
     */
    private static final class DepInfo
    implements Comparable<DepInfo> {
        private final JavaClass cls;
        private final int depsCount;

        private DepInfo(JavaClass cls, int depsCount) {
            this.cls = cls;
            this.depsCount = depsCount;
        }

        public JavaClass getCls() {
            return this.cls;
        }

        public int getDepsCount() {
            return this.depsCount;
        }

        @Override
        public int compareTo(@NotNull DepInfo o) {
            int deps = Integer.compare(this.depsCount, o.depsCount);
            if (deps == 0) {
                return this.cls.getClassNode().compareTo(o.cls.getClassNode());
            }
            return deps;
        }

        @Override
        public String toString() {
            return String.valueOf(this.cls) + ":" + this.depsCount;
        }
    }
}

