package com.xebialabs.deployit.engine.tasker

import java.io._
import java.util
import javax.annotation.{PostConstruct, PreDestroy}

import akka.actor.ActorDSL._
import akka.actor._
import akka.pattern._
import akka.util.Timeout
import com.xebialabs.deployit.engine.api.execution.TaskExecutionState
import com.xebialabs.deployit.engine.spi.services.RepositoryFactory
import com.xebialabs.deployit.engine.tasker.ArchivedListeningActor.Forward
import com.xebialabs.deployit.engine.tasker.TaskManagingActor.messages.{ArchiveTask, Cancel, Recovered, Register, Schedule}
import com.xebialabs.deployit.engine.tasker.messages.{Abort, Archived, Cancelled, Stop, _}
import grizzled.slf4j.Logging

import scala.collection.convert.wrapAll._
import scala.concurrent.duration.{Duration, _}
import scala.concurrent.{Await, Promise}
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}

/**
 * Tasker engine: registers, recovers, schedules, executes, stops, aborts, cancels and archives tasks.
 *
 * Every task is driven by a [[TaskManagingActor]] created as a top-level actor of `system` named after the
 * task id; most operations resolve that actor with [[lookupTaskActor]] and send it a message.
 *
 * @param archive           archive handed to the [[ArchiveActor]] registered at construction time
 * @param repositoryFactory creates the repository placed in each task's context
 * @param system            actor system hosting the task actors; shut down by [[shutdownTasks]]
 */
class TaskExecutionEngine(archive: Archive, repositoryFactory: RepositoryFactory, implicit val system: ActorSystem) extends IEngine with Logging {

  /** Creates an engine backed by its own actor system named "TaskExecutionEngine". */
  def this(archive: Archive, repositoryFactory: RepositoryFactory) = this(archive, repositoryFactory, ActorSystem("TaskExecutionEngine"))

  // Register the ArchiveActor
  val archiver: ActorRef = system.actorOf(ArchiveActor.props(archive), ArchiveActor.name)

  /** States in which a task is still running and has to be stopped before the actor system goes down. */
  private val runningStates = Set(TaskExecutionState.ABORTING, TaskExecutionState.EXECUTING, TaskExecutionState.FAILING, TaskExecutionState.STOPPING)

  /**
   * Stops every running task (waiting up to 10 seconds per task for it to report done), then shuts down
   * the actor system.
   */
  @PreDestroy
  def shutdownTasks() {
    TaskRegistryExtension(system).getTasks.filter(t => runningStates.contains(t.getState)).foreach(stopForShutdown)
    system.shutdown()
  }

  /** Sends `Stop` to one task and waits (at most 10 seconds) for a matching `TaskDone`, logging the outcome. */
  private def stopForShutdown(t: Task): Unit = {
    val taskId = t.getId
    info(s"Stopping task [$taskId] due to shutdown")
    val promise = Promise[String]()
    // Create the listener before sending Stop so a fast TaskDone is not missed.
    val listener = actor(new Act {
      become {
        case TaskDone(doneTask) if doneTask.getId == taskId => promise.trySuccess(doneTask.getId)
      }
    })
    // NOTE(review): an anonymous actor can only receive TaskDone via the event stream; this assumes the
    // task actor publishes TaskDone there — confirm against TaskManagingActor.
    system.eventStream.subscribe(listener, classOf[TaskDone])
    try {
      lookupTaskActor(taskId) ! Stop(taskId)
      Try(Await.result(promise.future, 10.seconds)) match {
        case Failure(exception) => warn(s"Failed to stop task [$taskId]", exception)
        case Success(value) => info(s"Successfully stopped task [$value]")
      }
    } finally {
      // Do not leak one listener per stopped task.
      system.eventStream.unsubscribe(listener)
      system.stop(listener)
    }
  }

  /** Recovers every `*.task` file found in the configured recovery directory (a missing listing is treated as empty). */
  @PostConstruct
  def recoverTasks() {
    val taskFiles: Array[File] = Option(TaskerSettings(system).recoveryDir.listFiles(new FileFilter {
      def accept(candidate: File): Boolean = candidate.getName.endsWith(".task")
    })).getOrElse(Array.empty[File])
    taskFiles.foreach(recover)
  }

  /**
   * Recreates the task stored in `file` (if `Task(file)` yields one), attaches a fresh repository and
   * waits up to one second for its new actor to answer the `Recovered` message.
   */
  def recover(file: File) {
    Task(file).foreach { t =>
      t.context.repository = repositoryFactory.create(t.getTempWorkDir)
      implicit val timeout = new Timeout(1 second)
      Await.ready(createTaskActor(t) ? Recovered(t), timeout.duration)
    }
  }

