package com.atlassian.braid.transformation;

import com.atlassian.braid.BatchLoaderEnvironment;
import com.atlassian.braid.Link;
import com.atlassian.braid.SchemaNamespace;
import com.atlassian.braid.TypeUtils;
import graphql.execution.DataFetcherResult;
import graphql.language.FieldDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.dataloader.BatchLoader;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import static com.atlassian.braid.TypeUtils.findMutationType;
import static com.atlassian.braid.TypeUtils.findQueryType;
import static com.atlassian.braid.transformation.DataFetcherUtils.getLinkDataLoaderKey;
import static java.lang.String.format;

/**
 * Schema transformation that applies the {@link Link}s declared by every data source.
 * <p>
 * For each link it:
 * <ul>
 *     <li>resolves the object type the link is declared on and validates that the link's "source from" field
 *     exists in the source's private schema, and that the target schema source and target type exist;</li>
 *     <li>unless the link reports that no schema change is needed, rewrites that object type so the link's new
 *     field has the link's target type;</li>
 *     <li>wires a data fetcher for the new field that delegates to a data loader, and creates the matching batch
 *     loader from the target schema source.</li>
 * </ul>
 */
public class LinkSchemaTransformation implements SchemaTransformation {

    /**
     * Applies all links of all data sources in {@code braidingContext}.
     *
     * @param braidingContext context providing the data sources, braid registry, runtime wiring and batch loader
     *                        environment
     * @return the batch loaders created for the links, keyed by their link data loader key
     * @throws IllegalArgumentException if a link references a source type, source field, target schema source
     *                                  or target type that cannot be found
     */
    @Override
    public Map<String, BatchLoader> transform(BraidingContext braidingContext) {
        return linkTypes(braidingContext.getDataSources(), braidingContext.getQueryObjectTypeDefinition(),
                braidingContext.getMutationObjectTypeDefinition(),
                braidingContext.getRuntimeWiringBuilder(),
                braidingContext.getBatchLoaderEnvironment(),
                braidingContext.getRegistry());
    }

    private static Map<String, BatchLoader> linkTypes(Map<SchemaNamespace, BraidSchemaSource> sources,
                                                      ObjectTypeDefinition queryObjectTypeDefinition,
                                                      ObjectTypeDefinition mutationObjectTypeDefinition,
                                                      RuntimeWiring.Builder runtimeWiringBuilder,
                                                      BatchLoaderEnvironment batchLoaderEnvironment,
                                                      TypeDefinitionRegistry braidTypeRegistry) {
        Map<String, BatchLoader> batchLoaders = new HashMap<>();

        // Nothing in this method adds or removes types in braidTypeRegistry, so these lookups are the same for
        // every source and every link; compute them once.
        final Map<String, TypeDefinition> dsTypes = new HashMap<>(braidTypeRegistry.types());
        final ObjectTypeDefinition braidQueryType = findQueryType(braidTypeRegistry).orElse(null);
        final ObjectTypeDefinition braidMutationType = findMutationType(braidTypeRegistry).orElse(null);

        for (BraidSchemaSource source : sources.values()) {
            final TypeDefinitionRegistry typeRegistry = source.getTypeRegistry();
            final TypeDefinitionRegistry privateTypes = source.getSchemaSource().getPrivateSchema();

            for (Link link : source.getSchemaSource().getLinks()) {
                final String type = source.getLinkBraidSourceType(link);
                final String field = link.getNewFieldName();

                ObjectTypeDefinition braidObjectTypeDefinition = resolveTopLevelType(
                        getObjectTypeDefinition(queryObjectTypeDefinition, mutationObjectTypeDefinition,
                                braidTypeRegistry, dsTypes, type),
                        braidQueryType, braidMutationType, typeRegistry, link);

                validateSourceFromFieldExists(source, link, privateTypes);

                Optional<FieldDefinition> newField = findField(braidObjectTypeDefinition, field);
                Optional<FieldDefinition> sourceInputField =
                        findField(braidObjectTypeDefinition, link.getSourceInputFieldName());

                BraidSchemaSource targetSource = findTargetSource(sources, link);

                if (!link.isNoSchemaChangeNeeded()) {
                    modifySchema(link, braidObjectTypeDefinition, newField, sourceInputField);
                }

                String linkDataLoaderKey = getLinkDataLoaderKey(type, field);

                DataFetcher dataFetcher = env -> env.getDataLoader(linkDataLoaderKey).load(env);
                runtimeWiringBuilder.type(type, wiring -> wiring.dataFetcher(field, dataFetcher));

                BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> batchLoader =
                        targetSource.getSchemaSource().newBatchLoader(targetSource.getSchemaSource(),
                                new LinkTransformation(link), batchLoaderEnvironment);

                batchLoaders.put(linkDataLoaderKey, batchLoader);
            }
        }
        return batchLoaders;
    }

