/*
 * Copyright 2017 University of Rostock
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sessl.ml3

import java.util.function.Consumer

import org.jamesii.core.util.misc.Pair
import org.jamesii.ml3.experiment.Job
import org.jamesii.ml3.model.values.IValue
import org.jamesii.ml3.observation._
import org.jamesii.ml3.simulator.simulators.ISimulator
import sessl._
import sessl.util.SimpleObservation

import scala.collection.JavaConverters._

/**
  * This trait adds two ML3-specific observation methods.
  * First, observations triggered by events are possible.
  * As these events are triggered by a rule that fires for a specific agent, the value of
  * expressions that are evaluated on the agent can be observed.
  * Second, time-triggered observations of expressions that are evaluated for all agents of
  * a certain type are possible, which result in a distribution of observed values for each
  * observation time.
  *
  * @example
  * {{{
  * val exp = new Experiment with Observation {
  *
  *   // observe at certain time steps
  *   observeAt(range(0, 1, 500))
  *   observe("numPerson" ~ agentCount("Person"))
  *   observe("numPersonWithFilter" ~ agentCount("Person", "ego.status = 'active'"))
  *   observe("statusDistribution" ~ expressionDistribution("Person", "ego.status"))
  *
  *   // observe each time a Person dies
  *   observeAt(Death("Person")) {
  *     observe("statusAtDeath" ~ expression("ego.status"))
  *   }
  * }
  * }}}
  *
  * @author Tom Warnke
  */
trait Observation extends SimpleObservation {
  this: Experiment =>

  /** Observables registered per event type; time-triggered observables live under [[Times]]. */
  private[this] val observations = scala.collection.mutable.Map.empty[EventType, Set[Observable]]
  observations(Times) = Set.empty

  // observers

  /** Event type that observables are currently registered for; [[Times]] outside `observeAt` blocks. */
  private[this] var currentEventType: EventType = Times

  /**
    * Registers all observables defined in `obs` for the given event type.
    *
    * Calling this method several times with the same event type adds to the observables that were
    * registered before instead of discarding them.
    *
    * @param eventType the event that triggers the observation
    * @param obs       block containing the observation definitions, e.g., calls to [[expression]]
    */
  def observeAt(eventType: EventType)(obs: => Unit): Unit = {
    val previousEventType = currentEventType
    currentEventType = eventType
    // keep observables that were already registered for this event type
    observations.getOrElseUpdate(eventType, Set.empty)
    // restore the previous context even if the block throws, so that later definitions are not
    // attributed to the wrong event type
    try obs
    finally currentEventType = previousEventType
  }

  // listeners

  /**
    * Observes the number of agents of the given type at each observation time.
    * Only allowed outside of event-based `observeAt` blocks.
    *
    * @param agentType the agent type to count
    * @param filter    filter expression restricting the counted agents (default `"true"`)
    * @return the key under which the observed values are stored
    */
  def agentCount(agentType: String, filter: String = "true"): String = {
    require(currentEventType == Times, "agentCount can only be used with time-based observation")
    observations(Times) += AgentCount(agentType, filter)
    agentCountKey(agentType, filter)
  }

  /**
    * Observes the value of an expression each time the current event occurs.
    * Only allowed inside an event-based `observeAt` block.
    *
    * @param expression the expression to observe
    * @return the key under which the observed values are stored
    */
  def expression(expression: String): String = {
    require(currentEventType != Times, "expression can only be used with event-based observation")
    observations(currentEventType) += Expression(expression)
    expressionKey(expression)
  }

  /**
    * Observes, at each observation time, the values of an expression for all agents of a type.
    * Only allowed outside of event-based `observeAt` blocks.
    *
    * @param agentType  the agent type whose agents are evaluated
    * @param expression the expression to evaluate
    * @param filter     filter expression restricting the evaluated agents (default `"true"`)
    * @return the key under which the observed values are stored
    */
  def expressionDistribution(agentType: String, expression: String, filter: String = "true"): String = {
    require(currentEventType == Times, "expressionDistribution can only be used with time-based observation")
    observations(currentEventType) += ExpressionDistribution(agentType, expression, filter)
    expressionDistributionKey(agentType, expression, filter)
  }

  // Keys under which observed values are stored. The observation methods above return them (so that
  // they can be bound to a name with `~`), and the listener callbacks below must use the very same
  // keys, hence they are computed in one place only.

  private[this] def agentCountKey(agentType: String, filter: String): String =
    agentType + filter

  // NOTE(review): this key does not include the event type, so the same expression observed for two
  // different events writes into the same key -- confirm whether that is intended.
  private[this] def expressionKey(expression: String): String =
    expression

  private[this] def expressionDistributionKey(agentType: String, expression: String,
    filter: String): String =
    agentType + expression + filter