  /** Archives the task with id `taskid`, blocking until the archive outcome is known; failures are rethrown. */
  def archive(taskid: String) {
    archiveWith(taskid, ArchiveTask.apply)
  }

  /**
   * Sends the message built by `msg(taskid, archiver, listener)` to the task actor through an
   * [[ArchivedListeningActor]] and blocks (without a time limit) until that listener completes its promise.
   *
   * @throws Throwable whatever exception the listener failed the promise with
   */
  def archiveWith(taskid: TaskId, msg: (TaskId, ActorRef, ActorRef) => AnyRef): Unit = {
    val p = Promise[TaskId]()
    val listener: ActorRef = system.actorOf(ArchivedListeningActor.props(taskid, p))
    try {
      val message: AnyRef = msg(taskid, archiver, listener)
      val taskActor: ActorSelection = lookupTaskActor(taskid)
      listener ! Forward(taskActor, message)
      // Await.result rethrows the exception the promise was failed with.
      Await.result(p.future, Duration.Inf)
    } finally {
      // Always release the listener, also when building the message or waiting throws.
      system.stop(listener)
    }
  }

  /** Adds a pause step at the given block path of task `taskid`. */
  def addPauseStep(taskid: String, position: BlockPath) {
    retrieve(taskid).addPause(position)
  }

  /** Adds a pause step at the given position of task `taskid`. */
  def addPauseStep(taskid: String, position: Integer) {
    retrieve(taskid).addPause(position)
  }

  /** Moves step `stepNr` of task `taskid` to `newPosition`. */
  def moveStep(taskid: String, stepNr: Int, newPosition: Int) {
    retrieve(taskid).moveStep(stepNr, newPosition)
  }

  /** Un-skips the given step numbers of task `taskid`. */
  def unskipSteps(taskid: String, stepNrs: util.List[Integer]) {
    retrieve(taskid).unskip(stepNrs.map(_.intValue()).toList)
  }

  /** Skips the given step numbers of task `taskid`. */
  def skipSteps(taskid: String, stepNrs: util.List[Integer]) {
    retrieve(taskid).skip(stepNrs.map(_.intValue()).toList)
  }

  /** Un-skips the steps at the given block paths of task `taskid`. */
  def unskipStepPaths(taskid: String, stepNrs: util.List[BlockPath]) {
    retrieve(taskid).unskipPaths(stepNrs.toList)
  }

  /** Skips the steps at the given block paths of task `taskid`. */
  def skipStepPaths(taskid: String, stepNrs: util.List[BlockPath]) {
    retrieve(taskid).skipPaths(stepNrs.toList)
  }

  /** Cancels task `taskid`, blocking until the outcome is known; failures are rethrown. */
  def cancel(taskid: String) {
    archiveWith(taskid, Cancel.apply)
  }

  /** Asks task `taskid` to stop (fire-and-forget). */
  def stop(taskid: String) {
    lookupTaskActor(taskid) ! Stop(taskid)
  }

  /** Asks task `taskid` to abort (fire-and-forget). */
  def abort(taskid: String) {
    lookupTaskActor(taskid) ! Abort(taskid)
  }

  /** Enqueues task `taskid` for execution (fire-and-forget). */
  def execute(taskid: String) {
    lookupTaskActor(taskid) ! Enqueue(taskid)
  }

  /**
   * Schedules task `taskid` to run at `scheduleAt`.
   *
   * @throws TaskerException if `scheduleAt` lies in the past, or further ahead than
   *                         `Int.MaxValue` ticks of the configured tick duration
   */
  def schedule(taskid: String, scheduleAt: com.github.nscala_time.time.Imports.DateTime): Unit = {
    import com.github.nscala_time.time.Imports._
    def p(d: DateTime) = d.toString("yyyy-MM-dd HH:mm:ss Z")
    // Sample the clock once so the past-check and the delay computation agree; building an Interval from a
    // second, later `now` could otherwise throw when scheduleAt is (almost) now.
    val now = DateTime.now
    if (scheduleAt.isBefore(now)) {
      throw new TaskerException(s"Cannot schedule a task for the past, date entered was [${p(scheduleAt)}], now is [${p(now)}]")
    }
    val delayMillis: Long = scheduleAt.getMillis - now.getMillis
    val tickMillis: Long = TaskerSettings(system).tickDuration
    val maxDelayMillis: Long = Int.MaxValue.toLong * tickMillis
    if (delayMillis > maxDelayMillis) {
      throw new TaskerException(s"Cannot schedule task [$taskid] at [${p(scheduleAt)}], because it is too far into the future. Can only schedule to [${p(new DateTime(now.getMillis + maxDelayMillis))}]")
    }
    lookupTaskActor(taskid) ! Schedule(taskid, scheduleAt)
  }

