package com.xebialabs.xlrelease.security.sql

import com.xebialabs.deployit.engine.api.dto.{Ordering, Paging}
import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.security.authentication.{AuthenticationFailureException, UserAlreadyExistsException}
import com.xebialabs.deployit.security.{RepoUser, SHA256PasswordEncoder, UserService}
import com.xebialabs.xlplatform.repository.sql.Database
import com.xebialabs.xlrelease.domain.distributed.events.DistributedXLReleaseEvent
import com.xebialabs.xlrelease.events.{AsyncSubscribe, EventListener}
import com.xebialabs.xlrelease.security.sql.db.Tables
import com.xebialabs.xlrelease.security.sql.db.Tables.{User, users}
import com.xebialabs.xlrelease.security.sql.SecurityCacheConfigurationConstants._
import com.xebialabs.xlrelease.service.BroadcastService
import grizzled.slf4j.Logging
import org.springframework.cache.annotation.{CacheConfig, CacheEvict, Cacheable}
import org.springframework.context.annotation.Conditional
import org.springframework.stereotype.Component
import slick.dbio.DBIOAction.seq
import slick.jdbc.JdbcProfile

import java.util.{List => JList}

/**
 * [[SqlUserService]] that caches the result of the no-argument `listUsernames()` in the
 * `SECURITY_USERNAMES` cache.
 *
 * After a successful `create` or `delete`, the cache is invalidated through
 * [[UserCacheEvicter.onEvictUser]] with a publishing event.
 *
 * @param securityDatabase database holding the security tables
 * @param evicter          bean that evicts the username cache (and broadcasts the eviction)
 */
@CacheConfig(cacheManager = SECURITY_USER_CACHE_MANAGER)
class CachingSqlUserService(securityDatabase: Database, evicter: UserCacheEvicter) extends SqlUserService(securityDatabase) {

  @Cacheable(cacheNames = Array(SECURITY_USERNAMES))
  override def listUsernames(): JList[String] = super.listUsernames()

  override def create(username: String, password: String): Unit = {
    super.create(username, password)
    invalidateUsernames()
  }

  override def delete(username: String): Unit = {
    super.delete(username)
    invalidateUsernames()
  }

  /** Requests eviction of the cached username list; only reached when the preceding write returned normally. */
  private def invalidateUsernames(): Unit = evicter.onEvictUser(EvictUser(publish = true))
}

/**
 * Event requesting eviction of the cached username list.
 *
 * @param publish when true, the receiving [[UserCacheEvicter]] also broadcasts a copy of this
 *                event (with `publish = false`) through the `BroadcastService`
 */
case class EvictUser(
  publish: Boolean = true
) extends DistributedXLReleaseEvent

/**
 * Evicts every entry of the `SECURITY_USERNAMES` cache when an [[EvictUser]] event is handled.
 * For events with `publish = true`, it also re-broadcasts a non-publishing copy of the event.
 *
 * Only registered when `SecurityUserCacheConfigurationCondition` matches.
 *
 * @param broadcastService used to broadcast the eviction event
 */
@Component
@EventListener
@Conditional(value = Array(classOf[SecurityUserCacheConfigurationCondition]))
@CacheConfig(cacheManager = SECURITY_USER_CACHE_MANAGER)
class UserCacheEvicter(broadcastService: BroadcastService) {

  /**
   * Clears the username cache and, when `event.publish` is true, broadcasts a copy of the event.
   *
   * The cache is cleared before the method body runs (`beforeInvocation = true`). Spring's default
   * clears it only after a successful return, so a failing broadcast would leave this node's cache
   * stale even though the user table had already been changed by the caller.
   *
   * @param event the eviction request
   */
  @CacheEvict(cacheNames = Array(SECURITY_USERNAMES), allEntries = true, beforeInvocation = true)
  @AsyncSubscribe
  def onEvictUser(event: EvictUser): Unit = {
    if (event.publish) {
      // The copy carries publish = false so that, presumably delivered to this listener on other
      // nodes, it is not broadcast again.
      // NOTE(review): the meaning of the second `broadcast` argument is defined by BroadcastService — confirm.
      broadcastService.broadcast(event.copy(publish = false), false)
    }
  }

}

/**
 * [[UserService]] backed by the `users` table of the security database, queried through Slick.
 *
 * Passwords are stored as hashes produced by [[SHA256PasswordEncoder]]. Username lookups made by
 * `read`, `authenticate`, the duplicate check in `create` and the old-password variant of
 * `modifyPassword` ignore case. `delete` and the two-argument `modifyPassword` match the username exactly.
 *
 * @param securityDatabase database holding the security tables
 */
class SqlUserService(securityDatabase: Database) extends UserService with Logging {

  import java.util.Locale

  import securityDatabase._

  val profile: JdbcProfile = config.databaseType.profile

  import profile.api._

  private val passwordEncoder = new SHA256PasswordEncoder()

  // for backwards compatibility with JCR
  private val JCR_ADMIN_USER = "admin"

  type Q = Query[Tables.Users, (String, String), Seq]

