package org.mule.weave.v2.interpreted.node.executors

import org.mule.weave.v2.core.functions.BaseBinaryFunctionValue
import org.mule.weave.v2.core.exception.ExecutionException
import org.mule.weave.v2.core.exception.UnexpectedFunctionCallTypesException

import org.mule.weave.v2.interpreted.Frame
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.allTargets
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.findMatchingFunctionWithCoercion
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.indexOfFunction
import org.mule.weave.v2.interpreted.node.ValueNode
import org.mule.weave.v2.interpreted.node.structure.ConstantValueNode
import org.mule.weave.v2.model.types.Type
import org.mule.weave.v2.model.values.FunctionValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.model.values.ValuesHelper
import org.mule.weave.v2.interpreted.ExecutionContext
import org.mule.weave.v2.parser.location.WeaveLocation
import org.mule.weave.v2.runtime.core.functions.OverloadedFunctionValue

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference

/**
  * Base executor for binary calls (operators and two-argument functions) whose target may be overloaded.
  *
  * Dispatch tries, from fastest to slowest:
  *  1. the overload index cached by a previous direct (non-coerced) match;
  *  2. the dispatch decision cached by a previous coercion-based match;
  *  3. a linear scan for an overload that accepts both arguments as they are;
  *  4. a search for an overload that accepts the arguments after coercion, throwing
  *     [[UnexpectedFunctionCallTypesException]] when none does.
  *
  * The caches are only written when the target reports `dispatchCanBeCached`.
  *
  * @param isOperator when true, the cached-dispatch path skips per-parameter materialization and
  *                   exceptions raised by the call are not decorated with this call site
  */
abstract class BinaryOverloadedStaticExecutor(val isOperator: Boolean) extends BinaryExecutor {

  // Index into `allTargets(target())` of the overload that last matched without coercion; -1 when unset.
  private val cachedDispatchIndex: AtomicInteger = new AtomicInteger(-1)
  // Last coercion-based dispatch decision; null when unset.
  private val cachedCoercedDispatchIndex: AtomicReference[CachedCoercedFunctionDispatch] = new AtomicReference()

  override def execute(arguments: Array[Value[_]])(implicit ctx: ExecutionContext): Value[Any] = {
    // Binary executors are expected to always receive exactly two arguments
    executeBinary(arguments(0), arguments(1))
  }

