/*
 * Decompiled with CFR 0.152.
 */
package org.teavm.backend.lowlevel.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.teavm.common.Graph;
import org.teavm.common.GraphSplittingBackend;
import org.teavm.common.GraphUtils;
import org.teavm.hppc.IntHashSet;
import org.teavm.hppc.IntIntHashMap;
import org.teavm.hppc.IntIntMap;
import org.teavm.model.BasicBlock;
import org.teavm.model.ClassReader;
import org.teavm.model.ClassReaderSource;
import org.teavm.model.Incoming;
import org.teavm.model.Instruction;
import org.teavm.model.MethodDescriptor;
import org.teavm.model.MethodReader;
import org.teavm.model.MethodReference;
import org.teavm.model.Program;
import org.teavm.model.TextLocation;
import org.teavm.model.ValueType;
import org.teavm.model.Variable;
import org.teavm.model.instructions.BranchingCondition;
import org.teavm.model.instructions.BranchingInstruction;
import org.teavm.model.instructions.DoubleConstantInstruction;
import org.teavm.model.instructions.ExitInstruction;
import org.teavm.model.instructions.FloatConstantInstruction;
import org.teavm.model.instructions.InitClassInstruction;
import org.teavm.model.instructions.IntegerConstantInstruction;
import org.teavm.model.instructions.InvocationType;
import org.teavm.model.instructions.InvokeInstruction;
import org.teavm.model.instructions.JumpInstruction;
import org.teavm.model.instructions.LongConstantInstruction;
import org.teavm.model.instructions.MonitorEnterInstruction;
import org.teavm.model.instructions.NullConstantInstruction;
import org.teavm.model.instructions.SwitchInstruction;
import org.teavm.model.instructions.SwitchTableEntry;
import org.teavm.model.util.BasicBlockMapper;
import org.teavm.model.util.BasicBlockSplitter;
import org.teavm.model.util.DefinitionExtractor;
import org.teavm.model.util.LivenessAnalyzer;
import org.teavm.model.util.PhiUpdater;
import org.teavm.model.util.ProgramUtils;
import org.teavm.model.util.TypeInferer;
import org.teavm.model.util.UsageExtractor;
import org.teavm.model.util.VariableType;
import org.teavm.runtime.Fiber;

/**
 * Rewrites a method so that it can be suspended and later resumed by a {@link Fiber}.
 *
 * <p>A "split instruction" is any of the following:
 * <ul>
 *   <li>an invocation of {@code Fiber.suspend(AsyncCall)};</li>
 *   <li>an invocation of a method contained in the async-method set (other {@code Fiber} methods are
 *       excluded);</li>
 *   <li>a class initialization whose {@code <clinit>} is in the async-method set;</li>
 *   <li>a monitor enter, when threads are enabled.</li>
 * </ul>
 *
 * <p>Every split instruction is isolated in its own basic block and followed by a
 * {@code Fiber.isSuspending()} check. When the fiber is suspending, all variables live across the
 * split point are pushed to the fiber, followed by a state number, and the method returns a dummy
 * value. On re-entry, a prologue checks {@code Fiber.isResuming()}, pops the state number, restores
 * the saved variables and jumps back to the split instruction so that it is executed again.
 */
public class CoroutineTransformation {
    /** {@code Fiber.suspend(AsyncCall)}; always treated as a split point. */
    private static final MethodReference FIBER_SUSPEND = new MethodReference(Fiber.class, "suspend",
            Fiber.AsyncCall.class, Object.class);
    /** Binary name of {@code Fiber.AsyncCall}; methods of classes implementing it are left untouched. */
    private static final String ASYNC_CALL = Fiber.class.getName() + "$AsyncCall";

    private final ClassReaderSource classSource;
    private final LivenessAnalyzer livenessAnalysis = new LivenessAnalyzer();
    private final TypeInferer variableTypes = new TypeInferer();
    private final Set<MethodReference> asyncMethods;
    private final boolean hasThreads;

    // Per-method state, (re)initialized by apply().
    private Program program;
    private Variable fiberVar;
    private BasicBlockSplitter splitter;
    private SwitchInstruction resumeSwitch;
    private int parameterCount;
    private ValueType returnType;

