package com.xebialabs.deployit.security.sql

import ai.digital.deploy.permissions.client.RoleServiceClient
import com.xebialabs.deployit.checks.Checks.checkArgument
import com.xebialabs.deployit.core.sql.spring.{MapRowMapper, Setter, XlSingleColumnRowMapper}
import com.xebialabs.deployit.core.sql.spring.Setter.setString
import com.xebialabs.deployit.core.sql.{ColumnName, OrderBy, Queries, SchemaInfo, SelectBuilder, SelectFragmentBuilder, SqlFunction, TableName, SqlCondition => cond}
import com.xebialabs.deployit.engine.api.dto.{Ordering, Paging}
import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.security.archive.ArchiveSecurity
import com.xebialabs.deployit.security.sql.RolesSchema._
import com.xebialabs.deployit.security.{Permissions, Role, RoleService, _}
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core._
import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.stereotype.Component
import org.springframework.transaction.annotation.Transactional

import java.sql.{PreparedStatement, ResultSet}
import java.util
import java.util.{Objects, UUID}
import scala.collection.mutable
import scala.jdk.CollectionConverters._

/**
 * JDBC implementation of [[RoleService]] backed by the `XL_ROLES`, `XL_ROLE_PRINCIPALS` and
 * `XL_ROLE_ROLES` tables (see [[RolesSchema]]).
 *
 * Role creation, update and deletion are also recorded through [[ArchiveSecurity]]; single-role
 * create / update / delete operations are additionally forwarded to [[RoleServiceClient]].
 */
