package com.atlassian.braid.transformation;

import com.atlassian.braid.BatchLoaderEnvironment;
import com.atlassian.braid.Link;
import com.atlassian.braid.LinkArgument;
import com.atlassian.braid.SchemaNamespace;
import com.atlassian.braid.SchemaSource;
import com.atlassian.braid.TypeUtils;
import graphql.execution.DataFetcherResult;
import graphql.language.FieldDefinition;
import graphql.language.InputValueDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.parser.Parser;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.atlassian.braid.LinkArgument.ArgumentSource.FIELD_ARGUMENT;
import static com.atlassian.braid.LinkArgument.ArgumentSource.OBJECT_FIELD;
import static com.atlassian.braid.TypeUtils.findMutationType;
import static com.atlassian.braid.TypeUtils.findQueryType;
import static com.atlassian.braid.transformation.DataFetcherUtils.getLinkDataLoaderKey;
import static java.lang.String.format;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

/**
 * A {@link SchemaTransformation} for processing links, which add fields to source object types. The field to add is
 * specified by the {@link Link}. The link field values
 * are fetched from top-level fields of the target schema sources.
 * <p>
 * For every link of every schema source this transformation validates the link against the source and target
 * schemas, optionally rewrites the source object type so that it exposes the link field, registers a
 * {@link DataFetcher} for that field and creates the {@link BatchLoader} the data fetcher delegates to.
 */
public class LinkSchemaTransformation implements SchemaTransformation {

    /**
     * Processes all links declared by the schema sources of the given context.
     *
     * @param braidingContext the context providing the schema sources, the braid type registry, the braid query and
     *                        mutation type definitions and the batch loader environment
     * @return the batch loaders created for the links, keyed by the link data loader key built from the link's
     * braid source type and new field name
     * @throws IllegalArgumentException if the link's source type, source object fields, target schema source or
     *                                  target type cannot be found
     * @throws IllegalStateException    if the top level query field (or the batch query field) of the link cannot be
     *                                  found in the target source, or the source has no query/mutation type to
     *                                  attach a top level link to
     */
    @Override
    public Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> transform(BraidingContext braidingContext) {
        final Map<SchemaNamespace, BraidSchemaSource> sources = braidingContext.getDataSources();
        final ObjectTypeDefinition queryObjectTypeDefinition = braidingContext.getQueryObjectTypeDefinition();
        final ObjectTypeDefinition mutationObjectTypeDefinition = braidingContext.getMutationObjectTypeDefinition();
        final BatchLoaderEnvironment batchLoaderEnvironment = braidingContext.getBatchLoaderEnvironment();
        final TypeDefinitionRegistry braidTypeRegistry = braidingContext.getRegistry();

        final Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> batchLoaders = new HashMap<>();
        for (BraidSchemaSource source : sources.values()) {
            final TypeDefinitionRegistry sourceTypeRegistry = source.getTypeRegistry();
            final SchemaSource sourceSchemaSource = source.getSchemaSource();
            final TypeDefinitionRegistry privateTypes = sourceSchemaSource.getPrivateSchema();

            for (Link link : sourceSchemaSource.getLinks()) {
                final String type = source.getLinkBraidSourceType(link);
                final String field = link.getNewFieldName();

                // Snapshot the braid types for every link: previous iterations may have replaced type definitions.
                Map<String, TypeDefinition> dsTypes = new HashMap<>(braidTypeRegistry.types());
                ObjectTypeDefinition braidObjectTypeDefinition = getObjectTypeDefinition(
                        queryObjectTypeDefinition,
                        mutationObjectTypeDefinition,
                        braidTypeRegistry,
                        dsTypes,
                        type
                );

                // TopLevelFields are not yet in braidTypeRegistry
                // they will be copied in the TopLevelFieldTransformation that runs next
                if (braidObjectTypeDefinition.equals(findQueryType(braidTypeRegistry).orElse(null))) {
                    braidObjectTypeDefinition = findQueryType(sourceTypeRegistry)
                            .orElseThrow(() -> new IllegalStateException(format(
                                    "Cannot find query type in source '%s' for link on field '%s'",
                                    link.getSourceNamespace(), link.getNewFieldName())));
                }
                if (braidObjectTypeDefinition.equals(findMutationType(braidTypeRegistry).orElse(null))) {
                    braidObjectTypeDefinition = findMutationType(sourceTypeRegistry)
                            .orElseThrow(() -> new IllegalStateException(format(
                                    "Cannot find mutation type in source '%s' for link on field '%s'",
                                    link.getSourceNamespace(), link.getNewFieldName())));
                }

                validateSourceFromFieldExists(source, link, privateTypes);

                final BraidSchemaSource targetSource = resolveTargetSource(sources, link);
                final FieldDefinition topLevelField = topLevelFieldForLink(link, targetSource);
                final BatchMapping batchMapping = BatchUtils.getBatchMapping(topLevelField);

                if (!link.isNoSchemaChangeNeeded()) {
                    List<FieldDefinition> fieldDefinitions = modifySchema(link, braidObjectTypeDefinition, topLevelField);
                    if (batchMapping != null) {
                        // Add batched query field definition if batching is enabled
                        FieldDefinition batchFieldDef = findQueryField(targetSource, batchMapping.batchField)
                                .orElseThrow(() -> new IllegalStateException("Could not find query field: "+ batchMapping.batchField));
                        fieldDefinitions.add(batchFieldDef);
                    }

                    ObjectTypeDefinition newBraidObjectTypeDefinition = braidObjectTypeDefinition
                            .transform(builder -> builder.fieldDefinitions(fieldDefinitions));

                    // The query/mutation types were taken from the source registry above, so the replacement
                    // has to go back into the registry the definition came from.
                    boolean isSourceOperationType =
                            braidObjectTypeDefinition.equals(findQueryType(sourceTypeRegistry).orElse(null))
                                    || braidObjectTypeDefinition.equals(findMutationType(sourceTypeRegistry).orElse(null));
                    TypeDefinitionRegistry owningRegistry = isSourceOperationType ? sourceTypeRegistry : braidTypeRegistry;
                    owningRegistry.remove(braidObjectTypeDefinition);
                    owningRegistry.add(newBraidObjectTypeDefinition);
                }

                String linkDataLoaderKey = getLinkDataLoaderKey(type, field);

                // Create the coordinating DataFetcher that uses the BatchLoader registered below to load data.
                DataFetcher<CompletableFuture<DataFetcherResult<Object>>> dataFetcher = env -> {
                    DataLoader<DataFetchingEnvironment, DataFetcherResult<Object>> dataLoader = env.getDataLoader(linkDataLoaderKey);
                    return dataLoader.load(env);
                };
                braidingContext.registerDataFetcher(type, field, dataFetcher);

                // Create the BatchLoader using the target SchemaSource. This BatchLoader is used by the DataFetcher
                // created above during execution.
                SchemaSource targetSchemaSource = targetSource.getSchemaSource();
                BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> batchLoader =
                        targetSchemaSource.newBatchLoader(
                                targetSchemaSource,
                                new LinkTransformation(link, batchMapping),
                                batchLoaderEnvironment
                        );

                batchLoaders.put(linkDataLoaderKey, batchLoader);
            }
        }
        return batchLoaders;
    }

