package com.xebialabs.xlrelease.repository.sql.persistence

import com.xebialabs.deployit.plugin.api.reflect.Type
import com.xebialabs.xlrelease.db.sql.SqlBuilder.Dialect
import com.xebialabs.xlrelease.db.sql.transaction.{IsReadOnly, IsTransactional}
import com.xebialabs.xlrelease.db.sql.{DatabaseInfo, LimitOffset, SqlWithParameters}
import com.xebialabs.xlrelease.domain.status.TaskStatus
import com.xebialabs.xlrelease.domain.{ContainerTask, CustomScriptTask, Task}
import com.xebialabs.xlrelease.repository.Ids
import com.xebialabs.xlrelease.repository.Ids.{getFolderlessId, getName}
import com.xebialabs.xlrelease.repository.sql.SqlRepository
import com.xebialabs.xlrelease.repository.sql.persistence.CiId.{CiId, _}
import com.xebialabs.xlrelease.repository.sql.persistence.Schema.{RELEASES, TASKS}
import com.xebialabs.xlrelease.repository.sql.persistence.TaskPersistence.hash
import com.xebialabs.xlrelease.repository.sql.persistence.TaskTagsPersistence.TaskTag
import com.xebialabs.xlrelease.repository.sql.persistence.Utils.{params, _}
import com.xebialabs.xlrelease.repository.sql.persistence.data.TaskRow
import com.xebialabs.xlrelease.spring.config.SqlCiReferenceCacheConfiguration.{CACHE_CI_TASK_TYPES, CACHE_MANAGER_CI_REFERENCE}
import grizzled.slf4j.Logging
import org.apache.commons.codec.digest.DigestUtils
import org.springframework.cache.annotation.CacheEvict
import org.springframework.jdbc.core.{JdbcTemplate, RowMapper}

import java.sql.ResultSet
import java.util.{Date, Map => JMap}
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

/**
 * SQL persistence for rows of the `TASKS` table.
 *
 * Rows are addressed either by their CI uid or by the pair (`TASK_ID_HASH`, `TASK_ID`). `TASK_ID` holds the
 * folderless task id and `TASK_ID_HASH` holds [[TaskPersistence.hash]] of it; lookups by id always bind both.
 * Tag rows are delegated to [[TaskTagsPersistence]].
 */
