package org.mule.weave.v2.module.dwb.reader.memory

import org.mule.weave.v2.core.io.SeekableStream
import org.mule.weave.v2.dwb.api.IWeaveValue
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.capabilities.UnknownLocationCapable
import org.mule.weave.v2.model.structure.ArraySeq
import org.mule.weave.v2.model.structure.KeyValuePair
import org.mule.weave.v2.model.structure.NameSeq
import org.mule.weave.v2.model.structure.NameValuePair
import org.mule.weave.v2.model.structure.Namespace
import org.mule.weave.v2.model.structure.ObjectSeq
import org.mule.weave.v2.model.structure.QualifiedName
import org.mule.weave.v2.model.structure.schema.Schema
import org.mule.weave.v2.model.structure.schema.SchemaProperty
import org.mule.weave.v2.model.types.StringType
import org.mule.weave.v2.model.values.ArrayValue
import org.mule.weave.v2.model.values.AttributesValue
import org.mule.weave.v2.model.values.BinaryValue
import org.mule.weave.v2.model.values.BooleanValue
import org.mule.weave.v2.model.values.DateTimeValue
import org.mule.weave.v2.model.values.KeyValue
import org.mule.weave.v2.model.values.LocalDateTimeValue
import org.mule.weave.v2.model.values.LocalDateValue
import org.mule.weave.v2.model.values.LocalTimeValue
import org.mule.weave.v2.model.values.MaterializedObjectValue
import org.mule.weave.v2.model.values.NullValue
import org.mule.weave.v2.model.values.NumberValue
import org.mule.weave.v2.model.values.ObjectValue
import org.mule.weave.v2.model.values.PeriodValue
import org.mule.weave.v2.model.values.RangeValue
import org.mule.weave.v2.model.values.RegexValue
import org.mule.weave.v2.model.values.StringValue
import org.mule.weave.v2.model.values.TimeValue
import org.mule.weave.v2.model.values.TimeZoneValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.module.core.json.reader.memory.InMemoryJsonArray
import org.mule.weave.v2.module.dwb.DwTokenHelper
import org.mule.weave.v2.module.dwb.DwTokenType
import org.mule.weave.v2.module.dwb.reader.RedefinedValueRetriever
import org.mule.weave.v2.module.dwb.reader.WeaveBinaryParser
import org.mule.weave.v2.module.dwb.reader.WeaveValue
import org.mule.weave.v2.module.dwb.reader.exceptions.DWBRuntimeExecutionException
import org.mule.weave.v2.module.dwb.reader.indexed.BinaryValueRetriever
import org.mule.weave.v2.parser.location.LocationCapable
import org.mule.weave.v2.parser.location.UnknownLocation

import java.io.BufferedInputStream
import java.io.DataInputStream
import java.time.Instant
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.LocalTime
import java.time.OffsetTime
import java.time.ZoneId
import java.time.ZoneOffset
import java.time.ZonedDateTime
import java.util.{ Map => JMap }
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Reads a DataWeave binary (DWB) document from `inputStream` and builds the whole value tree in
  * memory: arrays and objects are collected eagerly into materialized sequences.
  *
  * The stream is consumed strictly in order. `DeclareName` / `DeclareNS` tokens append entries to
  * lookup tables, and later key, namespace and time-zone tokens refer to those entries by index,
  * so an instance carries state and is meant to read its stream front to back.
  *
  * @param name        name associated with this input (stored; not used by the code in this class)
  * @param inputStream source of the DWB bytes; it is wrapped in a buffered `DataInputStream`
  */
class InMemoryWeaveBinaryParser(val name: String, val inputStream: SeekableStream)(implicit ctx: EvaluationContext) extends WeaveBinaryParser {

  private val input = new DataInputStream(new BufferedInputStream(inputStream))

  // Lookup tables filled by declaration tokens; entries are addressed by their declaration index,
  // which is why the stream must be read in order.
  private val namespaces = new ArrayBuffer[Namespace](0)
  private val names = new ArrayBuffer[String]()
  // Parsed zone ids / offsets, cached by the index of their textual form in `names`.
  private val zoneIds = new mutable.HashMap[Int, ZoneId]()
  private val zoneOffsets = new mutable.HashMap[Int, ZoneOffset]()

