package com.xebialabs.xlplatform.endpoints

import java.security.Principal

import akka.actor.ActorSystem
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.AuthenticationFailedRejection.{CredentialsMissing, CredentialsRejected}
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.directives.Credentials
import com.xebialabs.xlplatform.endpoints.routes.{ScalaExtensionRoute, ScriptExtensionRoute}
import com.xebialabs.xlplatform.endpoints.security.{PrincipalVerifier, RequestPrincipal}
import org.springframework.security.core

import scala.concurrent.Future
import scala.language.postfixOps

/**
 * REST API assembled from extension-contributed routes.
 *
 * Combines a built-in `ping` route with the [[AuthenticatedRoute]]s exposed by the mixed-in
 * [[ScriptExtensionRoute]] (`customScriptEndpoints`) and [[ScalaExtensionRoute]]
 * (`codedExtensionRoutes`). Everything is mounted under `settings.ServerExtension.rootPath`,
 * behind a principal check, with rejections translated by [[customRejectionHandler]].
 */
trait ExtendableRestApi extends ScriptExtensionRoute with ScalaExtensionRoute {
  // NOTE(review): nothing visible in this trait obviously needs an implicit ExecutionContext.
  // It is kept in case a member resolved from elsewhere (e.g. RequestPrincipal) requires one.
  // Confirm before removing.
  import scala.concurrent.ExecutionContext.Implicits.global

  /** Actor system, supplied implicitly by the concrete implementation (presumably used by the mixed-in route traits). */
  implicit def system: ActorSystem

  /**
   * Converts the rejections this API cares about into plain-text responses:
   *  - missing credentials             -> 401 "Please login first"
   *  - rejected credentials            -> 401 "Provided credentials are incorrect"
   *  - failed authorization            -> 403 "You do not have permission to perform this action"
   *  - malformed request content       -> 400 "Malformed Content"
   *
   * Rejections not matched by these cases are not handled here.
   */
  val customRejectionHandler: RejectionHandler = {
    // Kept as ONE partial function: registering the cases through separate `handle` calls would
    // change which rejection wins when several are present.
    val toPlainTextResponse: PartialFunction[Rejection, Route] = {
      case AuthenticationFailedRejection(CredentialsMissing, _) =>
        complete(StatusCodes.Unauthorized, "Please login first")

      case AuthenticationFailedRejection(CredentialsRejected, _) =>
        complete(StatusCodes.Unauthorized, "Provided credentials are incorrect")

      case AuthorizationFailedRejection =>
        complete(StatusCodes.Forbidden, "You do not have permission to perform this action")

      case MalformedRequestContentRejection(_, _) =>
        complete(StatusCodes.BadRequest, "Malformed Content")
    }

    RejectionHandler.newBuilder().handle(toPlainTextResponse).result()
  }

  /**
   * Accepts a principal only when it is a Spring Security `Authentication` that reports itself
   * as authenticated, wrapping it in [[AuthenticatedData]]. Anything else (including no
   * principal) yields `None`.
   */
  private val springPrincipalVerifier: PrincipalVerifier[AuthenticatedData] =
    (maybePrincipal: Option[Principal]) =>
      maybePrincipal.collect {
        case authentication: core.Authentication if authentication.isAuthenticated =>
          AuthenticatedData(authentication)
      }

  /** Answers `GET` requests under the `ping` path prefix with "Pong!". */
  private val pingRoute: Route =
    (pathPrefix("ping") & get) {
      complete("Pong!")
    }

  /** All extension-contributed routes: script endpoints first, then coded (Scala) extension routes. */
  private[endpoints] def defaultRoutes: Seq[AuthenticatedRoute] =
    Seq(customScriptEndpoints, codedExtensionRoutes).flatten

  /** The extension API built from [[defaultRoutes]], using the Spring-Security-based principal verifier. */
  def extendableRoutes: Route =
    extendableRoutes(defaultRoutes)

  /** The extension API built from `apiRoutes`, using the Spring-Security-based principal verifier. */
  def extendableRoutes(apiRoutes: Seq[AuthenticatedRoute]): Route =
    customRoutes(springPrincipalVerifier)(apiRoutes)

  /**
   * Mounts `ping` plus `apiRoutes` under `settings.ServerExtension.rootPath`.
   *
   * `segments` and `settings` come from outside this trait. `segments` presumably turns the
   * configured root path into a path matcher; confirm against its definition.
   *
   * Per request, the principal is resolved with `verifier`:
   *  - On failure, the request is rejected with the resulting rejection, so even `ping` requires
   *    a verified principal.
   *  - On success, `ping` is tried first, then each of `apiRoutes` in order, each bound to the
   *    verified data.
   *
   * Rejections from inside are passed through [[customRejectionHandler]].
   *
   * @param verifier  decides whether the request's principal is acceptable
   * @param apiRoutes routes to expose after successful verification
   */
  private[endpoints] def customRoutes(verifier: PrincipalVerifier[AuthenticatedData])(apiRoutes: Seq[AuthenticatedRoute]): Route =
    handleRejections(customRejectionHandler) {
      rawPathPrefix(segments(settings.ServerExtension.rootPath)) {
        extractRequestContext { requestContext =>
          // `.apply` is spelled out on purpose: a bare `RequestPrincipal(verifier)(requestContext)`
          // could bind `requestContext` to a second parameter list of the companion's `apply`.
          RequestPrincipal(verifier).apply(requestContext).fold[Route](
            rejection => reject(rejection),
            authenticated =>
              apiRoutes.foldLeft(pingRoute) { (combined, authenticatedRoute) =>
                combined ~ authenticatedRoute(authenticated)
              }
          )
        }
      }
    }
}
