package com.xebialabs.xlrelease.security.sql

import com.xebialabs.deployit.engine.api.dto.Paging
import com.xebialabs.deployit.security.Permissions.authenticationToPrincipals
import com.xebialabs.deployit.security.{PermissionLister, Role}
import com.xebialabs.xlplatform.repository.sql.Database
import com.xebialabs.xlrelease.config.CacheManagementConstants._
import com.xebialabs.xlrelease.security.sql.SecurityCacheConfigurationConstants._
import com.xebialabs.xlrelease.security.sql.db.Ids.{fromDbId, isGlobalId}
import com.xebialabs.xlrelease.security.sql.db.Tables
import org.springframework.cache.annotation.{CacheConfig, Cacheable}
import org.springframework.security.core.Authentication
import slick.jdbc.JdbcProfile

import java.util
import scala.jdk.CollectionConverters._


/**
 * [[SqlPermissionLister]] whose per-user permission lookup is annotated for Spring caching.
 *
 * The cache manager is `SECURITY_CACHE_MANAGER`; the lookup itself is fully delegated to the
 * parent class.
 *
 * @param securityDatabase database handed straight to [[SqlPermissionLister]]
 */
@CacheConfig(cacheManager = SECURITY_CACHE_MANAGER)
class CachingPermissionLister(securityDatabase: Database) extends SqlPermissionLister(securityDatabase) {

  /**
   * Cached variant of the parent lookup, stored in the `SECURITY_USER_PERMISSIONS` cache.
   *
   * The cache key combines the configuration item id, the authentication name and the hash code
   * of the global roles. Caching is only applied when the authentication, its name and the
   * configuration item id are all non-null and non-empty; otherwise the annotation's condition
   * is false and the call is not cached.
   *
   * NOTE(review): the key relies on `hashCode()` of the role sequence, so two different role
   * sets with the same hash would share an entry for the same user and item - confirm this is
   * acceptable.
   */
  @Cacheable(
    cacheNames = Array(SECURITY_USER_PERMISSIONS),
    condition = "#auth != null and #auth.name != null and #auth.name != '' and #onConfigurationItem != null and #onConfigurationItem != ''",
    key = "#onConfigurationItem + '-' + #auth?.name + '-' + #globalRoles?.hashCode()"
  )
  override def listUserPermissionsFor(
    onConfigurationItem: String,
    auth: Authentication,
    globalRoles: Seq[Role]
  ): util.List[String] = {
    super.listUserPermissionsFor(onConfigurationItem, auth, globalRoles)
  }
}

/**
 * [[PermissionLister]] implementation that reads roles and role permissions through Slick
 * queries against the supplied security database.
 *
 * @param securityDatabase database whose members (imported below) and configured JDBC profile
 *                         are used for every query in this class
 */
class SqlPermissionLister(securityDatabase: Database) extends PermissionLister {

  import securityDatabase._

  val profile: JdbcProfile = config.databaseType.profile

  import profile.api._

  type Q = Query[Tables.RolePermissions, Tables.RolePermission, Seq]

  /**
   * Lists the distinct permission names granted on a configuration item to the given roles.
   *
   * For a global id only `globalRoles` are considered; otherwise the roles returned by
   * [[getLocalRolesFor]] for that item are added to them.
   *
   * @param onConfigurationItem id of the configuration item to inspect
   * @param auth                authentication whose principals are used to resolve local roles
   * @param globalRoles         global roles already resolved for the caller
   * @return distinct permission names, converted with `asJavaMutable()`
   */
  def listUserPermissionsFor(onConfigurationItem: String, auth: Authentication, globalRoles: Seq[Role]): util.List[String] = {
    // NOTE(review): a null id becomes Some(null) here; whether Option(...) would be more
    // appropriate depends on how isGlobalId treats None - confirm before changing.
    val configurationItem = Some(onConfigurationItem)
    val allRoles =
      if (isGlobalId(configurationItem)) globalRoles
      else globalRoles ++ getLocalRolesFor(configurationItem, auth, globalRoles)

    runAwait {
      Tables.roles
        .join(Tables.rolePermissions).on(_.id === _.roleId)
        .filter { case (role, rolePermissions) =>
          role.id.inLarge(allRoles.map(_.getId)) &&
            rolePermissions.isOnConfigurationItem(configurationItem)
        }.map(_._2.permissionName).distinct.result
    }.asJavaMutable()
  }