@IsTransactional
class TaskPersistence(taskTagsPersistence: TaskTagsPersistence, dbInfo: DatabaseInfo)(implicit val jdbcTemplate: JdbcTemplate, implicit val dialect: Dialect)
  extends SqlRepository
    with PersistenceSupport
    with LimitOffset
    with Logging {

  private val STMT_UID_BY_ID: String =
    s"""|SELECT
        |   ${TASKS.CI_UID}
        |FROM ${TASKS.TABLE}
        |WHERE
        |   ${TASKS.TASK_ID_HASH} = :taskIdHash
        |   AND ${TASKS.TASK_ID} = :taskId
     """.stripMargin

  /**
   * Looks up the CI uid of a task by its id.
   *
   * @param taskId task id; its folderless form and hash are used in the lookup
   * @return the task's CI uid, or `None` when no row matches
   */
  def taskUidById(taskId: CiId): Option[CiUid] = {
    findOne {
      sqlQuery(STMT_UID_BY_ID, params(
        "taskId" -> getFolderlessId(taskId),
        "taskIdHash" -> hash(taskId)
      ), rs => CiUid(rs.getInt(TASKS.CI_UID)))
    }
  }

  /**
   * Inserts a task row and its tags, then stores the generated CI uid on the task.
   *
   * @param task       task to persist
   * @param releaseUid CI uid of the owning release
   * @return the generated CI uid of the new task row
   */
  def insert(task: Task, releaseUid: Int): CiUid = {
    val ciUid = insertTask(releaseUid, task)
    taskTagsPersistence.insertTags(ciUid, task.getTags.asScala.toSet)
    task.setCiUid(ciUid)
    ciUid
  }

  /**
   * Inserts rows for all given tasks through `sqlBatch`.
   *
   * Unlike [[insert]], this neither writes tags (see [[batchInsertTags]]) nor sets the CI uid on the tasks.
   */
  def batchInsert(tasks: Set[Task], releaseUid: Int): Unit = {
    sqlBatch(STMT_INSERT_TASK, tasks.collect {
      case task: Task => params("releaseUid" -> releaseUid) ++ commonTaskParams(task)
    })
  }

  /** Delegates batch insertion of task tags to [[TaskTagsPersistence]]. */
  def batchInsertTags(taskTags: List[TaskTag]): Unit = {
    taskTagsPersistence.batchInsertTags(taskTags)
  }

  private val STMT_RELEASE_TASK_UIDS = s"SELECT ${TASKS.TASK_ID}, ${TASKS.CI_UID} FROM ${TASKS.TABLE} WHERE ${TASKS.RELEASE_UID} = :${TASKS.RELEASE_UID}"

  /**
   * Returns every task of a release as a map from the stored (folderless) task id to its CI uid.
   *
   * @param releaseUid CI uid of the release
   */
  def releaseTaskUids(releaseUid: Int): Map[String, CiUid] = {
    val mapper: RowMapper[(String, CiUid)] = (rs, _) => rs.getString(1) -> rs.getInt(2)
    val results = sqlQuery(STMT_RELEASE_TASK_UIDS, params(TASKS.RELEASE_UID -> releaseUid), mapper)
    results.toMap
  }

  private val STMT_INSERT_TASK =
    s"""| INSERT INTO ${TASKS.TABLE}
        |   ( ${TASKS.RELEASE_UID}
        |   , ${TASKS.TASK_ID}
        |   , ${TASKS.TASK_ID_HASH}
        |   , ${TASKS.TASK_TYPE}
        |   , ${TASKS.TITLE}
        |   , ${TASKS.STATUS}
        |   , ${TASKS.STATUS_LINE}
        |   , ${TASKS.OWNER}
        |   , ${TASKS.TEAM}
        |   , ${TASKS.START_DATE}
        |   , ${TASKS.END_DATE}
        |   , ${TASKS.IS_AUTOMATED}
        |   , ${TASKS.LOCKED}
        |   , ${TASKS.PLANNED_DURATION}
        |   , ${TASKS.IS_OVERDUE_NOTIFIED}
        |   , ${TASKS.IS_DUE_SOON_NOTIFIED}
        |   )
        | VALUES
        |   ( :releaseUid
        |   , :taskId
        |   , :taskIdHash
        |   , :taskType
        |   , :title
        |   , :status
        |   , :statusLine
        |   , :owner
        |   , :team
        |   , :startDate
        |   , :endDate
        |   , :isAutomated
        |   , :locked
        |   , :plannedDuration
        |   , :overdueNotified
        |   , :dueSoonNotified
        |   )
      """.stripMargin

  /** Inserts the task row only (no tags) and returns the generated `CI_UID`. */
  private def insertTask(releaseUid: Int, task: Task): CiUid = {
    sqlInsert(TASKS.CI_UID, STMT_INSERT_TASK, params("releaseUid" -> releaseUid) ++ commonTaskParams(task))
  }

  private val STMT_UPDATE_TASK =
    s"""| UPDATE ${TASKS.TABLE}
        |   SET
        |   ${TASKS.TITLE} = :title,
        |   ${TASKS.STATUS} = :status,
        |   ${TASKS.STATUS_LINE} = :statusLine,
        |   ${TASKS.OWNER} = :owner,
        |   ${TASKS.TEAM} = :team,
        |   ${TASKS.START_DATE} = :startDate,
        |   ${TASKS.END_DATE} = :endDate,
        |   ${TASKS.IS_AUTOMATED} = :isAutomated,
        |   ${TASKS.TASK_TYPE} = :taskType,
        |   ${TASKS.LOCKED} = :locked,
        |   ${TASKS.PLANNED_DURATION} = :plannedDuration,
        |   ${TASKS.IS_OVERDUE_NOTIFIED} = :overdueNotified,
        |   ${TASKS.IS_DUE_SOON_NOTIFIED} = :dueSoonNotified
        | WHERE
        |   ${TASKS.CI_UID} = :taskUid
      """.stripMargin

  /**
   * Rewrites the stored properties and tags of an existing task, located by its id.
   *
   * @throws NoSuchElementException if no task row exists for `task.getId`
   */
  def updateProperties(task: Task): Unit = {
    val taskUid = taskUidById(task.getId).getOrElse(
      throw new NoSuchElementException(s"Task [${task.getId}] does not exist")
    )
    taskTagsPersistence.updateTags(taskUid, task.getTags.asScala.toSet)
    sqlUpdate(STMT_UPDATE_TASK, params("taskUid" -> taskUid) ++ commonTaskParams(task), _ => ())
  }

  private val STMT_UPDATE_TASK_BY_ID_HASH =
    s"""| UPDATE ${TASKS.TABLE}
        |   SET
        |   ${TASKS.TITLE} = :title,
        |   ${TASKS.STATUS} = :status,
        |   ${TASKS.STATUS_LINE} = :statusLine,
        |   ${TASKS.OWNER} = :owner,
        |   ${TASKS.TEAM} = :team,
        |   ${TASKS.START_DATE} = :startDate,
        |   ${TASKS.END_DATE} = :endDate,
        |   ${TASKS.IS_AUTOMATED} = :isAutomated,
        |   ${TASKS.TASK_TYPE} = :taskType,
        |   ${TASKS.LOCKED} = :locked,
        |   ${TASKS.PLANNED_DURATION} = :plannedDuration,
        |   ${TASKS.IS_OVERDUE_NOTIFIED} = :overdueNotified,
        |   ${TASKS.IS_DUE_SOON_NOTIFIED} = :dueSoonNotified
        | WHERE
        |   ${TASKS.TASK_ID_HASH} = :${TASKS.TASK_ID_HASH}
        |   AND ${TASKS.TASK_ID} = :${TASKS.TASK_ID}
      """.stripMargin

  /**
   * Updates the stored properties of all given tasks through `sqlBatch`, matching rows by id and id hash.
   * Tags are not touched.
   */
  def batchUpdateTaskProperties(tasks: Set[Task]): Unit = {
    sqlBatch(STMT_UPDATE_TASK_BY_ID_HASH, tasks.collect {
      case task: Task => params(TASKS.TASK_ID_HASH -> hash(task.getId), TASKS.TASK_ID -> Ids.getFolderlessId(task.getId)) ++ commonTaskParams(task)
    })
  }

  private val QUERY_TASK_HASH_AND_CIUID_BY_RELEASE_UID =
    s""" SELECT ${TASKS.TASK_ID_HASH}, ${TASKS.CI_UID}
       | FROM ${TASKS.TABLE}
       | WHERE ${TASKS.RELEASE_UID} = :${TASKS.RELEASE_UID}
       |""".stripMargin

  /** Returns every task of a release as a map from `TASK_ID_HASH` to `CI_UID`. */
  def findReleaseTaskCiUids(releaseUid: CiUid): Map[Hash, CiUid] = {
    val mapper: RowMapper[(Hash, CiUid)] = (rs, _) => rs.getString(1) -> rs.getInt(2)
    sqlQuery(QUERY_TASK_HASH_AND_CIUID_BY_RELEASE_UID, params(TASKS.RELEASE_UID -> releaseUid), mapper).toMap
  }

  /** Persists a task whose type changed; the task type column is written by [[updateProperties]]. */
  def updateType(task: Task): Unit = {
    updateProperties(task)
  }

  private val STMT_MOVE_TASK =
    s"""| UPDATE ${TASKS.TABLE}
        |   SET
        |   ${TASKS.TASK_ID_HASH} = :newIdHash,
        |   ${TASKS.TASK_ID} = :newId
        | WHERE
        |   ${TASKS.TASK_ID_HASH} = :oldIdHash
        |   AND ${TASKS.TASK_ID} = :oldId
     """.stripMargin

  /** Re-keys a task row from `oldId` to `newId` (both id and id hash are rewritten). */
  def move(oldId: String, newId: String): Unit = {
    sqlUpdate(STMT_MOVE_TASK, params(
      "oldId" -> getFolderlessId(oldId),
      "oldIdHash" -> hash(oldId),
      "newId" -> getFolderlessId(newId),
      "newIdHash" -> hash(newId)
    ), _ => ())
  }

  private val STMT_DELETE_TASK =
    s"""| DELETE FROM ${TASKS.TABLE}
        | WHERE
        |   ${TASKS.TASK_ID_HASH} = :taskIdHash
        |   AND ${TASKS.TASK_ID} = :taskId
      """.stripMargin

  private val STMT_DELETE_TASK_BY_UID =
    s"""| DELETE FROM ${TASKS.TABLE}
        | WHERE
        |   ${TASKS.CI_UID} = :${TASKS.CI_UID}
      """.stripMargin

  /**
   * Deletes a task row and evicts its entry from the task-type cache.
   *
   * The row is matched by CI uid when the task carries one, otherwise by id and id hash.
   */
  @CacheEvict(cacheNames = Array(CACHE_CI_TASK_TYPES), cacheManager = CACHE_MANAGER_CI_REFERENCE, key = "#root.args[0].id")
  def delete(task: Task): Unit = {
    if (task.getCiUid != null) {
      sqlUpdate(STMT_DELETE_TASK_BY_UID, params(TASKS.CI_UID -> task.getCiUid), _ => ())
    } else {
      // STMT_DELETE_TASK binds only the id and its hash; building the full commonTaskParams map here
      // would needlessly dereference every other task property.
      sqlUpdate(STMT_DELETE_TASK, params(
        "taskId" -> getFolderlessId(task.getId),
        "taskIdHash" -> hash(task.getId)
      ), _ => ())
    }
  }

  // Upper bound on the number of bind variables in one IN list; Oracle rejects lists longer than 1000.
  private val MAX_IN_CLAUSE_SIZE = 1000

  /**
   * Deletes the task rows with the given CI uids.
   *
   * The uids are deleted in chunks of at most [[MAX_IN_CLAUSE_SIZE]] so the generated IN list stays within
   * database limits; an empty input issues no statement (an empty `IN ()` is invalid SQL).
   */
  def deleteByUids(taskUids: Seq[CiUid]): Unit = {
    // S-91304 no need to evict task type from cache as this method is called when phase is deleted
    //  and it should not be possible to delete a phase that is referenced by a dependency
    taskUids.grouped(MAX_IN_CLAUSE_SIZE).foreach { chunk =>
      val query =
        s"""| DELETE FROM ${TASKS.TABLE}
            | WHERE ${TASKS.CI_UID} IN (${chunk.map(_ => "?").mkString(",")})
        """.stripMargin
      jdbcTemplate.update(query, chunk: _*)
    }
  }

  private val STMT_DELETE_TASKS_BY_RELEASE_UID =
    s"""| DELETE FROM ${TASKS.TABLE}
        | WHERE
        |   ${TASKS.RELEASE_UID} = :releaseUid
      """.stripMargin

  /** Deletes all task rows belonging to the release with the given CI uid. */
  def deleteTasksByReleaseUid(releaseUid: Int): Unit = {
    sqlUpdate(STMT_DELETE_TASKS_BY_RELEASE_UID, params("releaseUid" -> releaseUid), _ => ())
  }

  // Column list shared by the SELECTs below. This string is not stripMargin'd itself: the `|` margins on its
  // continuation lines are removed by the stripMargin of the statement it is interpolated into.
  private val TASK_ROWS =
    s"""   task.${TASKS.CI_UID},
       |   task.${TASKS.RELEASE_UID},
       |   task.${TASKS.TASK_ID},
       |   task.${TASKS.TASK_TYPE},
       |   task.${TASKS.TITLE},
       |   task.${TASKS.STATUS},
       |   task.${TASKS.STATUS_LINE},
       |   task.${TASKS.OWNER},
       |   task.${TASKS.TEAM},
       |   task.${TASKS.START_DATE},
       |   task.${TASKS.END_DATE},
       |   task.${TASKS.IS_AUTOMATED},
       |   task.${TASKS.LOCKED}"""
  private val STMT_FIND_BY_IDS =
    s"""| SELECT
        |  $TASK_ROWS
        | FROM ${TASKS.TABLE} task
        | WHERE
        |   task.${TASKS.TASK_ID_HASH} IN (:taskIdHashes)
        |   AND task.${TASKS.TASK_ID} IN (:taskIds)
     """.stripMargin

  /**
   * Loads the task rows for the given task ids.
   *
   * An empty input returns an empty result without querying, since an empty IN list is invalid SQL.
   */
  def findByIds(taskIds: Iterable[String]): Seq[TaskRow] = {
    if (taskIds.isEmpty) {
      Seq.empty
    } else {
      sqlQuery(STMT_FIND_BY_IDS, params(
        "taskIds" -> taskIds.map(getFolderlessId).asJava,
        "taskIdHashes" -> taskIds.map(hash).asJava
      ), taskRowMapper).toSeq
    }
  }

  private val STMT_FIND_BY_RELEASE_UID =
    s"""| SELECT
        |  $TASK_ROWS
        | FROM ${TASKS.TABLE} task
        | WHERE
        |   task.${TASKS.RELEASE_UID} = :${TASKS.RELEASE_UID}
     """.stripMargin

  /** Loads all task rows belonging to the release with the given CI uid. */
  def findByReleaseUid(releaseUid: CiUid): Seq[TaskRow] = {
    sqlQuery(STMT_FIND_BY_RELEASE_UID, params(
      TASKS.RELEASE_UID -> releaseUid
    ), taskRowMapper).toSeq
  }

  /** Loads the row of a single task, or `None` when it does not exist. */
  def findById(taskId: String): Option[TaskRow] = {
    findByIds(Seq(taskId)).headOption
  }

  private val STMT_EXISTS =
    s"""SELECT COUNT(*)
       |FROM ${TASKS.TABLE}
       |WHERE
       |  ${TASKS.TASK_ID_HASH} = :taskIdHash
       |  AND ${TASKS.TASK_ID} = :taskId
       |""".stripMargin

  /** Returns whether a task row exists for the given id (a COUNT(*) query always yields one row). */
  def exists(taskId: String): Boolean = {
    sqlQuery(STMT_EXISTS, params(
      "taskId" -> getFolderlessId(taskId),
      "taskIdHash" -> hash(taskId)
    ), _.getInt(1) > 0).head
  }

  val STMT_GET_TASK_UID_BY_ID = s"SELECT ${TASKS.CI_UID} FROM ${TASKS.TABLE} WHERE ${TASKS.TASK_ID_HASH} = :taskIdHash AND ${TASKS.TASK_ID} = :taskId"

  /** Read-only variant of [[taskUidById]]: the CI uid of the task with the given id, if any. */
  @IsReadOnly
  def getTaskUidById(taskId: String): Option[CiUid] = {
    findOne {
      sqlQuery(STMT_GET_TASK_UID_BY_ID, params(
        "taskId" -> getFolderlessId(taskId),
        "taskIdHash" -> hash(taskId)
      ), rs => CiUid(rs.getInt(TASKS.CI_UID)))
    }
  }

  /**
   * Runs a query with positional parameters and returns the first column of every row as a task id.
   *
   * @param sqlWithParameters the SQL text and its positional arguments
   */
  @IsReadOnly
  def findTaskIdsByQuery(sqlWithParameters: SqlWithParameters): Seq[CiId] = {
    val (sql, args) = sqlWithParameters
    val rowSet = jdbcTemplate.queryForRowSet(sql, args: _*)
    // Append in place: `buffer :+ x` would copy the whole buffer for every row (quadratic).
    val ids = ArrayBuffer.empty[CiId]
    while (rowSet.next()) {
      ids += rowSet.getString(1)
    }
    ids.toSeq
  }

  private val STMT_GET_TASK_IDS_BY_TASK_TYPE_STATUS_AND_START_DATE =
    s"""SELECT ${TASKS.TASK_ID}
       |FROM ${TASKS.TABLE}
       |WHERE
       |  ${TASKS.TASK_TYPE} = ?
       |  AND ${TASKS.STATUS} = ?
       |  AND ${TASKS.START_DATE} <= ?
       |""".stripMargin

  /** Returns the ids of tasks with the given type and status whose start date is at or before `startedBefore`. */
  @IsReadOnly
  def findTaskIdsByTaskTypeStatusAndStartDate(taskType: Type, taskStatus: TaskStatus, startedBefore: Date): Seq[String] = {
    findTaskIdsByQuery((
      STMT_GET_TASK_IDS_BY_TASK_TYPE_STATUS_AND_START_DATE,
      Seq(
        taskType.toString,
        taskStatus.value(),
        startedBefore
      )
    ))
  }

  private val STMT_GET_TITLE_BY_ID =
    s"""SELECT ${TASKS.TITLE}
       |FROM ${TASKS.TABLE}
       |WHERE
       |  ${TASKS.TASK_ID_HASH} = :taskIdHash
       |  AND ${TASKS.TASK_ID} = :taskId
       |""".stripMargin

  /** Returns the stored title of a task, or `None` when the task does not exist. */
  @IsReadOnly
  def getTitle(taskId: CiId): Option[String] = {
    sqlQuery(STMT_GET_TITLE_BY_ID, params(
      "taskId" -> getFolderlessId(taskId),
      "taskIdHash" -> hash(taskId)
    ), _.getString(TASKS.TITLE)).headOption
  }

  /** Delegates to [[TaskTagsPersistence.findAllTags]] with the given limit. */
  @IsReadOnly
  def findAllTags(limitNumber: Int): Set[String] = {
    taskTagsPersistence.findAllTags(limitNumber)
  }

  private val STMT_GET_STATUS =
    s"""SELECT t.${TASKS.STATUS} FROM ${TASKS.TABLE} t
       |WHERE
       |  t.${TASKS.TASK_ID_HASH} = :taskIdHash AND
       |  t.${TASKS.TASK_ID} = :taskId
       |""".stripMargin

  /** Returns the stored status of a task, or `None` when the task does not exist. */
  @IsReadOnly
  def getStatus(taskId: String): Option[TaskStatus] = {
    sqlQuery(STMT_GET_STATUS, params(
      "taskId" -> Ids.getFolderlessId(taskId),
      "taskIdHash" -> hash(taskId)
    ), readStatus).headOption
  }

  private val STMT_GET_STATUSES =
    s"""SELECT t.${TASKS.TASK_ID}, t.${TASKS.STATUS}
       | FROM ${TASKS.TABLE} t
       | JOIN ${RELEASES.TABLE} r ON t.${TASKS.RELEASE_UID} = r.${RELEASES.CI_UID}
       | WHERE
       |  r.${RELEASES.RELEASE_ID} = :${RELEASES.RELEASE_ID}
       |""".stripMargin

  /**
   * Returns the status of every task in a release, keyed by the stored task id.
   *
   * The release is matched on `RELEASE_ID = getName(releaseId.normalized)`.
   */
  @IsReadOnly
  def getTaskStatuses(releaseId: String): JMap[String, TaskStatus] = {
    sqlQuery(STMT_GET_STATUSES, params(RELEASES.RELEASE_ID -> getName(releaseId.normalized)), readStatuses).toMap.asJava
  }

  // Statuses are written as TaskStatus.value() and read back by upper-casing into the enum constant name.
  // Locale.ROOT keeps this independent of the JVM default locale (e.g. Turkish maps "i" to dotted "İ").
  private def parseStatus(raw: String): TaskStatus =
    TaskStatus.valueOf(raw.toUpperCase(java.util.Locale.ROOT))

  private def readStatuses: RowMapper[(String, TaskStatus)] = (rs: ResultSet, _: Int) =>
    rs.getString(TASKS.TASK_ID) -> parseStatus(rs.getString(TASKS.STATUS))

  private def readStatus: RowMapper[TaskStatus] = (rs: ResultSet, _: Int) =>
    parseStatus(rs.getString(TASKS.STATUS))

  /**
   * Builds the named parameters shared by the task INSERT and UPDATE statements.
   *
   * `statusLine` is `null` except for [[CustomScriptTask]] and [[ContainerTask]], which supply their own;
   * the title is truncated to `Schema.COLUMN_LENGTH_TITLE`.
   */
  private[persistence] def commonTaskParams(task: Task): Map[String, Any] = {
    val commonParams = params(
      "taskId" -> getFolderlessId(task.getId),
      "taskIdHash" -> hash(task.getId),
      "taskType" -> task.getTaskType.toString,
      "title" -> task.getTitle.truncate(Schema.COLUMN_LENGTH_TITLE),
      "status" -> task.getStatus.value(),
      "statusLine" -> null,
      "owner" -> task.getOwner,
      "team" -> task.getTeam,
      "startDate" -> task.getStartOrScheduledDate,
      "endDate" -> task.getEndOrDueDate,
      "isAutomated" -> task.isAutomated.asInteger,
      "locked" -> task.isLocked.asInteger,
      "plannedDuration" -> task.getPlannedDuration,
      "overdueNotified" -> task.isOverdueNotified.asInteger,
      "dueSoonNotified" -> task.isDueSoonNotified.asInteger
    )
    val taskSpecificParams = task match {
      case cst: CustomScriptTask => Map("statusLine" -> cst.getStatusLine)
      case ct: ContainerTask => Map("statusLine" -> ct.getStatusLine)
      case _ => Map()
    }
    commonParams ++ taskSpecificParams
  }

  /** Maps one row selected with [[TASK_ROWS]] to a [[TaskRow]]. */
  private val taskRowMapper: RowMapper[TaskRow] = (rs: ResultSet, _: Int) =>
    TaskRow(
      ciUid = rs.getInt(TASKS.CI_UID),
      releaseUid = rs.getInt(TASKS.RELEASE_UID),
      taskId = rs.getString(TASKS.TASK_ID),
      taskType = rs.getString(TASKS.TASK_TYPE),
      title = rs.getString(TASKS.TITLE),
      status = rs.getString(TASKS.STATUS),
      statusLine = rs.getString(TASKS.STATUS_LINE),
      owner = rs.getString(TASKS.OWNER),
      team = rs.getString(TASKS.TEAM),
      // NOTE(review): ResultSet.getDate returns java.sql.Date, which may drop the time-of-day that was
      // written from a full Date; confirm whether getTimestamp is intended (depends on TaskRow's field type).
      startDate = rs.getDate(TASKS.START_DATE),
      endDate = rs.getDate(TASKS.END_DATE),
      isAutomated = rs.getInt(TASKS.IS_AUTOMATED).asBoolean,
      locked = rs.getInt(TASKS.LOCKED).asBoolean
    )

  val overdueTasksQuery = new NotificationQueries.OverdueTasksQuery(dbInfo, namedTemplate)

  /** Returns the ids produced by [[NotificationQueries.OverdueTasksQuery]]. */
  def findOverdueTaskIds(): Seq[String] = {
    overdueTasksQuery.execute
  }

  val dueSoonTasksQuery = new NotificationQueries.DueSoonTasksQuery(dbInfo, namedTemplate)

  /** Returns the ids produced by [[NotificationQueries.DueSoonTasksQuery]]. */
  def findDueSoonTaskIds(): Seq[String] = {
    dueSoonTasksQuery.execute
  }

  private val FIND_TASK_CI_UIDS_BY_RELEASE_CI_UID: String = s"""SELECT ${TASKS.CI_UID} FROM ${TASKS.TABLE} WHERE ${TASKS.RELEASE_UID} = :releaseUid"""

  /** Returns the CI uids of all tasks belonging to the release with the given CI uid. */
  def findTaskCiUidsByReleaseCiUid(releaseCiUid: CiUid): Seq[CiUid] = {
    val queryParams = Map("releaseUid" -> releaseCiUid)
    sqlQuery(FIND_TASK_CI_UIDS_BY_RELEASE_CI_UID, queryParams, row => CiUid(row.getInt(TASKS.CI_UID))).toSeq
  }

  //noinspection DuplicatedCode
  val STMT_UPDATE_TASK_STATUSLINE =
    s"""
       | UPDATE ${TASKS.TABLE}
       | SET ${TASKS.STATUS_LINE} = :${TASKS.STATUS_LINE}
       | WHERE
       |   ${TASKS.TASK_ID_HASH} = :${TASKS.TASK_ID_HASH}
       |   AND ${TASKS.TASK_ID} = :${TASKS.TASK_ID}
       |""".stripMargin

  /** Overwrites only the `STATUS_LINE` column of the task with the given id. */
  def updateStatusLine(taskId: String, statusLine: String): Unit = {
    val folderLessTaskId = getFolderlessId(taskId)
    sqlUpdate(STMT_UPDATE_TASK_STATUSLINE,
      params(TASKS.TASK_ID_HASH -> hash(taskId), TASKS.TASK_ID -> folderLessTaskId, TASKS.STATUS_LINE -> statusLine),
      _ => ()
    )
  }

}

object TaskPersistence {

  /**
   * Computes the value bound to `TASK_ID_HASH` for a task id.
   *
   * The id is reduced to its folderless form before hashing, so ids that differ only in their folder
   * prefix produce the same hash.
   *
   * @param taskId task id, with or without folder prefix
   * @return SHA-256 of the folderless id, hex-encoded
   */
  def hash(taskId: String): String = {
    val folderlessId = getFolderlessId(taskId)
    DigestUtils.sha256Hex(folderlessId)
  }
}
