/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.test.context.aot;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.springframework.aot.AotDetector;
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GeneratedClasses;
import org.springframework.aot.generate.GeneratedFiles;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
import org.springframework.aot.hint.annotation.ReflectiveRuntimeHintsRegistrar;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.aot.AotServices;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.SpringProperties;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.javapoet.ClassName;
import org.springframework.test.context.BootstrapUtils;
import org.springframework.test.context.ContextLoadException;
import org.springframework.test.context.ContextLoader;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.TestContextAnnotationUtils;
import org.springframework.test.context.TestContextBootstrapper;
import org.springframework.test.context.aot.AotContextLoader;
import org.springframework.test.context.aot.AotTestAttributes;
import org.springframework.test.context.aot.AotTestAttributesCodeGenerator;
import org.springframework.test.context.aot.AotTestAttributesFactory;
import org.springframework.test.context.aot.AotTestContextInitializers;
import org.springframework.test.context.aot.AotTestContextInitializersCodeGenerator;
import org.springframework.test.context.aot.AotTestContextInitializersFactory;
import org.springframework.test.context.aot.AotTestExecutionListener;
import org.springframework.test.context.aot.DisabledInAotMode;
import org.springframework.test.context.aot.MergedContextConfigurationRuntimeHints;
import org.springframework.test.context.aot.TestContextAotException;
import org.springframework.test.context.aot.TestContextGenerationContext;
import org.springframework.test.context.aot.TestRuntimeHintsRegistrar;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import wiremock.org.apache.commons.logging.Log;
import wiremock.org.apache.commons.logging.LogFactory;

/**
 * Generates ahead-of-time (AOT) artifacts and runtime hints for a set of test classes.
 *
 * <p>Test classes are grouped by their {@link MergedContextConfiguration}. For each
 * distinct merged configuration, the application context is loaded for AOT processing
 * once, and the resulting initializer class name is mapped to every test class that
 * shares that configuration. Test classes annotated with {@link DisabledInAotMode}
 * are skipped.
 */
public class TestContextAotGenerator {

    /**
     * Name of the Spring property that controls whether AOT processing errors are
     * rethrown ({@code true}) or only logged ({@code false}). When the property is
     * unset or blank, errors are rethrown.
     */
    public static final String FAIL_ON_ERROR_PROPERTY_NAME = "spring.test.aot.processing.failOnError";

    private static final Log logger = LogFactory.getLog(TestContextAotGenerator.class);

    /** Matches test classes for which {@link TestContextAnnotationUtils} finds {@link DisabledInAotMode}. */
    private static final Predicate<? super Class<?>> isDisabledInAotMode =
            testClass -> TestContextAnnotationUtils.hasAnnotation(testClass, DisabledInAotMode.class);

    private final ApplicationContextAotGenerator aotGenerator = new ApplicationContextAotGenerator();

    private final AotServices<TestRuntimeHintsRegistrar> testRuntimeHintsRegistrars;

    private final MergedContextConfigurationRuntimeHints mergedConfigRuntimeHints =
            new MergedContextConfigurationRuntimeHints();

    /** Source of the numeric suffix that gives each generated test context a unique name. */
    private final AtomicInteger sequence = new AtomicInteger();

    private final GeneratedFiles generatedFiles;

    private final RuntimeHints runtimeHints;

    /** Whether processing errors are rethrown as {@link TestContextAotException} instead of being logged. */
    final boolean failOnError;

    /**
     * Create a generator that collects hints into a new {@link RuntimeHints} instance.
     * It reads {@code failOnError} from {@link #FAIL_ON_ERROR_PROPERTY_NAME}.
     * @param generatedFiles destination for generated files
     */
    public TestContextAotGenerator(GeneratedFiles generatedFiles) {
        this(generatedFiles, new RuntimeHints());
    }

