package com.xebialabs.xlrelease.authentication

import com.xebialabs.deployit.ServerConfiguration
import com.xebialabs.xlplatform.utils.SecureRandomHolder
import com.xebialabs.xlrelease.repository.CustomPersistentTokenRepository
import org.springframework.security.core.Authentication
import org.springframework.security.core.userdetails.{UserDetails, UserDetailsService}
import org.springframework.security.web.authentication.rememberme._
import org.springframework.util.Assert

import java.util
import java.util._
import jakarta.servlet.http.{Cookie, HttpServletRequest, HttpServletResponse}

/**
  * Persistent-token remember-me services backed by a [[CustomPersistentTokenRepository]].
  *
  * Behaviour implemented here:
  *  - remember-me tokens are only issued and honoured while `serverConfiguration.isClientSessionRememberEnabled` is true;
  *  - usernames are lower-cased before being stored with a token;
  *  - a cookie carrying an older token of a known series is still accepted (see `processAutoLoginCookie`);
  *  - a refreshed token is issued at most once per `TOKEN_REFRESH_INTERVAL_MS`;
  *  - cookies are written `HttpOnly` and `SameSite=Strict`.
  *
  * @param key                 key passed through to [[AbstractRememberMeServices]]
  * @param userDetailsService  loads the [[UserDetails]] for a successfully matched token
  * @param tokenRepository     persistence for remember-me tokens, including previous tokens of a series
  * @param serverConfiguration supplies the remember-me on/off switch and cookie security settings
  */