    /**
     * Looks up the schema source a link points to and checks that it knows the link's target type.
     *
     * @param sources all schema sources by namespace
     * @param link    the link being processed
     * @return the schema source registered for the link's target namespace
     * @throws IllegalArgumentException if there is no source for the target namespace or the source does not have
     *                                  the (unwrapped) target type
     */
    private static BraidSchemaSource resolveTargetSource(Map<SchemaNamespace, BraidSchemaSource> sources, Link link) {
        BraidSchemaSource targetSource = sources.get(link.getTargetNamespace());
        if (targetSource == null) {
            throw new IllegalArgumentException("Can't find target schema source: " + link.getTargetNamespace());
        }
        if (!targetSource.hasType(TypeUtils.unwrap(Parser.parseType(link.getTargetType())))) {
            throw new IllegalArgumentException("Can't find target type: " + link.getTargetType());
        }
        return targetSource;
    }

    /**
     * Finds a field by name on the query type of the private schema of the given schema source.
     *
     * @param schemaSource the schema source whose private schema is searched
     * @param fieldName    the name of the query field to find
     * @return the first query field with that name, or empty if there is no query type or no such field
     */
    private static Optional<FieldDefinition> findQueryField(BraidSchemaSource schemaSource, String fieldName) {
        return findQueryType(schemaSource.getSchemaSource().getPrivateSchema())
                .flatMap(queryType -> queryType.getFieldDefinitions().stream()
                        .filter(fieldDefinition -> Objects.equals(fieldName, fieldDefinition.getName()))
                        .findFirst());
    }

    /**
     * Finds the top level query field of the target source that the link fetches its values from.
     *
     * @param link         the link being processed
     * @param targetSource the schema source the link points to
     * @return the query field named by {@link Link#getTopLevelQueryField()}
     * @throws IllegalStateException if the target source's private schema has no such query field
     */
    private static FieldDefinition topLevelFieldForLink(Link link, BraidSchemaSource targetSource) {
        return findQueryField(targetSource, link.getTopLevelQueryField())
                .orElseThrow(() -> new IllegalStateException(format("Cannot find top level query field '%s' in source '%s' for link on field '%s' defined in '%s'",
                        link.getTopLevelQueryField(), link.getTargetNamespace(), link.getNewFieldName(), link.getSourceNamespace())));
    }