    /**
     * Maps the braid query/mutation type to the data source's own query/mutation type.
     * <p>
     * Top-level fields are not yet in the braid type registry; they are copied by the
     * TopLevelFieldTransformation that runs next. A link declared on a top-level type must therefore modify the
     * data source's top-level type instead. Any other type is returned unchanged.
     *
     * @throws IllegalArgumentException if the braid top-level type is matched but the data source has none
     */
    private static ObjectTypeDefinition resolveTopLevelType(ObjectTypeDefinition braidObjectTypeDefinition,
                                                            ObjectTypeDefinition braidQueryType,
                                                            ObjectTypeDefinition braidMutationType,
                                                            TypeDefinitionRegistry sourceTypeRegistry,
                                                            Link link) {
        ObjectTypeDefinition resolved = braidObjectTypeDefinition;
        if (resolved.equals(braidQueryType)) {
            resolved = findQueryType(sourceTypeRegistry).orElseThrow(() -> new IllegalArgumentException(
                    format("Can't find query type in data source schema for link %s", link.getSourceField())));
        }
        if (resolved.equals(braidMutationType)) {
            resolved = findMutationType(sourceTypeRegistry).orElseThrow(() -> new IllegalArgumentException(
                    format("Can't find mutation type in data source schema for link %s", link.getSourceField())));
        }
        return resolved;
    }

    /** Returns the first non-null field of {@code typeDefinition} named {@code name}, if any. */
    private static Optional<FieldDefinition> findField(ObjectTypeDefinition typeDefinition, String name) {
        return typeDefinition.getFieldDefinitions().stream()
                .filter(Objects::nonNull)
                .filter(d -> d.getName().equals(name))
                .findFirst();
    }

    /**
     * Looks up the link's target schema source and checks that it declares the link's target type.
     *
     * @throws IllegalArgumentException if the target schema source or the target type is missing
     */
    private static BraidSchemaSource findTargetSource(Map<SchemaNamespace, BraidSchemaSource> sources, Link link) {
        BraidSchemaSource targetSource = sources.get(link.getTargetNamespace());
        if (targetSource == null) {
            throw new IllegalArgumentException("Can't find target schema source: " + link.getTargetNamespace());
        }
        if (!targetSource.hasType(link.getTargetType())) {
            throw new IllegalArgumentException("Can't find target type: " + link.getTargetType());
        }
        return targetSource;
    }

