package com.xebialabs.xlrelease.scheduler.strategies

import com.xebialabs.xlrelease.domain.distributed.events.DistributedReleaseExecutedEvent
import com.xebialabs.xlrelease.events.AsyncSubscribe
import com.xebialabs.xlrelease.repository.Ids
import com.xebialabs.xlrelease.runner.domain.JobId
import com.xebialabs.xlrelease.scheduler.repository.JobRepository
import com.xebialabs.xlrelease.scheduler.strategies.JobSchedulerStrategy.ScheduleResult
import com.xebialabs.xlrelease.scheduler.{Job, TaskJob}
import com.xebialabs.xlrelease.support.pekko.spring.ScalaSpringAwareBean
import grizzled.slf4j.Logging

import java.util
import java.util.concurrent.{ConcurrentHashMap, Semaphore}


/**
 * Job scheduler strategy that limits how many [[TaskJob]]s of the same release are admitted in parallel.
 *
 * Each release gets its own [[Semaphore]], created lazily with `maxJobs` permits on the first task job seen
 * for that release. A task job is admitted (`Right`) only if a permit can be taken without blocking. Otherwise,
 * or when the non-blocking back-pressure check says the job should wait, the job is delayed by
 * `delayDuration` and returned as `Left`. Jobs that are not task jobs are always admitted unchanged.
 *
 * Permits are handed back through [[unlock]]. The whole per-release semaphore is dropped when a
 * [[DistributedReleaseExecutedEvent]] for that release is received.
 */
object LimitParallelJobSchedulerStrategy extends JobSchedulerStrategy[LimitParallelJobSchedulerStrategySettings]
  with ScalaSpringAwareBean with Logging {

  // Per-release semaphores, keyed by the release name produced by getReleaseId / Ids.getName.
  // Backed by a ConcurrentHashMap, so each individual computeIfAbsent/get/remove call is atomic.
  // Presumably schedule/unlock/onReleaseExecutedEvent run on different threads — verify against callers.
  private val semaphores: util.Map[String, Semaphore] = new ConcurrentHashMap[String, Semaphore]()

  lazy val jobRepository: JobRepository = springBean[JobRepository]

  /**
   * Decides whether `job` may run now.
   *
   * @param configuration strategy settings (`maxJobs` per release, `delayDuration` used when delaying)
   * @param job           the job being scheduled
   * @return `Right(job)` if the job may run now; `Left(job)` (after `job.delay(...)` was applied) if it
   *         must wait. Non-[[TaskJob]] jobs always yield `Right`.
   */
  override def schedule(configuration: LimitParallelJobSchedulerStrategySettings)(job: Job): ScheduleResult = {
    job match {
      case taskJob: TaskJob[_] => handleJob(configuration, taskJob)
      case other => Right(other)
    }
  }

  /**
   * Returns one permit to the semaphore of the release that owns `maybeTaskId`, if that semaphore still exists.
   * Does nothing when no task id is given or the release's semaphore was already dropped.
   *
   * NOTE(review): release() is unconditional. If this is ever called for a job that never acquired a permit
   * (e.g. one that was delayed), the semaphore's permit count grows beyond `maxJobs`. Confirm callers only
   * unlock jobs that were admitted by this strategy.
   */
  override def unlock(jobId: JobId, maybeTaskId: Option[String]): Unit =
    maybeTaskId.foreach { taskId =>
      val releaseId = getReleaseId(taskId)
      logger.trace(s"Releasing permit for Release [$releaseId]")
      // Single lookup instead of containsKey()+get(): onReleaseExecutedEvent may remove the entry between
      // those two calls, which would make get() return null and release() throw a NullPointerException.
      Option(semaphores.get(releaseId)).foreach(_.release())
    }

  /**
   * Drops the semaphore of the release named in `event`, discarding all of its permit bookkeeping.
   * A later task job of the same release would get a fresh semaphore via computeIfAbsent.
   */
  @AsyncSubscribe
  def onReleaseExecutedEvent(event: DistributedReleaseExecutedEvent): Unit = {
    logger.trace(s"Received DistributedReleaseExecutedEvent event for Release [${event.releaseId}]. Going to clear all permits for the release.")
    val rId = Ids.getName(event.releaseId)
    semaphores.remove(rId)
  }

  /**
   * Admits or delays a task job.
   *
   * The back-pressure check runs first. Because of `||` short-circuiting, no permit is taken for a job
   * that is already going to be delayed.
   */
  private def handleJob(configuration: LimitParallelJobSchedulerStrategySettings, job: TaskJob[_]): ScheduleResult = {
    val jobShouldBeDelayed = NonBlockingBackpressuredJobSchedulerStrategy.shouldBeDelayed(configuration, job.taskId, System.currentTimeMillis())

    if (jobShouldBeDelayed || isMaxJobThresholdReached(configuration, job)) {
      val delay = configuration.delayDuration
      job.delay(delay)
      Left(job)
    } else {
      Right(job)
    }
  }

  /**
   * Tries to take one permit, without blocking, from the job's release semaphore.
   *
   * The semaphore is created with `configuration.maxJobs` permits the first time the release is seen.
   * A later change to `maxJobs` does not affect an already-created semaphore.
   *
   * @return `true` if no permit was available (the limit is reached), `false` if a permit was acquired
   */
  private def isMaxJobThresholdReached(configuration: LimitParallelJobSchedulerStrategySettings,
                                       job: TaskJob[_]): Boolean = {
    val releaseId = getReleaseId(job.taskId)
    val maxJobThreshold = configuration.maxJobs
    val releaseSemaphore = semaphores.computeIfAbsent(releaseId, _ => new Semaphore(maxJobThreshold))
    val permitAvailable = releaseSemaphore.tryAcquire()
    logger.trace(s"PermitAvailable is $permitAvailable for JobId [${job.id}] with Release [$releaseId]")
    !permitAvailable
  }

  /** Maps a task id to the key used in `semaphores`: the name part of the owning release's id. */
  private def getReleaseId(taskId: String): String = Ids.getName(Ids.releaseIdFrom(taskId))

}