@Component
@Transactional("mainTransactionManager")
class SqlRoleRepository(@Autowired @Qualifier("mainJdbcTemplate") val jdbcTemplate: JdbcTemplate,
                        @Autowired val ciResolver: CiResolver,
                        @Autowired val securityArchive: ArchiveSecurity,
                        @Autowired val roleServiceClient: RoleServiceClient)
                        (@Autowired @Qualifier("mainSchema") implicit val schemaInfo: SchemaInfo)
  extends RoleService with RoleQueries {

  val ADMIN_READ_ONLY_ROLE_NAME = "deploy_admin_read_only"

  /** Maps a row holding `ID` and `NAME` columns of `XL_ROLES` to a [[Role]] without principals/children. */
  private val roleMapper: RowMapper[Role] = (rs: ResultSet, _: Int) =>
    new Role(rs.getString(Roles.ID.name), rs.getString(Roles.NAME.name))

  override def getRoles(rolePattern: String, paging: Paging, order: Ordering): util.List[Role] =
    getRoles(null, rolePattern, paging, order)

  /**
   * Lists the roles defined on `onConfigurationItem` (global roles for `null`, `""` or the global alias).
   *
   * @param rolePattern optional case-insensitive substring filter on the role name
   * @param paging      optional page window; `resultsPerPage == -1` disables paging
   * @param order       optional name ordering; ascending when `null`
   */
  override def getRoles(onConfigurationItem: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val builder = new SelectBuilder(Roles.tableName)
    builder.select(Roles.ID).select(Roles.NAME).where(cond.equals(Roles.CI_ID, getPk(onConfigurationItem)))
    withCriteria(builder, rolePattern, paging, order)
    jdbcTemplate.query(builder.query, Setter(builder.parameters), roleMapper)
  }

  /** Applies the optional name filter, then the name ordering, then the page window. */
  private def withCriteria(builder: SelectBuilder, rolePattern: String, paging: Paging, order: Ordering): Unit = {
    withRolePattern(builder, Option(rolePattern))
    withOrder(builder, order)
    withPage(builder, paging)
  }

  /** Adds a case-insensitive `LIKE %pattern%` condition on the role name when a non-empty pattern is given. */
  private def withRolePattern(builder: SelectBuilder, rolePattern: Option[String]): Unit =
    rolePattern.filter(_.nonEmpty).foreach { pattern =>
      // Locale.ROOT: the default locale (e.g. Turkish) would lowercase differently from SQL lower().
      builder.where(cond.likeEscaped(SqlFunction.lower(Roles.NAME), s"%$pattern%".toLowerCase(util.Locale.ROOT)))
    }

  private def withPage(builder: SelectBuilder, paging: Paging): Unit =
    Option(paging).filter(_.resultsPerPage != -1).foreach(p => builder.showPage(p.page, p.resultsPerPage))

  /**
   * Orders by the lower-cased role name, ascending unless `order` says otherwise.
   *
   * @param alias table alias to qualify the name column with, when the query joins several tables
   */
  private def withOrder(builder: SelectBuilder, order: Ordering, alias: Option[String] = None): Unit = {
    val column = SqlFunction.lower(alias.map(a => Roles.NAME.tableAlias(a)).getOrElse(Roles.NAME))
    val ascending = Option(order).forall(_.isAscending)
    builder.orderBy(if (ascending) OrderBy.asc(column) else OrderBy.desc(column))
  }

  /** Lists the global roles assigned to `principal` (matched case-insensitively). */
  override def getRolesFor(principal: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val roleIdsOfPrincipal = new SelectBuilder(RolePrincipals.tableName)
      .select(RolePrincipals.ROLE_ID)
      .where(cond.equals(SqlFunction.lower(RolePrincipals.PRINCIPAL_NAME), principal.toLowerCase(util.Locale.ROOT)))
    queryGlobalRoles(roleIdsOfPrincipal, rolePattern, paging, order)
  }

  /**
   * Lists the global roles assigned to any principal derived from `auth`.
   * Returns an empty list when `auth` yields no principals.
   */
  override def getRolesFor(auth: Authentication, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val principals = Permissions.authenticationToPrincipals(auth)
    if (principals.isEmpty) {
      new util.ArrayList()
    } else {
      val roleIdsOfPrincipals = new SelectBuilder(RolePrincipals.tableName)
        .select(RolePrincipals.ROLE_ID)
        .where(cond.in(SqlFunction.lower(RolePrincipals.PRINCIPAL_NAME), principals.asScala.map(_.toLowerCase(util.Locale.ROOT))))
      queryGlobalRoles(roleIdsOfPrincipals, rolePattern, paging, order)
    }
  }

  /** Selects the global roles whose id is returned by the `roleIds` sub-select, applying filter/order/page. */
  private def queryGlobalRoles(roleIds: SelectBuilder, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val builder = new SelectBuilder(Roles.tableName)
      .select(Roles.ID)
      .select(Roles.NAME)
      .where(cond.equals(Roles.CI_ID, GLOBAL_ID))
      .where(cond.subselect(Roles.ID, roleIds))
    withCriteria(builder, rolePattern, paging, order)
    jdbcTemplate.query(builder.query, Setter(builder.parameters), roleMapper)
  }

  /**
   * Loads a role by name, including its principals and child role names.
   *
   * @throws NotFoundException when no role has that name
   */
  override def getRoleForRoleName(roleName: String): Role = {
    val builder = new SelectBuilder(Roles.tableName)
      .select(Roles.ID)
      .select(Roles.NAME)
      .where(cond.equals(Roles.NAME, roleName))
    val role: Role = jdbcTemplate.query(builder.query, Setter(builder.parameters), roleMapper).asScala
      .headOption.getOrElse(throw new NotFoundException("Could not find the role [%s]", roleName))

    role.setPrincipals(getRolePrincipalsForRoleId(role.getId))
    role.setRoles(getRoleRolesForRoleId(role.getId))
    role
  }

  override def writeRoleAssignments(roles: util.List[Role]): Unit = writeRoleAssignments(null, roles)

  /**
   * Makes the roles stored on `onConfigurationItem` match `roles`: new roles are created, roles present
   * on both sides (by `Role` equality) are updated, and stored roles missing from `roles` are deleted.
   * Roles without an id get a fresh UUID; the passed-in objects are mutated.
   * Fails via `checkArgument` when two roles in `roles` share a name.
   */
  override def writeRoleAssignments(onConfigurationItem: String, roles: util.List[Role]): Unit = {
    val convertedRoles = roles.asScala
    checkDuplicates(convertedRoles.toSeq)

    convertedRoles.filter(_.getId == null).foreach(generateId)

    val onConfigurationItemId: Number = getPk(onConfigurationItem)
    val originalRoles = readAllRoleAssignments(onConfigurationItemId).asScala

    val (toCreate, toUpdate, toDelete) = diff3(convertedRoles.toSet, originalRoles.toSet)

    // NOTE(review): unlike doDelete, the bulk delete does not call roleServiceClient.deleteAllReferences,
    // while creates/updates below are forwarded to it — confirm this asymmetry is intended.
    doListDelete(toDelete.toList.map(_.getId))
    toCreate.foreach(doCreate(_, onConfigurationItemId))
    toUpdate.foreach(r => doUpdate(r._1, r._2, onConfigurationItemId))
  }

  private def checkDuplicates(updatedRoles: Seq[Role]): Unit = {
    val duplicateRoles = updatedRoles.groupBy(_.getName).filter(_._2.size > 1)
    checkArgument(duplicateRoles.isEmpty, s"Roles with duplicate names [${duplicateRoles.keys.mkString(", ")}] are not allowed")
  }

  private def generateId(role: Role): Unit = role.setId(UUID.randomUUID().toString)

  /** Deletes the given roles with their permissions, child links and principals, in one batch per table. */
  private def doListDelete(ids: List[String]): Unit = if (ids.nonEmpty) {
    // Indexed access: the batch setter is called once per index, List(i) would make this O(n^2).
    val idsByIndex = ids.toIndexedSeq
    val setter = new BatchPreparedStatementSetter {
      override def getBatchSize: Int = idsByIndex.size

      override def setValues(ps: PreparedStatement, i: Int): Unit = setString(ps, 1, idsByIndex(i))
    }
    jdbcTemplate.batchUpdate(DELETE_ROLE_PERMISSIONS, setter)
    jdbcTemplate.batchUpdate(DELETE_CHILD_ROLE, setter)
    jdbcTemplate.batchUpdate(DELETE_ROLE_PRINCIPAL, setter)
    jdbcTemplate.batchUpdate(DELETE_ROLE, setter)

    securityArchive.doListDelete(ids)
  }

  /**
   * Inserts `role` with its child roles (resolved among global roles) and principals, records it in the
   * archive and forwards it to the role service.
   *
   * @return the role id
   */
  private def doCreate(role: Role, onConfigurationItem: Number): String = {
    jdbcTemplate.update(INSERT_ROLE, role.getId, role.getName, onConfigurationItem)

    // Deduplicate like doUpdate (which works on sets) so no duplicate link rows are inserted.
    val childRoleIds: Seq[String] = role.getRoles.asScala.distinct.map(getRoleIdByName(GLOBAL_ID)(_)).toSeq
    batchInsert(INSERT_CHILD_ROLE, role.getId, childRoleIds)
    val principals: Seq[String] = role.getPrincipals.asScala.distinct.toSeq
    batchInsert(INSERT_ROLE_PRINCIPAL, role.getId, principals)

    securityArchive.doCreate(role.getId, role.getName, childRoleIds, principals, onConfigurationItem)

    // Send the same deduplicated principals that were stored locally.
    roleServiceClient.create(role.getName, principals.toList)
    role.getId
  }

  /**
   * Applies the differences between `original` and `updated` (name, principals, child roles) to the
   * database, the archive and the role service.
   *
   * @return the id of `original`
   */
  private def doUpdate(original: Role, updated: Role, onConfigurationItem: Number): String = {
    if (original.getName != updated.getName) {
      jdbcTemplate.update(UPDATE_ROLE, updated.getName, updated.getId)
    }
    val originalPrincipals = original.getPrincipals.asScala.toSet
    val updatedPrincipals = updated.getPrincipals.asScala.toSet
    updateChildEntities(INSERT_ROLE_PRINCIPAL, DELETE_ROLE_PRINCIPAL_BY_PRINCIPAL, original.getId, originalPrincipals, updatedPrincipals)

    val originalRoleIds = original.getRoles.asScala.toSet.map(getRoleIdByName(onConfigurationItem))
    val updateRoleIds = updated.getRoles.asScala.toSet.map(getRoleIdByName(onConfigurationItem))
    updateChildEntities(INSERT_CHILD_ROLE, DELETE_CHILD_ROLE_BY_CHILD, original.getId, originalRoleIds, updateRoleIds)

    securityArchive.doUpdate(original, updated, originalPrincipals, updatedPrincipals, originalRoleIds, updateRoleIds)
    val (principalsToCreate, principalsToDelete) = diff2(originalPrincipals, updatedPrincipals)
    roleServiceClient.update(original.getName, updated.getName, principalsToCreate, principalsToDelete)

    original.getId
  }

  /**
   * Resolves a role name to its id among roles defined on `onConfigurationItem` or globally.
   * NOTE(review): when both scopes contain the name, which row wins is unspecified — confirm.
   *
   * @throws NotFoundException when no such role exists
   */
  private def getRoleIdByName(onConfigurationItem: Number)(roleName: String): String = {
    val builder = new SelectBuilder(Roles.tableName)
      .select(Roles.ID)
      .where(cond.equals(Roles.NAME, roleName))
      .where(cond.in(Roles.CI_ID, Set(onConfigurationItem, GLOBAL_ID)))
    jdbcTemplate.query(builder.query, Setter(builder.parameters), (rs: ResultSet, _: Int) => rs.getString(1)).asScala
      .headOption.getOrElse(throw new NotFoundException("Role [%s] not found", roleName))
  }

  /** Returns `(added, removed)` going from `original` to `updated`. */
  private def diff2(original: Set[String], updated: Set[String]) = (
    updated.diff(original),
    original.diff(updated)
  )

  /**
   * Returns `(toCreate, toUpdate, toDelete)`, where `toUpdate` pairs each stored role with the
   * equal (per `Role` equality) incoming role as `(original, updated)`.
   */
  private def diff3(updated: Set[Role], original: Set[Role]) = {
    // Hash lookup instead of a nested loop over both sets; relies on the same equals/hashCode as Set.diff.
    val originalByRole: Map[Role, Role] = original.iterator.map(o => o -> o).toMap
    (
      updated.diff(original),
      updated.flatMap(u => originalByRole.get(u).map(o => (o, u))),
      original.diff(updated)
    )
  }

  override def readRoleAssignments(rolePattern: String, paging: Paging, order: Ordering): util.List[Role] =
    readRoleAssignments(null, rolePattern, paging, order)

  /**
   * Lists the roles of `onConfigurationItem` with their principals and child role names, optionally
   * filtered by name pattern, ordered and paged.
   */
  override def readRoleAssignments(onConfigurationItem: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val onConfigurationItemId: Number = getPk(onConfigurationItem)
    val unrestricted = rolePattern == null && (paging == null || paging.resultsPerPage == -1) && order == null
    if (unrestricted) {
      readAllRoleAssignments(onConfigurationItemId)
    } else {
      val roleSelection = new SelectBuilder(Roles.tableName).select(Roles.ID)
      withRolePattern(roleSelection, Option(rolePattern))

      val fullSelection = new SelectFragmentBuilder(SELECT_FULL_ROLES_FRAGMENT)
        .where(cond.subselect(Roles.ID.tableAlias("role"), roleSelection))
        .where(cond.equals(Roles.CI_ID.tableAlias("role"), onConfigurationItemId))

      withOrder(fullSelection, order, Option("role"))
      // NOTE(review): the page window applies to the joined role/child/principal rows, not to distinct
      // roles, so a page may hold fewer roles than requested or split one role's data — confirm.
      withPage(fullSelection, paging)

      groupRolesById(jdbcTemplate.query(fullSelection.query, Setter(fullSelection.parameters), MapRowMapper).asScala)
    }
  }

  /** Loads every role of the given configuration item (by primary key) with principals and children, by name ascending. */
  private def readAllRoleAssignments(onConfigurationItem: Number): util.List[Role] = {
    val selection = new SelectFragmentBuilder(SELECT_FULL_ROLES)
    withOrder(selection, new Ordering("ASC"), Option("role"))
    groupRolesById(jdbcTemplate.query(selection.query, MapRowMapper, onConfigurationItem).asScala)
  }

  /**
   * Folds the joined rows (one per role x child x principal combination) into one [[Role]] per role id,
   * keeping the row order of first appearance and dropping the nulls produced by the left joins.
   */
  private def groupRolesById(results: Iterable[util.Map[String, Object]]): util.List[Role] = {
    val names = new mutable.LinkedHashMap[String, String]()
    val principals = new mutable.HashMap[String, mutable.LinkedHashSet[String]]
    val children = new mutable.HashMap[String, mutable.LinkedHashSet[String]]
    results.foreach { row =>
      val id = row.get("roleId").asInstanceOf[String]
      if (!names.contains(id)) {
        names.put(id, row.get("roleName").asInstanceOf[String])
        principals.put(id, new mutable.LinkedHashSet[String])
        children.put(id, new mutable.LinkedHashSet[String])
      }
      val principal = row.get("principalName").asInstanceOf[String]
      if (Objects.nonNull(principal)) principals(id) += principal
      val child = row.get("childName").asInstanceOf[String]
      if (Objects.nonNull(child)) children(id) += child
    }

    // Build each Role with modifiable lists instead of appending to an unmodifiable asJava wrapper.
    names.map { case (id, name) =>
      new Role(id, name, new util.ArrayList[String](principals(id).asJava), new util.ArrayList[String](children(id).asJava))
    }.toList.asJava
  }

  /** Maps a configuration item id to its primary key; `null`, `""` and the global alias map to `GLOBAL_ID`. */
  private def getPk(onConfigurationItem: String): Number = onConfigurationItem match {
    case GLOBAL_SECURITY_ALIAS | "" | null => GLOBAL_ID
    case _ => ciResolver.getPkFromId(onConfigurationItem)
  }

  override def countRoles(onConfigurationItem: Number, rolePattern: String): Long = {
    val builder = new SelectBuilder(Roles.tableName)
    builder.select(SqlFunction.countAll).where(cond.equals(Roles.CI_ID, onConfigurationItem))
    withRolePattern(builder, Option(rolePattern))
    queryCount(builder)
  }

  override def roleExists(roleName: String): Boolean = {
    val countSelection = new SelectBuilder(Roles.tableName).select(SqlFunction.countAll)
    countSelection.where(cond.equals(Roles.NAME, roleName))
    queryCount(countSelection) > 0
  }

  /** Runs a single-column count query and returns its value, or 0 when no row comes back. */
  private def queryCount(builder: SelectBuilder): Long =
    jdbcTemplate
      .query(builder.query, Setter(builder.parameters), new SingleColumnRowMapper(classOf[Long]))
      .asScala
      .headOption
      .getOrElse(0L)

  override def countRoles(onConfigurationItem: String, rolePattern: String): Long =
    this.countRoles(getPk(onConfigurationItem), rolePattern)

  override def create(name: String, onConfigurationItem: String): String = {
    doCreate(createRole(name), getPk(onConfigurationItem))
  }

  override def rename(name: String, newName: String, onConfigurationItem: String): String = {
    val original: Role = getRoleForRoleName(name)
    val newRole: Role = new Role(original.getId, newName, original.getPrincipals, original.getRoles)
    doUpdate(original, newRole, getPk(onConfigurationItem))
  }

  override def createOrUpdateRole(role: Role, onConfigurationItem: String): String = {
    updatePrincipals(role.getId, role.getName, role.getPrincipals, role.getRoles, getPk(onConfigurationItem))
  }

  /**
   * Updates the role with `id` (or, when `id` is null, the role named `name`, creating it first if
   * needed) to the given name, principals and child roles.
   */
  private def updatePrincipals(id: String, name: String, principals: util.List[String], roles: util.List[String], onConfigurationItemPk: Number): String = {
    val original: Role = Option(id)
      .map(getRoleForRoleId)
      .getOrElse(getExistingOrCreateRole(name, onConfigurationItemPk))
    val newRole: Role = new Role(original.getId, name, principals, roles)
    doUpdate(original, newRole, onConfigurationItemPk)
  }

  override def deleteByName(name: String): Unit = {
    val role: Role = getRoleForRoleName(name)
    doDelete(role.getId, Option(name))
  }

  override def deleteById(roleId: String): Unit = doDelete(roleId)

  /**
   * Loads a role by id, including its principals and child role names.
   *
   * @throws NotFoundException when no role has that id
   */
  override def getRoleForRoleId(roleId: String): Role = {
    val builder = new SelectBuilder(Roles.tableName)
      .select(Roles.ID)
      .select(Roles.NAME)
      .where(cond.equals(Roles.ID, roleId))
    val role: Role = jdbcTemplate.query(builder.query, Setter(builder.parameters), roleMapper).asScala
      .headOption.getOrElse(throw new NotFoundException("Could not find the role with id [%s]", roleId))

    role.setPrincipals(getRolePrincipalsForRoleId(roleId))
    role.setRoles(getRoleRolesForRoleId(roleId))
    role
  }

  /** Creates a role named `name` on `onConfigurationItem` unless one with that name exists, then loads it. */
  private def getExistingOrCreateRole(name: String, onConfigurationItem: Number): Role = {
    if (!roleExists(name)) {
      doCreate(createRole(name), onConfigurationItem)
    }
    getRoleForRoleName(name)
  }

  private def getRolePrincipalsForRoleId(roleId: String): util.List[String] = {
    val builder = new SelectBuilder(RolePrincipals.tableName)
      .select(RolePrincipals.PRINCIPAL_NAME)
      .where(cond.equals(RolePrincipals.ROLE_ID, roleId))

    jdbcTemplate.query(builder.query, Setter(builder.parameters), new XlSingleColumnRowMapper(classOf[String]))
  }

  /** Names of the child roles linked to `roleId` in `XL_ROLE_ROLES`. */
  private def getRoleRolesForRoleId(roleId: String): util.List[String] = {
    jdbcTemplate.queryForList(SELECT_ROLE_ROLES_NAMES, classOf[String], roleId)
  }

  /** Batch-inserts one `(roleId, value)` row per entry; does nothing for an empty list. */
  private def batchInsert(query: String, roleId: String, childEntityList: Seq[String]): Unit = {
    // Indexed access: the batch setter is called once per index, Seq/List(i) would make this O(n^2).
    val values = childEntityList.toIndexedSeq
    if (values.nonEmpty) {
      jdbcTemplate.batchUpdate(query, new BatchPreparedStatementSetter {
        override def getBatchSize: Int = values.length

        override def setValues(ps: PreparedStatement, i: Int): Unit = {
          setString(ps, 1, roleId)
          setString(ps, 2, values(i))
        }
      })
    }
  }

  /**
   * Deletes one role with its permissions, child links and principals, records the deletion in the
   * archive and removes its references from the role service.
   *
   * @param roleNameMaybe the role name if already known; otherwise it is loaded by id first
   */
  private def doDelete(id: String, roleNameMaybe: Option[String] = None): Unit = {
    val roleName = roleNameMaybe.getOrElse(getRoleForRoleId(id).getName)
    jdbcTemplate.update(DELETE_ROLE_PERMISSIONS, id)
    jdbcTemplate.update(DELETE_CHILD_ROLE, id)
    jdbcTemplate.update(DELETE_ROLE_PRINCIPAL, id)
    jdbcTemplate.update(DELETE_ROLE, id)
    securityArchive.doDelete(id)

    roleServiceClient.deleteAllReferences(roleName)
  }

  private def createRole(name: String): Role = {
    val role = new Role(name)
    generateId(role)
    role
  }

  /** Inserts rows for values only in `updates` and deletes rows for values only in `originals`. */
  private def updateChildEntities(insertQuery: String, deleteQuery: String,
                                  roleId: String, originals: Set[String], updates: Set[String]): Unit = {
    val (toCreate, toDelete) = diff2(originals, updates)
    toCreate.foreach(jdbcTemplate.update(insertQuery, roleId, _))
    toDelete.foreach(jdbcTemplate.update(deleteQuery, roleId, _))
  }

  /**
   * Whether the current security-context user holds the global read-only admin role.
   * A missing authentication yields `false`.
   */
  override def isReadOnlyAdmin(): Boolean =
    Option(SecurityContextHolder.getContext.getAuthentication).exists { authentication =>
      getRolesFor(authentication).asScala.exists(_.getName == ADMIN_READ_ONLY_ROLE_NAME)
    }
}

