package com.xebialabs.xlrelease.scheduler.sql

import com.xebialabs.xlrelease.db.sql.SqlBuilder.Dialect
import com.xebialabs.xlrelease.db.sql.transaction.{IsReadOnly, IsTransactional}
import com.xebialabs.xlrelease.repository.sql.persistence.Schema._
import com.xebialabs.xlrelease.repository.sql.persistence.TaskPersistence.hash
import com.xebialabs.xlrelease.repository.sql.persistence.Utils.params
import com.xebialabs.xlrelease.repository.sql.persistence.{PersistenceSupport, Utils}
import com.xebialabs.xlrelease.runner.domain.JobId
import com.xebialabs.xlrelease.scheduler._
import com.xebialabs.xlrelease.scheduler.filters.JobFilters
import com.xebialabs.xlrelease.scheduler.repository._
import com.xebialabs.xlrelease.scheduler.sql.SqlJobRepository.InstantHelper
import grizzled.slf4j.Logging
import org.springframework.dao.OptimisticLockingFailureException
import org.springframework.data.domain.{Page, Pageable}
import org.springframework.jdbc.core.JdbcTemplate
import org.springframework.jdbc.support.GeneratedKeyHolder

import java.sql.Timestamp
import java.time.Instant
import java.time.temporal.ChronoUnit
import scala.util.Try

/**
 * JDBC-backed [[JobRepository]] storing task jobs in the table named by `TASK_JOBS.TABLE`.
 *
 * Every update of an existing row increments its `VERSION` column and only succeeds when the stored
 * version still equals the previous one (optimistic locking).
 *
 * @param jdbcTemplate template used for SQL access
 * @param dialect      SQL dialect; passed to the filter query builders used by the paged finders
 */