    /**
     * @param classSource  used to resolve classes, their interfaces, parents and methods
     * @param asyncMethods methods whose invocation (or, for {@code <clinit>}, whose class initialization)
     *                     may suspend the current fiber
     * @param hasThreads   when {@code true}, {@code monitorenter} is also treated as a split point
     */
    public CoroutineTransformation(ClassReaderSource classSource, Set<MethodReference> asyncMethods,
            boolean hasThreads) {
        this.classSource = classSource;
        this.asyncMethods = asyncMethods;
        this.hasThreads = hasThreads;
    }

    /**
     * Transforms {@code program} in place into a resumable coroutine.
     *
     * <p>The program is left untouched when the method belongs to {@code Fiber} itself, when its class
     * implements {@code Fiber.AsyncCall}, or when it contains no split instruction.
     *
     * @param program         the method body to transform
     * @param methodReference the method owning {@code program}
     */
    public void apply(Program program, MethodReference methodReference) {
        if (methodReference.getClassName().equals(Fiber.class.getName())) {
            return;
        }
        ClassReader cls = classSource.get(methodReference.getClassName());
        if (cls != null && cls.getInterfaces().contains(ASYNC_CALL)) {
            return;
        }

        boolean hasJob = false;
        for (BasicBlock block : program.getBasicBlocks()) {
            if (hasSplitInstructions(block)) {
                hasJob = true;
                break;
            }
        }
        if (!hasJob) {
            return;
        }

        this.program = program;
        parameterCount = methodReference.parameterCount();
        returnType = methodReference.getReturnType();
        variableTypes.inferTypes(program, methodReference);
        livenessAnalysis.analyze(program, methodReference.getDescriptor());
        splitter = new BasicBlockSplitter(program);

        int basicBlockCount = program.basicBlockCount();
        createSplitPrologue();
        // Block 0 now holds only the prologue. Its original code lives in the continuation block
        // created by createSplitPrologue(). The loop therefore skips index 0 and runs up to and
        // including basicBlockCount, so that the continuation block is processed too.
        // NOTE(review): this relies on the continuation block receiving index basicBlockCount,
        // which is the first block created after the count was taken. Liveness was computed before
        // that block existed, so confirm that LivenessAnalyzer/BasicBlockSplitter cope with it.
        for (int i = 1; i <= basicBlockCount; ++i) {
            processBlock(program.basicBlockAt(i));
        }
        splitter.fixProgram();
        processIrreducibleCfg();
        new PhiUpdater().updatePhis(program, methodReference.parameterCount() + 1);
    }

    /**
     * Splits the entry block and installs the resume dispatch.
     *
     * <p>The new entry code obtains {@code Fiber.current()} into {@link #fiberVar}. It then calls
     * {@code isResuming()}:
     * <ul>
     *   <li>if the fiber is resuming, it pops a state number and dispatches on it through
     *       {@link #resumeSwitch};</li>
     *   <li>otherwise, it continues with the original entry code.</li>
     * </ul>
     * The switch has no entries yet; each split point adds one. Its default target is the original
     * entry code.
     */
    private void createSplitPrologue() {
        fiberVar = program.createVariable();
        fiberVar.setLabel("fiber");

        BasicBlock firstBlock = program.basicBlockAt(0);
        BasicBlock continueBlock = splitter.split(firstBlock, null);
        BasicBlock switchStateBlock = program.createBasicBlock();
        TextLocation location = continueBlock.getFirstInstruction().getLocation();

        InvokeInstruction getFiber = new InvokeInstruction();
        getFiber.setType(InvocationType.SPECIAL);
        getFiber.setMethod(new MethodReference(Fiber.class, "current", Fiber.class));
        getFiber.setReceiver(fiberVar);
        getFiber.setLocation(location);
        firstBlock.add(getFiber);

        InvokeInstruction isResuming = new InvokeInstruction();
        isResuming.setType(InvocationType.SPECIAL);
        isResuming.setMethod(new MethodReference(Fiber.class, "isResuming", Boolean.TYPE));
        isResuming.setInstance(fiberVar);
        isResuming.setReceiver(program.createVariable());
        isResuming.setLocation(location);
        firstBlock.add(isResuming);

        BranchingInstruction jumpIfResuming = new BranchingInstruction(BranchingCondition.NOT_EQUAL);
        jumpIfResuming.setOperand(isResuming.getReceiver());
        jumpIfResuming.setConsequent(switchStateBlock);
        jumpIfResuming.setAlternative(continueBlock);
        // Give the branch the same location as the rest of the prologue (previously left unset).
        jumpIfResuming.setLocation(location);
        firstBlock.add(jumpIfResuming);

        InvokeInstruction popInt = new InvokeInstruction();
        popInt.setType(InvocationType.SPECIAL);
        popInt.setMethod(new MethodReference(Fiber.class, "popInt", Integer.TYPE));
        popInt.setInstance(fiberVar);
        popInt.setReceiver(program.createVariable());
        popInt.setLocation(location);
        switchStateBlock.add(popInt);

        resumeSwitch = new SwitchInstruction();
        resumeSwitch.setDefaultTarget(continueBlock);
        resumeSwitch.setCondition(popInt.getReceiver());
        resumeSwitch.setLocation(location);
        switchStateBlock.add(resumeSwitch);
    }

