package com.atlassian.braid.source;

import com.atlassian.braid.Link;
import com.atlassian.braid.LinkArgument;
import com.atlassian.braid.SchemaSource;
import com.atlassian.braid.transformation.BraidSchemaSource;
import graphql.execution.ConditionalNodes;
import graphql.language.Field;
import graphql.language.FragmentDefinition;
import graphql.language.FragmentSpread;
import graphql.language.InlineFragment;
import graphql.language.Node;
import graphql.language.NodeTraverser;
import graphql.language.NodeVisitorStub;
import graphql.language.Selection;
import graphql.language.SelectionSet;
import graphql.language.SelectionSetContainer;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLFieldsContainer;
import graphql.schema.GraphQLOutputType;
import graphql.schema.GraphQLSchema;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static com.atlassian.braid.LinkArgument.ArgumentSource.OBJECT_FIELD;
import static graphql.schema.GraphQLTypeUtil.unwrapAll;
import static java.lang.String.format;

/**
 * When a linked field is queried, the query to the first schema source should include the argument field but not
 * the linked field. The methods in this class are used to handle the update to the field selection.
 *
 * For example, assume the following link definition.
 * <pre>
 *     - from:
 *         type: Foo
 *         field: bar
 *         fromField: barId
 *       to:
 *         namespace: bar
 *         type: Bar
 *         field: bar
 *         argument: id
 * </pre>
 *
 * For the following query,
 * <pre>
 *     query {
 *         foo {
 *             bar {
 *                 value
 *             }
 *         }
 *     }
 * </pre>
 *
 * The actual query sent to schema source "foo" is the following. Note that the field selection "bar" is removed and
 * "barId" is added.
 * <pre>
 *     query {
 *         foo {
 *             barId
 *         }
 *     }
 * </pre>
 */
public class TrimFieldsSelection {

    /**
     * Adds argument field selection and removes the linked field selection.
     * <p>
     * The root field and every fragment definition of the query are rewritten. Only the rewritten fragment
     * definitions that the rewritten root field references (directly or through other fragments) are returned.
     *
     * @param schemaSource      The schema source containing the link.
     * @param environment       The environment of the field being fetched. It supplies the schema, the query
     *                          variables (used to evaluate conditional directives), the parent type of
     *                          {@code rootField} and the query's fragment definitions.
     * @param rootField         The field selection containing the link field.
     * @param skipTopLevelField Whether to skip the top-level field. This is true when called from LinkTransformation
     *                          and ExtensionTransformation because link field cannot be at the top-level in those cases.
     * @return the rewritten root field and the rewritten fragment definitions it references, in depth-first
     * discovery order
     * @throws IllegalStateException if {@code rootField} itself is a linked field (and {@code skipTopLevelField} is
     *                               false) whose link has no argument sourced from an object field
     */
    public static FieldAndReferencedFragments trimFieldSelection(SchemaSource schemaSource,
                                                                 DataFetchingEnvironment environment,
                                                                 Field rootField,
                                                                 boolean skipTopLevelField) {
        BraidSchemaSource braidSchemaSource = new BraidSchemaSource(schemaSource);
        GraphQLSchema schema = environment.getGraphQLSchema();
        Map<String, Object> variables = environment.getVariables();

        // Traverse the root field to replace the link field with argument source field.
        LinkProcessor linkProcessor = new LinkProcessor(
                braidSchemaSource, schema, unwrapAll(environment.getParentType()).getName(), variables, skipTopLevelField);
        Field newRootField = linkProcessor.field(rootField);

        // Each fragment definition needs to be transformed individually, starting from its own type condition.
        Map<String, FragmentDefinition> fragmentsByName = environment.getFragmentsByName();
        Map<String, FragmentDefinition> newFragmentsByName = fragmentsByName.entrySet().stream()
                .collect(Collectors.toMap(Entry::getKey, entry -> {
                    FragmentDefinition fragment = entry.getValue();
                    LinkProcessor fragmentLinkProcessor = new LinkProcessor(
                            braidSchemaSource, schema, fragment.getTypeCondition().getName(), variables, skipTopLevelField);
                    return fragmentLinkProcessor.fragmentDefinition(fragment);
                }));

        // Only ship the fragments the rewritten root field actually uses.
        Set<FragmentDefinition> referencedFragments = new LinkedHashSet<>();
        getReferencedFragments(newRootField, newFragmentsByName, referencedFragments);
        return new FieldAndReferencedFragments(newRootField, new ArrayList<>(referencedFragments));
    }

