package org.mule.weave.v2.mapping

import org.mule.weave.v2.codegen.CodeGenerator
import org.mule.weave.v2.codegen.StringCodeWriter
import org.mule.weave.v2.grammar.AsOpId
import org.mule.weave.v2.parser.annotation.EnclosedMarkAnnotation
import org.mule.weave.v2.parser.annotation.InfixNotationFunctionCallAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallParametersNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.ContentType
import org.mule.weave.v2.parser.ast.header.directives.OutputDirective
import org.mule.weave.v2.parser.ast.header.directives.VersionDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.structure.ArrayNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.KeyNode
import org.mule.weave.v2.parser.ast.structure.KeyValuePairNode
import org.mule.weave.v2.parser.ast.structure.ObjectNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.location.UnknownLocation
import org.mule.weave.v2.ts.WeaveType

import scala.annotation.tailrec

/**
  * Generates a DataWeave script (or its AST) from a [[DataMapping]].
  *
  * The body of the script is built by walking the mapping tree and mutating `ObjectNode`s in place:
  * every target path is followed segment by segment, intermediate objects are created on demand, and the
  * final segment receives either the assigned expression or a `map` call generated from a child mapping.
  */
class DataMappingCodeGenerator {

  /**
    * Renders the given mapping as DataWeave source text.
    *
    * @param mapping    the root mapping to translate
    * @param outputType the content type placed in the output directive of the header
    * @return the content written by the [[CodeGenerator]] for the generated document
    */
  def generateCode(mapping: DataMapping, outputType: String = "application/json"): String = {
    val documentAst = generateDocumentAst(mapping, outputType)
    val writer = new StringCodeWriter()
    new CodeGenerator(writer).generate(documentAst)
    writer.codeContent()
  }

  /**
    * Builds the full document AST (header plus body) for the given mapping.
    *
    * @param rootMapping the root mapping to translate
    * @param outputType  the content type placed in the output directive of the header
    * @return a `DocumentNode` whose body is generated from `rootMapping`
    */
  def generateDocumentAst(rootMapping: DataMapping, outputType: String = "application/json"): AstNode = {
    DocumentNode(createHeader(outputType), generateAst(rootMapping))
  }

  /**
    * Generates the body AST of a single mapping.
    *
    * A mapping that consists of exactly one child mapping and no assignments is delegated entirely to that
    * child. Otherwise an object is created and filled, in this order, from the child mappings, the field
    * assignments and the expression assignments.
    */
  private def generateAst(mapping: DataMapping): AstNode = {
    val childMappings: Array[DataMapping] = mapping.childMappings()
    val arrowAssignments = mapping.fieldAssignments()
    val exprAssignments = mapping.expressionAssignments()

    if (arrowAssignments.isEmpty && exprAssignments.isEmpty && childMappings.length == 1) { //single innerMapping
      val onlyChild = childMappings.head
      processInnerMapping(onlyChild, onlyChild.relativeTargetPath(), currentNodeMaybe = None, parentMaybe = None, result = None)
      //TODO: check if all mappings are to a same target, then append them
    } else {
      //this 'rootObject' will be filled with the nodes generated from the mappings
      val rootObject = ObjectNode(Seq())

      for (inner <- childMappings) {
        processInnerMapping(inner, inner.relativeTargetPath(), Some(rootObject), None, Some(rootObject))
      }

      for (FieldAssignment(source, target, sourceType, targetType) <- arrowAssignments) {
        val targetPath = target.relativePathFrom(mapping.target)
        processAssignment(targetPath, source.toSelector(mapping.getMappingIndex), rootObject, sourceType, targetType)
      }

      for (ExpressionAssignment(target, expressionNode, sourceType, targetType) <- exprAssignments) {
        val targetPath = target.relativePathFrom(mapping.target)
        processAssignment(targetPath, expressionNode, rootObject, sourceType, targetType)
      }

      //this hack is here for the single-var-to-array-item-in-root case (/flowVars/department -> /[]/employee/customerNumber)
      //`headOption` keeps the check safe when a target path has no segments.
      val wrapInArray = mapping.isRootMapping() && arrowAssignments.nonEmpty &&
        arrowAssignments.forall { assignment =>
          assignment.target.path().headOption.exists(_.isArray()) &&
            assignment.target.cardinality > assignment.source.cardinality
        }

      if (wrapInArray) ArrayNode(Seq(rootObject)) else rootObject
    }
  }

