/*
 * Copyright 2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.zonky.test.db.postgres;

import com.google.common.collect.ImmutableMap;
import io.zonky.test.db.AutoConfigureEmbeddedDatabase;
import io.zonky.test.db.AutoConfigureEmbeddedDatabase.Replace;
import io.zonky.test.db.AutoConfigureEmbeddedDatabases;
import io.zonky.test.db.flyway.DefaultFlywayDataSourceContext;
import io.zonky.test.db.flyway.FlywayClassUtils;
import io.zonky.test.db.flyway.FlywayDataSourceContext;
import io.zonky.test.db.provider.DatabaseDescriptor;
import io.zonky.test.db.provider.DatabaseType;
import io.zonky.test.db.provider.ProviderType;
import io.zonky.test.db.provider.impl.DockerPostgresDatabaseProvider;
import io.zonky.test.db.provider.impl.OpenTablePostgresDatabaseProvider;
import io.zonky.test.db.provider.impl.PrefetchingDatabaseProvider;
import io.zonky.test.db.provider.impl.YandexPostgresDatabaseProvider;
import io.zonky.test.db.provider.impl.ZonkyPostgresDatabaseProvider;
import org.apache.commons.lang3.StringUtils;
import org.flywaydb.core.Flyway;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.autoconfigure.flyway.FlywayProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.support.AbstractApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.env.MapPropertySource;
import org.springframework.test.context.ContextConfigurationAttributes;
import org.springframework.test.context.ContextCustomizer;
import org.springframework.test.context.ContextCustomizerFactory;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;

import javax.sql.DataSource;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseProvider.DEFAULT;
import static io.zonky.test.db.AutoConfigureEmbeddedDatabase.EmbeddedDatabaseType;

/**
 * Implementation of the {@link org.springframework.test.context.ContextCustomizerFactory} interface,
 * which is responsible for initialization of the embedded postgres database and its registration to the application context.
 * The applied initialization strategy is driven by the {@link AutoConfigureEmbeddedDatabase} annotation.
 *
 * @see AutoConfigureEmbeddedDatabase
 */
public class EmbeddedPostgresContextCustomizerFactory implements ContextCustomizerFactory {

    private static final Logger logger = LoggerFactory.getLogger(EmbeddedPostgresContextCustomizerFactory.class);

    private static final boolean flywayNameAttributePresent = FlywayClassUtils.isFlywayNameAttributePresent();
    private static final boolean repeatableAnnotationPresent = FlywayClassUtils.isRepeatableFlywayTestAnnotationPresent();

    /**
     * Builds a context customizer for the given test class, driven by the
     * {@link AutoConfigureEmbeddedDatabase} annotations (including repeated ones) found on it.
     *
     * <p>Only the first annotation seen for each bean name is considered; of those, only the ones
     * targeting the postgres database type with a replace mode other than {@link Replace#NONE} are kept.
     *
     * @param testClass the test class to inspect
     * @param configAttributes the context configuration attributes (not used by this implementation)
     * @return a customizer carrying the applicable annotations, or {@code null} when none apply
     */
    @Override
    public ContextCustomizer createContextCustomizer(Class<?> testClass, List<ContextConfigurationAttributes> configAttributes) {
        Set<AutoConfigureEmbeddedDatabase> declaredAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
                testClass, AutoConfigureEmbeddedDatabase.class, AutoConfigureEmbeddedDatabases.class);

        Predicate<AutoConfigureEmbeddedDatabase> firstPerBeanName = distinctByKey(AutoConfigureEmbeddedDatabase::beanName);
        Set<AutoConfigureEmbeddedDatabase> applicableAnnotations = new LinkedHashSet<>();

        for (AutoConfigureEmbeddedDatabase candidate : declaredAnnotations) {
            // the de-duplication check is stateful and must see every annotation,
            // so it is evaluated before (and independently of) the other conditions
            boolean firstForBeanName = firstPerBeanName.test(candidate);
            if (firstForBeanName
                    && candidate.type() == EmbeddedDatabaseType.POSTGRES
                    && candidate.replace() != Replace.NONE) {
                applicableAnnotations.add(candidate);
            }
        }

