package com.xebialabs.xlrelease.repository.sql.persistence

import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.repository.ItemConflictException
import com.xebialabs.xlrelease.db.sql.SqlBuilder.{CommonDialect, Dialect}
import com.xebialabs.xlrelease.domain.id.CiUid
import com.xebialabs.xlrelease.repository.sql.persistence.PersistenceConstants.BLOB_TYPE
import com.xebialabs.xlrelease.repository.sql.persistence.Utils._
import org.springframework.jdbc.core.namedparam.{MapSqlParameterSource, NamedParameterJdbcTemplate, SqlParameterSource}
import org.springframework.jdbc.core.support.SqlBinaryValue
import org.springframework.jdbc.core.{JdbcTemplate, PreparedStatementCallback, ResultSetExtractor, RowMapper}

import java.sql._
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.language.implicitConversions
import scala.{Array => ScalaArray}

/**
 * Mix-in with thin Scala-friendly helpers over Spring's `JdbcTemplate` and
 * `NamedParameterJdbcTemplate` for the SQL persistence layer.
 *
 * Implementors supply the raw `jdbcTemplate` and the SQL `dialect`; the named-parameter
 * template is derived lazily from `jdbcTemplate`.
 */
trait PersistenceSupport extends Utils with CompressionSupport {
  /** Underlying Spring JDBC template used by every helper in this trait. */
  def jdbcTemplate: JdbcTemplate

  /** SQL dialect of the target database; consulted by [[pkName]]. */
  implicit val dialect: Dialect

