package org.mule.weave.v2.module.dwb.reader.indexed

import org.mule.weave.v2.core.io.SeekableStream
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.structure.Namespace
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.module.core.common.LocationCacheBuilder
import org.mule.weave.v2.module.core.xml.reader.indexed.LocationCaches
import org.mule.weave.v2.module.core.xml.reader.indexed.TokenArray
import org.mule.weave.v2.module.dwb.DwTokenHelper
import org.mule.weave.v2.module.dwb.DwTokenHelper._
import org.mule.weave.v2.module.dwb.DwTokenType
import org.mule.weave.v2.module.dwb.DwTokenType.DwTokenType
import org.mule.weave.v2.module.dwb.WeaveBinaryUtils
import org.mule.weave.v2.module.dwb.WeaveKeyToken
import org.mule.weave.v2.module.dwb.WeaveValueToken
import org.mule.weave.v2.module.dwb.reader.WeaveBinaryParser
import org.mule.weave.v2.module.dwb.reader.exceptions.DWBRuntimeExecutionException
import org.mule.weave.v2.module.dwb.writer.WeaveBinaryWriter
import org.mule.weave.v2.module.reader.DefaultLongArray
import org.mule.weave.v2.module.reader.ILongArray
import org.mule.weave.v2.module.reader.SeekableStreamSourceReader
import org.mule.weave.v2.module.reader.SourceReader

import java.io.DataInputStream
import java.lang.{ Double => JDouble }
import java.lang.{ Long => JLong }
import java.lang.{ Long => JShort }
import java.nio.charset.Charset
import java.time.ZoneId
import java.time.ZoneOffset
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Parser for the DataWeave binary (DWB) format that works off a token index plus location caches.
  *
  * After the header is verified, the index is either read from the end of the input (when the header reports
  * that an index is present) or built by walking the whole value tree once. In both cases the stream is then
  * positioned right after the header and the root token (index 0) is wrapped in a [[WeaveBinaryValue]].
  *
  * @param name parser name
  * @param ss   input stream; the parser reads from a spin-off of it and hands `ss` itself to the produced values
  */
class IndexedWeaveBinaryParser(val name: String, ss: SeekableStream)(implicit ctx: EvaluationContext) extends WeaveBinaryParser {
  private val seekableStream = ss.spinOff()
  private val charset: Charset = Charset.forName("UTF-8")

  private val dataInputStream = new DataInputStream(seekableStream)
  private val sourceReader = SeekableStreamSourceReader(seekableStream, charset, ctx.serviceManager.memoryService)

  // Declarations are appended as they are met, so name/namespace indexes assume the input is read in order.
  private val namespaces = new ArrayBuffer[Namespace](0)
  private val names = new ArrayBuffer[String]()

  private val zoneIds = new ZoneIdMap(names)
  private val zoneOffsets = new ZoneOffsetMap(names)

  private val lcBuilder = new LocationCacheBuilder(DwTokenHelper)
  var tokenArray: TokenArray = _
  var locationCaches: LocationCaches = _

  private var depth = 0
  private var hasIndex: Boolean = _

  // Width of a zone id / zone offset reference inside temporal values. The reference is an index into `names`
  // (see ZoneIdMap / ZoneOffsetMap), assumed to be written as a 2-byte short like the other name indexes.
  // NOTE: the file-level `JShort` alias actually points at java.lang.Long (8 bytes), so it must not be used here.
  private val zoneReferenceBytes: Int = java.lang.Short.BYTES

  /** Closes the data stream and the source reader opened over the spun-off stream. */
  def close(): Unit = {
    dataInputStream.close()
    sourceReader.close()
  }

  /** Bundles the parser state (streams, index, caches and lookup tables) for the produced values. */
  def getInput(): BinaryParserInput = {
    new BinaryParserInput(ss, seekableStream, dataInputStream, sourceReader, locationCaches, tokenArray, namespaces, names, zoneIds, zoneOffsets)
  }

  /**
    * Appends `token` to the global token array and, when requested, registers it in the location caches
    * under its index in the token array.
    */
  def addToken(token: Array[Long], putTokenInLC: Boolean): Unit = {
    tokenArray += token
    if (putTokenInLC) {
      lcBuilder.addToken(token, tokenIndex = tokenArray.length - 1)
    }
  }

  /** Reads the header, obtains the index (read or built) and returns the root value. */
  def parse(): Value[_] = {
    readHeader()
    if (hasIndex) {
      readCaches()
    } else {
      loadCaches()
    }
    WeaveBinaryValue.apply(0, None, getInput())
  }