  /**
    * Consumes the DWB header: the magic word and the version (checked by `verifyMagicWord` /
    * `verifyVersion`), followed by the index-presence flag byte, which is currently skipped.
    */
  def readHeader(): Unit = {
    verifyMagicWord(input)
    verifyVersion(input)
    input.read() //FIXME: currently skipping index presence flag
  }

  /**
    * Reads the header and then the single root value of the document.
    *
    * @return the fully materialized root value
    * @throws IllegalArgumentException if the stream ends right after the header
    */
  def parse(): Value[_] = {
    readHeader()
    val rootTokenType = input.read()
    if (rootTokenType == -1) {
      throw new IllegalArgumentException("Input should not be empty")
    }
    readValue(rootTokenType)
  }

  /**
    * Reads a schema block: an unsigned 16-bit property count followed by that many key/value pairs.
    *
    * @return the schema built from the properties, in stream order
    * @throws DWBRuntimeExecutionException if the stream ends before all properties are read
    */
  def readSchema(): Schema = {
    val props = new ArrayBuffer[SchemaProperty]
    val propCount = input.readUnsignedShort()
    var i = 0
    while (i < propCount) {
      val key = readValue(readRequiredTokenType("schema property key")).asInstanceOf[KeyValue]
      //FIXME: Maybe use another type for these schema props, I'm writing them as DwTokenType.Key,
      // so it's parsed as KeyValue, but schema properties expect StringValue
      val keyStrValue = StringValue(key.evaluate.name)
      val value = readValue(readRequiredTokenType("schema property value"))
      props += SchemaProperty(keyStrValue, value)
      i += 1
    }
    Schema(props)
  }

  /**
    * Reads the stream and returns the next byte or -1 if the end of the stream is reached.
    */
  def readTokenType(): Int = {
    input.read()
  }

  /**
    * Reads the next token type where the format requires one to follow.
    *
    * Without this check an EOF (-1) would be handed to `readValue` and decoded through the token
    * mask as if it were a real token. For payload-free tokens that could loop forever.
    *
    * @param context what was being read, used in the error message
    * @throws DWBRuntimeExecutionException if the stream has ended
    */
  private def readRequiredTokenType(context: String): Int = {
    val tokenTypeWithFlags = readTokenType()
    if (tokenTypeWithFlags == -1) {
      throw new DWBRuntimeExecutionException("Unexpected end of input while reading " + context)
    }
    tokenTypeWithFlags
  }