  /** Lists all permissions (global and per configuration item) of a single role. */
  override def listPermissions(role: Role, paging: Paging): util.Map[String, util.List[String]] =
    listPermissions(Seq(role), includeCiPermissions = true, paging)

  /** Lists all permissions (global and per configuration item) of the given roles. */
  override def listPermissions(roles: util.List[Role], paging: Paging): util.Map[String, util.List[String]] =
    listPermissions(roles.asScala.toSeq, includeCiPermissions = true, paging)

  /** Lists only the permissions of the given roles that satisfy the table's `isGlobal` filter. */
  override def listGlobalPermissions(roles: util.List[Role], paging: Paging): util.Map[String, util.List[String]] =
    listPermissions(roles.asScala.toSeq, includeCiPermissions = false, paging)

  /**
   * Loads the permissions of `roles` and groups the permission names by configuration item id.
   *
   * @param roles                roles whose permissions are loaded
   * @param includeCiPermissions when false, only rows passing the `isGlobal` filter are loaded
   * @param paging               optional paging; `null` or `resultsPerPage == -1` returns every
   *                             entry, otherwise the entries are ordered by id and the requested
   *                             page is returned
   * @return map from configuration item id to the permission names granted on it
   */
  private def listPermissions(roles: Seq[Role], includeCiPermissions: Boolean, paging: Paging): util.Map[String, util.List[String]] = {
    val rolePermissions = Tables.roles
      .join(Tables.rolePermissions).on(_.id === _.roleId)
      .filter(_._1.id.inLarge(roles.map(_.getId)))
      .map(_._2)

    val query = if (includeCiPermissions) rolePermissions else rolePermissions.filter(_.isGlobal)

    val permissionsByCi: Map[String, Seq[String]] = runAwait(query.result)
      .groupMap(permission => fromDbId(permission.ciId))(_.permissionName)

    val selected: Iterable[(String, Seq[String])] = Option(paging) match {
      case Some(p) if p.resultsPerPage != -1 =>
        val offset = (p.page - 1) * p.resultsPerPage
        // A Map has no meaningful iteration order, so slicing it directly yields arbitrary
        // pages. Order the entries by id first (Option keeps the comparison null-safe).
        permissionsByCi.toSeq
          .sortBy { case (ciId, _) => Option(ciId) }
          .slice(offset, offset + p.resultsPerPage)
      case _ => permissionsByCi
    }

    selected
      .map { case (ciId, permissionNames) => ciId -> permissionNames.toList.asJavaMutable() }
      .toMap
      .asJavaMutable()
  }

  /**
   * Finds the roles defined on a configuration item that apply to the given authentication.
   *
   * A role applies when one of its principals equals (case-insensitively) one of the principals
   * of `auth`, or when one of its member roles is among `globalRoles`.
   *
   * @param onConfigurationItem configuration item the roles must be defined on
   * @param auth                authentication whose principals are matched
   * @param globalRoles         global roles that may be members of a local role
   * @return matching roles, each id reported once
   */
  def getLocalRolesFor(onConfigurationItem: Option[String], auth: Authentication, globalRoles: Seq[Role]): Seq[Role] = {
    val principals = authenticationToPrincipals(auth).asScala.map(_.toLowerCase)
    val globalRoleIds = globalRoles.map(_.getId)

    val matchingRoles = runAwait {
      Tables.roles
        .joinLeft(Tables.rolePrincipals).on { case (role, rolePrincipal) => role.id === rolePrincipal.roleId }
        .joinLeft(Tables.roleRoles).on { case ((role, _), roleRole) => role.id === roleRole.roleId }
        .filter { case ((role, rolePrincipals), roleRoles) =>
          role.isOnConfigurationItem(onConfigurationItem) &&
            (rolePrincipals.map(_.principalName.toLowerCase).in(principals) ||
              roleRoles.map(_.memberRoleId).inLarge(globalRoleIds))
        }.map(_._1._1).result
    }

    // The two left joins emit one row per (principal, member role) combination, so a role that
    // matches in several ways would otherwise be returned several times.
    matchingRoles
      .distinctBy(_.id)
      .map(role => new Role(role.id, role.name))
  }
}
