package com.xebialabs.xlrelease.test

import com.xebialabs.deployit.plugin.api.reflect.Type
import com.xebialabs.xlrelease.domain.events._
import com.xebialabs.xlrelease.domain.{ParallelGroup, Release, Task}
import com.xebialabs.xlrelease.utils.ConditionBuilder
import com.xebialabs.xlrelease.utils.ConditionBuilder.DEFAULT_TIMEOUT
import com.xebialabs.xlrelease.utils.MatchingAwaiter._
import org.hamcrest.Matcher

import scala.annotation.varargs

/**
 * Event matchers and `ConditionBuilder` syntax used by tests that wait for release, phase and
 * task events.
 *
 * Each matcher is a pattern-matching literal typed as `MATCHER` (imported from `MatchingAwaiter`).
 * It only has a case for the event it describes and yields `MatchFound()` for it; other events
 * are simply not covered by the literal.
 */
package object async {

  // ---------------------------------------------------------------------------------------------
  // Task matchers. Overloads taking a `Task` compare with `==`; overloads taking a `String`
  // compare against the task's `getId`.
  // ---------------------------------------------------------------------------------------------

  /** Matches a `TaskRecoveryStartedEvent` whose task equals `task`. */
  def failureHandlerInProgress(task: Task): MATCHER = { case TaskRecoveryStartedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskAbortScriptStartedEvent` whose task equals `task`. */
  def taskAbortScriptInProgress(task: Task): MATCHER = { case TaskAbortScriptStartedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskAbortScriptCompletedEvent` whose task equals `task`. */
  def taskAbortScriptCompleted(task: Task): MATCHER = { case TaskAbortScriptCompletedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskSkippedEvent` whose task equals `task`. */
  def taskSkipped(task: Task): MATCHER = { case TaskSkippedEvent(t, _) if t == task => MatchFound() }

  /** Matches a `TaskSkippedEvent` whose task id equals `taskId`. */
  def taskSkipped(taskId: String): MATCHER = { case TaskSkippedEvent(t, _) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskRecoveredEvent` whose task equals `task`. */
  def taskRecovered(task: Task): MATCHER = { case TaskRecoveredEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskStartedEvent` whose task equals `task`. */
  def taskStarted(task: Task): MATCHER = { case TaskStartedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskStartedEvent` whose task id equals `taskId`. */
  def taskStarted(taskId: String): MATCHER = { case TaskStartedEvent(t) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskDelayedEvent` whose task id equals `taskId`. */
  def taskDelayed(taskId: String): MATCHER = { case TaskDelayedEvent(t) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskCompletedEvent` whose task equals `task`. */
  def taskCompleted(task: Task): MATCHER = { case TaskCompletedEvent(t, _) if t == task => MatchFound() }

  /** Matches a `TaskCompletedEvent` whose task id equals `taskId`. */
  def taskCompleted(taskId: String): MATCHER = { case TaskCompletedEvent(t, _) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskFailedEvent` whose task equals `task`. */
  def taskFailed(task: Task): MATCHER = { case TaskFailedEvent(t, _) if t == task => MatchFound() }

  /** Matches a `TaskFailedEvent` whose task id equals `taskId`. */
  def taskFailed(taskId: String): MATCHER = { case TaskFailedEvent(t, _) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskRetriedEvent` whose task equals `task`. */
  def taskRetried(task: Task): MATCHER = { case TaskRetriedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskRetriedEvent` whose task id equals `taskId`. */
  def taskRetried(taskId: String): MATCHER = { case TaskRetriedEvent(t) if t.getId == taskId => MatchFound() }

  /** Matches a `TaskAbortedEvent` whose task equals `task`. */
  def taskAborted(task: Task): MATCHER = { case TaskAbortedEvent(t) if t == task => MatchFound() }

  /** Matches a `TaskAbortedEvent` whose task id equals `taskId`. */
  def taskAborted(taskId: String): MATCHER = { case TaskAbortedEvent(t) if t.getId == taskId => MatchFound() }

  /**
   * Matches a `TaskUpdatedEvent` where the original task has the same id as `task` and the
   * updated task's `getTaskType` equals `newType`.
   */
  def taskTypeChanged(task: Task, newType: Type): MATCHER = {
    case TaskUpdatedEvent(original, updated) if original.getId == task.getId && updated.getTaskType == newType => MatchFound()
  }

  /** Matches a `TaskStartedEvent` for any task whose container equals the parallel group `pg`. */
  def pgTaskStarted(pg: ParallelGroup): MATCHER = { case TaskStartedEvent(t) if t.getContainer == pg => MatchFound() }

  // ---------------------------------------------------------------------------------------------
  // Phase matchers.
  // ---------------------------------------------------------------------------------------------

  /** Matches a `PhaseStartedEvent` whose phase id equals `phaseId`. */
  def phaseStarted(phaseId: String): MATCHER = { case PhaseStartedEvent(phase) if phase.getId == phaseId => MatchFound() }

  /** Matches a `PhaseStartedEvent` whose phase is accepted by the given Hamcrest `matcher`. */
  def phaseStarted(matcher: Matcher[_]): MATCHER = { case PhaseStartedEvent(phase) if matcher.matches(phase) => MatchFound() }

  /** Matches a `PhaseClosedEvent` whose phase id equals `phaseId`. */
  def phaseClosed(phaseId: String): MATCHER = { case PhaseClosedEvent(phase) if phase.getId == phaseId => MatchFound() }

  // ---------------------------------------------------------------------------------------------
  // Release matchers. Overloads taking a `Release` compare with `==`; overloads taking a `String`
  // compare against the release's `getId`.
  // ---------------------------------------------------------------------------------------------

  /** Matches a `ReleaseStartedEvent` whose release equals `release`. */
  def releaseStarted(release: Release): MATCHER = { case ReleaseStartedEvent(r, _) if r == release => MatchFound() }

  /** Matches a `ReleaseStartedEvent` whose release id equals `releaseId`. */
  def releaseStarted(releaseId: String): MATCHER = { case ReleaseStartedEvent(r, _) if r.getId == releaseId => MatchFound() }

  /** Matches a `ReleaseFailedEvent` whose release equals `release`. */
  def releaseFailed(release: Release): MATCHER = { case ReleaseFailedEvent(r) if r == release => MatchFound() }

  /** Matches a `ReleaseStartedFailingEvent` whose release equals `release`. */
  def releaseStartedFailing(release: Release): MATCHER = { case ReleaseStartedFailingEvent(r) if r == release => MatchFound() }

  /** Matches a `ReleaseAbortedEvent` whose release equals `release`. */
  def releaseAborted(release: Release): MATCHER = { case ReleaseAbortedEvent(r, _) if r == release => MatchFound() }

  /** Matches a `ReleaseAbortedEvent` whose release id equals `releaseId`. */
  def releaseAborted(releaseId: String): MATCHER = { case ReleaseAbortedEvent(r, _) if r.getId == releaseId => MatchFound() }

  /** Matches a `ReleaseRetriedEvent` whose release equals `release`. */
  def releaseRetried(release: Release): MATCHER = { case ReleaseRetriedEvent(r) if r == release => MatchFound() }

  /** Matches a `ReleaseCompletedEvent` whose release equals `release`. */
  def releaseCompleted(release: Release): MATCHER = { case ReleaseCompletedEvent(r) if r == release => MatchFound() }

  /** Matches a `ReleaseCompletedEvent` whose release id equals `releaseId`. */
  def releaseCompleted(releaseId: String): MATCHER = { case ReleaseCompletedEvent(r) if r.getId == releaseId => MatchFound() }

  /** Matches a `ReleaseResumedEvent` whose release equals `release`. */
  def releaseResumed(release: Release): MATCHER = { case ReleaseResumedEvent(r) if r == release => MatchFound() }

  /** Matches a `ReleaseResumedEvent` whose release id equals `releaseId`. */
  def releaseResumed(releaseId: String): MATCHER = { case ReleaseResumedEvent(r) if r.getId == releaseId => MatchFound() }

  /** Matches a `ReleasePausedEvent` whose release id equals `releaseId`. */
  def releasePaused(releaseId: String): MATCHER = { case ReleasePausedEvent(r) if r.getId == releaseId => MatchFound() }

  // ---------------------------------------------------------------------------------------------
  // Comment matchers.
  // ---------------------------------------------------------------------------------------------

  /**
   * Matches a comment event on the task with id `taskId` whose text contains `comment`:
   * either a `CommentCreatedEvent` (text taken from `ev.comment`) or a `CommentUpdatedEvent`
   * (text taken from `ev.updated`).
   */
  def commentCreatedOrUpdated(taskId: String, comment: String): MATCHER = {
    case ev: CommentCreatedEvent if ev.task.getId == taskId && ev.comment.getText.contains(comment) => MatchFound()
    case ev: CommentUpdatedEvent if ev.task.getId == taskId && ev.updated.getText.contains(comment) => MatchFound()
  }

  /** Catch-all matcher: yields `MatchIgnore()` for every event. */
  def ignoreEvent: MATCHER = {
    case _ => MatchIgnore()
  }

  /**
   * Adds `waitFor*` syntax to a [[ConditionBuilder]].
   *
   * The non-strict variants go through a [[ConditionBuilderContext]] created with
   * `DEFAULT_TIMEOUT`; use [[withTimeout]] to obtain a context with a different timeout.
   *
   * @param builder the builder being extended
   * @tparam R result type produced by the builder
   */
  implicit class ConditionBuilderExtension[R](builder: ConditionBuilder[R]) {

    // A fresh context per call, always using the default timeout.
    private def builderContext: ConditionBuilderContext[R] = {
      ConditionBuilderContext[R](timeout = DEFAULT_TIMEOUT, builder)
    }

    /** Same as [[ConditionBuilderContext.waitFor]] with `DEFAULT_TIMEOUT`. */
    @varargs
    def waitFor(matchers: MATCHER*): R = {
      builderContext.waitFor(matchers: _*)
    }

    /** Same as [[ConditionBuilderContext.waitForAnyInAnyOrder]] with `DEFAULT_TIMEOUT`. */
    @varargs
    def waitForAnyInAnyOrder(matchers: MATCHER*): R = {
      builderContext.waitForAnyInAnyOrder(matchers: _*)
    }

    /** Same as [[ConditionBuilderContext.waitForAllInAnyOrder]] with `DEFAULT_TIMEOUT`. */
    @varargs
    def waitForAllInAnyOrder(matchers: MATCHER*): R = {
      builderContext.waitForAllInAnyOrder(matchers: _*)
    }

    /**
     * Passes `matchers` to `builder.until` unchanged: no `ignoreEvent` fallback is appended and
     * no explicit timeout argument is supplied (contrast with [[waitFor]]).
     *
     * Annotated with `@varargs` like the other variadic methods here so that a Java-style
     * varargs forwarder is generated as well.
     */
    @varargs
    def waitForStrict(matchers: MATCHER*): R = {
      builder.until(matchers: _*)
    }

    /**
     * Creates a [[ConditionBuilderContext]] for this builder with the given timeout.
     *
     * @param timeout value forwarded to `builder.until`; the unit is defined by
     *                `ConditionBuilder` (not visible here - confirm before relying on it)
     */
    def withTimeout(timeout: Long): ConditionBuilderContext[R] = {
      ConditionBuilderContext[R](timeout, builder)
    }
  }

  /**
   * A [[ConditionBuilder]] paired with the timeout to use when waiting.
   *
   * @param timeout value passed as the first argument of `builder.until`
   * @param builder the builder that performs the wait
   * @tparam T result type produced by the builder
   */
  case class ConditionBuilderContext[T](timeout: Long, builder: ConditionBuilder[T]) {

    /**
     * Calls `builder.until(timeout, ...)` with every matcher extended by `orElse(ignoreEvent)`,
     * so an event that a matcher has no case for yields `MatchIgnore()` instead.
     */
    @varargs
    def waitFor(matchers: MATCHER*): T = {
      builder.until(timeout, matchers.map(_.orElse(ignoreEvent)): _*)
    }

    /**
     * Combines `matchers` into a single matcher via `anyOrder` (from `MatchingAwaiter`) and
     * waits for it with [[waitFor]].
     */
    @varargs
    def waitForAnyInAnyOrder(matchers: MATCHER*): T = {
      waitFor(anyOrder(matchers: _*))
    }

    /**
     * Transforms `matchers` via `allOfInAnyOrder` (from `MatchingAwaiter`) and waits for the
     * resulting matchers with [[waitFor]].
     */
    @varargs
    def waitForAllInAnyOrder(matchers: MATCHER*): T = {
      waitFor(allOfInAnyOrder(matchers: _*): _*)
    }
  }

}