  /**
    * Reads the payload of one value whose token byte has already been consumed.
    *
    * Declaration tokens (`DeclareNS`, `DeclareName`) register their entry and then read the value
    * that follows them, so a single call may consume several tokens.
    *
    * @param tokenTypeWithFlags the raw token byte: token type plus flag bits (decoded via `DwTokenHelper`)
    * @return the materialized value
    * @throws DWBRuntimeExecutionException on an unexpected token type or truncated input
    */
  def readValue(tokenTypeWithFlags: Int): Value[_] = {
    val tokenType = DwTokenHelper.getTokenType(tokenTypeWithFlags)
    val hasSchema = DwTokenHelper.hasSchemaProps(tokenTypeWithFlags)
    tokenType match {

      // Keys refer to previously declared names (and namespaces) by index.
      case DwTokenType.Key =>
        val localName: String = readLocalName()
        KeyValue(QualifiedName(localName, None), None)

      case DwTokenType.KeyWithNS =>
        val localName: String = readLocalName()
        val namespace = readNamespace()
        KeyValue(QualifiedName(localName, Some(namespace)), None)

      case DwTokenType.KeyWithAttr =>
        val localName: String = readLocalName()
        val qName = QualifiedName(localName, None)
        val attrs = readAttributes()
        KeyValue(qName, Some(attrs))

      case DwTokenType.KeyWithNSAttr =>
        val localName: String = readLocalName()
        val namespace = readNamespace()
        val qName = QualifiedName(localName, Some(namespace))
        val attrs = readAttributes()
        KeyValue(qName, Some(attrs))

      case DwTokenType.String8 =>
        // Length prefix is a single unsigned byte (0-255).
        val length = input.read()
        val str = BinaryValueRetriever.readString(input, length, getReadBuffer(length))
        if (hasSchema) {
          StringValue(str, UnknownLocationCapable, Some(readSchema()))
        } else {
          StringValue(str)
        }

      case DwTokenType.String32 =>
        // Length prefix is a signed 32-bit int.
        val length = input.readInt()
        val str = BinaryValueRetriever.readString(input, length, getReadBuffer(length))
        if (hasSchema) {
          StringValue(str, UnknownLocationCapable, Some(readSchema()))
        } else {
          StringValue(str)
        }

      case DwTokenType.Null =>
        NullValue

      case DwTokenType.True =>
        BooleanValue.TRUE_BOOL

      case DwTokenType.False =>
        BooleanValue.FALSE_BOOL

      case DwTokenType.Int =>
        NumberValue(input.readInt())

      case DwTokenType.Long =>
        NumberValue(input.readLong())

      case DwTokenType.BigInt =>
        // Unsigned 16-bit length, then the two's-complement big-endian bytes.
        val length = input.readUnsignedShort()
        val buffer = new Array[Byte](length)
        input.readFully(buffer, 0, length)
        NumberValue(BigInt.apply(buffer))

      case DwTokenType.Double =>
        NumberValue(input.readDouble())

      case DwTokenType.BigDecimal =>
        NumberValue(BigDecimal.apply(readShortString()))

      case DwTokenType.ObjectStart =>
        readObject()

      case DwTokenType.ArrayStart =>
        readArray()

      case DwTokenType.Binary =>
        val length = input.readInt()
        val buffer = new Array[Byte](length) //TODO: Use a seekableStream.spinOff()
        input.readFully(buffer, 0, length)
        BinaryValue(buffer)

      case DwTokenType.DeclareNS =>
        readNSDeclaration()
        readValue(readRequiredTokenType("value after namespace declaration"))

      case DwTokenType.DeclareName =>
        readNameDeclaration()
        readValue(readRequiredTokenType("value after name declaration"))

      case DwTokenType.DateTime =>
        // Epoch seconds + nanosecond adjustment, then the index of the zone id in `names`.
        val unixTimestamp = input.readLong()
        val nanos = input.readInt()
        val instant = Instant.ofEpochSecond(unixTimestamp, nanos)
        DateTimeValue(ZonedDateTime.ofInstant(instant, readZoneId()))

      case DwTokenType.LocalDateTime =>
        // Stored as epoch seconds + nanos interpreted at UTC.
        val unixTimestamp = input.readLong()
        val nanos = input.readInt()
        LocalDateTimeValue(LocalDateTime.ofEpochSecond(unixTimestamp, nanos, ZoneOffset.UTC))

      case DwTokenType.LocalDate =>
        LocalDateValue(LocalDate.ofEpochDay(input.readLong()))

      case DwTokenType.Time =>
        val localTime = LocalTime.ofNanoOfDay(input.readLong())
        val zoneOffset = readZoneOffset()
        TimeValue(OffsetTime.of(localTime, zoneOffset))

      case DwTokenType.LocalTime =>
        LocalTimeValue(LocalTime.ofNanoOfDay(input.readLong()))

      case DwTokenType.TimeZone =>
        TimeZoneValue(readZoneId())

      case DwTokenType.Period =>
        PeriodValue(input.readUTF(), UnknownLocationCapable)

      case DwTokenType.Regex =>
        RegexValue(input.readUTF())

      case DwTokenType.Range =>
        val start = input.readInt()
        val end = input.readInt()
        RangeValue(start, end, UnknownLocationCapable)

      case x =>
        throw new DWBRuntimeExecutionException("Unexpected value type '" + DwTokenType.getNameFor(x) + "'")
    }
  }

  private def readShortString() = {
    BinaryValueRetriever.readShortString(input, getReadBuffer())
  }