    /**
     * Returns whether {@code selections} already contains a field with the same name and alias as
     * {@code fieldToCheck}. Non-field selections (fragments) are ignored.
     */
    private static boolean selectionSetContainsField(List<Selection> selections, Field fieldToCheck) {
        return selections.stream()
                .filter(selection -> selection instanceof Field)
                .map(field -> (Field) field)
                .anyMatch(field -> field.getName().equals(fieldToCheck.getName())
                        && Objects.equals(field.getAlias(), fieldToCheck.getAlias()));
    }

    /**
     * Recursively searches for fragments starting from the given root node
     *
     * @param root                  - The node to look for references in
     * @param fragmentDefinitionMap - the map of defined fragments in the query keyed by name
     * @param referencedFragments   - The set of already known referenced fragments; newly found fragments are added
     *                              to it in depth-first discovery order
     */
    private static void getReferencedFragments(SelectionSetContainer<?> root,
                                               Map<String, FragmentDefinition> fragmentDefinitionMap,
                                               Set<FragmentDefinition> referencedFragments) {
        Set<FragmentDefinition> childFragments = new LinkedHashSet<>();
        NodeVisitorStub nodeVisitorStub = new NodeVisitorStub() {
            @Override
            public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, TraverserContext<Node> context) {
                childFragments.add(fragmentDefinitionMap.get(fragmentSpread.getName()));
                return TraversalControl.CONTINUE;
            }
        };
        new NodeTraverser().preOrder(nodeVisitorStub, root);

