package com.xebialabs.deployit.security.sql

import com.xebialabs.deployit.core.sql.{ColumnName, Queries, SchemaInfo, TableName}
import com.xebialabs.deployit.security.repository.XldUserGroupRepository
import com.xebialabs.deployit.security.sql.XldUserGroupsSchema.{ID, NAME, tableName}
import grizzled.slf4j.Logging
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.context.annotation.{Scope, ScopedProxyMode}
import org.springframework.dao.{DuplicateKeyException, EmptyResultDataAccessException}
import org.springframework.jdbc.core.{BatchPreparedStatementSetter, JdbcTemplate}
import org.springframework.stereotype.Component
import org.springframework.transaction.annotation.Transactional
import org.springframework.util.StringUtils.hasText

import java.sql.{PreparedStatement, ResultSet}
import java.util.UUID
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

/**
 * SQL-backed [[XldUserGroupRepository]].
 *
 * Groups are rows of `XLD_GROUPS` (ID, NAME); memberships are rows of `XLD_GROUP_PRINCIPALS`
 * (GROUP_ID, PRINCIPAL_NAME). Group names are matched case-insensitively (`lower(NAME) = ?`) and
 * principal names are always stored and queried in lower case.
 */
@Component
@Scope(proxyMode = ScopedProxyMode.TARGET_CLASS)
@Transactional("mainTransactionManager")
class SqlUserGroupRepository(@Autowired @Qualifier("mainJdbcTemplate") val jdbcTemplate: JdbcTemplate)
                            (@Autowired @Qualifier("mainSchema") override implicit val schemaInfo: SchemaInfo
                             )
  extends XldUserGroupRepository with SqlXldUserGroupsRepositoryQueries with Logging {

  /** Maximum number of statements per JDBC batch and of bound values per `IN (...)` list. */
  private final val ChunkSize = 100

  /** True when `value` is null, empty, or contains only characters stripped by `String.trim`. */
  private def isBlank(value: String): Boolean = value == null || value.trim.isEmpty

  /**
   * Returns the names of all groups the user is a member of.
   *
   * Lookup failures are logged and reported as an empty set instead of being propagated.
   *
   * @param username principal name; matched in lower case
   * @return group names as stored in `XLD_GROUPS` (original letter case)
   * @throws IllegalArgumentException if `username` is null or blank
   */
  override def findGroupsForUser(username: String): Set[String] = {
    if (isBlank(username)) {
      throw new IllegalArgumentException("User name cannot be empty")
    }
    logger.trace(s"Finding groups for user '$username'")
    try {
      jdbcTemplate.query(
        FIND_GROUPS_FOR_USER,
        (rs: ResultSet, _: Int) => rs.getString(NAME.name),
        username.toLowerCase
      ).asScala.toSet
    } catch {
      // Best-effort contract kept from the original: errors become an empty result.
      case NonFatal(e) =>
        logger.error(s"Error finding groups for user '$username'", e)
        Set.empty
    }
  }

  /**
   * Creates a new group with the given name under a freshly generated UUID.
   *
   * @param groupName name stored as-is (letter case preserved)
   * @return the generated group ID
   * @throws IllegalArgumentException if `groupName` is null or blank
   */
  override def createGroup(groupName: String): String = {
    if (isBlank(groupName)) {
      throw new IllegalArgumentException("Group names cannot be empty")
    }
    logger.trace(s"Creating new group '$groupName'")
    val groupId = UUID.randomUUID().toString
    logger.debug(s"Generated group ID: $groupId")
    jdbcTemplate.update(INSERT, groupId, groupName)
    groupId
  }

  /**
   * Creates one group per non-blank name, inserting in JDBC batches of at most [[ChunkSize]] rows.
   *
   * @param groupNames names to create; blank entries are skipped
   * @return the generated IDs of the groups that were inserted
   * @throws IllegalArgumentException if `groupNames` is empty
   */
  override def createGroups(groupNames: Set[String]): Set[String] = {
    if (groupNames.isEmpty) {
      throw new IllegalArgumentException("Group names cannot be empty")
    }
    val idsAndNames: Set[(String, String)] =
      groupNames.filter(hasText).filter(_.trim.nonEmpty).map { groupName =>
        val groupId = UUID.randomUUID().toString
        logger.debug(s"Generated group ID for groupName: $groupName: $groupId")
        (groupId, groupName)
      }
    // `grouped` yields a lazy Iterator: it must be consumed with `foreach` (not `flatMap`),
    // otherwise the batch inserts are never executed.
    idsAndNames.grouped(ChunkSize).foreach { chunk =>
      // Materialise once so `setValues` is O(1) per row instead of rebuilding a Seq each call.
      val rows = chunk.toIndexedSeq
      jdbcTemplate.batchUpdate(INSERT, new BatchPreparedStatementSetter {
        override def getBatchSize: Int = rows.size

        override def setValues(ps: PreparedStatement, i: Int): Unit = {
          val (groupId, groupName) = rows(i)
          ps.setString(1, groupId)
          ps.setString(2, groupName)
        }
      })
    }
    logger.debug(s"Created groups: ${idsAndNames.map(_._2).mkString(", ")}")
    idsAndNames.map(_._1)
  }

  /**
   * Looks up the ID of the group whose lower-cased name equals `groupName.toLowerCase`.
   *
   * @return `None` when no such group exists
   * @throws RuntimeException wrapping any other data-access failure
   */
  private def findGroupId(groupName: String): Option[String] = {
    logger.debug(s"Get Group Id for groupName: $groupName")
    try {
      Option(jdbcTemplate.queryForObject(FIND_GROUP_ID, classOf[String], groupName.toLowerCase))
    } catch {
      case _: EmptyResultDataAccessException => None
      case e: Exception => throw new RuntimeException(s"Error while retrieving groupid for groupName $groupName", e)
    }
  }

  /**
   * Deletes the group with the given name, and all of its memberships, if it exists.
   *
   * @throws IllegalArgumentException if `groupName` is null or blank
   */
  override def deleteGroup(groupName: String): Unit = {
    if (isBlank(groupName)) {
      throw new IllegalArgumentException("Group names cannot be empty")
    }
    logger.debug(s"Deleting group '$groupName'")
    findGroupId(groupName) match {
      case Some(groupId) =>
        // Memberships go first (presumably GROUP_ID references XLD_GROUPS.ID — confirm).
        jdbcTemplate.update(DELETE_USER_GROUP_MEMBERSHIP_BY_GROUP, groupId)
        jdbcTemplate.update(DELETE_GROUP, groupName.toLowerCase)
      case None =>
        logger.trace(s"Group '$groupName' not found")
    }
  }

  /**
   * Adds the user to the group, creating the group first if it does not exist.
   *
   * @throws IllegalArgumentException if either name is null or blank
   */
  override def addUserToGroup(username: String, groupName: String): Unit = {
    if (isBlank(groupName) || isBlank(username)) {
      throw new IllegalArgumentException("Group name/user name cannot be empty")
    }
    logger.trace(s"Adding user '$username' to group '$groupName'")
    val groupId = findGroupId(groupName).getOrElse(createGroup(groupName))
    jdbcTemplate.update(INSERT_USER_GROUP_MEMBERSHIP, groupId, username.toLowerCase)
  }

  /**
   * Fetches the groups whose names match (case-insensitively) any of `groupNames`.
   * Names are queried in `IN (...)` lists of at most [[ChunkSize]] values.
   */
  private def findGroups(groupNames: Set[String]): Set[Group] =
    groupNames
      .filter(hasText)
      // Lower-case (and thereby de-duplicate) BEFORE chunking, so the number of placeholders
      // always equals the number of bound values.
      .map(_.toLowerCase)
      .grouped(ChunkSize)
      .flatMap { chunk =>
        logger.debug(s"Finding groups for names: ${chunk.mkString(", ")}")
        val placeholders = Seq.fill(chunk.size)("?").mkString(", ")
        val sql = FIND_GROUPS.replace("(?)", s"($placeholders)")
        jdbcTemplate.query(
          sql,
          (rs: ResultSet, _: Int) => Group(rs.getString(ID.name), rs.getString(NAME.name)),
          chunk.toArray: _*
        ).asScala
      }
      .toSet

  /**
   * Removes the user's membership of the given group, if the group exists.
   *
   * @throws IllegalArgumentException if either name is null or blank
   */
  override def removeUserFromGroup(username: String, groupName: String): Unit = {
    if (isBlank(groupName) || isBlank(username)) {
      throw new IllegalArgumentException("Group name/user name cannot be empty")
    }
    logger.trace(s"Removing user '$username' from group '$groupName'")
    findGroupId(groupName) match {
      case Some(groupId) =>
        jdbcTemplate.update(DELETE_USER_GROUPS_MEMBERSHIP, username.toLowerCase, groupId)
      case None =>
        logger.trace(s"Group '$groupName' not found")
    }
  }

  /**
   * Removes the user's membership of every existing group named in `groupNames`,
   * in JDBC batches of at most [[ChunkSize]] deletes.
   *
   * @throws IllegalArgumentException if `groupNames` is empty or `username` is null or blank
   */
  override def removeUserFromGroups(username: String, groupNames: Set[String]): Unit = {
    if (groupNames.isEmpty || isBlank(username)) {
      throw new IllegalArgumentException("Group name/user name cannot be empty")
    }
    logger.trace(s"Removing user '$username' from groups '$groupNames'")
    val principal = username.toLowerCase
    findGroups(groupNames).map(_.id).grouped(ChunkSize).foreach { chunk =>
      val groupIds = chunk.toIndexedSeq
      jdbcTemplate.batchUpdate(DELETE_USER_GROUPS_MEMBERSHIP, new BatchPreparedStatementSetter {
        override def getBatchSize: Int = groupIds.size

        override def setValues(ps: PreparedStatement, i: Int): Unit = {
          ps.setString(1, principal)
          ps.setString(2, groupIds(i))
        }
      })
    }
  }

  /**
   * Adds the user to the group, treating an already-existing membership (duplicate key) as success.
   * Any other failure is logged and rethrown.
   *
   * NOTE(review): some databases (e.g. PostgreSQL) mark the whole transaction as failed after a
   * constraint violation, so swallowing DuplicateKeyException may not keep the surrounding
   * transactional work usable — confirm for the supported databases.
   */
  private def tryAddUserToGroup(username: String, groupName: String): Unit = {
    Try(addUserToGroup(username, groupName)) match {
      case Success(_) =>
        logger.trace(s"Group $groupName inserted successfully for user $username")
      case Failure(e: DuplicateKeyException) =>
        logger.trace(s"Group $groupName insertion failed for user $username: ${e.getMessage}")
      case Failure(e) =>
        logger.warn(s"Unexpected failure occurred during group $groupName insert for user $username: ${e.getMessage}")
        throw e
    }
  }

  /**
   * Makes the user's memberships equal to `groupNames`: missing groups are joined (and created if
   * needed) and groups no longer listed are left.
   *
   * Names are compared case-insensitively, consistent with how groups are looked up.
   */
  override def updateGroupsMembershipForUser(username: String, groupNames: Set[String]): Unit = {
    // Keyed by lower-cased name: a case-only change must not add and then remove the same
    // membership (both operations resolve to the same group ID).
    val existingByKey = findGroupsForUser(username).map(name => name.toLowerCase -> name).toMap
    val requestedByKey = groupNames.map(name => name.toLowerCase -> name).toMap
    val diff = Diff(existingByKey.keySet, requestedByKey.keySet)
    diff.newEntries.foreach(key => tryAddUserToGroup(username, requestedByKey(key)))
    diff.deletedEntries.foreach(key => removeUserFromGroup(username, existingByKey(key)))
  }
}
  /** Table and column names of the group table `XLD_GROUPS`. */
  object XldUserGroupsSchema {
    /** Name of the group table. */
    val tableName: TableName = TableName("XLD_GROUPS")

    /** Group name column. */
    val NAME: ColumnName = ColumnName("NAME")

    /** Group identifier column. */
    val ID: ColumnName = ColumnName("ID")
  }

  /** Table and column names of the membership table `XLD_GROUP_PRINCIPALS` (principal ↔ group). */
  object XldUserGroupsPrincipalsSchema {
    /** Name of the membership table. */
    val tableName: TableName = TableName("XLD_GROUP_PRINCIPALS")

    /** Principal (user) name column. */
    val PRINCIPAL_NAME: ColumnName = ColumnName("PRINCIPAL_NAME")

    /** Column joined against the group table's ID column. */
    val GROUP_ID: ColumnName = ColumnName("GROUP_ID")
  }


  /**
   * SQL statements for the group and group-membership tables, built with the `sqlb`
   * interpolator (presumably provided by [[Queries]] — confirm).
   */
  trait SqlXldUserGroupsRepositoryQueries extends Queries {
    /** Inserts a group row; binds (ID, NAME). */
    val INSERT: String = sqlb"insert into $tableName ($ID, $NAME) values (?, ?)"
    /** Selects a group's ID; compares `lower(NAME)` to the bound value, so bind a lower-cased name. */
    val FIND_GROUP_ID: String = sqlb"select $ID from $tableName where lower($NAME) = ?"
    /**
     * Selects (ID, NAME) of groups whose lower-cased name is in a list. The literal `(?)` is a
     * single placeholder that callers replace with one `?` per bound value before executing.
     */
    val FIND_GROUPS: String = sqlb"select $ID, $NAME from $tableName where lower($NAME) in (?)"
    /** Deletes groups whose `lower(NAME)` equals the bound value. */
    val DELETE_GROUP: String = sqlb"delete from $tableName where lower($NAME) = ?"
    /** Deletes every membership of one group; binds GROUP_ID. */
    val DELETE_USER_GROUP_MEMBERSHIP_BY_GROUP: String = sqlb"delete from ${XldUserGroupsPrincipalsSchema.tableName} where ${XldUserGroupsPrincipalsSchema.GROUP_ID} = ?"
    /** Inserts a membership row; binds (GROUP_ID, PRINCIPAL_NAME). */
    val INSERT_USER_GROUP_MEMBERSHIP: String =
      sqlb"insert into ${XldUserGroupsPrincipalsSchema.tableName} (${XldUserGroupsPrincipalsSchema.GROUP_ID}, ${XldUserGroupsPrincipalsSchema.PRINCIPAL_NAME}) values (?, ?)"
    /** Deletes one membership; binds (PRINCIPAL_NAME, GROUP_ID). */
    val DELETE_USER_GROUPS_MEMBERSHIP: String = sqlb"delete from ${XldUserGroupsPrincipalsSchema.tableName} where ${XldUserGroupsPrincipalsSchema.PRINCIPAL_NAME} = ? and ${XldUserGroupsPrincipalsSchema.GROUP_ID} = ?"
    /** Selects the names of all groups a principal belongs to; binds PRINCIPAL_NAME. */
    val FIND_GROUPS_FOR_USER: String =
      sqlb"SELECT g.${XldUserGroupsSchema.NAME} FROM ${XldUserGroupsSchema.tableName} g join ${XldUserGroupsPrincipalsSchema.tableName} gp ON g.${XldUserGroupsSchema.ID} = gp.${XldUserGroupsPrincipalsSchema.GROUP_ID} where gp.${XldUserGroupsPrincipalsSchema.PRINCIPAL_NAME} = ?"
  }
  /** A group row read from `XLD_GROUPS`: its ID and its stored name. */
  private final case class Group(id: String, name: String)
  /**
   * Set difference between an `original` and an `updated` set of entries.
   *
   * @param original entries before the change
   * @param updated  entries after the change
   * @tparam T entry type, compared the same way `Set` membership compares elements
   */
  case class Diff[T](original: Set[T], updated: Set[T]) {
    /** Entries in `updated` but not in `original`. */
    lazy val newEntries: Set[T] = updated -- original

    /**
     * Entries present in both sets, paired as `(originalEntry, updatedEntry)`.
     * Uses a hash lookup instead of comparing every pair (O(n) rather than O(n·m)).
     */
    lazy val updatedEntries: Set[(T, T)] = {
      val originalByValue: Map[T, T] = original.map(o => o -> o).toMap
      updated.flatMap(u => originalByValue.get(u).map(o => (o, u)))
    }

    /** Entries in `original` but not in `updated`. */
    lazy val deletedEntries: Set[T] = original -- updated
  }