  /**
    * Reads a name declaration and appends it to the name table.
    *
    * @return the declared name
    */
  def readNameDeclaration(): String = {
    val declaredName = readShortString()
    names += declaredName
    declaredName
  }

  /**
    * Reads a namespace declaration of the form `prefix:uri` (the prefix may be empty) and appends
    * it to the namespace table. The split happens at the first `:`.
    *
    * @return the declared namespace
    * @throws DWBRuntimeExecutionException if the declaration contains no `:` separator
    */
  def readNSDeclaration(): Namespace = {
    val str = readShortString()
    val separatorIndex = str.indexOf(':')
    if (separatorIndex < 0) {
      // Previously this silently dropped the first URI character; fail loudly instead.
      throw new DWBRuntimeExecutionException("Invalid namespace declaration '" + str + "'")
    }
    val ns = Namespace(str.substring(0, separatorIndex), str.substring(separatorIndex + 1))
    namespaces += ns
    ns
  }

  /**
    * Reads array elements until a `StructureEnd` token.
    *
    * @throws DWBRuntimeExecutionException if the stream ends before the array is closed
    */
  private def readArray(): ArrayValue = {
    val values = new ArrayBuffer[Value[_]]
    var nextValueType = readRequiredTokenType("array")
    while (nextValueType != DwTokenType.StructureEnd) {
      values += readValue(nextValueType)
      nextValueType = readRequiredTokenType("array")
    }
    new InMemoryJsonArray(ArraySeq(values, materialized = true), UnknownLocation)
  }

  /**
    * Reads key/value pairs until a `StructureEnd` token. The schema flag is carried on that
    * closing token; when set, the schema follows it. A schema with a `processor` property wraps
    * the entries in a [[RedefinedObjectSeq]].
    *
    * @throws DWBRuntimeExecutionException if the stream ends before the object is closed
    */
  private def readObject(): ObjectValue = {
    val kvps = new ArrayBuffer[KeyValuePair]
    var tokenTypeWithFlags = readRequiredTokenType("object")
    while (tokenType(tokenTypeWithFlags) != DwTokenType.StructureEnd) {
      // Go through readValue() instead of reading a key directly: name/namespace declarations may
      // precede the key, and readValue() consumes them before reading the key itself.
      val key: KeyValue = readValue(tokenTypeWithFlags).asInstanceOf[KeyValue]
      val value: Value[_] = readValue(readRequiredTokenType("object value"))
      kvps += KeyValuePair(key, value)
      tokenTypeWithFlags = readRequiredTokenType("object")
    }
    val objectSeq = ObjectSeq(kvps, materialized = true)
    if (DwTokenHelper.hasSchemaProps(tokenTypeWithFlags)) {
      val schema = readSchema()
      getProcessor(schema) match {
        case Some(processor) =>
          val schemaMap: JMap[String, IWeaveValue[_]] = WeaveValue.toWeaveValueMap(schema)
          //TODO: put real locations
          val redefinedObjectSeq = new RedefinedObjectSeq(objectSeq, processor, schemaMap, UnknownLocationCapable)
          new MaterializedObjectValue(redefinedObjectSeq, UnknownLocationCapable, Some(schema))
        case None =>
          new MaterializedObjectValue(objectSeq, UnknownLocationCapable, Some(schema))
      }
    } else {
      new MaterializedObjectValue(objectSeq, UnknownLocationCapable)
    }
  }

  /** Returns the schema's `processor` property coerced to a string, if present. */
  private def getProcessor(schema: Schema): Option[String] = {
    schema
      .valueOf("processor")
      .map(processor => StringType.coerce(processor).evaluate.toString)
  }

  /** Strips the flag bits from a raw token byte. */
  private def tokenType(tokenTypeWithFlags: Int): Int = {
    tokenTypeWithFlags & DwTokenHelper.TOKEN_TYPE_MASK
  }