class XlPersistentTokenRememberMeServices(key: String,
                                          userDetailsService: UserDetailsService,
                                          tokenRepository: CustomPersistentTokenRepository,
                                          serverConfiguration: ServerConfiguration)
  extends AbstractRememberMeServices(key, userDetailsService) {

  private val DEFAULT_SERIES_LENGTH = 16
  private val DEFAULT_TOKEN_LENGTH = 16

  /** Number of values encoded in a remember-me cookie: series and token. */
  private val COOKIE_TOKEN_COUNT = 2

  /** When a previous token is presented, tokens of its series dated more than this before it are purged. */
  private val PREVIOUS_TOKEN_RETENTION_MS = 60000L

  /** Minimum age of the current token before a successful auto-login issues a fresh one. */
  private val TOKEN_REFRESH_INTERVAL_MS = 30000L

  // Number of random bytes behind series and token values; mutable so they can be set through the bean setters below.
  private var seriesLength = DEFAULT_SERIES_LENGTH
  private var tokenLength = DEFAULT_TOKEN_LENGTH

  /** Returns the instant `timeInMs` milliseconds before `givenDate`. */
  private def goPast(givenDate: Date, timeInMs: Long): Date = new Date(givenDate.getTime - timeInMs)

  /** Returns the instant `timeInMs` milliseconds after `givenDate`. */
  private def goFuture(givenDate: Date, timeInMs: Long): Date = new Date(givenDate.getTime + timeInMs)

  /**
    * Normalises a username before it is stored with a token or used to remove a user's tokens.
    *
    * NOTE(review): this lower-cases with the JVM default locale, as the original code did. `Locale.ROOT` would avoid
    * locale surprises (e.g. Turkish dotless i), but it must match how already-stored usernames were lower-cased.
    */
  private def normalizeUsername(username: String): String = username.toLowerCase()

  /**
    * Issues a new series/token pair after an interactive login, if remember-me is enabled.
    *
    * Failures while persisting the token or writing the cookie are logged and not propagated.
    */
  override def onLoginSuccess(request: HttpServletRequest, response: HttpServletResponse, successfulAuthentication: Authentication): Unit = {
    if (serverConfiguration.isClientSessionRememberEnabled) {
      val username = successfulAuthentication.getName

      logger.debug("Creating new persistent login for user " + username)

      val persistentToken = new PersistentRememberMeToken(normalizeUsername(username), generateSeriesData, generateTokenData, new Date)
      try {
        tokenRepository.createNewToken(persistentToken)
        addCookie(persistentToken, request, response)
      } catch {
        case e: Exception =>
          logger.error("Failed to save persistent token ", e)
      }
    }
  }

  /**
    * Authenticates a request from its decoded remember-me cookie `[series, token]`.
    *
    *  1. Looks up the current token for the presented series. If none exists, authentication fails.
    *  2. If the presented token is not the current one, it must equal a previous token of the same series
    *     (presumably a request that raced a token refresh). Otherwise all of the user's tokens are removed
    *     and a [[CookieTheftException]] is thrown.
    *  3. Rejects the login if the current token is older than the token validity period, or if remember-me
    *     has been disabled.
    *  4. Issues a refreshed token in the same series when the current one is older than
    *     `TOKEN_REFRESH_INTERVAL_MS`, then loads the user.
    *
    * @throws InvalidCookieException            if the cookie does not contain exactly two values
    * @throws CookieTheftException              if the presented token matches neither the current nor a previous token
    * @throws RememberMeAuthenticationException on an unknown series, expiry, disabled remember-me, or a failure
    *                                           while refreshing the token
    */
  override def processAutoLoginCookie(cookieTokens: Array[String], request: HttpServletRequest, response: HttpServletResponse): UserDetails = {
    if (cookieTokens.length != COOKIE_TOKEN_COUNT) {
      // mkString rather than util.Arrays.asList(cookieTokens): from Scala the latter does not expand the array
      // as varargs, so it wraps the array as a single element and prints its identity hash.
      val presented = cookieTokens.mkString("[", ", ", "]")
      throw new InvalidCookieException(
        s"Cookie token did not contain $COOKIE_TOKEN_COUNT tokens, but contained '$presented'")
    }

    val presentedSeries: String = cookieTokens(0)
    val presentedToken: String = cookieTokens(1)

    // No series match, so we can't authenticate using this cookie.
    val token: PersistentRememberMeToken = Option(tokenRepository.getTokenForSeries(presentedSeries)).getOrElse {
      throw new RememberMeAuthenticationException("No persistent token found for series id: " + presentedSeries)
    }

    if (presentedToken != token.getTokenValue) {
      acceptPreviousTokenOrFail(token, presentedSeries, presentedToken)
    }

    // Expiry is measured from the series' current token, even when a previous token was presented.
    if (goFuture(token.getDate, getTokenValiditySeconds * 1000L).getTime < System.currentTimeMillis) {
      throw new RememberMeAuthenticationException("Remember-me login has expired")
    }

    if (!serverConfiguration.isClientSessionRememberEnabled) {
      throw new RememberMeAuthenticationException("Remember-me login is disabled")
    }

    // Token matches, so login is valid. Refresh the token value while keeping the *same* series.
    if (logger.isDebugEnabled) logger.debug("Refreshing persistent login token for user '" + token.getUsername + "', series '" + token.getSeries + "'")

    try {
      // Only issue a new token if the current one is older than the refresh interval.
      if (goFuture(token.getDate, TOKEN_REFRESH_INTERVAL_MS).getTime < System.currentTimeMillis) {
        val newToken = new PersistentRememberMeToken(normalizeUsername(token.getUsername), token.getSeries, generateTokenData, new Date)
        tokenRepository.createNewToken(newToken)
        addCookie(newToken, request, response)
      }
    } catch {
      case e: Exception =>
        logger.error("Failed to update token: ", e)
        throw new RememberMeAuthenticationException("Autologin failed due to data access problem", e)
    }

    getUserDetailsService.loadUserByUsername(token.getUsername)
  }

  /**
    * Handles a presented token that differs from the current token of its series.
    *
    * If it matches one of the series' previous tokens, tokens of the series dated more than
    * `PREVIOUS_TOKEN_RETENTION_MS` before that previous token are removed, and the login may proceed.
    * Otherwise every token of the user is removed and a [[CookieTheftException]] is thrown.
    */
  private def acceptPreviousTokenOrFail(current: PersistentRememberMeToken, presentedSeries: String, presentedToken: String): Unit = {
    val previousTokens: util.List[PersistentRememberMeToken] =
      tokenRepository.getPreviousTokenForSeries(presentedSeries, current.getTokenValue)
    val matchingPrevious: Option[PersistentRememberMeToken] = Option(
      previousTokens.stream
        .filter((previousToken: PersistentRememberMeToken) => previousToken.getTokenValue == presentedToken)
        .findFirst
        .orElse(null)
    )
    matchingPrevious match {
      case Some(previous) =>
        tokenRepository.removeTokenInSeriesBeforeGivenDate(current.getSeries, goPast(previous.getDate, PREVIOUS_TOKEN_RETENTION_MS))
      case None =>
        // Token not present. Delete all logins for this user and throw an exception to warn them.
        tokenRepository.removeUserTokens(current.getUsername)
        throw new CookieTheftException(
          messages.getMessage("PersistentTokenBasedRememberMeServices.cookieStolen",
            "Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack."))
    }
  }

  /**
    * Logs the user out of remember-me and clears the security context.
    *
    * First delegates to the superclass `logout`. Then, if `authentication` is present, all tokens of that user
    * are removed. Otherwise the series from the request's remember-me cookie (if any) is used to remove tokens,
    * and a malformed cookie is ignored so it cannot abort logout. Finally a `SecurityContextLogoutHandler` is run.
    */
  override def logout(request: HttpServletRequest, response: HttpServletResponse, authentication: Authentication): Unit = {
    super.logout(request, response, authentication)
    if (authentication != null) {
      tokenRepository.removeUserTokens(normalizeUsername(authentication.getName))
    } else {
      removeTokensForPresentedCookie(request)
    }

    import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler
    val securityContextLogoutHandler = new SecurityContextLogoutHandler
    securityContextLogoutHandler.logout(request, response, null)
  }

  /** Removes the tokens of the series carried by the request's remember-me cookie, if there is one and it decodes. */
  private def removeTokensForPresentedCookie(request: HttpServletRequest): Unit =
    for {
      cookies <- Option(request.getCookies)
      rememberMeCookie <- cookies.find(_.getName == this.getCookieName)
    } {
      try {
        decodeCookie(rememberMeCookie.getValue).headOption.foreach(tokenRepository.removeUserTokensBasedOnSeries)
      } catch {
        case e: InvalidCookieException =>
          logger.debug("Ignoring malformed remember-me cookie during logout", e)
      }
    }

  /** Base64 encoding of `seriesLength` random bytes, used as a new series identifier. */
  protected def generateSeriesData: String = randomBase64(seriesLength)

  /** Base64 encoding of `tokenLength` random bytes, used as a new token value. */
  protected def generateTokenData: String = randomBase64(tokenLength)

  /** Base64-encodes `byteCount` bytes from `SecureRandomHolder`. */
  private def randomBase64(byteCount: Int): String =
    Base64.getEncoder.encodeToString(SecureRandomHolder.getRandomByteArray(byteCount))

  /**
    * Writes the remember-me cookie for `token` to the response.
    *
    * The cookie lives for the token validity period and is scoped to the context path. When the server is not
    * configured for SSL, the `Secure` flag follows `isSecureCookieEnabled` (presumably for TLS terminated upstream;
    * TODO confirm). Otherwise it follows the request's own scheme.
    */
  private def addCookie(token: PersistentRememberMeToken, request: HttpServletRequest, response: HttpServletResponse): Unit = {
    val cookieValue = encodeCookie(Array[String](token.getSeries, token.getTokenValue))
    val cookie = new Cookie(getCookieName, cookieValue)
    val maxAge = getTokenValiditySeconds
    cookie.setMaxAge(maxAge)
    cookie.setPath(getCookiePath(request))
    if (maxAge < 1) {
      cookie.setVersion(1)
    }
    if (!serverConfiguration.isSsl) {
      cookie.setSecure(serverConfiguration.isSecureCookieEnabled)
    } else {
      cookie.setSecure(request.isSecure)
    }
    cookie.setHttpOnly(true)
    cookie.setAttribute("SameSite", "Strict")
    response.addCookie(cookie)
  }

  /** Cookie path: the request's context path, or "/" for the root context (or a null path). */
  private def getCookiePath(request: HttpServletRequest): String =
    Option(request.getContextPath).filter(_.nonEmpty).getOrElse("/")

  /** Sets the number of random bytes used for new series identifiers. */
  def setSeriesLength(seriesLength: Int): Unit = {
    this.seriesLength = seriesLength
  }

  /** Sets the number of random bytes used for new token values. */
  def setTokenLength(tokenLength: Int): Unit = {
    this.tokenLength = tokenLength
  }

  /**
    * Sets the token validity. Only positive values are accepted, because expiry and cookie max-age rely on it.
    *
    * @throws IllegalArgumentException if `tokenValiditySeconds` is not positive
    */
  override def setTokenValiditySeconds(tokenValiditySeconds: Int): Unit = {
    Assert.isTrue(tokenValiditySeconds > 0, "tokenValiditySeconds must be positive for this implementation")
    super.setTokenValiditySeconds(tokenValiditySeconds)
  }
}