        if (applicableAnnotations.isEmpty()) {
            return null;
        }
        return new PreloadableEmbeddedPostgresContextCustomizer(applicableAnnotations);
    }

    /**
     * Context customizer that installs an {@link EnvironmentPostProcessor} on the context and registers
     * a {@link PreloadableEmbeddedPostgresRegistrar} bean definition parameterized with the database annotations.
     *
     * <p>Equality and hash code are derived solely from the set of database annotations.
     */
    protected static class PreloadableEmbeddedPostgresContextCustomizer implements ContextCustomizer {

        private final Set<AutoConfigureEmbeddedDatabase> databaseAnnotations;

        public PreloadableEmbeddedPostgresContextCustomizer(Set<AutoConfigureEmbeddedDatabase> databaseAnnotations) {
            this.databaseAnnotations = databaseAnnotations;
        }

        /**
         * Adds the environment post-processor and registers the registrar bean definition
         * under the name {@code preloadableEmbeddedPostgresRegistrar}.
         */
        @Override
        public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) {
            ConfigurableEnvironment contextEnvironment = context.getEnvironment();
            context.addBeanFactoryPostProcessor(new EnvironmentPostProcessor(contextEnvironment));

            BeanDefinitionRegistry targetRegistry = getBeanDefinitionRegistry(context);

            RootBeanDefinition registrar = new RootBeanDefinition(PreloadableEmbeddedPostgresRegistrar.class);
            // the annotations are handed over as the registrar's single constructor argument
            registrar.getConstructorArgumentValues().addIndexedArgumentValue(0, databaseAnnotations);

            targetRegistry.registerBeanDefinition("preloadableEmbeddedPostgresRegistrar", registrar);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            PreloadableEmbeddedPostgresContextCustomizer other = (PreloadableEmbeddedPostgresContextCustomizer) o;
            return databaseAnnotations.equals(other.databaseAnnotations);
        }

        @Override
        public int hashCode() {
            return databaseAnnotations.hashCode();
        }
    }

    /**
     * Registry post-processor that puts a property source in front of all others in the given
     * environment, setting {@code spring.test.database.replace} to {@code NONE}.
     *
     * <p>NOTE(review): presumably this switches off Spring Boot's own test database replacement
     * so it does not compete with the embedded database registered here - confirm.
     */
    protected static class EnvironmentPostProcessor implements BeanDefinitionRegistryPostProcessor {

        private final ConfigurableEnvironment environment;

        public EnvironmentPostProcessor(ConfigurableEnvironment environment) {
            this.environment = environment;
        }

        @Override
        public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
            String sourceName = PreloadableEmbeddedPostgresContextCustomizer.class.getSimpleName();
            ImmutableMap<String, Object> overrides = ImmutableMap.of("spring.test.database.replace", "NONE");

            MapPropertySource overridingSource = new MapPropertySource(sourceName, overrides);
            // addFirst gives this source the highest precedence
            environment.getPropertySources().addFirst(overridingSource);
        }

        @Override
        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
            // intentionally empty - all work happens during registry post-processing
        }
    }

    /**
     * Bean post-processor that clears the url, user and password of every {@link FlywayProperties}
     * bean after its initialization; all other beans pass through untouched.
     *
     * <p>NOTE(review): presumably this keeps Flyway from opening its own connection instead of
     * using the replaced data source - confirm.
     */
    protected static class FlywayPropertiesPostProcessor implements BeanPostProcessor {

        @Override
        public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
            return bean;
        }

        @Override
        public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
            if (!(bean instanceof FlywayProperties)) {
                return bean;
            }

            FlywayProperties flywayProperties = (FlywayProperties) bean;
            flywayProperties.setUrl(null);
            flywayProperties.setUser(null);
            flywayProperties.setPassword(null);
            return bean;
        }
    }

    /**
     * Registry post-processor that replaces data source bean definitions with embedded postgres ones,
     * one per {@link AutoConfigureEmbeddedDatabase} annotation it was constructed with.
     *
     * <p>Besides the data sources, it registers the database provider beans (and a
     * {@link FlywayPropertiesPostProcessor}) whose backing libraries are present on the classpath,
     * unless beans with the same names are already defined.
     */
    protected static class PreloadableEmbeddedPostgresRegistrar implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {

        private final Set<AutoConfigureEmbeddedDatabase> databaseAnnotations;

        private Environment environment;

        public PreloadableEmbeddedPostgresRegistrar(Set<AutoConfigureEmbeddedDatabase> databaseAnnotations) {
            this.databaseAnnotations = databaseAnnotations;
        }

        @Override
        public void setEnvironment(Environment environment) {
            this.environment = environment;
        }

        /**
         * Registers the supporting beans and then, for every database annotation, replaces (or creates)
         * the targeted data source bean definition with an embedded one.
         *
         * @param registry the registry to modify; must also be a {@link ConfigurableListableBeanFactory}
         * @throws IllegalArgumentException if the registry is not a {@link ConfigurableListableBeanFactory}
         */
        @Override
        public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
            Assert.isInstanceOf(ConfigurableListableBeanFactory.class, registry,
                    "Embedded Database Auto-configuration can only be used with a ConfigurableListableBeanFactory");
            ConfigurableListableBeanFactory beanFactory = (ConfigurableListableBeanFactory) registry;

            registerInfrastructureBeans(registry);

            for (AutoConfigureEmbeddedDatabase databaseAnnotation : databaseAnnotations) {
                DatabaseDescriptor databaseDescriptor = resolveDatabaseDescriptor(environment, databaseAnnotation);

                BeanDefinitionHolder dataSourceInfo = getDataSourceBeanDefinition(beanFactory, databaseAnnotation);
                BeanDefinitionHolder flywayInfo = getFlywayBeanDefinition(beanFactory);

                RootBeanDefinition dataSourceDefinition = new RootBeanDefinition();
                // the replacement keeps the primary flag of the definition it stands in for
                dataSourceDefinition.setPrimary(dataSourceInfo.getBeanDefinition().isPrimary());

                if (flywayInfo == null || databaseAnnotations.size() > 1) {
                    // no single flyway bean, or several databases requested: plain embedded data source
                    dataSourceDefinition.setBeanClass(EmptyEmbeddedPostgresDataSourceFactoryBean.class);
                    dataSourceDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, databaseDescriptor);
                } else {
                    BeanDefinitionHolder contextInfo = getDataSourceContextBeanDefinition(beanFactory, flywayInfo.getBeanName());

                    if (contextInfo == null) {
                        // fall back to a default data source context when none can be resolved
                        RootBeanDefinition dataSourceContextDefinition = new RootBeanDefinition();
                        dataSourceContextDefinition.setBeanClass(DefaultFlywayDataSourceContext.class);
                        registry.registerBeanDefinition("defaultDataSourceContext", dataSourceContextDefinition);
                        contextInfo = new BeanDefinitionHolder(dataSourceContextDefinition, "defaultDataSourceContext");
                    }

                    contextInfo.getBeanDefinition().getPropertyValues().addPropertyValue("descriptor", databaseDescriptor);

                    // the flyway-aware data source is wired by bean name to the flyway bean and its context
                    dataSourceDefinition.setBeanClass(FlywayEmbeddedPostgresDataSourceFactoryBean.class);
                    dataSourceDefinition.getConstructorArgumentValues()
                            .addIndexedArgumentValue(0, flywayInfo.getBeanName());
                    dataSourceDefinition.getConstructorArgumentValues()
                            .addIndexedArgumentValue(1, contextInfo.getBeanName());
                }

                String dataSourceBeanName = dataSourceInfo.getBeanName();
                if (registry.containsBeanDefinition(dataSourceBeanName)) {
                    logger.info("Replacing '{}' DataSource bean with embedded version", dataSourceBeanName);
                    registry.removeBeanDefinition(dataSourceBeanName);
                }
                registry.registerBeanDefinition(dataSourceBeanName, dataSourceDefinition);
            }
        }

        @Override
        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
            // nothing to do
        }

        /**
         * Resolves the descriptor of the database to create for the given annotation.
         *
         * <p>An explicit provider on the annotation wins. Otherwise the provider is read from the
         * {@code zonky.test.database.provider} property, defaulting to {@link ProviderType#ZONKY}.
         * The property value is trimmed and upper-cased before the lookup, so values such as
         * {@code docker} or {@code " Zonky "} resolve the same way as their upper-case form.
         *
         * @param environment the environment to read the provider property from
         * @param databaseAnnotation the annotation describing the requested database
         * @return the descriptor of a postgres database backed by the resolved provider
         * @throws IllegalArgumentException if the provider name matches no {@link ProviderType}
         */
        protected DatabaseDescriptor resolveDatabaseDescriptor(Environment environment, AutoConfigureEmbeddedDatabase databaseAnnotation) {
            String providerName;
            if (databaseAnnotation.provider() != DEFAULT) {
                providerName = databaseAnnotation.provider().name();
            } else {
                String configuredProvider = environment.getProperty("zonky.test.database.provider", ProviderType.ZONKY.toString());
                // property values are free-form text; normalize them locale-independently before the lookup
                providerName = configuredProvider.trim().toUpperCase(Locale.ROOT);
            }
            return new DatabaseDescriptor(DatabaseType.POSTGRES, ProviderType.valueOf(providerName));
        }

        /**
         * Registers the default database provider plus every optional bean whose marker class
         * is present on the classpath, skipping bean names that are already defined.
         */
        private void registerInfrastructureBeans(BeanDefinitionRegistry registry) {
            registerBeanIfMissing(registry, "defaultDatabaseProvider", PrefetchingDatabaseProvider.class);

            if (ClassUtils.isPresent("org.testcontainers.containers.PostgreSQLContainer", null)) {
                registerBeanIfMissing(registry, "dockerPostgresProvider", DockerPostgresDatabaseProvider.class);
            }
            if (ClassUtils.isPresent("io.zonky.test.db.postgres.embedded.EmbeddedPostgres", null)) {
                registerBeanIfMissing(registry, "zonkyPostgresProvider", ZonkyPostgresDatabaseProvider.class);
            }
            if (ClassUtils.isPresent("com.opentable.db.postgres.embedded.EmbeddedPostgres", null)) {
                registerBeanIfMissing(registry, "openTablePostgresProvider", OpenTablePostgresDatabaseProvider.class);
            }
            if (ClassUtils.isPresent("ru.yandex.qatools.embed.postgresql.EmbeddedPostgres", null)) {
                registerBeanIfMissing(registry, "yandexPostgresProvider", YandexPostgresDatabaseProvider.class);
            }
            if (ClassUtils.isPresent("org.springframework.boot.autoconfigure.flyway.FlywayProperties", null)) {
                registerBeanIfMissing(registry, "flywayPropertiesPostProcessor", FlywayPropertiesPostProcessor.class);
            }
        }
    }

    /**
     * Registers a bean definition of the given class under the given name,
     * unless the registry already contains a definition with that name.
     *
     * @param registry the registry to register the definition in
     * @param beanName the name of the bean definition
     * @param beanClass the class of the bean to define
     */
    protected static void registerBeanIfMissing(BeanDefinitionRegistry registry, String beanName, Class<?> beanClass) {
        if (registry.containsBeanDefinition(beanName)) {
            return;
        }
        registry.registerBeanDefinition(beanName, new RootBeanDefinition(beanClass));
    }

    /**
     * Locates the bean definition registry behind the given application context: the context itself
     * when it is a registry, otherwise the bean factory of an {@link AbstractApplicationContext}.
     *
     * @param context the application context to inspect
     * @return the registry backing the context
     * @throws IllegalStateException if the context is neither a registry nor an {@link AbstractApplicationContext}
     */
    protected static BeanDefinitionRegistry getBeanDefinitionRegistry(ApplicationContext context) {
        Object registryCandidate;
        if (context instanceof BeanDefinitionRegistry) {
            registryCandidate = context;
        } else if (context instanceof AbstractApplicationContext) {
            registryCandidate = ((AbstractApplicationContext) context).getBeanFactory();
        } else {
            throw new IllegalStateException("Could not locate BeanDefinitionRegistry");
        }
        return (BeanDefinitionRegistry) registryCandidate;
    }

    /**
     * Finds the data source bean definition that should be replaced for the given annotation.
     *
     * <p>When the annotation names a bean, the definition registered under that name in this factory
     * is returned, or a holder with an empty definition when this factory has no such definition.
     * Otherwise the only {@link DataSource} bean is used, or the primary one when there are several.
     *
     * @param beanFactory the bean factory to search
     * @param annotation the annotation describing the data source to replace
     * @return the holder of the targeted bean definition and its bean name
     * @throws IllegalStateException if no bean name is given and no data source,
     *         or no primary one among several, can be found
     */
    protected static BeanDefinitionHolder getDataSourceBeanDefinition(ConfigurableListableBeanFactory beanFactory, AutoConfigureEmbeddedDatabase annotation) {
        String requestedBeanName = annotation.beanName();

        if (StringUtils.isNotBlank(requestedBeanName)) {
            // containsBeanDefinition (not containsBean) is the matching guard for getBeanDefinition:
            // containsBean also reports aliases, parent-factory beans and manually registered
            // singletons, for which getBeanDefinition throws NoSuchBeanDefinitionException
            if (beanFactory.containsBeanDefinition(requestedBeanName)) {
                BeanDefinition beanDefinition = beanFactory.getBeanDefinition(requestedBeanName);
                return new BeanDefinitionHolder(beanDefinition, requestedBeanName);
            } else {
                return new BeanDefinitionHolder(new RootBeanDefinition(), requestedBeanName);
            }
        }

        String[] beanNames = beanFactory.getBeanNamesForType(DataSource.class);

        if (ObjectUtils.isEmpty(beanNames)) {
            throw new IllegalStateException("No DataSource beans found, embedded version will not be used, " +
                    "you must specify data source name - use @AutoConfigureEmbeddedDatabase(beanName = \"dataSource\") annotation");
        }

        if (beanNames.length == 1) {
            String beanName = beanNames[0];
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
            return new BeanDefinitionHolder(beanDefinition, beanName);
        }

        // several candidates: only a primary data source is an unambiguous target
        for (String beanName : beanNames) {
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
            if (beanDefinition.isPrimary()) {
                return new BeanDefinitionHolder(beanDefinition, beanName);
            }
        }

        throw new IllegalStateException("No primary DataSource found, embedded version will not be used");
    }

    /**
     * Finds the {@link FlywayDataSourceContext} bean definition belonging to the given flyway bean.
     *
     * <p>A single context bean is returned as is. With several context beans, the one defined under
     * the name {@code <flywayName>DataSourceContext} in this factory is returned.
     *
     * @param beanFactory the bean factory to search
     * @param flywayName the name of the flyway bean the context belongs to
     * @return the holder of the matching bean definition, or {@code null} when none can be resolved
     */
    protected static BeanDefinitionHolder getDataSourceContextBeanDefinition(ConfigurableListableBeanFactory beanFactory, String flywayName) {
        String[] beanNames = beanFactory.getBeanNamesForType(FlywayDataSourceContext.class, true, false);

        if (ObjectUtils.isEmpty(beanNames)) {
            return null;
        }

        if (beanNames.length == 1) {
            String beanName = beanNames[0];
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
            return new BeanDefinitionHolder(beanDefinition, beanName);
        }

        String contextBeanName = flywayName + "DataSourceContext";
        // containsBeanDefinition (not containsBean) is the matching guard for getBeanDefinition:
        // containsBean also reports aliases and parent-factory beans, for which
        // getBeanDefinition throws NoSuchBeanDefinitionException
        if (beanFactory.containsBeanDefinition(contextBeanName)) {
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(contextBeanName);
            return new BeanDefinitionHolder(beanDefinition, contextBeanName);
        }

        return null;
    }

    /**
     * Finds the flyway bean definition, provided exactly one {@link Flyway} bean is present.
     *
     * @param beanFactory the bean factory to search
     * @return the holder of the flyway bean definition, or {@code null} when there is no flyway bean
     *         or more than one
     */
    protected static BeanDefinitionHolder getFlywayBeanDefinition(ConfigurableListableBeanFactory beanFactory) {
        String[] flywayBeanNames = beanFactory.getBeanNamesForType(Flyway.class, true, false);

        if (flywayBeanNames.length != 1) {
            return null;
        }

        String flywayBeanName = flywayBeanNames[0];
        return new BeanDefinitionHolder(beanFactory.getBeanDefinition(flywayBeanName), flywayBeanName);
    }

    /**
     * Creates a stateful predicate that accepts an element only the first time its key is seen.
     *
     * @param keyExtractor the function computing the key elements are compared by; must not produce {@code null}
     * @param <T> the type of the tested elements
     * @return a predicate returning {@code true} for the first element of each distinct key
     */
    protected static <T> Predicate<T> distinctByKey(Function<? super T, ?> keyExtractor) {
        ConcurrentHashMap<Object, Boolean> observedKeys = new ConcurrentHashMap<>();
        return element -> {
            Object key = keyExtractor.apply(element);
            // putIfAbsent yields null exactly when the key has not been recorded before
            return observedKeys.putIfAbsent(key, Boolean.TRUE) == null;
        };
    }
}
