package com.xebialabs.xlplatform.webhooks.endpoint

import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.xlplatform.webhooks.authentication.RequestAuthenticationMethod
import com.xebialabs.xlplatform.webhooks.domain.{Endpoint, HttpRequestEvent}
import com.xebialabs.xlplatform.webhooks.endpoint.WebhooksEndpointController.exceptions.{EndpointNotFound, EndpointDisabled}
import com.xebialabs.xlplatform.webhooks.endpoint.WebhooksEndpointController.{exceptions, requestPrefix}
import com.xebialabs.xlplatform.webhooks.events.handlers.EventSourceHandler
import grizzled.slf4j.Logging
import javax.servlet.http.HttpServletRequest
import javax.ws.rs.core.{Context, Response}
import javax.ws.rs.{GET, POST, Path, PathParam}
import org.apache.commons.io.IOUtils

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}


abstract class WebhooksEndpointController(endpointProvider: EndpointProvider)
  extends EventSourceHandler[HttpRequestEvent, Endpoint]
    with Logging {

  @GET
  @Path("{endpointPath}")
  def acceptGET(@PathParam("endpointPath") path: String,
               @Context request: HttpServletRequest): Response = {
    accept(path, request)
  }

  @POST
  @Path("{endpointPath}")
  def acceptPOST(@PathParam("endpointPath") path: String,
                 @Context request: HttpServletRequest): Response = {
    accept(path, request)
  }

  def accept(path: String, request: HttpServletRequest): Response = {
    processRequest(path, request).recover {
      case e: exceptions.WebhookEndpointControllerException =>
        logger.warn(e.getMessage, if (e.logStackTrace) e else null)
        Response.status(e.status).entity(e.getMessage).build()

      case e: NotFoundException =>
        logger.warn(e.getMessage)
        Response.status(Response.Status.NOT_FOUND).entity(e.getMessage).build()

      case e: Throwable =>
        logger.warn(e.getMessage, e)
        Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity("Exception happened when trying to process your request, check XLRelease log for details").build()
    }.get
  }

  protected def processRequest(path: String, request: HttpServletRequest): Try[Response] = {
    def getHeaders = request.getHeaderNames.asScala.map(name => name -> request.getHeader(name)).toMap

    def getParams = request.getParameterMap.asScala.toMap

    for {
      endpoint <- getEndpoint(request, path)
      _ <- checkEnabled(request, endpoint)
      _ <- checkMethod(request, endpoint)
      payload <- Try(IOUtils.toString(request.getReader))
      reqHeaders = getHeaders
      reqParams = getParams
      authenticationMethod <- getRequestAuthenticationMethod(request, endpoint)
      _ <- authenticateRequest(request, endpoint, reqHeaders, reqParams, payload)(authenticationMethod)
      event = HttpRequestEvent(endpoint, reqHeaders.asJava, reqParams.asJava, payload)
      published <- Try(publish(endpoint, event))
      response = if (published) Response.ok() else Response.notAcceptable(List.empty.asJava)
    } yield {
      logger.trace(s"${requestPrefix(request)}: Published event for $endpoint, authenticated with ${authenticationMethod.getClass.getName}")
      response.build()
    }
  }

  protected def getEndpoint(request: HttpServletRequest, path: String): Try[Endpoint] = {
    endpointProvider.findEndpointByPath(path)
      .recoverWith {
        case _: NotFoundException => Failure(EndpointNotFound(request, path))
      }
  }

  protected def checkEnabled(request: HttpServletRequest, endpoint: Endpoint): Try[Unit] = {
    if (!endpoint.sourceEnabled) {
      Failure(EndpointDisabled(request, endpoint))
    } else {
      Success(())
    }
  }

  protected def checkMethod(request: HttpServletRequest, endpoint: Endpoint): Try[Unit] = {
    if (endpoint.method.name() != request.getMethod) Failure(exceptions.WrongMethod(request, endpoint)) else Success(())
  }

  protected def getRequestAuthenticationMethod(request: HttpServletRequest, endpoint: Endpoint): Try[RequestAuthenticationMethod] = {
    val authenticationMethod = for {
      auth <- Option(endpoint.authentication)
      method <- Option(auth.requestAuthentication)
    } yield method
    authenticationMethod
      .toRight(exceptions.EndpointAuthenticationMethodNotFound(request, endpoint))
      .toTry
  }

  protected def authenticateRequest(request: HttpServletRequest,
                                    endpoint: Endpoint,
                                    headers: Map[String, String],
                                    params: Map[String, Array[String]],
                                    payload: String)
                                   (authenticationMethod: RequestAuthenticationMethod): Try[Unit] = {
    if (authenticationMethod.authenticateScala(endpoint, headers, params, payload))
      Success(())
    else
      Failure(exceptions.UnauthorizedRequest(request, endpoint, authenticationMethod))
  }
}


object WebhooksEndpointController {

  def requestPrefix(request: HttpServletRequest) = s"${request.getRequestURI} from ${request.getRemoteAddr}:${request.getRemotePort}"

  // TODO: we should really not pollute the logs with stack traces when rejecting an incoming payload.
  object exceptions {

    sealed abstract class WebhookEndpointControllerException(val request: HttpServletRequest,
                                                             val message: String,
                                                             val status: Response.Status,
                                                             val logStackTrace: Boolean = false)
      extends Throwable(message)

    case class EndpointNotFound(req: HttpServletRequest, path: String)
      extends WebhookEndpointControllerException(req,
        message = s"Endpoint not found for path '$path'",
        status = Response.Status.NOT_FOUND
      )

    case class WrongMethod(req: HttpServletRequest, endpoint: Endpoint)
      extends WebhookEndpointControllerException(req,
        message = s"Wrong HTTP method for '$endpoint': expected ${endpoint.method}, got ${req.getMethod}.",
        status = Response.Status.BAD_REQUEST
      )

    case class EndpointAuthenticationMethodNotFound(req: HttpServletRequest, endpoint: Endpoint)
      extends WebhookEndpointControllerException(req,
        message = s"Authentication method not found for '$endpoint'.",
        status = Response.Status.NOT_FOUND
      )

    case class UnauthorizedRequest(req: HttpServletRequest, endpoint: Endpoint, authenticationMethod: RequestAuthenticationMethod)
      extends WebhookEndpointControllerException(req,
        message = s"Unauthorized request for '$endpoint' (authentication method: ${authenticationMethod.getClass.getName})",
        status = Response.Status.UNAUTHORIZED
      )

    case class EndpointDisabled(req: HttpServletRequest, endpoint: Endpoint)
      extends WebhookEndpointControllerException(req,
        message = s"Endpoint '$endpoint' is disabled",
        status = Response.Status.NOT_FOUND
      )

  }

}