  /**
    * Dispatches the call to the overload of `target()` that matches `lv` and `rv`.
    *
    * @param lv the left argument
    * @param rv the right argument
    * @return the result of calling the selected overload
    * @throws UnexpectedFunctionCallTypesException when no overload accepts the arguments, even after coercion
    */
  def executeBinary(lv: Value[_], rv: Value[_])(implicit ctx: ExecutionContext): Value[Any] = {
    val targetFunction: FunctionValue = target()
    val resolvedTargets: Array[_ <: FunctionValue] = allTargets(targetFunction)

    // Stage 1: reuse the overload that matched directly on a previous call.
    val dispatchIndex = cachedDispatchIndex.get()
    if (dispatchIndex != -1) {
      val binaryFunctionValue: BaseBinaryFunctionValue = resolvedTargets(dispatchIndex).asInstanceOf[BaseBinaryFunctionValue]
      // Arguments with a constant type were already validated when the index was cached, and their
      // type can never change, so their check is skipped.
      val firstValue: Value[_] = resolveCachedDispatchFirstValue(lv, binaryFunctionValue)
      val secondValue: Value[_] = resolveCachedDispatchSecondValue(rv, binaryFunctionValue)

      if ((firstArgConstantType || binaryFunctionValue.L.accepts(firstValue)) && (secondArgConstantType || binaryFunctionValue.R.accepts(secondValue))) {
        return doCall(firstValue, secondValue, binaryFunctionValue)
      }
    }

    // Materialize the arguments once when the target's parameter types require it. All the stages
    // below check and pass these values, so every type check sees the same argument value.
    val firstValue: Value[_] = if (targetFunction.paramsTypesRequiresMaterialize) {
      FunctionDispatchingHelper.materializeOverloadedFunctionArgs(resolvedTargets, 0, lv)
    } else {
      lv
    }
    val secondValue: Value[_] = if (targetFunction.paramsTypesRequiresMaterialize) {
      FunctionDispatchingHelper.materializeOverloadedFunctionArgs(resolvedTargets, 1, rv)
    } else {
      rv
    }

    // Stage 2: reuse the last coercion decision while the argument base types still match.
    val coercedOperation = cachedCoercedDispatchIndex.get()
    if (coercedOperation != null) {
      // Arguments with a constant type were validated when the decision was cached and cannot change type.
      // The check uses the same (possibly materialized) values that are passed to the call.
      if ((firstArgConstantType || coercedOperation.leftValueType.accepts(firstValue)) &&
        (secondArgConstantType || coercedOperation.rightValueType.accepts(secondValue))) {
        val functionToDispatch = resolvedTargets(coercedOperation.functionIndex).asInstanceOf[BaseBinaryFunctionValue]

        val maybeFirstValue = if (!coercedOperation.leftArgNeedsCoercion) Some(firstValue) else functionToDispatch.L.coerceMaybe(firstValue)
        val maybeSecondValue = if (!coercedOperation.rightArgNeedsCoercion) Some(secondValue) else functionToDispatch.R.coerceMaybe(secondValue)

        (maybeFirstValue, maybeSecondValue) match {
          case (Some(coercedFirst), Some(coercedSecond)) =>
            return doCall(coercedFirst, coercedSecond, functionToDispatch)
          case _ =>
          // The cached coercion no longer applies to these values; fall through to a full lookup.
        }
      }
    }

    // Stage 3: linear scan for an overload that accepts the arguments without coercion.
    val matchingOp = findMatchingFunction(resolvedTargets, firstValue, secondValue)
    if (matchingOp > -1) {
      if (targetFunction.dispatchCanBeCached) {
        cachedDispatchIndex.set(matchingOp)
      }
      doCall(firstValue, secondValue, resolvedTargets(matchingOp).asInstanceOf[BaseBinaryFunctionValue])
    } else {
      // Stage 4 (very slow path): look for an overload that matches after coercing the arguments.
      val materializedValues: Array[Value[Any]] = ValuesHelper.array(firstValue.materialize, secondValue.materialize)
      val argTypes: Array[Type] = materializedValues.map(_.valueType)
      val sortedOperators: Array[FunctionValue] = FunctionDispatchingHelper.sortByParameterTypeWeight(resolvedTargets, argTypes)
      val functionToCallWithCoercion: Option[(Int, Array[Value[_]], Seq[Int])] = findMatchingFunctionWithCoercion(materializedValues, sortedOperators, this)
      functionToCallWithCoercion match {
        case Some((functionToDispatch, argumentsWithCoercion, paramsToCoerce)) =>
          // `functionToDispatch` indexes `sortedOperators`, not `resolvedTargets`.
          val binaryFunctionValue = sortedOperators(functionToDispatch).asInstanceOf[BaseBinaryFunctionValue]
          val leftCoercedValue = argumentsWithCoercion(0)
          val rightCoercedValue = argumentsWithCoercion(1)
          // Cache only the base types: full types may hold references to streams or objects,
          // which would leak if kept in the cache.
          val leftValueType = firstValue.valueType.baseType
          val rightValueType = secondValue.valueType.baseType
          if (targetFunction.dispatchCanBeCached && binaryFunctionValue.allowUseCachedOnCoerce(leftValueType, rightValueType)) {
            // Stage 2 looks the overload up in `resolvedTargets`, so store its index there.
            val idx = indexOfFunction(resolvedTargets, binaryFunctionValue)
            val cached = CachedCoercedFunctionDispatch(idx, leftValueType, rightValueType, paramsToCoerce.contains(0), paramsToCoerce.contains(1))
            cachedCoercedDispatchIndex.set(cached)
          }
          doCall(leftCoercedValue, rightCoercedValue, binaryFunctionValue)
        case None =>
          throw new UnexpectedFunctionCallTypesException(location(), name, materializedValues, sortedOperators.map(_.parameters.map(_.wtype)))
      }
    }
  }