    /**
     * Create a generator that collects hints into the supplied {@link RuntimeHints}.
     * It reads {@code failOnError} from {@link #FAIL_ON_ERROR_PROPERTY_NAME}.
     * @param generatedFiles destination for generated files
     * @param runtimeHints hints instance to register into
     */
    public TestContextAotGenerator(GeneratedFiles generatedFiles, RuntimeHints runtimeHints) {
        this(generatedFiles, runtimeHints, getFailOnErrorFlag());
    }

    /**
     * Create a generator with an explicit error-handling mode.
     * @param generatedFiles destination for generated files
     * @param runtimeHints hints instance to register into
     * @param failOnError {@code true} to rethrow processing errors, {@code false} to log them
     */
    public TestContextAotGenerator(GeneratedFiles generatedFiles, RuntimeHints runtimeHints, boolean failOnError) {
        this.testRuntimeHintsRegistrars = AotServices.factories().load(TestRuntimeHintsRegistrar.class);
        this.generatedFiles = generatedFiles;
        this.runtimeHints = runtimeHints;
        this.failOnError = failOnError;
    }

    /**
     * Return the {@link RuntimeHints} instance that this generator registers hints into.
     */
    public final RuntimeHints getRuntimeHints() {
        return this.runtimeHints;
    }

    /**
     * Process the supplied test classes ahead of time.
     * <p>For every test class this builds its merged context configuration, runs
     * {@link AotTestExecutionListener}s, and collects runtime hints from
     * {@code @ImportRuntimeHints}, {@code @Reflective}-style metadata and any
     * {@link TestRuntimeHintsRegistrar} services. It then generates one context
     * initializer per distinct merged configuration and writes the initializer and
     * attribute mapping classes.
     * <p>The AOT test factories are reset both before and after processing.
     * @param testClasses the test classes to process
     * @throws IllegalStateException if invoked while generated AOT artifacts are in use
     * @throws TestContextAotException if a test context cannot be processed and
     * {@code failOnError} is enabled
     */
    public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextAotException {
        Assert.state(!AotDetector.useGeneratedArtifacts(), "Cannot perform AOT processing during AOT run-time execution");
        try {
            resetAotFactories();

            Set<Class<? extends RuntimeHintsRegistrar>> coreRuntimeHintsRegistrarClasses = new LinkedHashSet<>();
            ReflectiveRuntimeHintsRegistrar reflectiveRuntimeHintsRegistrar = new ReflectiveRuntimeHintsRegistrar();
            MultiValueMap<MergedContextConfiguration, Class<?>> mergedConfigMappings = new LinkedMultiValueMap<>();
            ClassLoader classLoader = getClass().getClassLoader();

            testClasses.forEach(testClass -> {
                MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass);
                mergedConfigMappings.add(mergedConfig, testClass);
                collectRuntimeHintsRegistrarClasses(testClass, coreRuntimeHintsRegistrarClasses);
                reflectiveRuntimeHintsRegistrar.registerRuntimeHints(this.runtimeHints, testClass);
                this.testRuntimeHintsRegistrars.forEach(registrar -> {
                    if (logger.isTraceEnabled()) {
                        logger.trace("Processing RuntimeHints contribution from class [%s]".formatted(registrar.getClass().getCanonicalName()));
                    }
                    registrar.registerHints(this.runtimeHints, testClass, classLoader);
                });
            });

            // Each registrar class is instantiated once, even if several test classes import it.
            coreRuntimeHintsRegistrarClasses.stream()
                    .map(BeanUtils::instantiateClass)
                    .forEach(registrar -> {
                        if (logger.isTraceEnabled()) {
                            logger.trace("Processing RuntimeHints contribution from class [%s]".formatted(registrar.getClass().getCanonicalName()));
                        }
                        registrar.registerHints(this.runtimeHints, classLoader);
                    });

            MultiValueMap<ClassName, Class<?>> initializerClassMappings = processAheadOfTime(mergedConfigMappings);
            generateAotTestContextInitializerMappings(initializerClassMappings);
            generateAotTestAttributeMappings();
            registerSkippedExceptionTypes();
        }
        finally {
            resetAotFactories();
        }
    }

    /**
     * Add every {@link RuntimeHintsRegistrar} class declared via {@code @ImportRuntimeHints}
     * anywhere in the test class's type hierarchy to the supplied set.
     */
    private void collectRuntimeHintsRegistrarClasses(Class<?> testClass,
            Set<Class<? extends RuntimeHintsRegistrar>> coreRuntimeHintsRegistrarClasses) {

        MergedAnnotations.from(testClass, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY)
                .stream(ImportRuntimeHints.class)
                .filter(MergedAnnotation::isPresent)
                .map(MergedAnnotation::synthesize)
                .map(ImportRuntimeHints::value)
                .flatMap(Arrays::stream)
                .forEach(coreRuntimeHintsRegistrarClasses::add);
    }

    /** Clear the state held by the AOT test attribute and context-initializer factories. */
    private void resetAotFactories() {
        AotTestAttributesFactory.reset();
        AotTestContextInitializersFactory.reset();
    }

    /**
     * Generate AOT artifacts for each distinct merged context configuration.
     * @param mergedConfigMappings merged configuration mapped to the test classes that share it
     * @return generated initializer class name mapped to the test classes it serves.
     * Configurations that are skipped or that fail (with {@code failOnError} disabled)
     * have no entry.
     */
    private MultiValueMap<ClassName, Class<?>> processAheadOfTime(
            MultiValueMap<MergedContextConfiguration, Class<?>> mergedConfigMappings) {

        ClassLoader classLoader = getClass().getClassLoader();
        MultiValueMap<ClassName, Class<?>> initializerClassMappings = new LinkedMultiValueMap<>();
        mergedConfigMappings.forEach((mergedConfig, testClasses) -> {
            long numDisabled = testClasses.stream().filter(isDisabledInAotMode).count();
            if (numDisabled > 0) {
                // Test classes sharing one context must agree: if any is disabled, all should be.
                if (numDisabled != testClasses.size()) {
                    if (this.failOnError) {
                        throw new TestContextAotException("All test classes that share an ApplicationContext must be annotated with @DisabledInAotMode if one of them is: " + classNames(testClasses));
                    }
                    if (logger.isWarnEnabled()) {
                        logger.warn("All test classes that share an ApplicationContext must be annotated with @DisabledInAotMode if one of them is: " + classNames(testClasses));
                    }
                }
                if (logger.isInfoEnabled()) {
                    logger.info("Skipping AOT processing due to the presence of @DisabledInAotMode for test classes " + classNames(testClasses));
                }
                return;
            }

            if (logger.isDebugEnabled()) {
                logger.debug("Generating AOT artifacts for test classes " + classNames(testClasses));
            }
            this.mergedConfigRuntimeHints.registerHints(this.runtimeHints, mergedConfig, classLoader);
            try {
                // All classes in the group share the context, so the first one names the generated code.
                Class<?> testClass = testClasses.get(0);
                DefaultGenerationContext generationContext = createGenerationContext(testClass);
                ClassName initializer = processAheadOfTime(mergedConfig, generationContext);
                Assert.state(!initializerClassMappings.containsKey(initializer),
                        () -> "ClassName [%s] already encountered".formatted(initializer.reflectionName()));
                initializerClassMappings.addAll(initializer, testClasses);
                generationContext.writeGeneratedContent();
            }
            catch (Exception ex) {
                if (this.failOnError) {
                    throw new TestContextAotException("Failed to generate AOT artifacts for test classes " + classNames(testClasses), ex);
                }
                // Log once: the full stack trace at DEBUG, otherwise a WARN summary that
                // tells the user how to obtain the stack trace.
                if (logger.isDebugEnabled()) {
                    logger.debug("Failed to generate AOT artifacts for test classes " + classNames(testClasses), ex);
                }
                else if (logger.isWarnEnabled()) {
                    logger.warn("Failed to generate AOT artifacts for test classes %s. Enable DEBUG logging to view the stack trace. %s".formatted(classNames(testClasses), ex));
                }
            }
        });
        return initializerClassMappings;
    }

    /**
     * Load the application context for the given merged configuration and run
     * {@link ApplicationContextAotGenerator} on it.
     * @return the class name returned by the AOT generator for the context
     * @throws TestContextAotException if the context cannot be loaded or processed
     */
    ClassName processAheadOfTime(MergedContextConfiguration mergedConfig, GenerationContext generationContext) throws TestContextAotException {
        GenericApplicationContext gac = loadContextForAotProcessing(mergedConfig);
        try {
            return this.aotGenerator.processAheadOfTime(gac, generationContext);
        }
        catch (Throwable ex) {
            throw new TestContextAotException("Failed to process test class [%s] for AOT".formatted(mergedConfig.getTestClass().getName()), ex);
        }
    }

    /**
     * Load a {@link GenericApplicationContext} for AOT processing via the configured
     * {@link AotContextLoader}.
     * @throws IllegalArgumentException if no {@link ContextLoader} is configured
     * @throws TestContextAotException if the loader is not an {@link AotContextLoader},
     * loading fails, or the loaded context is not a {@link GenericApplicationContext}
     */
    private GenericApplicationContext loadContextForAotProcessing(MergedContextConfiguration mergedConfig) throws TestContextAotException {
        Class<?> testClass = mergedConfig.getTestClass();
        ContextLoader contextLoader = mergedConfig.getContextLoader();
        Assert.notNull(contextLoader, () -> "Cannot load an ApplicationContext with a NULL 'contextLoader'. Consider annotating test class [%s] with @ContextConfiguration or @ContextHierarchy.".formatted(testClass.getName()));

        if (contextLoader instanceof AotContextLoader aotContextLoader) {
            try {
                ApplicationContext context = aotContextLoader.loadContextForAotProcessing(mergedConfig);
                if (context instanceof GenericApplicationContext gac) {
                    return gac;
                }
            }
            catch (Exception ex) {
                // Unwrap ContextLoadException so the underlying failure is reported as the cause.
                Throwable cause = (ex instanceof ContextLoadException cle ? cle.getCause() : ex);
                throw new TestContextAotException("Failed to load ApplicationContext for AOT processing for test class [%s]".formatted(testClass.getName()), cause);
            }
        }
        throw new TestContextAotException("Cannot generate AOT artifacts for test class [%s]. The configured ContextLoader [%s] must be an AotContextLoader and must create a GenericApplicationContext.".formatted(testClass.getName(), contextLoader.getClass().getName()));
    }

    /**
     * Resolve the test class's bootstrapper and register constructor hints for it and
     * its listeners. Run any {@link AotTestExecutionListener}, then return the merged
     * context configuration.
     */
    private MergedContextConfiguration buildMergedContextConfiguration(Class<?> testClass) {
        TestContextBootstrapper testContextBootstrapper = BootstrapUtils.resolveTestContextBootstrapper(testClass);
        registerDeclaredConstructors(testContextBootstrapper.getClass());
        testContextBootstrapper.getTestExecutionListeners().forEach(listener -> {
            registerDeclaredConstructors(listener.getClass());
            if (listener instanceof AotTestExecutionListener aotListener) {
                aotListener.processAheadOfTime(this.runtimeHints, testClass, getClass().getClassLoader());
            }
        });
        return testContextBootstrapper.buildMergedContextConfiguration();
    }

    /**
     * Create a generation context whose generated class names are based on the given
     * test class and carry a unique {@code TestContextNNN_} name.
     */
    DefaultGenerationContext createGenerationContext(Class<?> testClass) {
        ClassNameGenerator classNameGenerator = new ClassNameGenerator(ClassName.get(testClass));
        TestContextGenerationContext generationContext =
                new TestContextGenerationContext(classNameGenerator, this.generatedFiles, this.runtimeHints);
        return generationContext.withName(nextTestContextId());
    }

    /** Return the next sequential test-context name, such as {@code TestContext001_}. */
    private String nextTestContextId() {
        return "TestContext%03d_".formatted(this.sequence.incrementAndGet());
    }

    /**
     * Generate the class that maps initializer classes to test classes, write it,
     * and register its public methods for reflection.
     */
    private void generateAotTestContextInitializerMappings(MultiValueMap<ClassName, Class<?>> initializerClassMappings) {
        ClassNameGenerator classNameGenerator = new ClassNameGenerator(ClassName.get(AotTestContextInitializers.class));
        DefaultGenerationContext generationContext =
                new DefaultGenerationContext(classNameGenerator, this.generatedFiles, this.runtimeHints);
        GeneratedClasses generatedClasses = generationContext.getGeneratedClasses();

        AotTestContextInitializersCodeGenerator codeGenerator =
                new AotTestContextInitializersCodeGenerator(initializerClassMappings, generatedClasses);
        generationContext.writeGeneratedContent();
        String className = codeGenerator.getGeneratedClass().getName().reflectionName();
        registerPublicMethods(className);
    }

    /**
     * Generate the class holding the attributes collected by
     * {@link AotTestAttributesFactory}, write it, and register its public methods
     * for reflection.
     */
    private void generateAotTestAttributeMappings() {
        ClassNameGenerator classNameGenerator = new ClassNameGenerator(ClassName.get(AotTestAttributes.class));
        DefaultGenerationContext generationContext =
                new DefaultGenerationContext(classNameGenerator, this.generatedFiles, this.runtimeHints);
        GeneratedClasses generatedClasses = generationContext.getGeneratedClasses();

        Map<String, String> attributes = AotTestAttributesFactory.getAttributes();
        AotTestAttributesCodeGenerator codeGenerator =
                new AotTestAttributesCodeGenerator(attributes, generatedClasses);
        generationContext.writeGeneratedContent();
        String className = codeGenerator.getGeneratedClass().getName().reflectionName();
        registerPublicMethods(className);
    }

    private void registerPublicMethods(String className) {
        this.runtimeHints.reflection().registerType(TypeReference.of(className), MemberCategory.INVOKE_PUBLIC_METHODS);
    }

    private void registerDeclaredConstructors(Class<?> type) {
        this.runtimeHints.reflection().registerType(type, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
    }

    /**
     * Register reflection hints, with no member categories, for the "test skipped"
     * exception types of JUnit 5 (opentest4j), JUnit 4 and TestNG.
     */
    private void registerSkippedExceptionTypes() {
        ReflectionHints reflectionHints = this.runtimeHints.reflection();
        Stream.of(
                    "org.opentest4j.TestAbortedException",
                    "org.junit.AssumptionViolatedException",
                    "org.testng.SkipException")
                .map(TypeReference::of)
                .forEach(type -> reflectionHints.registerType(type));
    }

    /**
     * Read {@link #FAIL_ON_ERROR_PROPERTY_NAME}. Returns {@code true} when the property
     * is unset or blank; otherwise returns {@link Boolean#parseBoolean} of the trimmed value.
     */
    private static boolean getFailOnErrorFlag() {
        String failOnError = SpringProperties.getProperty(FAIL_ON_ERROR_PROPERTY_NAME);
        if (StringUtils.hasText(failOnError)) {
            return Boolean.parseBoolean(failOnError.trim());
        }
        return true;
    }

    /** Return the fully qualified names of the given classes, in order. */
    private static List<String> classNames(List<Class<?>> classes) {
        return classes.stream().map(Class::getName).toList();
    }
}