    /**
     * Isolates every split instruction of {@code block} and turns each one into a split point.
     *
     * <p>For each split instruction, the code before it stays in the current block. The instruction
     * itself moves to an "intermediate" block, and the code after it moves to a "next" block, which
     * becomes the current block for the following split instruction.
     */
    private void processBlock(BasicBlock block) {
        Map<Instruction, BitSet> splitInstructions = collectSplitInstructions(block);
        // collectSplitInstructions walks backwards, so reverse to process in program order.
        List<Instruction> instructionList = new ArrayList<>(splitInstructions.keySet());
        Collections.reverse(instructionList);
        for (Instruction instruction : instructionList) {
            BasicBlock intermediate = splitter.split(block, instruction.getPrevious());
            BasicBlock next = splitter.split(intermediate, instruction);
            createSplitPoint(block, intermediate, next, splitInstructions.get(instruction));
            block = next;
        }
    }

    /**
     * Maps each split instruction of {@code block} to the variables live immediately before it.
     *
     * <p>Each set includes the instruction's own operands and excludes its receiver, because on
     * resume the instruction is executed again. Entries are ordered last-to-first within the block;
     * an empty map is returned when the block contains no split instruction.
     */
    private Map<Instruction, BitSet> collectSplitInstructions(BasicBlock block) {
        if (!hasSplitInstructions(block)) {
            return Collections.emptyMap();
        }

        BitSet live = livenessAnalysis.liveOut(block.getIndex());
        Map<Instruction, BitSet> result = new LinkedHashMap<>();
        UsageExtractor use = new UsageExtractor();
        DefinitionExtractor def = new DefinitionExtractor();
        for (Instruction instruction = block.getLastInstruction(); instruction != null;
                instruction = instruction.getPrevious()) {
            instruction.acceptVisitor(def);
            if (def.getDefinedVariables() != null) {
                for (Variable var : def.getDefinedVariables()) {
                    live.clear(var.getIndex());
                }
            }
            instruction.acceptVisitor(use);
            if (use.getUsedVariables() != null) {
                for (Variable var : use.getUsedVariables()) {
                    live.set(var.getIndex());
                }
            }
            if (isSplitInstruction(instruction)) {
                result.put(instruction, (BitSet) live.clone());
            }
        }
        return result;
    }

