/*
 * Copyright 2017 University of Rostock
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sessl.ml3

import java.util.function.Consumer

import org.jamesii.core.util.misc.Pair
import org.jamesii.ml3.experiment.Job
import org.jamesii.ml3.model.values.IValue
import org.jamesii.ml3.observation._
import org.jamesii.ml3.simulator.simulators.ISimulator
import sessl._
import sessl.util.SimpleObservation

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
  * Observation support for ML3 experiments.
  *
  * The model is observed either at the configured observation times or, if none are configured,
  * at the model events registered via `observeAt(EventType)`. Agent counts (`agentCount`) and
  * expression values (`expression`) are recorded by listeners attached to those observers.
  *
  * @author Tom Warnke
  */
trait Observation extends SimpleObservation {
  this: Experiment =>

  // observers

  /** Event types registered via `observeAt`; one event observer is created per entry. */
  private[this] val observedEvents = mutable.Set.empty[EventType]

  /**
    * Observes the model whenever an event of the given type occurs.
    *
    * Event-based observation is only used when no observation times are configured: if
    * observation times are set, the time-point observer takes precedence and the registered
    * event types are ignored.
    *
    * @param eventType the agent creation, death or attribute change to observe at
    */
  def observeAt(eventType: EventType): Unit = observedEvents += eventType

  // listeners

  /** Creates a listener for a given run id and simulation job. */
  type ListenerGenerator = (Int, Job) => IListener

  /**
    * Generators for the listeners attached to every observer created in `configure`.
    *
    * Every generator contained in this set at the moment an observer is created for a run is
    * invoked once for that run, and the resulting listener is registered with the observer.
    */
  val listenerGenerators: mutable.Set[ListenerGenerator] = mutable.Set()

  /** (agent type, filter) pairs registered via `agentCount`. */
  private[this] val agentCounts = mutable.Set.empty[(String, String)]

  /** Expressions registered via `expression`. */
  private[this] val expressions = mutable.Set.empty[String]

  /**
    * Records the number of agents of a type that satisfy a filter.
    *
    * @param agentType the agent type to count
    * @param filter    filter expression handed to the ML3 agent count listener; `"true"` by default
    * @return the name under which the counts are stored (agent type followed by filter)
    */
  def agentCount(agentType: String, filter: String = "true"): String = {
    agentCounts += ((agentType, filter))
    agentCountName(agentType, filter)
  }

  /**
    * Records the value of an expression.
    *
    * @param expression the expression, handed verbatim to the ML3 expression listener
    * @return the expression itself, which is also the name under which its values are stored
    */
  def expression(expression: String): String = {
    expressions += expression
    expression
  }

  /**
    * Wires the configured observation into the experiment.
    *
    * Every registered agent count and expression is turned into a `ListenerGenerator` in
    * `listenerGenerators`. Then one instrumentation is added per observer factory returned by
    * `createObservers`; for each run it creates a fresh observer, registers a listener from every
    * generator in `listenerGenerators`, and attaches the observer to the simulator.
    *
    * NOTE(review): listeners are only ever attached to observers, so agent counts and
    * expressions record nothing unless observation times or observed events are configured.
    */
  override def configure(): Unit = {
    super.configure()

    for ((agentType, filter) <- agentCounts) {
      val generator: ListenerGenerator =
        (runId, job) => createAgentCountListener(runId, agentType, filter, job)
      listenerGenerators += generator
    }

    for (exp <- expressions) {
      val generator: ListenerGenerator =
        (runId, job) => createExpressionListener(runId, exp, job)
      listenerGenerators += generator
    }

    for (createObserver <- createObservers) {

      // creates the observer for one run, attaches all listeners and hooks it into the simulator
      val generateObserver = (runId: Int, job: Job, simulator: ISimulator) => {

        val observer = createObserver(job)

        for (listenerGenerator <- listenerGenerators)
          observer.registerListener(listenerGenerator(runId, job))

        simulator.addObserver(observer)
      }

      instrumentations += generateObserver
    }
  }