  /** Named-parameter template wrapping [[jdbcTemplate]], created on first use. */
  implicit lazy val namedTemplate: NamedParameterJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate)

  /** Implicitly converts a Scala parameter map into a Spring `MapSqlParameterSource`. */
  implicit def paramMap2MapSqlParameterSource(params: Map[String, Any]): MapSqlParameterSource = new MapSqlParameterSource(params.asJava)

  /**
   * Executes an insert (or any other update statement) with named parameters.
   * The affected-row count is discarded.
   */
  def sqlInsert(sqlStatement: String, params: Map[String, Any]): Unit = {
    sqlInsert(sqlStatement, paramSource(params.toSeq: _*))
  }

  /** Identical to `sqlInsert(sqlStatement, params)`; kept as a separate entry point for existing callers. */
  def sqlInsertTask(sqlStatement: String, params: Map[String, Any]): Unit = {
    sqlInsert(sqlStatement, params)
  }

  /** Executes an update statement with an already-built parameter source, discarding the row count. */
  def sqlInsert(sqlStatement: String, params: MapSqlParameterSource): Unit = {
    namedTemplate.update(sqlStatement, params)
  }

  /**
   * Executes `sqlStatement` with named parameters and hands the prepared statement to `callback`.
   *
   * @return whatever `callback` returns
   */
  def sqlExec[R](sqlStatement: String, params: Map[String, Any], callback: PreparedStatementCallback[R]): R = {
    namedTemplate.execute[R](sqlStatement, params.asJava, callback)
  }

  /**
   * Executes an update whose `contentParam` (name -> text) is compressed and bound as a binary
   * value of type `BLOB_TYPE`; the affected-row count is discarded.
   */
  def sqlExecWithContent(sqlStatement: String, params: Map[String, Any], contentParam: (String, String)): Unit = {
    sqlExecWithContent(sqlStatement, params, contentParam, identity)
  }

  /**
   * Executes an update whose `contentParam` (name -> text) is compressed and bound as a binary
   * value of type `BLOB_TYPE`, alongside the plain named `params`.
   *
   * @param mapper receives the affected-row count
   * @return the result of `mapper`
   */
  def sqlExecWithContent[R](sqlStatement: String, params: Map[String, Any], contentParam: (String, String), mapper: Int => R): R = {
    val (contentKey, content) = contentParam
    val parameters = paramSource(params.toSeq: _*)
    parameters.addValue(contentKey, new SqlBinaryValue(compress(content)), BLOB_TYPE)
    mapper(namedTemplate.update(sqlStatement, parameters))
  }

  /** Executes an update with named parameters and passes the affected-row count to `mapper`. */
  def sqlUpdate[R](sqlStatement: String, params: MapSqlParameterSource, mapper: Int => R): R = {
    mapper(namedTemplate.update(sqlStatement, params))
  }

  /**
   * Executes a positional-parameter update through the raw [[jdbcTemplate]].
   *
   * @param setup    binds parameters on the freshly prepared statement and returns the statement to run
   * @param callback receives the affected-row count
   */
  def sqlSet[R](sqlStatement: String, setup: PreparedStatement => PreparedStatement, callback: Int => R): R = {
    val numRowsUpdated = jdbcTemplate.update(
      (conn: Connection) =>
        setup(conn.prepareStatement(sqlStatement))
    )
    callback(numRowsUpdated)
  }

  /** Row-count check that throws `NotFoundException` when no row was updated; any other count is accepted. */
  def checkCiUpdated(entityId: String): Int => Unit = {
    case 0 => throw new NotFoundException(s"Repository entity [$entityId] could not be found.")
    case _ =>
  }

  /**
   * Row-count check for a token-guarded update.
   *
   * Exactly one updated row yields `freshToken`. Zero rows means the entity is missing when no
   * `token` was supplied (`NotFoundException`), or was modified concurrently when a `token` was
   * supplied (`ItemConflictException`). Any other count is an invariant violation
   * (`IllegalStateException`) rather than an opaque `MatchError`.
   */
  def checkCiUpdated(uid: CiUid, token: Option[Token], freshToken: Token): Int => Token = {
    case 1 => freshToken
    case 0 => token match {
      case None =>
        throw new NotFoundException(s"Repository entity [$uid] could not be found.")
      case Some(_) =>
        throw new ItemConflictException(s"Repository entity [$uid] has been updated since you read it. Please reload the CI from the repository again.")
    }
    case updated =>
      throw new IllegalStateException(s"Expected to update exactly 1 row for repository entity [$uid], but $updated rows were updated.")
  }

  /** Row-count check that accepts exactly one deleted row; any other count throws `NotFoundException`. */
  def checkCiDeleted(uid: CiUid): Int => Unit = {
    case 1 => ()
    case _ => throw new NotFoundException(s"Repository entity [$uid] could not be found.")
  }

  /** Runs a named-parameter query, mapping each row with a Spring `RowMapper`. */
  def sqlQuery[R](sqlStatement: String, params: Map[String, Any], mapper: RowMapper[R]): mutable.Buffer[R] = {
    namedTemplate.query[R](
      sqlStatement,
      params.asJava,
      mapper
    ).asScala
  }

  /** Runs a named-parameter query, mapping each row with a Scala function over the `ResultSet`. */
  def sqlQuery[R](sqlStatement: String, params: Map[String, Any], mapper: ResultSet => R): mutable.Buffer[R] =
    sqlQuery(sqlStatement, params, rowMapper(mapper))

  /** Runs a named-parameter query and lets `extractor` consume the whole `ResultSet`. */
  def sqlQuery[R](sqlStatement: String, params: Map[String, Any], extractor: ResultSetExtractor[R]): R = {
    namedTemplate.query[R](
      sqlStatement,
      params.asJava,
      extractor
    )
  }

  /**
   * Executes `sqlStatement` once per parameter map as a JDBC batch.
   *
   * @return per-statement update counts, with `Statement.SUCCESS_NO_INFO` reported as 1
   */
  def sqlBatch(sqlStatement: String, parameters: Set[Map[String, Any]]): Seq[Int] = {
    val batchValues = parameters.map(ps => ps.asJava: java.util.Map[String, _]).toArray
    normalizeBatchCounts(namedTemplate.batchUpdate(sqlStatement, batchValues))
  }

  /**
   * Executes `sqlStatement` as a JDBC batch where every entry also carries a text content
   * parameter (name -> text) that is compressed and bound as a binary value of type `BLOB_TYPE`.
   *
   * @return per-statement update counts, with `Statement.SUCCESS_NO_INFO` reported as 1
   *         (same normalization as [[sqlBatch]])
   */
  def sqlBatchWithContent(sqlStatement: String, params: Seq[(Map[String, Any], (String, String))]): Seq[Int] = {
    val paramsWithContents: ScalaArray[SqlParameterSource] = params.map {
      case (ps, (contentKey, content)) =>
        val parameters = paramSource(ps.toSeq: _*)
        parameters.addValue(contentKey, new SqlBinaryValue(compress(content)), BLOB_TYPE)
        parameters
    }.toArray
    normalizeBatchCounts(namedTemplate.batchUpdate(sqlStatement, paramsWithContents))
  }

  /**
   * Converts raw JDBC batch counts into a `Seq`, mapping `Statement.SUCCESS_NO_INFO` to 1.
   * Oracle does not return the number of updated rows in a batch before version 12.1.0.2.
   */
  private def normalizeBatchCounts(counts: ScalaArray[Int]): Seq[Int] =
    counts.map {
      case Statement.SUCCESS_NO_INFO => 1
      case count => count
    }.toSeq

  /**
   * Returns the primary-key column name as expected by the current dialect: lower-cased for
   * `CommonDialect`, unchanged otherwise. `Locale.ROOT` keeps the result independent of the JVM
   * default locale (for example, the Turkish dotless-i mapping of "I").
   */
  def pkName(pkColumn: String): String = dialect match {
    case CommonDialect(_) => pkColumn.toLowerCase(java.util.Locale.ROOT)
    case _ => pkColumn
  }

  /** Builds a `MapSqlParameterSource` from name/value pairs, preserving their order. */
  def paramSource(pairs: (String, Any)*): MapSqlParameterSource = {
    val params = new MapSqlParameterSource()
    pairs.foreach { case (name, value) => params.addValue(name, value) }
    params
  }

}