  /** Reads an unsigned 16-bit attribute count followed by that many key/value pairs. */
  private def readAttributes(): Value[NameSeq] = {
    val attrPairs = new ArrayBuffer[NameValuePair]()
    val attrCount = input.readUnsignedShort()
    var i = 0
    while (i < attrCount) {
      val qName = readValue(readRequiredTokenType("attribute name")).asInstanceOf[KeyValue]
      val value = readValue(readRequiredTokenType("attribute value"))
      attrPairs += NameValuePair(qName, value)
      i += 1
    }
    AttributesValue(attrPairs)
  }

  /**
    * Reads an index from the stream and returns the name it references.
    * Note: this assumes the name was previously declared.
    */
  def readLocalName(): String = {
    val nameIndex = input.readUnsignedShort()
    names(nameIndex)
  }

  /**
    * Reads an index from the stream and returns the Namespace it references.
    * Note: this assumes the Namespace was previously declared.
    */
  private def readNamespace(): Namespace = {
    val nsIndex = input.readUnsignedShort()
    namespaces(nsIndex)
  }

  /**
    * Reads an index into `names` holding a zone id's text and returns the parsed `ZoneId`.
    * The result is cached per index.
    */
  private def readZoneId(): ZoneId = {
    val nameIndex = input.readUnsignedShort()
    zoneIds.getOrElseUpdate(nameIndex, ZoneId.of(names(nameIndex)))
  }

  /**
    * Reads an index into `names` holding a zone offset's text and returns the parsed `ZoneOffset`.
    * The result is cached per index.
    */
  private def readZoneOffset(): ZoneOffset = {
    val nameIndex = input.readUnsignedShort()
    zoneOffsets.getOrElseUpdate(nameIndex, ZoneOffset.of(names(nameIndex)))
  }

}

/**
  * An [[ObjectSeq]] that wraps `seq` and, for keys missing from it, falls back to values supplied
  * by a `RedefinedValueRetriever` built from `processorClass` and `schema`.
  *
  * Only the key lookups (`selectKeyValue`, `allKeyValuesOf`) consult the retriever. Iteration,
  * size, indexing, conversions and key removal delegate straight to `seq`, so fallback entries
  * are not visible through them.
  */
class RedefinedObjectSeq(seq: ObjectSeq, processorClass: String, schema: JMap[String, IWeaveValue[_]], locationCapable: LocationCapable)(implicit ctx: EvaluationContext) extends ObjectSeq {
  val retriever = new RedefinedValueRetriever(seq, processorClass, schema, locationCapable)

  /** Returns the pair from `seq`, otherwise a retriever-supplied pair, otherwise `null`. */
  override def selectKeyValue(key: Value[QualifiedName])(implicit ctx: EvaluationContext): KeyValuePair = {
    val direct = seq.selectKeyValue(key)
    if (direct != null) {
      direct
    } else {
      val keyName = key.evaluate.name
      val fallback = retriever.getRedefinedValue(keyName)
      if (fallback != null) KeyValuePair(KeyValue(keyName), fallback) else null
    }
  }

  /** Returns `seq`'s matches unless they are `None`; then a single-entry seq from the retriever, if any. */
  override def allKeyValuesOf(key: Value[QualifiedName])(implicit ctx: EvaluationContext): Option[ObjectSeq] = {
    val direct = seq.allKeyValuesOf(key)
    direct match {
      case None =>
        val keyName: String = key.evaluate.name
        Option(retriever.getRedefinedValue(keyName)).map(fallback => ObjectSeq(Map(keyName -> fallback)))
      case _ =>
        direct
    }
  }

  override def toIterator()(implicit ctx: EvaluationContext): Iterator[KeyValuePair] = seq.toIterator()

  override def size()(implicit ctx: EvaluationContext): Long = seq.size()

  override def isEmpty()(implicit ctx: EvaluationContext): Boolean = seq.isEmpty()

  override def apply(index: Long)(implicit ctx: EvaluationContext): KeyValuePair = seq(index)

  override def toSeq()(implicit ctx: EvaluationContext): Seq[KeyValuePair] = seq.toSeq()

  override def toArray()(implicit ctx: EvaluationContext): Array[KeyValuePair] = seq.toArray()

  override def removeKey(keyNameToRemove: QualifiedName)(implicit ctx: EvaluationContext): ObjectSeq = seq.removeKey(keyNameToRemove)
}