  /**
    * Walks `targetPath` for a child mapping, reusing objects that already exist under the current node and
    * creating the missing ones, until only the last segment remains; that segment is handled by
    * [[processInnerMappingLastSegment]].
    *
    * @param mapping          the child mapping being generated
    * @param targetPath       the remaining target path segments
    * @param currentNodeMaybe the object the next segment is looked up in, if one exists yet
    * @param parentMaybe      the key/value pair holding the most recently created object, if any
    * @param result           the node to return once the walk finishes; defaults to the first object created
    */
  @tailrec
  private def processInnerMapping(mapping: DataMapping, targetPath: Array[NamePathElement], currentNodeMaybe: Option[ObjectNode], parentMaybe: Option[KeyValuePairNode], result: Option[AstNode]): AstNode = {
    targetPath match {
      case Array(_) =>
        //TODO: handle when the last target key already exists
        processInnerMappingLastSegment(mapping, currentNodeMaybe, parentMaybe, result)

      case segments =>
        val segment = segments.head
        val remaining = segments.tail
        selectValue(currentNodeMaybe, segment.name) match {
          case Some(existing) =>
            processInnerMapping(mapping, remaining, Some(existing), parentMaybe, result)
          case None =>
            val child = ObjectNode(Seq())
            val owner = currentNodeMaybe.getOrElse(ObjectNode(Seq()))
            val entry: KeyValuePairNode = addKvp(owner, segment.name, child)
            processInnerMapping(mapping, remaining, Some(child), Some(entry), result.orElse(Some(owner)))
        }
    }
  }

  /**
    * Generates the `map` call for a child mapping and attaches it at the last segment of its target path.
    *
    *  - array/repeated source to array target: the call becomes the value of a new key named after the target
    *    (or is returned directly when the target is the root);
    *  - array/repeated source to repeated target: the call, marked with `EnclosedMarkAnnotation`, becomes an
    *    element of an object, and each iteration yields an object keyed by the target.
    *
    * @throws RuntimeException when the source/target combination is neither of the above
    */
  private def processInnerMappingLastSegment(mapping: DataMapping, currentNodeMaybe: Option[ObjectNode], parentMaybe: Option[KeyValuePairNode], resultMaybe: Option[AstNode]): AstNode = {
    val source = mapping.source
    val target = mapping.target
    val indexOfInnerMap = mapping.getMappingIndex
    val sourceNode = source.toSelector(indexOfInnerMap - 1)

    // Lambda `(value<N>, index<N>) -> body` used as the right-hand side of the map call.
    def mapLambda(body: AstNode) = {
      val params = Seq(
        FunctionParameter(NameIdentifier(s"value$indexOfInnerMap")),
        FunctionParameter(NameIdentifier(s"index$indexOfInnerMap")))
      FunctionNode(FunctionParameters(params), body)
    }

    if (source.isArrayOrRepeated() && target.isArray()) {
      val mapOp = createMapOp(sourceNode, mapLambda(generateAst(mapping)))
      if (target.isRoot()) {
        mapOp
      } else {
        val owner = currentNodeMaybe.getOrElse(ObjectNode(Seq()))
        addKvp(owner, target.name, mapOp)
        resultMaybe.getOrElse(owner)
      }
    } else if (source.isArrayOrRepeated() && target.isRepeated()) {
      val body = ObjectNode(Seq(KeyValuePairNode(target.toKey, generateAst(mapping))))
      val mapOp = createMapOp(sourceNode, mapLambda(body))
      mapOp.annotate(EnclosedMarkAnnotation(UnknownLocation))

      if (target.isRoot()) {
        ObjectNode(Seq(mapOp))
      } else {
        (parentMaybe, currentNodeMaybe) match {
          case (Some(parent), _) =>
            // The parent currently holds the freshly created empty object; replace it with the wrapper.
            val wrapper = ObjectNode(Seq(mapOp))
            setValue(parent, wrapper)
            resultMaybe.getOrElse(wrapper)
          case (None, Some(current)) =>
            // No key/value pair was created on the way here (the target sits directly under an existing
            // object), so attach the enclosed map expression to that object instead of dropping it.
            current.elements = current.elements :+ mapOp
            resultMaybe.getOrElse(current)
          case (None, None) =>
            resultMaybe.getOrElse(ObjectNode(Seq(mapOp)))
        }
      }
    } else {
      throw new RuntimeException("Invalid inner mapping combination")
    }
  }

  /** Builds `lhs <functionName> rhs` as a function call annotated for infix notation. */
  private def createInfixCall(functionName: String, lhs: AstNode, rhs: AstNode): FunctionCallNode = {
    val call = FunctionCallNode(VariableReferenceNode(functionName), FunctionCallParametersNode(Seq(lhs, rhs)))
    call.annotate(InfixNotationFunctionCallAnnotation())
    call
  }

  /** Builds an infix `mapObject` call over the given operands. */
  private def createMapObject(lhs: AstNode, rhs: AstNode) = {
    createInfixCall("mapObject", lhs, rhs)
  }