  /**
    * Adds one instrumentation function per event type in `observations` (including [[Times]], even
    * when no time-based observables were defined). Each function creates the observer for its event
    * type, registers one listener per observable with it, and adds the observer to the simulator.
    */
  override def configure(): Unit = {
    super.configure()

    observations.foreach { case (event, observables) =>

      // a function to generate the observer for this kind of event
      val observerGenerator = event match {
        case Times => (_: Job) => createTimePointListObserver
        case Creation(agentType, filter) => (job: Job) =>
          new AgentCreationObserver(agentType, filter, job.getModel, job.getParameters)
        case Death(agentType, filter) => (job: Job) =>
          new AgentDeathObserver(agentType, filter, job.getModel, job.getParameters)
        case Change(agentType, field, filter) => (job: Job) =>
          new AgentChangeObserver(agentType, field, filter, job.getModel, job.getParameters)
      }

      // a set of functions that each generate a listener
      val listenerGenerators = observables map {
        case AgentCount(agentType, filter) =>
          (id: Int, job: Job) => createAgentCountListener(id, agentType, filter, job)
        case Expression(expression) =>
          (id: Int, job: Job) => createExpressionListener(id, expression, job)
        case ExpressionDistribution(agentType, expression, filter) => (id: Int, job: Job) =>
          createExpressionDistributionListener(id, agentType, expression, filter, job)
      }

      // a function that constructs the observer and its listeners and attaches them to the simulator
      val generateObserver = (runId: Int, job: Job, simulator: ISimulator) => {
        val observer = observerGenerator(job)
        for (listenerGenerator <- listenerGenerators)
          observer.registerListener(listenerGenerator(runId, job))
        simulator.addObserver(observer)
      }

      // add the defined function to the observer generators
      instrumentations += generateObserver
    }
  }

  /** Creates an observer that fires at each of the configured `observationTimes`. */
  private[this] def createTimePointListObserver: IObserver = {
    val jTimes = observationTimes.map(d => d.asInstanceOf[java.lang.Double]).asJava
    new TimePointListObserver(jTimes)
  }

  /** Creates a listener that stores (time, agent count) pairs for the given run. */
  private[this] def createAgentCountListener(runId: Int, agentType: String, filter: String, job: Job):
  AgentCountListener = {
    val key = agentCountKey(agentType, filter)
    val callback = new Consumer[Pair[java.lang.Double, Integer]] {
      override def accept(t: Pair[java.lang.Double, Integer]): Unit = {
        addValueFor(
          runId,
          key,
          (t.getFirstValue, t.getSecondValue)
            .asInstanceOf[TimeStampedData[_]])
      }
    }
    new AgentCountListener(agentType, job.getModel, job.getParameters, filter, callback)
  }

  /** Creates a listener that stores (time, converted expression value) pairs for the given run. */
  private[this] def createExpressionListener(runId: Int, expression: String, job: Job):
  AgentExpressionListener = {
    val key = expressionKey(expression)
    val callback = new Consumer[Pair[java.lang.Double, IValue]] {
      override def accept(t: Pair[java.lang.Double, IValue]): Unit = {
        addValueFor(
          runId,
          key,
          (t.getFirstValue, ML3ValueToSesslValue(t.getSecondValue))
            .asInstanceOf[TimeStampedData[_]])
      }
    }
    new AgentExpressionListener(job.getModel, job.getParameters, expression, callback)
  }

  /** Creates a listener that stores (time, vector of converted values) pairs for the given run. */
  private[this] def createExpressionDistributionListener(runId: Int, agentType: String, expression:
  String, filter: String, job: Job): ExpressionDistributionListener = {
    val key = expressionDistributionKey(agentType, expression, filter)
    val callback = new Consumer[Pair[java.lang.Double, java.util.List[IValue]]] {
      override def accept(t: Pair[java.lang.Double, java.util.List[IValue]]): Unit = {
        addValueFor(
          runId,
          key,
          (t.getFirstValue, t.getSecondValue.asScala.map(ML3ValueToSesslValue).toVector)
            .asInstanceOf[TimeStampedData[_]])
      }
    }
    new ExpressionDistributionListener(agentType, job.getModel, job.getParameters, expression, filter,
      callback)
  }

}

/**
  * The kind of event that triggers an ML3 observation.
  *
  * Extends `Product with Serializable` (as all cases are case classes/objects) so that expressions
  * mixing several cases are inferred as `EventType` rather than a compound type.
  */
sealed trait EventType extends Product with Serializable

/** Observation at the configured observation time points. */
case object Times extends EventType

/**
  * Observation on agent creation, for agents of type `agentType`.
  * `filter` is an ML3 expression restricting the observed agents (default `"true"`).
  */
case class Creation(agentType: String, filter: String = "true") extends EventType

/**
  * Observation on agent death, for agents of type `agentType`.
  * `filter` is an ML3 expression restricting the observed agents (default `"true"`).
  */
case class Death(agentType: String, filter: String = "true") extends EventType

/**
  * Observation on a change of `field`, for agents of type `agentType`.
  * `filter` is an ML3 expression restricting the observed agents (default `"true"`).
  */
case class Change(agentType: String, field: String, filter: String = "true") extends EventType

/**
  * A quantity that is recorded at each observation.
  *
  * Extends `Product with Serializable` (as all cases are case classes) so that expressions mixing
  * several cases are inferred as `Observable` rather than a compound type.
  */
sealed trait Observable extends Product with Serializable

/** The number of agents of type `agentType` matching `filter` (default `"true"`). */
case class AgentCount(agentType: String, filter: String = "true") extends Observable

/** The value of `expression`, evaluated on the agent for which an event-triggered observation fires. */
case class Expression(expression: String) extends Observable

/**
  * The values of `expression` for all agents of type `agentType` matching `filter`
  * (default `"true"`).
  */
case class ExpressionDistribution(agentType: String, expression: String, filter: String = "true")
  extends Observable