  /**
    * Creates the observer factories for the configured observation.
    *
    * Observation at time points takes precedence: if observation times are set, a single
    * time-point observer factory is returned and observed events are ignored. Otherwise one
    * factory per observed event type is returned (none if no events are observed).
    */
  private[this] def createObservers: Set[Job => IObserver] =
    if (observationTimes.nonEmpty) Set((_: Job) => createTimePointListObserver)
    else observedEvents.map(eventType => eventObserverFactory(eventType)).toSet[Job => IObserver]

  /** Maps an event type to a factory for the matching ML3 event observer. */
  private[this] def eventObserverFactory(eventType: EventType): Job => IObserver = eventType match {
    case Creation(agentType, filter) =>
      job => new AgentCreationObserver(agentType, filter, job.getModel, job.getParameters)
    case Death(agentType, filter) =>
      job => new AgentDeathObserver(agentType, filter, job.getModel, job.getParameters)
    case Change(agentType, field, filter) =>
      job => new AgentChangeObserver(agentType, field, filter, job.getModel, job.getParameters)
  }

  /** Creates a time-point observer for the configured observation times. */
  private[this] def createTimePointListObserver: IObserver = {
    val javaTimes = observationTimes.map(time => java.lang.Double.valueOf(time)).asJava
    new TimePointListObserver(javaTimes)
  }

  /**
    * Name under which counts of `agentType` agents matching `filter` are stored.
    *
    * NOTE(review): plain concatenation is ambiguous, e.g. ("ab", "c") and ("a", "bc") share a
    * name; kept as-is because callers may rely on the returned key format.
    */
  private[this] def agentCountName(agentType: String, filter: String): String = agentType + filter

  /** Creates a listener that stores agent counts under `agentCountName(agentType, filter)`. */
  private[this] def createAgentCountListener(runId: Int, agentType: String, filter: String, job: Job):
  AgentCountListener =
    new AgentCountListener(agentType, job.getModel, job.getParameters, filter,
      recordingCallback[Integer](runId, agentCountName(agentType, filter)))

  /** Creates a listener that stores the values of `expression` under the expression itself. */
  private[this] def createExpressionListener(runId: Int, expression: String, job: Job):
  AgentExpressionListener =
    new AgentExpressionListener(job.getModel, job.getParameters, expression,
      recordingCallback[IValue](runId, expression))

  /**
    * Builds a callback that forwards each (first, second) pair it receives to `addValueFor` as
    * time-stamped data for the given run and name. Presumably the ML3 listener invokes it with
    * (time, value) pairs.
    */
  private[this] def recordingCallback[V](runId: Int, name: String): Consumer[Pair[java.lang.Double, V]] =
    new Consumer[Pair[java.lang.Double, V]] {
      override def accept(t: Pair[java.lang.Double, V]): Unit =
        addValueFor(runId, name, (t.getFirstValue, t.getSecondValue).asInstanceOf[TimeStampedData])
    }

}

/** Kinds of model events at which an ML3 simulation can be observed (see `Observation.observeAt`). */
sealed trait EventType

/**
  * Observe whenever an agent of the given type is created.
  *
  * @param agentType type of the created agents to observe
  * @param filter    filter expression handed to the creation observer; `"true"` by default
  */
case class Creation(
    agentType: String,
    filter: String = "true"
) extends EventType

/**
  * Observe whenever an agent of the given type dies.
  *
  * @param agentType type of the dying agents to observe
  * @param filter    filter expression handed to the death observer; `"true"` by default
  */
case class Death(
    agentType: String,
    filter: String = "true"
) extends EventType

/**
  * Observe whenever an attribute of an agent of the given type changes.
  *
  * @param agentType type of the agents to observe
  * @param field     the attribute whose changes are observed
  * @param filter    filter expression handed to the change observer; `"true"` by default
  */
case class Change(
    agentType: String,
    field: String,
    filter: String = "true"
) extends EventType