    /** Returns {@code true} if any instruction of {@code block} is a split instruction. */
    private boolean hasSplitInstructions(BasicBlock block) {
        for (Instruction instruction : block) {
            if (isSplitInstruction(instruction)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether {@code instruction} may suspend the current fiber; see the class documentation
     * for the exact rules.
     */
    private boolean isSplitInstruction(Instruction instruction) {
        if (instruction instanceof InvokeInstruction) {
            InvokeInstruction invoke = (InvokeInstruction) instruction;
            MethodReference method = findRealMethod(invoke.getMethod());
            if (method.equals(FIBER_SUSPEND)) {
                return true;
            }
            if (method.getClassName().equals(Fiber.class.getName())) {
                return false;
            }
            return asyncMethods.contains(method);
        }
        if (instruction instanceof InitClassInstruction) {
            return isSplittingClassInitializer(((InitClassInstruction) instruction).getClassName());
        }
        return hasThreads && instruction instanceof MonitorEnterInstruction;
    }

    /**
     * Wires up one split point and registers its state number in {@link #resumeSwitch}.
     *
     * <ul>
     *   <li>{@code block} jumps to {@code intermediate}, which holds only the split instruction.</li>
     *   <li>{@code intermediate} then calls {@code isSuspending()}. If the fiber is suspending, control
     *       goes to a save block; otherwise it continues to {@code next}.</li>
     *   <li>The save block pushes {@code liveVars} and then the state number, and returns a dummy
     *       value.</li>
     *   <li>The restore block, reached from the resume switch, pops {@code liveVars} and jumps back
     *       to {@code intermediate}.</li>
     * </ul>
     */
    private void createSplitPoint(BasicBlock block, BasicBlock intermediate, BasicBlock next,
            BitSet liveVars) {
        int stateNumber = resumeSwitch.getEntries().size();
        Instruction splitInstruction = intermediate.getFirstInstruction();
        TextLocation location = splitInstruction.getLocation();

        JumpInstruction jumpToIntermediate = new JumpInstruction();
        jumpToIntermediate.setTarget(intermediate);
        jumpToIntermediate.setLocation(location);
        block.add(jumpToIntermediate);

        BasicBlock restoreBlock = program.createBasicBlock();
        BasicBlock saveBlock = program.createBasicBlock();

        SwitchTableEntry switchTableEntry = new SwitchTableEntry();
        switchTableEntry.setCondition(stateNumber);
        switchTableEntry.setTarget(restoreBlock);
        resumeSwitch.getEntries().add(switchTableEntry);

        InvokeInstruction isSuspending = new InvokeInstruction();
        isSuspending.setType(InvocationType.SPECIAL);
        isSuspending.setMethod(new MethodReference(Fiber.class, "isSuspending", Boolean.TYPE));
        isSuspending.setInstance(fiberVar);
        isSuspending.setReceiver(program.createVariable());
        isSuspending.setLocation(location);
        intermediate.add(isSuspending);

        BranchingInstruction branchIfSuspending = new BranchingInstruction(BranchingCondition.NOT_EQUAL);
        branchIfSuspending.setOperand(isSuspending.getReceiver());
        branchIfSuspending.setConsequent(saveBlock);
        branchIfSuspending.setAlternative(next);
        branchIfSuspending.setLocation(location);
        intermediate.add(branchIfSuspending);

        restoreBlock.addAll(restoreState(liveVars));
        JumpInstruction doneRestoring = new JumpInstruction();
        doneRestoring.setTarget(intermediate);
        restoreBlock.add(doneRestoring);
        for (Instruction instruction : restoreBlock) {
            instruction.setLocation(location);
        }

        for (Instruction instruction : saveState(liveVars)) {
            instruction.setLocation(location);
            saveBlock.add(instruction);
        }
        for (Instruction instruction : saveStateNumber(stateNumber)) {
            instruction.setLocation(location);
            saveBlock.add(instruction);
        }
        createReturnInstructions(location, saveBlock);
    }

    /** Builds {@code Fiber.push} calls for every variable in {@code vars}, in ascending index order. */
    private List<Instruction> saveState(BitSet vars) {
        List<Instruction> instructions = new ArrayList<>();
        for (int var = vars.nextSetBit(0); var >= 0; var = vars.nextSetBit(var + 1)) {
            saveVariable(var, instructions);
        }
        return instructions;
    }

    /** Builds the instructions that push the integer state {@code number} to the fiber. */
    private List<Instruction> saveStateNumber(int number) {
        IntegerConstantInstruction constant = new IntegerConstantInstruction();
        constant.setReceiver(program.createVariable());
        constant.setConstant(number);
        InvokeInstruction invoke = new InvokeInstruction();
        invoke.setType(InvocationType.SPECIAL);
        invoke.setMethod(new MethodReference(Fiber.class, "push", Integer.TYPE, Void.TYPE));
        invoke.setInstance(fiberVar);
        invoke.setArguments(constant.getReceiver());
        return Arrays.asList(constant, invoke);
    }

    /**
     * Builds {@code Fiber.pop*} calls for every variable in {@code vars}, in descending index order.
     * This is the reverse of {@link #saveState}, so values come off the fiber in the order they were
     * pushed back.
     */
    private List<Instruction> restoreState(BitSet vars) {
        int[] varArray = new int[vars.cardinality()];
        int count = 0;
        for (int var = vars.nextSetBit(0); var >= 0; var = vars.nextSetBit(var + 1)) {
            varArray[count++] = var;
        }

        List<Instruction> instructions = new ArrayList<>();
        for (int k = varArray.length - 1; k >= 0; --k) {
            restoreVariable(varArray[k], instructions);
        }
        return instructions;
    }

    /** Appends a {@code Fiber.push} overload matching the inferred type of variable {@code var}. */
    private void saveVariable(int var, List<Instruction> instructions) {
        VariableType type = variableTypes.typeOf(var);
        InvokeInstruction invoke = new InvokeInstruction();
        invoke.setType(InvocationType.SPECIAL);
        invoke.setInstance(fiberVar);
        invoke.setArguments(program.variableAt(var));
        switch (type) {
            case INT:
                invoke.setMethod(new MethodReference(Fiber.class, "push", Integer.TYPE, Void.TYPE));
                break;
            case LONG:
                invoke.setMethod(new MethodReference(Fiber.class, "push", Long.TYPE, Void.TYPE));
                break;
            case FLOAT:
                invoke.setMethod(new MethodReference(Fiber.class, "push", Float.TYPE, Void.TYPE));
                break;
            case DOUBLE:
                invoke.setMethod(new MethodReference(Fiber.class, "push", Double.TYPE, Void.TYPE));
                break;
            default:
                invoke.setMethod(new MethodReference(Fiber.class, "push", Object.class, Void.TYPE));
                break;
        }
        instructions.add(invoke);
    }

    /** Appends a {@code Fiber.pop*} call matching the inferred type of variable {@code var}. */
    private void restoreVariable(int var, List<Instruction> instructions) {
        VariableType type = variableTypes.typeOf(var);
        InvokeInstruction invoke = new InvokeInstruction();
        invoke.setType(InvocationType.SPECIAL);
        invoke.setInstance(fiberVar);
        invoke.setReceiver(program.variableAt(var));
        switch (type) {
            case INT:
                invoke.setMethod(new MethodReference(Fiber.class, "popInt", Integer.TYPE));
                break;
            case LONG:
                invoke.setMethod(new MethodReference(Fiber.class, "popLong", Long.TYPE));
                break;
            case FLOAT:
                invoke.setMethod(new MethodReference(Fiber.class, "popFloat", Float.TYPE));
                break;
            case DOUBLE:
                invoke.setMethod(new MethodReference(Fiber.class, "popDouble", Double.TYPE));
                break;
            default:
                invoke.setMethod(new MethodReference(Fiber.class, "popObject", Object.class));
                break;
        }
        instructions.add(invoke);
    }

    /**
     * Returns {@code true} when {@code className} resolves to a class whose {@code <clinit>} is an
     * async method.
     */
    private boolean isSplittingClassInitializer(String className) {
        ClassReader cls = classSource.get(className);
        if (cls == null) {
            return false;
        }
        MethodReader method = cls.getMethod(new MethodDescriptor("<clinit>", ValueType.VOID));
        return method != null && asyncMethods.contains(method.getReference());
    }

    /**
     * Walks up the superclass chain from {@code method}'s class and returns a reference to the first
     * class that declares a method with the same descriptor. Falls back to {@code method} itself
     * when no such class is found.
     */
    private MethodReference findRealMethod(MethodReference method) {
        String clsName = method.getClassName();
        while (clsName != null) {
            ClassReader cls = classSource.get(clsName);
            if (cls == null) {
                break;
            }
            if (cls.getMethod(method.getDescriptor()) != null) {
                return new MethodReference(clsName, method.getDescriptor());
            }
            String parent = cls.getParent();
            // Guard against a class that names itself as its parent, which would loop forever.
            if (parent != null && parent.equals(cls.getName())) {
                break;
            }
            clsName = parent;
        }
        return method;
    }

    /**
     * Appends a return to {@code block}. For non-void methods the return value is a default-valued
     * constant (a numeric constant instruction, or {@code null} for references). The caller is
     * suspending, so this value is not the method's real result.
     */
    private void createReturnInstructions(TextLocation location, BasicBlock block) {
        ExitInstruction exit = new ExitInstruction();
        exit.setLocation(location);
        if (returnType == ValueType.VOID) {
            block.add(exit);
            return;
        }
        exit.setValueToReturn(program.createVariable());
        Instruction returnValue = createReturnValueInstruction(exit.getValueToReturn());
        returnValue.setLocation(location);
        block.add(returnValue);
        block.add(exit);
    }

    /** Creates a constant instruction appropriate for {@link #returnType}, writing into {@code target}. */
    private Instruction createReturnValueInstruction(Variable target) {
        if (returnType instanceof ValueType.Primitive) {
            switch (((ValueType.Primitive) returnType).getKind()) {
                case BOOLEAN:
                case BYTE:
                case CHARACTER:
                case SHORT:
                case INTEGER: {
                    IntegerConstantInstruction instruction = new IntegerConstantInstruction();
                    instruction.setReceiver(target);
                    return instruction;
                }
                case LONG: {
                    LongConstantInstruction instruction = new LongConstantInstruction();
                    instruction.setReceiver(target);
                    return instruction;
                }
                case FLOAT: {
                    FloatConstantInstruction instruction = new FloatConstantInstruction();
                    instruction.setReceiver(target);
                    return instruction;
                }
                case DOUBLE: {
                    DoubleConstantInstruction instruction = new DoubleConstantInstruction();
                    instruction.setReceiver(target);
                    return instruction;
                }
                default:
                    break;
            }
        }
        NullConstantInstruction instruction = new NullConstantInstruction();
        instruction.setReceiver(target);
        return instruction;
    }

    /**
     * Makes the control flow graph reducible again if the inserted resume edges made it irreducible.
     * Nodes are copied via {@link SplittingBackend}, weighted by their instruction count, and then
     * phis are updated.
     */
    private void processIrreducibleCfg() {
        Graph graph = ProgramUtils.buildControlFlowGraph(program);
        if (!GraphUtils.isIrreducible(graph)) {
            return;
        }

        SplittingBackend splittingBackend = new SplittingBackend();
        int[] weights = new int[graph.size()];
        for (int i = 0; i < program.basicBlockCount(); ++i) {
            weights[i] = program.basicBlockAt(i).instructionCount();
        }
        GraphUtils.splitIrreducibleGraph(graph, weights, splittingBackend);
        new PhiUpdater().updatePhis(program, parameterCount + 1);
    }

    /**
     * Node-splitting callback used by {@link GraphUtils#splitIrreducibleGraph}.
     *
     * <p>It duplicates the given basic blocks and redirects jumps from {@code domain} blocks and from
     * the copies themselves to the copies. It also adds phi incomings for edges leaving the copied
     * region.
     */
    class SplittingBackend implements GraphSplittingBackend {
        @Override
        public int[] split(int[] domain, int[] nodes) {
            int[] copies = new int[nodes.length];
            // Maps an original block index to (copy index + 1); 0 (absent) means "not copied".
            IntIntHashMap map = new IntIntHashMap();
            IntHashSet nodeSet = IntHashSet.from(nodes);
            List<List<Incoming>> outputs = ProgramUtils.getPhiOutputs(program);

            for (int i = 0; i < nodes.length; ++i) {
                int node = nodes[i];
                BasicBlock block = program.basicBlockAt(node);
                BasicBlock blockCopy = program.createBasicBlock();
                ProgramUtils.copyBasicBlock(block, blockCopy);
                copies[i] = blockCopy.getIndex();
                map.put(node, copies[i] + 1);
            }

            // Explicitly typed (int) so the lambda is unambiguously an int -> int mapping. It is
            // written inline instead of as a static helper, because inner classes may not declare
            // static methods before Java 16.
            BasicBlockMapper copyBlockMapper = new BasicBlockMapper((int block) -> {
                int mappedIndex = map.get(block);
                return mappedIndex == 0 ? block : mappedIndex - 1;
            });
            for (int copy : copies) {
                copyBlockMapper.transform(program.basicBlockAt(copy));
            }
            for (int domainNode : domain) {
                copyBlockMapper.transformWithoutPhis(program.basicBlockAt(domainNode));
            }

            // Edges leaving the copied region need matching phi incomings from the copies.
            for (int i = 0; i < nodes.length; ++i) {
                int node = nodes[i];
                BasicBlock blockCopy = program.basicBlockAt(copies[i]);
                for (Incoming output : outputs.get(node)) {
                    if (nodeSet.contains(output.getPhi().getBasicBlock().getIndex())) {
                        continue;
                    }
                    Incoming outputCopy = new Incoming();
                    outputCopy.setSource(blockCopy);
                    outputCopy.setValue(output.getValue());
                    output.getPhi().getIncomings().add(outputCopy);
                }
            }
            return copies;
        }
    }
}