    /**
     * Rewrites {@code typeDefinition} in place so that the link's new field has the link's target type.
     * <ul>
     *     <li>If the link asks for it, the source input field is removed. This is skipped when the source input
     *     field is the very field being retyped, since removing it would drop the link field from the
     *     schema.</li>
     *     <li>If the new field does not exist, it is added under the link's new field name. Its type is a list of
     *     the target type when the source input field is a list, and the target type otherwise.</li>
     *     <li>If the new field exists and is a list, it becomes a list of the target type, keeping an outer
     *     non-null wrapper if there was one.</li>
     *     <li>Otherwise the existing field's type is replaced by the (nullable) target type.</li>
     * </ul>
     */
    private static void modifySchema(Link link, ObjectTypeDefinition typeDefinition, Optional<FieldDefinition> newField, Optional<FieldDefinition> sourceInputField) {
        if (sourceInputField.isPresent() && link.isRemoveInputField() && !sourceInputField.equals(newField)) {
            typeDefinition.getFieldDefinitions().remove(sourceInputField.get());
        }

        Type targetType = new TypeName(link.getTargetType());
        if (!newField.isPresent()) {
            // Add the link field to the schema if it is not already there. Use the same name the field is looked
            // up and wired under, so the data fetcher attaches to the added field.
            if (sourceInputField.isPresent() && isListType(sourceInputField.get().getType())) {
                targetType = new ListType(targetType);
            }
            FieldDefinition field = new FieldDefinition(link.getNewFieldName(), targetType);
            typeDefinition.getFieldDefinitions().add(field);
        } else if (isListType(newField.get().getType())) {
            if (newField.get().getType() instanceof NonNullType) {
                newField.get().setType(new NonNullType(new ListType(targetType)));
            } else {
                newField.get().setType(new ListType(targetType));
            }
        } else {
            // Change the existing field's type to the braided type
            newField.get().setType(targetType);
        }
    }

    /**
     * Finds the object type a link is declared on.
     * <p>
     * The type is looked up in {@code dsTypes} first. If it is missing there and its name matches the braid
     * query or mutation type, the corresponding top-level type from {@code typeRegistry} is used.
     *
     * @throws IllegalArgumentException if the type cannot be found or is not an object type
     */
    private static ObjectTypeDefinition getObjectTypeDefinition(ObjectTypeDefinition queryObjectTypeDefinition,
                                                                ObjectTypeDefinition mutationObjectTypeDefinition,
                                                                TypeDefinitionRegistry typeRegistry,
                                                                Map<String, TypeDefinition> dsTypes,
                                                                String linkSourceType) {
        TypeDefinition candidate = dsTypes.get(linkSourceType);
        if (candidate != null && !(candidate instanceof ObjectTypeDefinition)) {
            throw new IllegalArgumentException(format("Source type '%s' is not an object type", linkSourceType));
        }
        ObjectTypeDefinition typeDefinition = (ObjectTypeDefinition) candidate;

        if (typeDefinition == null && isNamed(queryObjectTypeDefinition, linkSourceType)) {
            typeDefinition = findQueryType(typeRegistry).orElse(null);
        }
        // Checked independently of the query branch so links declared on the mutation type are resolved too.
        if (typeDefinition == null && isNamed(mutationObjectTypeDefinition, linkSourceType)) {
            typeDefinition = findMutationType(typeRegistry).orElse(null);
        }

        if (typeDefinition == null) {
            throw new IllegalArgumentException("Can't find source type: " + linkSourceType);
        }
        return typeDefinition;
    }

    /** True when {@code definition} is non-null and has the given name. */
    private static boolean isNamed(ObjectTypeDefinition definition, String name) {
        return definition != null && name.equals(definition.getName());
    }

    /**
     * Checks that the link's "source from" field exists on the link's source type in the private schema.
     *
     * @throws IllegalArgumentException if the source type or the source-from field is missing
     */
    private static void validateSourceFromFieldExists(BraidSchemaSource source, Link link, TypeDefinitionRegistry privateTypeDefinitionRegistry) {
        final String sourceType = source.getSourceTypeName(link.getSourceType());
        ObjectTypeDefinition typeDefinition = privateTypeDefinitionRegistry
                .getType(sourceType, ObjectTypeDefinition.class)
                .orElseThrow(() -> new IllegalArgumentException(
                        format("Can't find source type '%s' in private schema for link %s",
                                sourceType, link.getSourceField())));
        boolean sourceFromFieldExists = typeDefinition.getFieldDefinitions().stream()
                .anyMatch(d -> d.getName().equals(link.getSourceFromField()));
        if (!sourceFromFieldExists) {
            throw new IllegalArgumentException(
                    format("Can't find source from field: %s", link.getSourceFromField()));
        }
    }

    /** True for {@code [T]} and {@code [T]!}; only one level of non-null wrapping is unwrapped. */
    private static boolean isListType(Type type) {
        return type instanceof ListType ||
                (type instanceof NonNullType && ((NonNullType) type).getType() instanceof ListType);
    }
}