  /** Verifies the magic word and version, and records whether the input carries a precomputed index. */
  def readHeader(): Unit = {
    verifyMagicWord(dataInputStream)
    verifyVersion(dataInputStream)
    hasIndex = verifyIndexPresence(dataInputStream)
  }

  /**
    * Builds the global token index and the location caches by walking the root value once, then rewinds the
    * stream to just after the header. Does nothing if the token array was already created.
    *
    * @throws IllegalArgumentException if there is no value after the header
    */
  def loadCaches(): Unit = {
    if (tokenArray == null) {
      tokenArray = ctx.registerCloseable(new TokenArray())

      val nextValueType = readTokenType()
      if (nextValueType == -1) {
        throw new IllegalArgumentException("Input should not be empty")
      }
      processValue(nextValueType, putTokenInLC = true)
      locationCaches = ctx.registerCloseable(lcBuilder.build())
      seekableStream.seek(WeaveBinaryWriter.HEADER_BYTES)
    }
  }

  /**
    * Indexes one value whose type byte has already been read, consuming its payload from the stream.
    *
    * @param tokenTypeWithFlags the type byte, including flag bits (e.g. the schema flag)
    * @param putTokenInLC       whether the value's token is also registered in the location caches
    * @throws DWBRuntimeExecutionException on an unknown token type
    */
  def processValue(tokenTypeWithFlags: Int, putTokenInLC: Boolean): Unit = {
    val tokenType = DwTokenHelper.getTokenType(tokenTypeWithFlags)
    var hasSchema = DwTokenHelper.hasSchemaProps(tokenTypeWithFlags)
    tokenType match {

      case DwTokenType.Key =>
        indexKey(DwTokenType.Key, putTokenInLC, withNamespace = false, withAttributes = false)

      case DwTokenType.KeyWithNS =>
        indexKey(DwTokenType.KeyWithNS, putTokenInLC, withNamespace = true, withAttributes = false)

      case DwTokenType.KeyWithAttr =>
        indexKey(DwTokenType.KeyWithAttr, putTokenInLC, withNamespace = false, withAttributes = true)

      case DwTokenType.KeyWithNSAttr =>
        indexKey(DwTokenType.KeyWithNSAttr, putTokenInLC, withNamespace = true, withAttributes = true)

      case DwTokenType.String8 =>
        // 1-byte length prefix, so at most 255
        indexString(dataInputStream.read(), tokenType, hasSchema, putTokenInLC)

      case DwTokenType.String32 =>
        // 4-byte length prefix
        indexString(dataInputStream.readInt(), tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Null | DwTokenType.True | DwTokenType.False =>
        // No payload: the token type alone carries the value
        addToken(createValueToken(length = 0, tokenType, hasSchema), putTokenInLC)

      case DwTokenType.Int =>
        indexSkippedValue(Integer.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Long =>
        indexSkippedValue(JLong.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Double =>
        indexSkippedValue(JDouble.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.BigInt | DwTokenType.BigDecimal | DwTokenType.Period | DwTokenType.Regex =>
        // Unsigned-short length prefix followed by that many bytes
        indexSkippedValue(dataInputStream.readUnsignedShort(), tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Binary =>
        // Int length prefix followed by that many bytes
        indexSkippedValue(dataInputStream.readInt(), tokenType, hasSchema, putTokenInLC)

      case DwTokenType.ObjectStart =>
        // For structures the schema flag is carried by the structure-end token, not the start token
        hasSchema = indexStructure(tokenType, putTokenInLC)(readObject())

      case DwTokenType.ArrayStart =>
        hasSchema = indexStructure(tokenType, putTokenInLC)(readArray())

      case DwTokenType.DeclareNS =>
        // A declaration is immediately followed by the value it precedes
        readNSDeclaration()
        processValue(readTokenType(), putTokenInLC)

      case DwTokenType.DeclareName =>
        readNameDeclaration()
        processValue(readTokenType(), putTokenInLC)

      case DwTokenType.DateTime =>
        //unixTimestamp + nanos + zoneId
        indexSkippedValue(JLong.BYTES + Integer.BYTES + zoneReferenceBytes, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.LocalDateTime =>
        //unixTimestamp + nanos
        indexSkippedValue(JLong.BYTES + Integer.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.LocalDate =>
        //unixTimestamp
        indexSkippedValue(JLong.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Time =>
        //nanoOfDay + zoneOffset
        indexSkippedValue(JLong.BYTES + zoneReferenceBytes, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.LocalTime =>
        //nanoOfDay
        indexSkippedValue(JLong.BYTES, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.TimeZone =>
        //zoneId
        indexSkippedValue(zoneReferenceBytes, tokenType, hasSchema, putTokenInLC)

      case DwTokenType.Range =>
        // two ints
        indexSkippedValue(Integer.BYTES + Integer.BYTES, tokenType, hasSchema, putTokenInLC)

      case x =>
        throw new DWBRuntimeExecutionException("Unexpected value type '" + DwTokenType.getNameFor(x) + "'")
    }
    if (hasSchema) {
      readSchema()
    }
  }

  /**
    * Indexes a key token: reads its name index (and namespace index when `withNamespace`), adds the key token
    * and then, when `withAttributes`, indexes the attributes that follow it.
    */
  private def indexKey(tokenType: DwTokenType, putTokenInLC: Boolean, withNamespace: Boolean, withAttributes: Boolean): Unit = {
    val nameIndex = dataInputStream.readUnsignedShort()
    val nameStr = names(nameIndex)
    val nsIndex = if (withNamespace) dataInputStream.readUnsignedShort() else NO_NAMESPACE
    // Created only after every index is consumed: the token records the current stream position
    addToken(createKeyToken(nameStr, nameIndex, nsIndex, tokenType), putTokenInLC)
    if (withAttributes) {
      readAttributes()
    }
  }

  /** Adds a string token of `length` and consumes the string payload from the stream. */
  private def indexString(length: Int, tokenType: DwTokenType, hasSchema: Boolean, putTokenInLC: Boolean): Unit = {
    addToken(createValueToken(length, tokenType, hasSchema), putTokenInLC)
    BinaryValueRetriever.readString(dataInputStream, length, getReadBuffer(length)) //TODO: skip instead of reading
  }

  /** Adds a value token whose payload is `length` bytes long, then skips over that payload. */
  private def indexSkippedValue(length: Int, tokenType: DwTokenType, hasSchema: Boolean, putTokenInLC: Boolean): Unit = {
    addToken(createValueToken(length, tokenType, hasSchema), putTokenInLC)
    skipBytes(length)
  }

  /**
    * Indexes an object or array.
    *
    * Adds the start token and reads the body with `readBody`, which must return the structure-end type byte.
    * When that end byte carries the schema flag, the start token is rewritten so that its length reaches the
    * current position (where the schema starts) and it is flagged as having a schema.
    *
    * @return whether schema properties follow the structure
    */
  private def indexStructure(tokenType: DwTokenType, putTokenInLC: Boolean)(readBody: => Int): Boolean = {
    val startToken = createValueToken(length = 0, tokenType, hasSchema = false)
    val startTokenIndex = tokenArray.length
    addToken(startToken, putTokenInLC)
    val structureEndToken = readBody
    val hasSchema = DwTokenHelper.hasSchemaProps(structureEndToken)
    if (hasSchema) {
      val startOffset = DwTokenHelper.getOffset(startToken)
      val length = seekableStream.position() - startOffset
      tokenArray.update(startTokenIndex, WeaveValueToken(startOffset, tokenType, depth, length, hasSchema))
    }
    hasSchema
  }

  /**
    * The place this function is invoked in matters because it depends on `seekableStream.position` and `depth` field values
    */
  private def createKeyToken(name: String, nameIndex: Int, nsIndex: Int, tokenType: DwTokenType): Array[Long] = {
    val length = WeaveBinaryUtils.getUTFByteLength(name)
    val nameHash = DwTokenHelper.hash(name)
    WeaveKeyToken(seekableStream.position(), tokenType, depth, length, nameHash, nameIndex, nsIndex)
  }

  /**
    * The place this function is invoked in matters because it depends on `seekableStream.position` and `depth` field values
    */
  private def createValueToken(length: Long, tokenType: DwTokenType, hasSchema: Boolean): Array[Long] = {
    WeaveValueToken(seekableStream.position(), tokenType, depth, length, hasSchema)
  }

  /** Reads the next type byte (with flags); -1 at end of stream. */
  def readTokenType(): Int = {
    dataInputStream.read()
  }

  /** Indexes an unsigned-short count of attributes, two values each, without registering them in the location caches. */
  private def readAttributes(): Unit = {
    val attrCount = dataInputStream.readUnsignedShort()
    var i = 0
    while (i < attrCount) {
      processValue(readTokenType(), putTokenInLC = false)
      processValue(readTokenType(), putTokenInLC = false)
      i += 1
    }
  }

  /** Indexes an unsigned-short count of schema properties, two values each, outside the location caches. */
  def readSchema(): Unit = {
    val propCount = dataInputStream.readUnsignedShort()
    var i = 0
    while (i < propCount) {
      processValue(dataInputStream.read(), putTokenInLC = false)
      //FIXME: Maybe use another type for these schema props, I'm writing them as DwTokenType.Key,
      // so it's parsed as KeyValue, but schema properties expect StringValue
      processValue(dataInputStream.read(), putTokenInLC = false)
      i += 1
    }
  }

  /** Strips the flag bits from a type byte. */
  private def tokenType(tokenTypeWithFlags: Int): Int = {
    tokenTypeWithFlags & DwTokenHelper.TOKEN_TYPE_MASK
  }

  /**
    * Indexes array elements one level deeper until the structure end.
    *
    * @return the structure-end type byte, including its flags
    */
  private def readArray(): Int = {
    depth += 1
    var tokenTypeWithFlags = readTokenType()
    while (tokenType(tokenTypeWithFlags) != DwTokenType.StructureEnd) {
      processValue(tokenTypeWithFlags, putTokenInLC = true)
      tokenTypeWithFlags = readTokenType()
    }
    depth -= 1
    tokenTypeWithFlags
  }

  /**
    * Indexes key/value pairs one level deeper until the structure end. Only keys go into the location caches.
    *
    * @return the structure-end type byte, including its flags
    */
  private def readObject(): Int = {
    depth += 1
    var tokenTypeWithFlags = readTokenType()
    while (tokenType(tokenTypeWithFlags) != DwTokenType.StructureEnd) {
      // Keys go through processValue (not a dedicated key reader) because DeclareName/DeclareNS tokens
      // may precede them.
      processValue(tokenTypeWithFlags, putTokenInLC = true)
      processValue(readTokenType(), putTokenInLC = false)
      tokenTypeWithFlags = readTokenType()
    }
    depth -= 1
    tokenTypeWithFlags
  }

  private def readShortString() = {
    BinaryValueRetriever.readShortString(dataInputStream, getReadBuffer())
  }

  /** Reads a name declaration and appends it to the name table. */
  def readNameDeclaration(): String = {
    val name = readShortString()
    names += name
    name
  }

  /** Reads a "prefix:uri" namespace declaration and appends it to the namespace table. */
  def readNSDeclaration(): Namespace = {
    val ns = parseNamespace(readShortString())
    namespaces += ns
    ns
  }

  /** Splits "prefix:uri" at the first ':' (the URI itself may contain more colons). */
  private def parseNamespace(prefixUri: String): Namespace = {
    val (prefix, uri) = prefixUri.splitAt(prefixUri.indexOf(":"))
    Namespace(prefix, uri.substring(1))
  }

  /** Consumes `n` bytes one at a time; end of stream is not detected here. */
  private def skipBytes(n: Int): Unit = {
    var i = 0
    while (i < n) {
      dataInputStream.read()
      i += 1
    }
  }

  /**
    * Reads the precomputed name/namespace tables, global token index and location caches stored at the end of
    * the input, then rewinds the stream to just after the header. Does nothing if the token array already exists.
    */
  def readCaches(): Unit = {
    if (tokenArray == null) {
      // The input ends with a 32-byte footer: four ints followed by two longs
      val size = seekableStream.size()
      seekableStream.seek(size - 32)
      val namesIndexBytes = dataInputStream.readInt()
      val namesIndexCount = dataInputStream.readInt()
      val nsIndexBytes = dataInputStream.readInt()
      val nsIndexCount = dataInputStream.readInt()
      val tokensLongsCount = dataInputStream.readLong()
      val lcIndexLength = dataInputStream.readLong()

      // Sections right before the footer, in order: names, namespaces, token index (8 bytes/long), location caches
      val offset = size - 32 - namesIndexBytes - nsIndexBytes - (tokensLongsCount * 8) - lcIndexLength
      seekableStream.seek(offset)

      for (_ <- 1 to namesIndexCount) {
        names += readLengthPrefixedString()
      }

      for (_ <- 1 to nsIndexCount) {
        namespaces += parseNamespace(readLengthPrefixedString())
      }

      tokenArray = ctx.registerCloseable(readGlobalTokenIndex(tokensLongsCount))
      locationCaches = ctx.registerCloseable(readLocationCaches())
      seekableStream.seek(WeaveBinaryWriter.HEADER_BYTES)
    }
  }

  /** Reads an unsigned-short byte length followed by a string of that many bytes. */
  private def readLengthPrefixedString(): String = {
    val length = dataInputStream.readUnsignedShort()
    BinaryValueRetriever.readString(dataInputStream, length, getReadBuffer(length))
  }

  /** Reads `tokensLongsCount` longs as tokens of two longs each. */
  def readGlobalTokenIndex(tokensLongsCount: Long): TokenArray = {
    val tokenArray = new TokenArray()
    var i = 0
    // NOTE(review): one scratch array is reused for every token; this is only correct if `TokenArray.+=`
    // copies the two longs instead of keeping the array reference — confirm.
    val token = new Array[Long](2)
    while (i < tokensLongsCount / 2) {
      token.update(0, dataInputStream.readLong())
      token.update(1, dataInputStream.readLong())
      tokenArray += token
      i += 1
    }
    tokenArray
  }

  /** Reads an unsigned-short count of caches; each is a long entry count followed by that many longs. */
  def readLocationCaches(): LocationCaches = {
    val lcsLength = dataInputStream.readUnsignedShort()
    val lcs = new Array[ILongArray](lcsLength)
    for (level <- 0 until lcsLength) {
      val lc = new DefaultLongArray
      val lcLength = dataInputStream.readLong()
      var entryIndex = 0L
      while (entryIndex < lcLength) {
        lc += dataInputStream.readLong()
        entryIndex += 1
      }
      lcs(level) = lc
    }
    new LocationCaches(lcs)
  }

}

/**
  * Parser state shared with the values built by [[IndexedWeaveBinaryParser]]: the streams, the token index,
  * the location caches and the name / namespace / zone lookup tables.
  *
  * @param parentSeekableStream the stream the parser was created from; used to spin off further streams
  */
class BinaryParserInput(
  private val parentSeekableStream: SeekableStream,
  val seekableStream: SeekableStream,
  val dataInputStream: DataInputStream,
  val sourceReader: SourceReader,
  val locationCaches: LocationCaches,
  val tokenArray: TokenArray,
  val namespaces: ArrayBuffer[Namespace],
  val names: ArrayBuffer[String],
  val zoneIds: ZoneIdMap,
  val zoneOffsets: ZoneOffsetMap)(implicit ctx: EvaluationContext) {

  /** Spins off a new stream from the parent stream (the same operation the parser uses for its own stream). */
  def streamSpinOff(): SeekableStream = {
    parentSeekableStream.spinOff()
  }
}

/**
  * Memoizing lookup from a name-table index to the [[ZoneId]] whose id string is stored at that index.
  *
  * @param names shared name table; the entries looked up here must be valid `ZoneId.of` inputs
  */
class ZoneIdMap(names: ArrayBuffer[String]) {
  // Keyed by Scala Int (matching `get`'s parameter) instead of boxed java.lang.Integer.
  private val zoneIds = new mutable.HashMap[Int, ZoneId]()

  /**
    * Returns the zone for `names(index)`, parsing it only on the first request for that index.
    *
    * @throws IndexOutOfBoundsException if `index` is outside `names`
    * @throws java.time.DateTimeException if the stored string is not a valid zone id
    */
  def get(index: Int): ZoneId =
    zoneIds.getOrElseUpdate(index, ZoneId.of(names(index)))

}

/**
  * Memoizing lookup from a name-table index to the [[ZoneOffset]] whose id (e.g. "+02:00" or "Z") is stored
  * at that index.
  *
  * @param names shared name table; the entries looked up here must be valid `ZoneOffset.of` inputs
  */
class ZoneOffsetMap(names: ArrayBuffer[String]) {
  // Keyed by Scala Int (matching `get`'s parameter) instead of boxed java.lang.Integer.
  private val zoneOffsets = new mutable.HashMap[Int, ZoneOffset]()

  /**
    * Returns the offset for `names(index)`, parsing it only on the first request for that index.
    *
    * @throws IndexOutOfBoundsException if `index` is outside `names`
    * @throws java.time.DateTimeException if the stored string is not a valid offset id
    */
  def get(index: Int): ZoneOffset =
    zoneOffsets.getOrElseUpdate(index, ZoneOffset.of(names(index)))
}