        // Only descend into fragments seen for the first time; this also stops infinite recursion on fragment cycles.
        for (FragmentDefinition childFragment : childFragments) {
            if (referencedFragments.add(childFragment)) {
                getReferencedFragments(childFragment, fragmentDefinitionMap, referencedFragments);
            }
        }
    }

    /**
     * Result of {@link #trimFieldSelection}: the rewritten field plus the rewritten fragment definitions it references.
     */
    public static class FieldAndReferencedFragments {
        public final Field field;
        public final List<FragmentDefinition> referencedFragments;

        private FieldAndReferencedFragments(Field field, List<FragmentDefinition> referencedFragments) {
            this.field = field;
            this.referencedFragments = referencedFragments;
        }
    }

    /**
     * Rewrites selection sets, replacing linked fields by the object fields their link arguments are sourced from.
     * <p>
     * {@code typeStack} holds the names of the enclosing GraphQL types. Its top is the type whose selection set is
     * currently being processed, which is needed to look up links by (type, field name). Instances are stateful and
     * are created per root field / fragment definition.
     */
    private static class LinkProcessor {
        private final ConditionalNodes conditionalNodes = new ConditionalNodes();
        private final BraidSchemaSource braidSchemaSource;
        private final GraphQLSchema schema;
        private final Map<String, Object> variables;
        private final boolean skipTopLevelField;
        private final Deque<String> typeStack;

        private LinkProcessor(BraidSchemaSource braidSchemaSource,
                              GraphQLSchema schema,
                              String parentType,
                              Map<String, Object> variables,
                              boolean skipTopLevelField) {
            this.braidSchemaSource = braidSchemaSource;
            this.schema = schema;
            this.variables = variables;
            this.skipTopLevelField = skipTopLevelField;

            this.typeStack = new ArrayDeque<>();
            typeStack.push(parentType);
        }

        /**
         * Rewrites the selection set of {@code node}. If {@code node} is itself a linked top-level field (and top-level
         * fields are not skipped), it is replaced by the first object-field source of its link.
         *
         * @throws IllegalStateException if the top-level link has no argument sourced from an object field
         */
        public Field field(Field node) {
            if (node.getSelectionSet() == null) {
                // No sub-selection, so there is nothing to rewrite below this field.
                return node;
            }

            boolean isTopLevelField = typeStack.size() == 1;
            if (isTopLevelField && !skipTopLevelField) {
                Link link = getLinkForField(node);
                if (link != null) {
                    LinkArgument linkArgument = link.getLinkArguments().stream()
                            .filter(argument -> argument.getArgumentSource() == OBJECT_FIELD)
                            .findFirst()
                            .orElseThrow(() -> new IllegalStateException(format("Link for top level field '%s' requires exactly one object field argument",
                                    link.getNewFieldName())));

                    // Rename the field to the link's source field and drop its sub-selection.
                    return node.transform(builder -> builder
                            .name(linkArgument.getSourceName())
                            .selectionSet(null)
                    );
                }
            }

            GraphQLFieldsContainer parentType = (GraphQLFieldsContainer) schema.getType(typeStack.peek());
            GraphQLOutputType fieldType = parentType.getFieldDefinition(node.getName()).getType();
            typeStack.push(unwrapAll(fieldType).getName());
            try {
                return node.transform(builder -> builder
                        .selectionSet(newSelectionSet(node.getSelectionSet()))
                );
            } finally {
                typeStack.pop();
            }
        }

        /**
         * Rewrites the selection set of an inline fragment in the context of its type condition.
         */
        public InlineFragment inlineFragment(InlineFragment node) {
            // The type condition is optional (e.g. `... @include(if: $flag) { ... }`); without it the selections apply
            // to the enclosing type. ArrayDeque rejects nulls, so a missing condition must not be pushed as null.
            String fragmentType = node.getTypeCondition() != null
                    ? node.getTypeCondition().getName()
                    : typeStack.peek();

            typeStack.push(fragmentType);
            try {
                return node.transform(builder -> builder
                        .selectionSet(newSelectionSet(node.getSelectionSet()))
                );
            } finally {
                typeStack.pop();
            }
        }

        /**
         * Rewrites the selection set of a fragment definition. The processor must have been created with the
         * fragment's type condition as its parent type.
         */
        public FragmentDefinition fragmentDefinition(FragmentDefinition node) {
            return node.transform(builder -> builder
                    .selectionSet(newSelectionSet(node.getSelectionSet()))
            );
        }

        /**
         * Builds the rewritten selection set:
         * <ul>
         *     <li>selections excluded by conditional directives are dropped;</li>
         *     <li>each linked field is replaced by the fields its OBJECT_FIELD link arguments are sourced from. A field
         *     is not added again if one with the same name and alias was already added earlier in this set;</li>
         *     <li>other fields and inline fragments are rewritten recursively;</li>
         *     <li>fragment spreads are kept unchanged (their definitions are rewritten separately).</li>
         * </ul>
         */
        private SelectionSet newSelectionSet(SelectionSet selectionSet) {
            List<Selection> newSelections = new ArrayList<>();
            selectionSet.getSelections().forEach(selection -> {
                if (selection instanceof Field) {
                    Field field = (Field) selection;
                    if (!conditionalNodes.shouldInclude(variables, field.getDirectives())) {
                        return;
                    }

                    Link linkForField = getLinkForField(field);
                    if (linkForField != null) {
                        linkForField.getLinkArguments().stream()
                                // We are interested in arguments that need to be taken from source object
                                .filter(argument -> argument.getArgumentSource() == OBJECT_FIELD)
                                .forEach(argument -> {
                                    Field newField = Field.newField()
                                            .name(argument.getSourceName())
                                            .build();

                                    if (!selectionSetContainsField(newSelections, newField)) {
                                        newSelections.add(newField);
                                    }
                                });
                    } else {
                        newSelections.add(field(field));
                    }
                } else if (selection instanceof InlineFragment) {
                    InlineFragment inlineFragment = (InlineFragment) selection;
                    if (!conditionalNodes.shouldInclude(variables, inlineFragment.getDirectives())) {
                        return;
                    }

                    newSelections.add(inlineFragment(inlineFragment));
                } else if (selection instanceof FragmentSpread) {
                    FragmentSpread fragmentSpread = (FragmentSpread) selection;
                    if (!conditionalNodes.shouldInclude(variables, fragmentSpread.getDirectives())) {
                        return;
                    }

                    newSelections.add(fragmentSpread);
                }
            });
            return SelectionSet.newSelectionSet()
                    .selections(newSelections)
                    .build();
        }

        /**
         * Returns the first link of the schema source declared on the current enclosing type for {@code field}'s
         * name, or {@code null} if there is none.
         */
        private Link getLinkForField(Field field) {
            String typeName = typeStack.peek();
            String fieldName = field.getName();

            return braidSchemaSource.getSchemaSource().getLinks().stream()
                    .filter(link -> braidSchemaSource.getLinkBraidSourceType(link).equals(typeName)
                            && link.getNewFieldName().equals(fieldName))
                    .findFirst().orElse(null);
        }
    }
}