    /**
     * Computes the field definitions the link's source type must have once the link is applied.
     * <p>
     * Object fields used as link arguments and flagged for removal are dropped. The link field is either added as
     * a new field or, when a field with the link's new field name already exists, that field is re-typed to the
     * link's target type (keeping its list / non-null list shape).
     *
     * @param link           the link being processed
     * @param typeDefinition the object type the link field belongs to; it is not modified
     * @param topLevelField  the target query field, used to type the arguments of a newly added link field
     * @return a new, mutable list with the resulting field definitions
     */
    private static List<FieldDefinition> modifySchema(Link link,
                                                      ObjectTypeDefinition typeDefinition,
                                                      FieldDefinition topLevelField) {
        List<FieldDefinition> fieldDefinitions = new ArrayList<>(typeDefinition.getFieldDefinitions());
        FieldDefinition existingLinkField = fieldDefinitions.stream()
                .filter(d -> d.getName().equals(link.getNewFieldName()))
                .findFirst()
                .orElse(null);

        // Index the original fields by name before anything is removed; the simple link adjustment below still
        // needs to see the source input field even when it is removed from the resulting schema.
        Map<String, FieldDefinition> objectFields = fieldDefinitions
                .stream()
                .filter(Objects::nonNull)
                .collect(toMap(FieldDefinition::getName, identity()));

        link.getLinkArguments().stream()
                .filter(linkArgument -> linkArgument.getArgumentSource() == OBJECT_FIELD
                        && linkArgument.isRemoveInputField())
                .map(LinkArgument::getSourceName)
                .map(objectFields::get)
                .filter(Objects::nonNull)
                .forEach(fieldDefinitions::remove);

        Type targetType = Parser.parseType(link.getTargetType());
        targetType = link.targetNonNullable() ? NonNullType.newNonNullType(targetType).build() : targetType;

        if (existingLinkField == null) {
            targetType = adjustTypeForSimpleLink(link, objectFields, targetType);

            // Only arguments supplied by the caller of the link field become arguments of the new field.
            List<InputValueDefinition> inputValueDefs = link.getLinkArguments().stream()
                    .filter(linkArgument -> linkArgument.getArgumentSource() == FIELD_ARGUMENT)
                    .flatMap(linkArgument -> buildInputValueDefinitionForLink(topLevelField, linkArgument))
                    .collect(Collectors.toList());

            fieldDefinitions.add(FieldDefinition.newFieldDefinition()
                    .name(link.getNewFieldName())
                    .type(targetType)
                    .inputValueDefinitions(inputValueDefs)
                    .build());
        } else {
            // Keep the list shape of the field that is being replaced by the link.
            if (isListType(existingLinkField.getType())) {
                if (existingLinkField.getType() instanceof NonNullType) {
                    targetType = new NonNullType(new ListType(targetType));
                } else {
                    targetType = new ListType(targetType);
                }
            }
            fieldDefinitions.remove(existingLinkField);
            Type finalTargetType = targetType;
            fieldDefinitions.add(existingLinkField.transform(builder -> builder.type(finalTargetType)));
        }

        return fieldDefinitions;
    }

    /**
     * Builds the argument definition of the link field for a link argument, typed like the matching argument of
     * the target query field.
     *
     * @param topLevelField the target query field
     * @param linkArgument  the link argument to expose on the link field
     * @return a stream with one input value definition named after the link argument's source name, or an empty
     * stream if the target query field has no argument with the link argument's query argument name
     */
    private static Stream<InputValueDefinition> buildInputValueDefinitionForLink(FieldDefinition topLevelField, LinkArgument linkArgument) {
        return topLevelField.getInputValueDefinitions().stream()
                .filter(input -> linkArgument.getQueryArgumentName().equals(input.getName()))
                .findFirst()
                .map(input -> Stream.of(InputValueDefinition.newInputValueDefinition()
                        .name(linkArgument.getSourceName())
                        .type(input.getType())
                        .build()))
                .orElse(Stream.empty());
    }

    /**
     * Wraps the target type in a list when the link is a simple link whose source input field is a list.
     *
     * @param link         the link being processed
     * @param objectFields the fields of the link's source type, by name
     * @param targetType   the type of the link field computed so far
     * @return the list-wrapped target type for simple links with a list source input field, otherwise
     * {@code targetType} unchanged
     */
    private static Type adjustTypeForSimpleLink(Link link, Map<String, FieldDefinition> objectFields, Type targetType) {
        if (!link.isSimpleLink()) {
            return targetType;
        }
        FieldDefinition sourceInputField = objectFields.get(link.getSourceInputFieldName());
        // A list-typed source input field makes the link field a list of the target type.
        if (sourceInputField != null && isListType(sourceInputField.getType())) {
            return new ListType(targetType);
        }
        return targetType;
    }