class SqlJobRepository(val jdbcTemplate: JdbcTemplate, val dialect: Dialect)
  extends JobRepository
    with PersistenceSupport
    with JobRowMapper
    with Utils
    with Logging {

  /**
   * Inserts `job` and returns it carrying the database-generated id.
   *
   * If the insert fails (typically a key collision with an existing job for the same task), the
   * existing job is replaced via [[replace]]. If the replacement fails as well, the original insert
   * failure is attached to the rethrown exception as a suppressed exception so it is not lost.
   *
   * NOTE(review): `replace` is invoked on `this`; if `@IsTransactional` is applied through Spring
   * proxies it does not take effect for this call. Also, on databases that abort the whole
   * transaction after a failed statement (e.g. PostgreSQL) the fallback cannot succeed inside an
   * enclosing transaction. Confirm how callers invoke this method.
   */
  override def create(job: JobRow): JobRow = {
    logger.trace(s"creating job $job")
    try {
      createTaskJob(job)
    } catch {
      case insertError: Exception =>
        logger.trace(s"creating job $job failed, falling back to replacing the existing job", insertError)
        try {
          replace(job)
        } catch {
          case replaceError: Exception =>
            // keep the root cause of the failed insert visible to whoever handles this failure
            replaceError.addSuppressed(insertError)
            throw replaceError
        }
    }
  }

  /**
   * Deletes any job with the same task id as `job`, then inserts `job`.
   *
   * @return `job` carrying the database-generated id
   */
  @IsTransactional
  override def replace(job: JobRow): JobRow = {
    logger.trace(s"replacing job $job")
    delete(DeleteByTaskId(job.taskId))
    createTaskJob(job)
  }

  private val STMT_INSERT_JOB =
    s"""INSERT INTO ${TASK_JOBS.TABLE} (
       |   ${TASK_JOBS.TASK_ID},
       |   ${TASK_JOBS.TASK_ID_HASH},
       |   ${TASK_JOBS.RELEASE_UID},
       |   ${TASK_JOBS.JOB_TYPE},
       |   ${TASK_JOBS.STATUS},
       |   ${TASK_JOBS.VERSION},
       |   ${TASK_JOBS.SUBMIT_TIME},
       |   ${TASK_JOBS.RESERVATION_TIME},
       |   ${TASK_JOBS.SCHEDULED_START_TIME},
       |   ${TASK_JOBS.EXECUTION_ID},
       |   ${TASK_JOBS.NODE},
       |   ${TASK_JOBS.RUNNER_ID}
       | ) VALUES (
       | :${TASK_JOBS.TASK_ID},
       | :${TASK_JOBS.TASK_ID_HASH},
       | :${TASK_JOBS.RELEASE_UID},
       | :${TASK_JOBS.JOB_TYPE},
       | :${TASK_JOBS.STATUS},
       | :${TASK_JOBS.VERSION},
       | :${TASK_JOBS.SUBMIT_TIME},
       | :${TASK_JOBS.RESERVATION_TIME},
       | :${TASK_JOBS.SCHEDULED_START_TIME},
       | :${TASK_JOBS.EXECUTION_ID},
       | :${TASK_JOBS.NODE},
       | :${TASK_JOBS.RUNNER_ID}
       | )
       |""".stripMargin

  /**
   * Inserts `job` as a new row and returns a copy carrying the generated primary key.
   * The hash of the task id is stored alongside it; lookups and deletes by task id use that hash.
   */
  private def createTaskJob(job: JobRow): JobRow = {
    val params = Map(
      TASK_JOBS.TASK_ID -> job.taskId,
      TASK_JOBS.TASK_ID_HASH -> hash(job.taskId),
      TASK_JOBS.RELEASE_UID -> job.releaseUid,
      TASK_JOBS.JOB_TYPE -> job.jobType.name(),
      TASK_JOBS.STATUS -> job.status.name(),
      TASK_JOBS.VERSION -> job.version,
      TASK_JOBS.SUBMIT_TIME -> job.submitTime.toTimestamp(),
      TASK_JOBS.RESERVATION_TIME -> job.reservationTime.toTimestamp(),
      TASK_JOBS.SCHEDULED_START_TIME -> job.scheduledStartTime.toTimestamp(),
      TASK_JOBS.EXECUTION_ID -> job.executionId,
      TASK_JOBS.NODE -> job.node,
      TASK_JOBS.RUNNER_ID -> job.runnerId
    )
    val holder = new GeneratedKeyHolder
    namedTemplate.update(STMT_INSERT_JOB, params, holder, Array(pkName(TASK_JOBS.ID)))
    job.copy(id = holder.getKey.longValue())
  }

  // Must be defined before STMT_LOCK_JOB, which interpolates it during construction.
  private val PREVIOUS_VERSION_PARAMETER_NAME = "PREVIOUS_VERSION"
  private val STMT_LOCK_JOB =
    s"""UPDATE ${TASK_JOBS.TABLE}
       | SET
       |   ${TASK_JOBS.STATUS} = :${TASK_JOBS.STATUS},
       |   ${TASK_JOBS.VERSION} = :${TASK_JOBS.VERSION},
       |   ${TASK_JOBS.NODE} = :${TASK_JOBS.NODE},
       |   ${TASK_JOBS.RUNNER_ID} = :${TASK_JOBS.RUNNER_ID},
       |   ${TASK_JOBS.RESERVATION_TIME} = :${TASK_JOBS.RESERVATION_TIME},
       |   ${TASK_JOBS.START_TIME} = :${TASK_JOBS.START_TIME}
       | WHERE ${TASK_JOBS.ID} = :${TASK_JOBS.ID} AND ${TASK_JOBS.VERSION} = :$PREVIOUS_VERSION_PARAMETER_NAME
       |""".stripMargin

  private val STMT_FIND_TASK_IDS_BY_RELEASE_UID =
    s"""
       |SELECT
       | ${TASK_JOBS.TASK_ID}
       |FROM ${TASK_JOBS.TABLE}
       |WHERE
       | ${TASK_JOBS.RELEASE_UID} = :${TASK_JOBS.RELEASE_UID} AND
       | ${TASK_JOBS.STATUS} = :${TASK_JOBS.STATUS}
       |""".stripMargin

  /** Returns the task ids of all jobs of the given release that are in status `QUEUED`. */
  @IsReadOnly
  override def findQueuedTaskIdsByReleaseUid(releaseUid: Integer): Seq[String] = {
    findMany(sqlQuery(STMT_FIND_TASK_IDS_BY_RELEASE_UID,
      params(TASK_JOBS.RELEASE_UID -> releaseUid, TASK_JOBS.STATUS -> JobStatus.QUEUED.name()),
      rs => rs.getString(TASK_JOBS.TASK_ID)))
  }

  private val STMT_FIND_VERSION_BY_ID =
    s"""
       |SELECT
       | ${TASK_JOBS.VERSION}
       |FROM ${TASK_JOBS.TABLE}
       |WHERE
       | ${TASK_JOBS.ID} = :${TASK_JOBS.ID}
       |""".stripMargin

  /** Reads the currently stored version of the job row `id`; only used to enrich error messages. */
  private def findVersionById(id: Long): Option[Integer] = {
    findOne(sqlQuery(STMT_FIND_VERSION_BY_ID,
      params(TASK_JOBS.ID -> id),
      rs => rs.getInt(TASK_JOBS.VERSION).asInstanceOf[Integer]))
  }

  private val SELECT_JOB_FIELDS =
    s"""
       |SELECT
       | ${TASK_JOBS.ID},
       | ${TASK_JOBS.TASK_ID},
       | ${TASK_JOBS.RELEASE_UID},
       | ${TASK_JOBS.EXECUTION_ID},
       | ${TASK_JOBS.JOB_TYPE},
       | ${TASK_JOBS.STATUS},
       | ${TASK_JOBS.VERSION},
       | ${TASK_JOBS.SUBMIT_TIME},
       | ${TASK_JOBS.RESERVATION_TIME},
       | ${TASK_JOBS.SCHEDULED_START_TIME},
       | ${TASK_JOBS.START_TIME},
       | ${TASK_JOBS.NODE},
       | ${TASK_JOBS.RUNNER_ID}
       |""".stripMargin


  private val STMT_FIND_JOB_BY_TASK_ID =
    s"""
       |$SELECT_JOB_FIELDS
       |FROM ${TASK_JOBS.TABLE}
       |WHERE
       | ${TASK_JOBS.TASK_ID_HASH} = :${TASK_JOBS.TASK_ID_HASH}
       |""".stripMargin

  /** Looks up the job of a task by the hash of its task id. */
  @IsReadOnly
  override def findByTaskId(taskId: String): Option[JobRow] = {
    findOne(sqlQuery(STMT_FIND_JOB_BY_TASK_ID,
      params(TASK_JOBS.TASK_ID_HASH -> hash(taskId)),
      jobMapper
    ))
  }

  private val STMT_FIND_JOB_BY_JOB_ID =
    s"""
       |$SELECT_JOB_FIELDS
       |FROM ${TASK_JOBS.TABLE}
       |WHERE
       | ${TASK_JOBS.ID} = :${TASK_JOBS.ID}
       |""".stripMargin

  /** Looks up a job by its primary key. Read-only, like the other finders of this repository. */
  @IsReadOnly
  override def read(jobId: JobId): Option[JobRow] = {
    findOne(sqlQuery(STMT_FIND_JOB_BY_JOB_ID,
      params(TASK_JOBS.ID -> jobId),
      jobMapper
    ))
  }


  /** Returns one page of jobs matching `jobFilters`. */
  @IsReadOnly
  override def findAll(jobFilters: JobFilters, pageable: Pageable): Page[JobRow] = {
    SqlJobFiltersQueryBuilder(dialect, namedTemplate)
      .from(jobFilters)
      .withPageable(pageable)
      .build()
      .execute()
  }

  /** Returns one page of job overviews matching `jobFilters`. */
  @IsReadOnly
  override def findAllJobOverview(jobFilters: JobFilters, pageable: Pageable): Page[JobOverview] = {
    SqlJobOverviewQueryBuilder(dialect, namedTemplate)
      .from(jobFilters)
      .withPageable(pageable)
      .build()
      .execute()
  }

  private val STMT_FIND_DISTINCT_NODE_IDS =
    s"""
       |SELECT DISTINCT
       | ${TASK_JOBS.NODE}
       |FROM ${TASK_JOBS.TABLE}
       |""".stripMargin

  /**
   * Returns every distinct node id referenced by a stored job.
   *
   * NOTE(review): rows with a NULL node would contribute a `null` element to the set — confirm
   * whether callers expect that.
   */
  @IsReadOnly
  override def findDistinctNodeIds(): Set[String] = {
    findMany(sqlQuery(STMT_FIND_DISTINCT_NODE_IDS,
      params(),
      rs => rs.getString(TASK_JOBS.NODE)))
      .toSet
  }

  private val STMT_DELETE_JOB =
    s"""DELETE FROM ${TASK_JOBS.TABLE}
       | WHERE ${TASK_JOBS.ID} = :${TASK_JOBS.ID}
       |""".stripMargin

  private val STMT_DELETE_BY_TASK_ID_HASH =
    s"""DELETE FROM ${TASK_JOBS.TABLE}
       | WHERE ${TASK_JOBS.TASK_ID_HASH} = :${TASK_JOBS.TASK_ID_HASH}
       |""".stripMargin

  private val STMT_DELETE_BY_TASK_ID_HASH_AND_EXECUTION_ID =
    s"""DELETE FROM ${TASK_JOBS.TABLE}
       | WHERE ${TASK_JOBS.TASK_ID_HASH} = :${TASK_JOBS.TASK_ID_HASH}
       | AND ${TASK_JOBS.EXECUTION_ID} = :${TASK_JOBS.EXECUTION_ID}
       |""".stripMargin

  /**
   * Deletes jobs selected by `deleteJob`:
   *  - [[DeleteById]] removes the row with that primary key; a missing row is only trace-logged.
   *  - [[DeleteByTaskId]] removes the rows whose task id hash matches.
   *  - [[DeleteByTaskIdAndExecutionId]] additionally requires the execution id to match.
   *
   * @throws UnsupportedOperationException for any other [[DeleteJob]] variant
   */
  @IsTransactional
  override def delete(deleteJob: DeleteJob): Unit = {
    deleteJob match {
      case DeleteById(id) =>
        logger.trace(s"deleting job with jobId $id")
        val deletedRows = sqlUpdate(STMT_DELETE_JOB, params(TASK_JOBS.ID -> id), rows => rows)
        if (deletedRows != 1) {
          // This is a normal situation when the job was superseded by one of:
          // * abort script
          // * task.schedule
          // In these cases a new job with the same task ID is created BEFORE the old one is completed,
          // so the new job causes a key collision when being created and the old job is deleted.
          // This should not cause consistency problems, because the replacement happens inside
          // a transaction and the task is completed anyway; results are already saved in the release.
          logger.trace(s"$deletedRows rows deleted when trying to delete scheduled job with id $id")
        }
      case DeleteByTaskId(taskId) =>
        logger.trace(s"deleting job with taskId $taskId")
        sqlExec(STMT_DELETE_BY_TASK_ID_HASH,
          params(TASK_JOBS.TASK_ID_HASH -> hash(taskId)),
          _.execute()
        )
      case DeleteByTaskIdAndExecutionId(taskId, executionId) =>
        logger.trace(s"deleting job with taskId $taskId and executionId $executionId")
        sqlExec(STMT_DELETE_BY_TASK_ID_HASH_AND_EXECUTION_ID,
          params(
            TASK_JOBS.TASK_ID_HASH -> hash(taskId),
            TASK_JOBS.EXECUTION_ID -> executionId
          ),
          _.execute()
        )
      case _ =>
        throw new UnsupportedOperationException(s"$deleteJob operation is not supported")
    }
  }

  /**
   * Applies `updateJob` to its job row, increments the row version and persists the result with an
   * optimistic version check.
   *
   * @return `Success` with the updated row, or `Failure` holding an
   *         `OptimisticLockingFailureException` if the row was modified or deleted concurrently, or an
   *         `IllegalStateException` if the [[ConfirmJobExecution]] preconditions do not hold or more
   *         than one row was updated
   *
   * NOTE(review): failures are returned inside `Try` rather than thrown, so a proxy-based
   * transaction around this method does not see them and will not roll back on them.
   */
  @IsTransactional
  override def update(updateJob: UpdateJob): Try[JobRow] = {
    Try {
      val newJobRow = updateJob match {
        case ReserveJob(job, nodeId, runnerId) =>
          logger.trace(s"reserving job $job")
          val reservationTime = Instant.now().truncatedTo(ChronoUnit.MILLIS)
          job.copy(version = job.version + 1, status = JobStatus.RESERVED, node = nodeId, runnerId = runnerId, reservationTime = reservationTime)
        case ConfirmJobExecution(job, runnerId) =>
          logger.trace(s"confirming job $job execution on $runnerId")
          val startTime = Instant.now().truncatedTo(ChronoUnit.MILLIS)
          val allowedStates = Seq(JobStatus.RESERVED, JobStatus.RUNNING)
          if (runnerId != job.runnerId || !allowedStates.contains(job.status)) {
            throw new IllegalStateException(s"Job [$job] does not have runner id set to [$runnerId] or is not in [$allowedStates] states")
          }
          job.copy(version = job.version + 1, status = JobStatus.RUNNING, runnerId = runnerId, startTime = startTime)
        case UpdateNode(job, nodeId) =>
          logger.trace(s"changing job $job node to $nodeId")
          job.copy(version = job.version + 1, node = nodeId)
        case UpdateJobStatus(job, jobStatus) =>
          logger.trace(s"changing status of job $job to $jobStatus")
          job.copy(version = job.version + 1, status = jobStatus)
        case UpdateNodeAndStatus(job, nodeId, jobStatus) =>
          logger.trace(s"changing job $job node to $nodeId and status to $jobStatus")
          job.copy(version = job.version + 1, node = nodeId, status = jobStatus)
      }
      updateWithVersionCheck(newJobRow)
    }
  }

  /**
   * Writes the mutable columns of `job`, matching on its id and on `job.version - 1`
   * (callers pass a row whose version has already been incremented).
   *
   * @return `job` unchanged when exactly one row was updated
   * @throws OptimisticLockingFailureException if no row matched the id and previous version
   * @throws IllegalStateException             if more than one row was updated
   */
  private def updateWithVersionCheck(job: JobRow): JobRow = {
    logger.trace(s"going to update with version check to $job")
    val previousVersion = job.version - 1
    val params = Map(
      TASK_JOBS.STATUS -> job.status.name(),
      TASK_JOBS.VERSION -> job.version,
      TASK_JOBS.NODE -> job.node,
      TASK_JOBS.RUNNER_ID -> job.runnerId,
      // same null-safe conversion as in createTaskJob
      TASK_JOBS.RESERVATION_TIME -> job.reservationTime.toTimestamp(),
      TASK_JOBS.START_TIME -> job.startTime.toTimestamp(),
      TASK_JOBS.ID -> job.id,
      PREVIOUS_VERSION_PARAMETER_NAME -> previousVersion
    )
    sqlUpdate(STMT_LOCK_JOB, params, handleOptimisticUpdate(job, Some(() => {
      findVersionById(job.id).map(version => s"current version is $version, expected version $previousVersion").getOrElse("record not found")
    })))
    job
  }

  /**
   * Interprets the affected-row count of an optimistic update of `job`.
   *
   * @param logRecoveryFn optional supplier of extra diagnostic text; it is only evaluated on failure
   */
  private def handleOptimisticUpdate(job: JobRow, logRecoveryFn: Option[() => String]): PartialFunction[Int, Unit] = {
    // a `def` so the (possibly DB-querying) supplier runs only when a failure message is built
    def recoveryDetails: String = logRecoveryFn.map(fn => s", ${fn.apply()}").getOrElse("")

    {
      case 0 =>
        val msg = s"Unable to update Job row $job $recoveryDetails"
        logger.trace(msg)
        throw new OptimisticLockingFailureException(msg)
      case 1 =>
        logger.trace(s"Successfully updated job to $job")
      case _ =>
        val msg = s"More than one row updated for job ${job.id} $recoveryDetails"
        logger.trace(msg)
        throw new IllegalStateException(msg)
    }
  }
}

/** Helpers shared by [[SqlJobRepository]] when binding values to SQL parameters. */
object SqlJobRepository {

  /** Enriches [[java.time.Instant]] with a null-tolerant conversion to [[java.sql.Timestamp]]. */
  implicit class InstantHelper(instant: Instant) {

    /** @return the wrapped instant as a `Timestamp`, or `null` if the wrapped instant is `null` */
    def toTimestamp(): Timestamp =
      Option(instant).map(Timestamp.from).orNull
  }
}