  /**
    * Prepares the left argument for the cached-dispatch path. It is materialized only when this is not
    * an operator, the left type is not constant, and the overload's left parameter type requires materialization.
    */
  final def resolveCachedDispatchFirstValue(lv: Value[_], binaryFunctionValue: BaseBinaryFunctionValue)(implicit ctx: ExecutionContext): Value[_] = {
    if (isOperator) {
      lv
    } else {
      if (!firstArgConstantType && binaryFunctionValue.leftParam.typeRequiresMaterialization) {
        lv.materialize
      } else {
        lv
      }
    }
  }

  /**
    * Prepares the right argument for the cached-dispatch path. It is materialized only when this is not
    * an operator, the right type is not constant, and the overload's right parameter type requires materialization.
    */
  final def resolveCachedDispatchSecondValue(rv: Value[_], binaryFunctionValue: BaseBinaryFunctionValue)(implicit ctx: ExecutionContext): Value[_] = {
    if (isOperator) {
      rv
    } else {
      if (!secondArgConstantType && binaryFunctionValue.rightParam.typeRequiresMaterialization) {
        rv.materialize
      } else {
        rv
      }
    }
  }

  /**
    * Invokes `operation` with the two arguments. For non-operators, an [[ExecutionException]] thrown by
    * the call gets this call site appended to its stack trace (when `showInStacktrace()`) and is rethrown.
    */
  final def doCall(leftValue: Value[_], rightValue: Value[_], operation: BaseBinaryFunctionValue)(implicit ctx: ExecutionContext): Value[_] = {
    if (isOperator) {
      operation.call(leftValue, rightValue)
    } else {
      try {
        operation.call(leftValue, rightValue)
      } catch {
        case ex: ExecutionException =>
          if (showInStacktrace()) {
            ex.addCallToStacktrace(location(), name())
          }
          throw ex
      }
    }
  }

  /**
    * Returns the index of the first target whose parameter types accept both arguments as they are,
    * or -1 when none does. A `while` loop is used because this is on the dispatch hot path.
    */
  def findMatchingFunction(targets: Array[_ <: FunctionValue], firstArg: Value[_], secondArg: Value[_])(implicit ctx: ExecutionContext): Int = {
    var i = 0
    while (i < targets.length) {
      val functionValue: BaseBinaryFunctionValue = targets(i).asInstanceOf[BaseBinaryFunctionValue]
      if (functionValue.L.accepts(firstArg) && functionValue.R.accepts(secondArg)) {
        return i
      }
      i = i + 1
    }
    -1
  }

  /** Whether the left argument's type is known to never change, so cached checks on it can be skipped. */
  def firstArgConstantType: Boolean

  /** Whether the right argument's type is known to never change, so cached checks on it can be skipped. */
  def secondArgConstantType: Boolean

  /** The (possibly overloaded) function this executor dispatches to. */
  def target()(implicit ctx: ExecutionContext): FunctionValue

}

/**
  * A remembered coercion-based dispatch decision, so that later calls whose arguments have the same
  * base types can skip the coercion search.
  *
  * @param functionIndex         index of the chosen overload within the resolved targets
  * @param leftValueType         base type of the left argument when the decision was made
  * @param rightValueType        base type of the right argument when the decision was made
  * @param leftArgNeedsCoercion  whether the left argument must be coerced before calling the overload
  * @param rightArgNeedsCoercion whether the right argument must be coerced before calling the overload
  */
case class CachedCoercedFunctionDispatch(
  functionIndex: Int,
  leftValueType: Type,
  rightValueType: Type,
  leftArgNeedsCoercion: Boolean,
  rightArgNeedsCoercion: Boolean)