  /** Builds an infix `map` call over the given operands. */
  private def createMapOp(lhs: AstNode, rhs: AstNode): FunctionCallNode = {
    createInfixCall("map", lhs, rhs)
  }

  /** Builds the binary `as` operation that casts `nodeToCast` to `typeNode`. */
  private def createAsOp(nodeToCast: AstNode, typeNode: WeaveTypeNode): AstNode = {
    BinaryOpNode(AsOpId, nodeToCast, typeNode)
  }

  /**
    * Resolves the type node an assigned value must be coerced to.
    *
    * @return the target type node when both types are known, they differ, and the target is a simple type;
    *         `None` when no coercion is needed or possible
    */
  private def resolveCoercionType(maybeSourceType: Option[WeaveType], maybeTargetType: Option[WeaveType]): Option[WeaveTypeNode] = {
    for {
      sourceType <- maybeSourceType
      targetType <- maybeTargetType
      if sourceType != targetType
      typeNode <- typeToNode(targetType)
    } yield typeNode
  }

  /**
    * Converts a type into a type reference node named after the type's string form.
    *
    * @return the reference node, or `None` when `WeaveType.getSimpleType` does not recognise the name
    */
  private def typeToNode(weaveType: WeaveType): Option[WeaveTypeNode] = {
    val typeName = weaveType.toString
    if (WeaveType.getSimpleType(typeName).isDefined) {
      Some(TypeReferenceNode(NameIdentifier(typeName)))
    } else {
      None
    }
  }

  /**
    * Follows `targetPath` below `currentNode`, creating the missing intermediate objects (wrapped in a
    * single-element array for array segments), and adds the assigned value under the last segment. The value
    * is wrapped in an `as` cast when [[resolveCoercionType]] yields a type.
    *
    * NOTE(review): existing keys are only reused when their value is an `ObjectNode`, so a second assignment
    * that goes through the same array segment adds another key with the same name - confirm this is intended.
    *
    * @param targetPath      the remaining target path segments
    * @param expressionNode  the expression to assign
    * @param currentNode     the object the next segment is resolved against; mutated in place
    * @param maybeSourceType the type of the expression, when known
    * @param maybeTargetType the type expected by the target, when known
    */
  @tailrec
  private def processAssignment(targetPath: Array[NamePathElement], expressionNode: AstNode, currentNode: ObjectNode, maybeSourceType: Option[WeaveType], maybeTargetType: Option[WeaveType]): Unit = {
    targetPath match {
      case Array(lastTarget) =>
        //TODO: handle when the last target key already exists
        val valueNode = resolveCoercionType(maybeSourceType, maybeTargetType)
          .fold[AstNode](expressionNode)(typeNode => createAsOp(expressionNode, typeNode))
        addKvp(currentNode, lastTarget.name, valueNode)

      case segments =>
        val segment = segments.head
        val segmentName = segment.name
        val nextNode: ObjectNode = selectValue(currentNode, segmentName).getOrElse {
          val child = ObjectNode(Seq())
          val value: AstNode = if (segment.isArray()) ArrayNode(Seq(child)) else child
          addKvp(currentNode, segmentName, value)
          child
        }
        processAssignment(segments.tail, expressionNode, nextNode, maybeSourceType, maybeTargetType)
    }
  }

  /** Looks up `keyName` in the given object, if there is one; see the `ObjectNode` overload. */
  private def selectValue(currentNodeMaybe: Option[ObjectNode], keyName: String): Option[ObjectNode] = {
    currentNodeMaybe.flatMap(node => selectValue(node, keyName))
  }

  /**
    * Finds the first element of `currentNode` that is a key/value pair with a string key equal to `keyName`
    * and an `ObjectNode` value.
    *
    * @return that value, or `None` when no such element exists
    */
  private def selectValue(currentNode: ObjectNode, keyName: String): Option[ObjectNode] = {
    currentNode.elements.collectFirst {
      case KeyValuePairNode(KeyNode(StringNode(name, _), _, _, _), value: ObjectNode, _) if name == keyName =>
        value
    }
  }

  /** Appends a new `keyName: value` pair to `obj` (mutating it) and returns the created pair. */
  private def addKvp(obj: ObjectNode, keyName: String, value: AstNode): KeyValuePairNode = {
    val kvp = KeyValuePairNode(KeyNode(keyName), value)
    obj.elements = obj.elements :+ kvp
    kvp
  }

  /** Replaces the value of an existing key/value pair. */
  private def setValue(kvp: KeyValuePairNode, value: AstNode): Unit = {
    kvp.value = value
  }

  /** Creates the script header: a version directive followed by an output directive for `outputType`. */
  private def createHeader(outputType: String) = {
    new HeaderNode(List(
      new VersionDirective(),
      OutputDirective(ContentType(outputType), None, None)))
  }
}