/**
 * Table and column names of the role tables, shared by the role repository and its SQL queries.
 */
object RolesSchema {

  /** `XL_ROLES`: one row per role — its id, its name and the configuration item it is defined on. */
  object Roles {
    val tableName: TableName = TableName("XL_ROLES")

    val ID: ColumnName = ColumnName("ID")
    val NAME: ColumnName = ColumnName("NAME")
    val CI_ID: ColumnName = ColumnName("CI_ID")
  }

  /** `XL_ROLE_PRINCIPALS`: `(ROLE_ID, PRINCIPAL_NAME)` pairs assigning principals to a role. */
  object RolePrincipals {
    val tableName: TableName = TableName("XL_ROLE_PRINCIPALS")

    val ROLE_ID: ColumnName = ColumnName("ROLE_ID")
    val PRINCIPAL_NAME: ColumnName = ColumnName("PRINCIPAL_NAME")
  }

  /** `XL_ROLE_ROLES`: links a role (`ROLE_ID`) to each of its member roles (`MEMBER_ROLE_ID`). */
  object RoleRoles {
    val tableName: TableName = TableName("XL_ROLE_ROLES")

    val ROLE_ID: ColumnName = ColumnName("ROLE_ID")
    val MEMBER_ROLE_ID: ColumnName = ColumnName("MEMBER_ROLE_ID")
  }

}

