package com.xebialabs.xlrelease.repository.sql.persistence

import com.xebialabs.xlrelease.db.sql.LimitOffset
import com.xebialabs.xlrelease.db.sql.SqlBuilder.Dialect
import com.xebialabs.xlrelease.db.sql.transaction.{IsReadOnly, IsTransactional}
import com.xebialabs.xlrelease.repository.sql.SqlRepository
import com.xebialabs.xlrelease.repository.sql.persistence.Schema.TASK_TAGS
import com.xebialabs.xlrelease.repository.sql.persistence.TaskTagsPersistence.TaskTag
import com.xebialabs.xlrelease.repository.sql.persistence.TasksSqlBuilder.{normalizeTag, normalizeTags}
import com.xebialabs.xlrelease.repository.sql.persistence.Utils.{params, _}
import org.springframework.jdbc.core.{BatchPreparedStatementSetter, JdbcTemplate}

import java.sql.PreparedStatement
import scala.jdk.CollectionConverters._

/**
 * SQL persistence for task tags, stored as (task CI uid, tag value) rows in `TASK_TAGS`.
 *
 * Tags are passed through `normalizeTags` / `normalizeTag` before being inserted or used as a
 * delete filter; values read back by [[getTags]] and [[findAllTags]] are returned as stored.
 */
@IsTransactional
class TaskTagsPersistence(val jdbcTemplate: JdbcTemplate, val dialect: Dialect)
  extends SqlRepository
    with PersistenceSupport
    with LimitOffset {

  /** Positional (task uid, tag value) insert shared by [[insertTags]] and [[batchInsertTags]]. */
  private val insertTagSql: String =
    s"INSERT INTO ${TASK_TAGS.TABLE} (${TASK_TAGS.CI_UID}, ${TASK_TAGS.VALUE}) VALUES(?, ?)"

  /**
   * Returns the distinct tag values used by any task.
   *
   * The query is ordered by tag value before the row cap is applied through `addLimitAndOffset`
   * (assumed to render a dialect-specific LIMIT of `limitNumber`), so the cap keeps the
   * alphabetically first values; the returned `Set` does not preserve that order.
   *
   * @param limitNumber maximum number of tag values to fetch
   */
  @IsReadOnly
  def findAllTags(limitNumber: Int): Set[String] = {
    val stmt = s"SELECT DISTINCT ${TASK_TAGS.VALUE} FROM ${TASK_TAGS.TABLE} ORDER BY ${TASK_TAGS.VALUE}"
    sqlQuery(addLimitAndOffset(stmt, Some(limitNumber)), params(), rs => rs.getString(TASK_TAGS.VALUE)).toSet
  }

  /**
   * Returns the tag values stored for one task, exactly as they appear in the table.
   *
   * @param taskUid CI uid of the task
   */
  @IsReadOnly
  def getTags(taskUid: CiUid): Set[String] = {
    sqlQuery(
      s"SELECT ${TASK_TAGS.VALUE} FROM ${TASK_TAGS.TABLE} WHERE ${TASK_TAGS.CI_UID} = :taskUid",
      params("taskUid" -> taskUid),
      (rs, _) => rs.getString(TASK_TAGS.VALUE)
    ).toSet[String]
  }

  /**
   * Inserts the normalized `tags` for one task using the `batch` helper on `jdbcTemplate`.
   * Does nothing when `tags` is empty. Existing rows are not checked; use [[updateTags]] to
   * apply only the difference.
   *
   * @param taskUid CI uid of the task
   * @param tags    tag values to insert; normalized before being written
   */
  def insertTags(taskUid: CiUid, tags: Set[String]): Unit = {
    if (tags.nonEmpty) {
      jdbcTemplate.batch(insertTagSql, taskUid, normalizeTags(tags).toList)
    }
  }

  /**
   * Inserts many (task, tag) pairs through `JdbcTemplate.batchUpdate`, normalizing each tag
   * with `normalizeTag`. Does nothing when `taskTags` is empty, matching [[insertTags]] and
   * [[deleteTags]].
   *
   * @param taskTags pairs to insert
   */
  def batchInsertTags(taskTags: Seq[TaskTag]): Unit = {
    // setValues is called once per batch index; materialize an IndexedSeq so each lookup is
    // constant time even when the caller passes a List (List#apply(i) is O(i)).
    val rows = taskTags.toIndexedSeq
    if (rows.nonEmpty) {
      val pss: BatchPreparedStatementSetter = new BatchPreparedStatementSetter {
        override def getBatchSize: Int = rows.size

        override def setValues(ps: PreparedStatement, i: Int): Unit = {
          val row = rows(i)
          ps.setInt(1, row.taskUid)
          ps.setString(2, normalizeTag(row.tag))
        }
      }
      jdbcTemplate.batchUpdate(insertTagSql, pss)
    }
  }

  /**
   * Makes the stored tags of one task equal to the normalized `tags`. Only the difference is
   * written: stored values absent from the new set are deleted, new values are inserted, and
   * nothing is executed when the sets are already equal.
   *
   * NOTE(review): `existingTags` are raw stored values while `updatedTags` are normalized, so a
   * stored value that differs from its normalized form is treated as removed; `deleteTags` then
   * normalizes it again before matching, which may miss that stored row — confirm stored values
   * are always normalized.
   *
   * @param taskUid CI uid of the task
   * @param tags    desired tag values; normalized before comparison
   */
  def updateTags(taskUid: CiUid, tags: Set[String]): Unit = {
    val updatedTags = normalizeTags(tags)
    val existingTags = getTags(taskUid)

    if (updatedTags != existingTags) {
      val deletedTags = existingTags.diff(updatedTags)
      val createdTags = updatedTags.diff(existingTags)
      deleteTags(taskUid, deletedTags)
      insertTags(taskUid, createdTags)
    }
  }

  /**
   * Deletes the given tags (after normalization) from one task. Does nothing when `tags` is
   * empty, which also avoids rendering an empty `IN ()` list.
   *
   * @param taskUid CI uid of the task
   * @param tags    tag values to remove; normalized before being used in the filter
   */
  def deleteTags(taskUid: CiUid, tags: Set[String]): Unit = {
    if (tags.nonEmpty) {
      val stmt =
        s"""|DELETE FROM ${TASK_TAGS.TABLE}
            |WHERE ${TASK_TAGS.CI_UID} = :taskUid
            |AND ${TASK_TAGS.VALUE} IN (:tags)""".stripMargin
      sqlExec(stmt, params("taskUid" -> taskUid, "tags" -> normalizeTags(tags).toList.asJava), _.execute())
    }
  }
}

object TaskTagsPersistence {

  /**
   * One task/tag association, the row shape accepted by `TaskTagsPersistence.batchInsertTags`.
   *
   * @param taskUid CI uid of the task (written to the `TASK_TAGS.CI_UID` column)
   * @param tag     tag value; `batchInsertTags` passes it through `normalizeTag` before writing
   */
  final case class TaskTag(taskUid: CiUid, tag: String)
}