/**
  * Executes binary operators. It knows how to dispatch to the matching overload and how to apply coercions when needed.
  *
  * @param targetOperators       All the candidate overloads of the operator
  * @param name                  The name of the function
  * @param location              The location where this call is made
  * @param firstArgConstantType  Whether the left type is constant, meaning it can be validated at compile time that the type of this param will never change
  * @param secondArgConstantType Whether the right type is constant, meaning it can be validated at compile time that the type of this param will never change
  */
class BinaryOpExecutor(
  val targetOperators: Array[_ <: BaseBinaryFunctionValue],
  val name: String,
  override val location: WeaveLocation,
  val firstArgConstantType: Boolean,
  val secondArgConstantType: Boolean)
  extends BinaryOverloadedStaticExecutor(isOperator = true)
  with Product4[Array[_ <: BaseBinaryFunctionValue], String, Boolean, Boolean] {

  // All operator overloads wrapped once, at construction, into a single cacheable overloaded function.
  private val overloadedOperator =
    OverloadedFunctionValue.createValue(targetOperators, Array.empty, Some(this.name), location, cacheabe = true)

  /** Always the same precomputed overloaded function. */
  override def target()(implicit ctx: ExecutionContext): FunctionValue = overloadedOperator

  /** A constant node wrapping the precomputed overloaded function. */
  override def node(): ValueNode[_] = ConstantValueNode(overloadedOperator)

  override def name()(implicit ctx: ExecutionContext): String = this.name

  /** Operator calls never add themselves to an exception's stack trace. */
  override def showInStacktrace(): Boolean = false

  // Product4 view: (targetOperators, name, firstArgConstantType, secondArgConstantType)
  override def _1: Array[_ <: BaseBinaryFunctionValue] = targetOperators
  override def _2: String = name
  override def _3: Boolean = firstArgConstantType
  override def _4: Boolean = secondArgConstantType
}

/**
  * Executes a call to a two-argument function whose value is obtained by evaluating `node`.
  * The result of that evaluation may be an overloaded function.
  *
  * @param node                  The node that evaluates to the function being called
  * @param name                  The name of the function
  * @param firstArgConstantType  Whether the left argument's type never changes
  * @param secondArgConstantType Whether the right argument's type never changes
  * @param showInStacktrace      Whether failures should add this call site to the exception's stack trace
  * @param location              The location where this call is made
  */
class BinaryStaticOverloadedFunctionExecutor(
  override val node: ValueNode[_],
  val name: String,
  override val firstArgConstantType: Boolean,
  override val secondArgConstantType: Boolean,
  override val showInStacktrace: Boolean,
  override val location: WeaveLocation)
  extends BinaryOverloadedStaticExecutor(isOperator = false)
  with Product4[ValueNode[_], String, Boolean, Boolean] {

  override def name()(implicit ctx: ExecutionContext): String = this.name

  /** Evaluates `node` on every call to obtain the function to dispatch to. */
  override def target()(implicit ctx: ExecutionContext): FunctionValue =
    node.execute.asInstanceOf[FunctionValue]

  /** Runs the dispatch with `node` registered as the call site of the currently active frame. */
  override def executeBinary(leftValue: Value[_], rightValue: Value[_])(implicit ctx: ExecutionContext): Value[Any] =
    withActiveCallSite(super.executeBinary(leftValue, rightValue))

  // Registers `node` as the active frame's call site while `body` runs, always cleaning it up afterwards.
  private def withActiveCallSite[T](body: => T)(implicit ctx: ExecutionContext): T = {
    val frame: Frame = ctx.executionStack().activeFrame()
    try {
      frame.updateCallSite(node)
      body
    } finally {
      frame.cleanCallSite()
    }
  }

  // Product4 view: (node, name, firstArgConstantType, secondArgConstantType)
  override def _1: ValueNode[_] = node
  override def _2: String = name
  override def _3: Boolean = firstArgConstantType
  override def _4: Boolean = secondArgConstantType
}