  /**
   * Lower-cases a username on the JVM side, for comparison with the database-side lower-cased column.
   * Uses Locale.ROOT so the result does not depend on the server's default locale: for example, under a
   * Turkish locale "I".toLowerCase yields a dotless "ı", which would not match the database's lowered value.
   */
  private def caseFold(username: String): String = username.toLowerCase(Locale.ROOT)

  /**
   * Validates a username for creation.
   *
   * @throws IllegalArgumentException   when `username` is null or empty
   * @throws UserAlreadyExistsException when a user with the same name (ignoring case) already exists
   */
  private def checkValidUsername(username: String): Unit = {
    if (username == null || username.isEmpty) throw new IllegalArgumentException("Username can neither be null nor empty.")
    if (readUser(username).isDefined) throw new UserAlreadyExistsException(username)
  }

  /**
   * Restricts `query` to usernames containing `username`, ignoring case. A null `username` applies no filter.
   * NOTE(review): `%` and `_` in `username` are not escaped, so they act as LIKE wildcards.
   */
  private def filterUsers(query: Q, username: String): Q = Option(username)
    .map(term => query.filter(_.username.toLowerCase.like(s"%${caseFold(term)}%")))
    .getOrElse(query)

  /**
   * Looks up a user by name, ignoring case.
   *
   * @return the stored (username, password hash) pair, or None when no user matches or `username` is null
   */
  private def readUser(username: String): Option[User] =
    Option(username).flatMap { name =>
      val key = caseFold(name)
      runAwait(
        users
          .filter(_.username.toLowerCase === key)
          .map(user => (user.username, user.password)).result
      ).headOption
    }

  /** Counts users whose name contains `username` (ignoring case); a null `username` counts all users. */
  override def countUsers(username: String): Long = runAwait(filterUsers(users, username).length.result).toLong

  /**
   * Inserts a new user with the encoded password.
   *
   * @throws IllegalArgumentException   when `username` is null or empty
   * @throws UserAlreadyExistsException when the name (ignoring case) is already taken
   */
  override def create(username: String, password: String): Unit = {
    checkValidUsername(username)
    val pwd = passwordEncoder.encode(password)
    runAwait(seq(users += ((username, pwd))))
  }

  /**
   * Reads a user by name, ignoring case. The returned user is flagged as admin when its stored name is "admin".
   *
   * @throws NotFoundException when no such user exists
   */
  override def read(username: String): RepoUser = {
    readUser(username)
      .map(user => new RepoUser(user._1, JCR_ADMIN_USER.equals(user._1)))
      .getOrElse(throw new NotFoundException(s"No such user: $username"))
  }

  /**
   * Lists usernames with optional filtering, ordering and paging.
   *
   * @param username substring filter (ignoring case); null disables filtering
   * @param paging   page to return (pages appear to be 1-based); null returns all results
   * @param order    sort direction on the lower-cased name; null sorts ascending
   */
  override def listUsernames(username: String, paging: Paging, order: Ordering): JList[String] = {
    val filters = List(
      (query: Q) => filterUsers(query, username),

      (query: Q) =>
        if (order == null || order.isAscending)
          query.sortBy(_.username.toLowerCase.asc) else query.sortBy(_.username.toLowerCase.desc),

      (query: Q) => if (paging == null) query else
        query
          .drop((paging.page - 1) * paging.resultsPerPage)
          .take(paging.resultsPerPage)
    )

    runAwait(filters
      .foldLeft(users.to[Seq])((acc, filter) => filter(acc))
      .map(_.username).result
    ).toList.asJavaMutable()
  }

  /** Lists all usernames, in no particular order. */
  override def listUsernames(): JList[String] =
    runAwait(users.map(_.username).result).toList.asJavaMutable()

  /** Replaces the password of the user whose name matches `username` exactly (case-sensitive). */
  override def modifyPassword(username: String, newPassword: String): Unit =
    runAwait(users.filter(_.username === username).map(_.password).update(passwordEncoder.encode(newPassword)))

  /**
   * Replaces a user's password after verifying the old one. The user is looked up ignoring case;
   * nothing happens when the user does not exist.
   *
   * @throws IllegalArgumentException when the provided old password does not match
   */
  override def modifyPassword(username: String, newPassword: String, oldPassword: String): Unit = {
    readUser(username).foreach { user =>
      if (!passwordEncoder.matches(oldPassword, user._2)) throw new IllegalArgumentException("Failed to change password: Old password does not match.")
      // Update by the stored name: readUser matched ignoring case, but the two-argument overload
      // matches exactly, so the caller's spelling could update zero rows.
      modifyPassword(user._1, newPassword)
    }
  }

  /** Deletes the user whose name matches `username` exactly (case-sensitive); a no-op when none matches. */
  override def delete(username: String): Unit = runAwait(users.filter(_.username === username).delete)

  /**
   * Verifies a username (ignoring case) and password pair.
   *
   * @throws AuthenticationFailureException when the user is unknown or the password does not match;
   *                                        the same message is used for both cases
   */
  override def authenticate(username: String, password: String): Unit = {
    val errorMessage = s"Cannot authenticate $username. Wrong username or password supplied."
    val user = readUser(username).getOrElse(throw new AuthenticationFailureException(errorMessage))
    if (!passwordEncoder.matches(password, user._2)) throw new AuthenticationFailureException(errorMessage)
  }
}