/**
 * SQL statements used by the role repository. All are `lazy` so they are rendered on first use,
 * after the mixing-in class is fully constructed.
 */
trait RoleQueries extends Queries {

  import RolesSchema._

  // ---- selects -------------------------------------------------------------------------------

  /** Column list and joins (role, child role names, principal names) without a leading `select`. */
  lazy val SELECT_FULL_ROLES_FRAGMENT: String =
    sqlb"""role.${Roles.ID} roleId, role.${Roles.NAME} roleName, child.${Roles.NAME} childName, principal.${RolePrincipals.PRINCIPAL_NAME} principalName
      from ${Roles.tableName} role
      left join ${RoleRoles.tableName} childRoles on role.${Roles.ID} = childRoles.${RoleRoles.ROLE_ID}
      left join ${Roles.tableName} child on childRoles.${RoleRoles.MEMBER_ROLE_ID} = child.${Roles.ID}
      left join ${RolePrincipals.tableName} principal on role.${Roles.ID} = principal.${RolePrincipals.ROLE_ID}"""

  /** [[SELECT_FULL_ROLES_FRAGMENT]] restricted to one configuration item (bound parameter). */
  lazy val SELECT_FULL_ROLES: String =
    sqlb"$SELECT_FULL_ROLES_FRAGMENT where role.${Roles.CI_ID} = ?"