    /**
     * Resolves the object type definition a link field is attached to.
     * <p>
     * The type is looked up by name in {@code dsTypes} first. If it is not there and the name matches the braid
     * query (respectively mutation) type, the query (respectively mutation) type of the braid registry is used.
     *
     * @param queryObjectTypeDefinition    the braid query type definition, may be {@code null}
     * @param mutationObjectTypeDefinition the braid mutation type definition, may be {@code null}
     * @param braidTypeRegistry            the braid type registry
     * @param dsTypes                      the types to search first, by name
     * @param linkSourceType               the name of the type the link is defined on
     * @return the object type definition for {@code linkSourceType}
     * @throws IllegalArgumentException if no type with that name can be found, or the type found is not an object
     *                                  type
     */
    private static ObjectTypeDefinition getObjectTypeDefinition(ObjectTypeDefinition queryObjectTypeDefinition,
                                                                ObjectTypeDefinition mutationObjectTypeDefinition,
                                                                TypeDefinitionRegistry braidTypeRegistry,
                                                                Map<String, TypeDefinition> dsTypes,
                                                                String linkSourceType) {
        TypeDefinition registeredType = dsTypes.get(linkSourceType);
        if (registeredType != null && !(registeredType instanceof ObjectTypeDefinition)) {
            // Fail with a descriptive message instead of a bare ClassCastException.
            throw new IllegalArgumentException("Source type is not an object type: " + linkSourceType);
        }

        ObjectTypeDefinition typeDefinition = (ObjectTypeDefinition) registeredType;
        if (typeDefinition == null
                && queryObjectTypeDefinition != null
                && linkSourceType.equals(queryObjectTypeDefinition.getName())) {
            typeDefinition = findQueryType(braidTypeRegistry).orElse(null);
        }
        // The mutation fallback is independent of the query fallback: a link on the mutation type does not
        // match the query type name.
        if (typeDefinition == null
                && mutationObjectTypeDefinition != null
                && linkSourceType.equals(mutationObjectTypeDefinition.getName())) {
            typeDefinition = findMutationType(braidTypeRegistry).orElse(null);
        }

        if (typeDefinition == null) {
            throw new IllegalArgumentException("Can't find source type: " + linkSourceType);
        }
        return typeDefinition;
    }

    /**
     * Checks that every object field a link reads its arguments from exists on the link's source type in the
     * private schema of the source.
     *
     * @param source                        the schema source declaring the link
     * @param link                          the link being processed
     * @param privateTypeDefinitionRegistry the private schema of the source
     * @throws IllegalArgumentException if the source type is not an object type of the private schema, or at least
     *                                  one object field used as a link argument is missing from it
     */
    private static void validateSourceFromFieldExists(BraidSchemaSource source, Link link,
                                                      TypeDefinitionRegistry privateTypeDefinitionRegistry) {
        final String sourceType = source.getSourceTypeName(link.getSourceType());
        ObjectTypeDefinition typeDefinition = privateTypeDefinitionRegistry
                .getType(sourceType, ObjectTypeDefinition.class)
                .orElseThrow(() -> new IllegalArgumentException(
                        format("Can't find source type '%s' in private schema for link %s",
                                sourceType, link.getNewFieldName())));
        Map<String, FieldDefinition> fieldsByName = typeDefinition.getFieldDefinitions().stream()
                .collect(toMap(FieldDefinition::getName, identity()));

        List<String> missingSourceObjectFields = link.getLinkArguments().stream()
                .filter(linkArgument -> linkArgument.getArgumentSource() == OBJECT_FIELD)
                .map(LinkArgument::getSourceName)
                .filter(sourceName -> !fieldsByName.containsKey(sourceName))
                .collect(Collectors.toList());

        if (!missingSourceObjectFields.isEmpty()) {
            throw new IllegalArgumentException(
                    "Can't find source from field: " + String.join(", ", missingSourceObjectFields));
        }
    }

    /**
     * Tells whether a type is a list, looking through a single outer non-null wrapper.
     *
     * @param type the type to inspect
     * @return {@code true} for {@code [T]} and {@code [T]!} shaped types, {@code false} otherwise
     */
    private static boolean isListType(Type type) {
        return type instanceof ListType ||
                (type instanceof NonNullType && ((NonNullType) type).getType() instanceof ListType);
    }
}