  /**
   * Creates a task from `spec`, attaches a repository for its temp work dir, creates its actor and waits up
   * to one second for the actor to answer `Register`.
   *
   * @return the id of the new task
   */
  def register(spec: TaskSpecification): String = {
    val task: Task = new Task(spec.getId, spec)
    task.context.repository = repositoryFactory.create(task.getTempWorkDir)
    implicit val timeout = new Timeout(1 second)
    Await.ready(createTaskActor(task) ? Register(task), timeout.duration)
    task.getId
  }

  /** Returns the registered task `taskid`, or throws [[TaskNotFoundException]] when it is not registered. */
  def retrieve(taskid: String): Task = TaskRegistryExtension(system).getTask(taskid).getOrElse(throw new TaskNotFoundException("registry", taskid))

  /** All tasks currently held by the task registry. */
  def getAllIncompleteTasks: util.List[Task] = TaskRegistryExtension(system).getTasks

  /** Selects the top-level actor named `s` (a task id). */
  protected[tasker] def lookupTaskActor(s: String): ActorSelection = system.actorSelection(system.child(s))

  /** Creates the managing actor for `task`, named after its id. */
  protected[tasker] def createTaskActor(task: Task): ActorRef = system.actorOf(TaskManagingActor.props, task.getId)

  /** The actor system this engine runs on. */
  def getSystem: ActorSystem = system

}

/** Protocol and factory for [[ArchivedListeningActor]]. */
object ArchivedListeningActor {

  /** Instructs the listener to resolve `actor` and deliver `message` to it once resolved. */
  case class Forward(actor: ActorSelection, message: AnyRef)

  /**
   * Props for a listener that reports the outcome for `taskId` by completing `promise`.
   *
   * @param taskId  id of the task being archived or cancelled
   * @param promise completed with `taskId` on success, or failed with the reason
   */
  def props(taskId: TaskId, promise: Promise[TaskId]): Props =
    Props(classOf[ArchivedListeningActor], taskId, promise)
}

/**
 * Helper actor that delivers one archive/cancel message to a task actor and completes `promise` with the
 * outcome.
 *
 * Protocol: on [[ArchivedListeningActor.Forward]] it identifies the target selection. If no actor is
 * found, `promise` fails with a [[TaskNotFoundException]]. Otherwise it watches the actor, subscribes to
 * dead letters, and forwards the message. `promise` succeeds on `Archived`/`Cancelled` for `taskId`. It
 * fails on `FailedToArchive`, on a `Status.Failure`, or when the task actor terminates or the message
 * becomes a dead letter.
 */
class ArchivedListeningActor(taskId: TaskId, promise: Promise[TaskId]) extends Actor {
  import context._

  def receive: Actor.Receive = {
    case Forward(actorSelection, message) =>
      become(identifyActor(message))
      actorSelection ! Identify("")
  }

  /** Waits for the identity of the target; anything else than a found actor fails the promise. */
  def identifyActor(originalMessage: AnyRef): Actor.Receive = {
    case ActorIdentity(_, Some(actorRef)) =>
      watch(actorRef)
      system.eventStream.subscribe(self, classOf[DeadLetter])
      become(await(actorRef))
      actorRef ! originalMessage
    case _ =>
      // tryFailure: a second unexpected message must not throw IllegalStateException.
      promise.tryFailure(new TaskNotFoundException("akka system", taskId))
  }

  /** Waits for the task actor's answer to the forwarded message. */
  def await(actorRef: ActorRef): Actor.Receive = {
    case Terminated(`actorRef`) | DeadLetter(_, `self`, `actorRef`) =>
      forget(actorRef)
      promise.tryFailure(new TaskerException(s"Task $taskId was terminated by a different process"))
    case FailedToArchive(`taskId`, exception) =>
      forget(actorRef)
      promise.tryFailure(new TaskerException(exception, s"Task [$taskId] failed to archive"))
    case akka.actor.Status.Failure(exception) =>
      forget(actorRef)
      promise.tryFailure(new TaskerException(exception, s"Task [$taskId] failed to archive"))
    case Archived(`taskId`) | Cancelled(`taskId`) =>
      forget(actorRef)
      promise.trySuccess(taskId)
    case _: DeadLetter =>
      // The subscription delivers every dead letter in the system; ones not about our message are ignored.
      ()
  }

  /** Makes sure the dead-letter subscription never outlives this actor, whichever path stopped it. */
  override def postStop(): Unit = {
    system.eventStream.unsubscribe(self)
  }

  /** Stops watching the task actor and drops the event-stream subscription. */
  private def forget(actorRef: ActorRef): Unit = {
    unwatch(actorRef)
    system.eventStream.unsubscribe(self)
  }
}