  /** Names of the member roles of the role whose id is bound as the parameter. */
  lazy val SELECT_ROLE_ROLES_NAMES: String =
    sqlb"""select role.${Roles.NAME} from ${Roles.tableName} role
          left join ${RoleRoles.tableName} childRoles on role.${Roles.ID} = childRoles.${RoleRoles.MEMBER_ROLE_ID}
          where childRoles.${RoleRoles.ROLE_ID} = ?"""

  // ---- inserts / updates ---------------------------------------------------------------------

  lazy val INSERT_ROLE: String =
    sqlb"insert into ${Roles.tableName} (${Roles.ID}, ${Roles.NAME}, ${Roles.CI_ID}) values (?, ?, ?)"

  lazy val INSERT_CHILD_ROLE: String =
    sqlb"insert into ${RoleRoles.tableName} (${RoleRoles.ROLE_ID}, ${RoleRoles.MEMBER_ROLE_ID}) values (?, ?)"

  lazy val INSERT_ROLE_PRINCIPAL: String =
    sqlb"insert into ${RolePrincipals.tableName} (${RolePrincipals.ROLE_ID}, ${RolePrincipals.PRINCIPAL_NAME}) values (?, ?)"

  lazy val UPDATE_ROLE: String =
    sqlb"update ${Roles.tableName} set ${Roles.NAME} = ? where ${Roles.ID} = ?"

  // ---- deletes -------------------------------------------------------------------------------

  lazy val DELETE_ROLE: String =
    sqlb"delete from ${Roles.tableName} where ${Roles.ID} = ?"

  lazy val DELETE_CHILD_ROLE: String =
    sqlb"delete from ${RoleRoles.tableName} where ${RoleRoles.ROLE_ID} = ?"

  lazy val DELETE_CHILD_ROLE_BY_CHILD: String =
    sqlb"delete from ${RoleRoles.tableName} where ${RoleRoles.ROLE_ID} = ? and ${RoleRoles.MEMBER_ROLE_ID} = ?"

  lazy val DELETE_ROLE_PRINCIPAL: String =
    sqlb"delete from ${RolePrincipals.tableName} where ${RolePrincipals.ROLE_ID} = ?"

  lazy val DELETE_ROLE_PRINCIPAL_BY_PRINCIPAL: String =
    sqlb"delete from ${RolePrincipals.tableName} where ${RolePrincipals.ROLE_ID} = ? and ${RolePrincipals.PRINCIPAL_NAME} = ?"

  lazy val DELETE_ROLE_PERMISSIONS: String = {
    import PermissionsSchema._
    sqlb"delete from $tableName where $ROLE_ID = ?"
  }
}
