package com.xebialabs.xlplatform.endpoints.servlet

import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.scaladsl.model.HttpHeader.ParsingResult.{Error, Ok}
import akka.http.scaladsl.model.HttpMethods._
import akka.http.scaladsl.model.HttpProtocols._
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Route
import akka.stream.scaladsl.{Broadcast, Flow, GraphDSL, RunnableGraph, Sink, Source, Zip}
import akka.stream.{ClosedShape, Materializer, SystemMaterializer}
import com.xebialabs.xlplatform.endpoints.security.CustomHttpHeaders.`X-User-Principal`
import grizzled.slf4j.Logging

import java.io.{ByteArrayOutputStream, IOException}
import java.security.Principal
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import javax.servlet.{AsyncContext, ServletInputStream}
import scala.annotation.tailrec
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.control.NonFatal

/**
 * Servlet that serves an Akka HTTP `Route` through the Servlet API.
 *
 * Request lifecycle:
 *  - Each request is switched to asynchronous mode.
 *  - It is converted into an Akka `HttpRequest` and run through the route (materialized as a `Flow`).
 *  - The resulting `HttpResponse` is copied onto the servlet response, then the `AsyncContext` is completed.
 *  - If building the request, running the route or writing the response fails, the request is
 *    answered with an error status and completed, rather than being left open until the async timeout.
 *
 * The actor system, route and async timeout are read from servlet-context attributes stored under the
 * `AkkaStreamServletInitializer` keys. Only strict (fully buffered) response entities can be written.
 * Chunked, close-delimited and default (streamed) entities are answered with a 500.
 */
class AkkaStreamServlet() extends HttpServlet with Logging {

  /** Lower-cased response header names that are not copied verbatim; they are derived from the entity. */
  protected val unhandledHeaders = Set("content-type", "content-length")

  protected implicit var system: ActorSystem = _
  protected implicit var materializer: Materializer = _

  protected var routes: Route = _
  protected var flow: Flow[HttpRequest, HttpResponse, NotUsed] = _

  /** Servlet context path; stripped from request URIs before they are handed to the route. */
  protected var prefix: String = _
  protected var prefixLength: Int = 0

  /** Value passed to `AsyncContext.setTimeout` for every request. */
  protected var timeout: Long = _

  /** Matches an IPv6 zone-index suffix (`%` up to the end of the host); compiled once instead of per request. */
  private val zoneIndexPattern = "(%.*)".r

  /**
   * Reads the actor system, route and async timeout from the servlet context and turns the route
   * into a request/response flow. A missing timeout attribute yields `0` (`null.asInstanceOf[Long]`).
   */
  override def init(): Unit = {
    logger.debug("Initializing AkkaStream Servlet")
    val ctx = getServletContext

    prefix = ctx.getContextPath
    prefixLength = prefix.length

    system = ctx.getAttribute(AkkaStreamServletInitializer.SYSTEM_KEY).asInstanceOf[ActorSystem]
    materializer = SystemMaterializer.get(system).materializer

    routes = ctx.getAttribute(AkkaStreamServletInitializer.ROUTES_KEY).asInstanceOf[Route]
    flow = Route.toFlow(routes)

    timeout = ctx.getAttribute(AkkaStreamServletInitializer.TIMEOUT_KEY).asInstanceOf[Long]

    logger.debug(s"AkkaStream Servlet initialized: prefix = $prefix, timeout = $timeout")
  }

  /**
   * Builds the stream that serves one servlet request.
   *
   * The converted request is broadcast twice. One copy goes through the route `flow`; the other is
   * zipped with the route's response so both reach `writeServletHttpResponse`. A failure of the
   * stream (for example an unsupported response entity) is reported through `failRequest`, so the
   * async context is always completed.
   *
   * The request is converted eagerly, while the graph is built, so conversion errors (such as a
   * `RequestProcessingException` from reading the body) are thrown from this method itself.
   */
  protected def pipeline(method: HttpMethod, ac: AsyncContext)(in: HttpServletRequest, out: HttpServletResponse): RunnableGraph[NotUsed] =
    RunnableGraph.fromGraph(GraphDSL.create() { implicit builder =>
      import GraphDSL.Implicits._
      val source = Source.single(toHttpRequest(method)(in))
      val broadcast = builder.add(Broadcast[HttpRequest](2)) // pass along the HttpRequest
      val zip = builder.add(Zip[HttpRequest, HttpResponse]()) // combine HttpRequest and HttpResponse
      val sink = Sink
        .foreach[(HttpRequest, HttpResponse)] { case (req, resp) =>
          writeServletHttpResponse(ac)(req, resp, out)
        }
        // GraphDSL.create() discards the sink's Future[Done]; hook into it at materialization time so
        // a failed stream still answers and completes the request instead of waiting for the timeout.
        .mapMaterializedValue { done =>
          done.failed.foreach(e => failRequest(ac, out, e))(system.dispatcher)
        }
      source ~> broadcast.in
      broadcast ~> zip.in0
      broadcast ~> flow ~> zip.in1
      zip.out ~> sink
      ClosedShape
    })

  /**
   * Starts asynchronous processing of a servlet request and runs `pipeline` on a container thread.
   * Exceptions raised while building or starting the pipeline are turned into an error response.
   */
  protected def handle(method: HttpMethod)(in: HttpServletRequest, out: HttpServletResponse): Unit = {
    logger.debug(s"$method ${in.getRequestURI}")
    val ac = in.startAsync()
    ac.setTimeout(timeout)
    ac.start(new Runnable() {
      def run(): Unit =
        try pipeline(method, ac)(in, out).run()
        catch {
          // Without this, the async context would stay open until the timeout fires.
          case NonFatal(e) => failRequest(ac, out, e)
        }
    })
  }

  /**
   * Best-effort error completion of an asynchronous request.
   *
   * The status is the one carried by a `RequestProcessingException`, otherwise 500. The error status
   * is sent only if the response is not yet committed. Problems while reporting are logged, not
   * rethrown; this covers the async context having already been completed or timed out.
   */
  private def failRequest(ac: AsyncContext, out: HttpServletResponse, cause: Throwable): Unit = {
    val status = cause match {
      case e: RequestProcessingException => e.code
      case _ => InternalServerError
    }
    logger.error(s"Failed to process request: ${cause.getMessage}", cause)
    Try(if (!out.isCommitted) out.sendError(status.intValue))
      .failed.foreach(e => logger.warn(s"Could not send error response: ${e.getMessage}"))
    Try(ac.complete())
      .failed.foreach(e => logger.warn(s"Could not complete async request: ${e.getMessage}"))
  }

  override def doGet(in: HttpServletRequest, out: HttpServletResponse): Unit = handle(GET)(in, out)

  override def doPost(in: HttpServletRequest, out: HttpServletResponse): Unit = handle(POST)(in, out)

  override def doDelete(in: HttpServletRequest, out: HttpServletResponse): Unit = handle(DELETE)(in, out)

  override def doHead(in: HttpServletRequest, out: HttpServletResponse): Unit = handle(HEAD)(in, out)

  override def doPut(in: HttpServletRequest, out: HttpServletResponse): Unit = handle(PUT)(in, out)

  /**
   * Converts a servlet request into an Akka `HttpRequest`.
   *
   * The headers are the `X-User-Principal` header (when a principal is present) followed by the
   * parsed request headers. The entity body is fully read into memory.
   */
  protected def toHttpRequest(method: HttpMethod)(in: HttpServletRequest): HttpRequest =
    HttpRequest(
      method = method,
      uri = parseUri(in),
      headers = userPrincipalHeader(in.getUserPrincipal).toList ++ parseHeaders(in),
      entity = toHttpEntity(in, parseContentType(in.getContentType), in.getContentLength),
      protocol = parseProtocol(in.getProtocol)
    )

  /**
   * Builds the request `Uri`: the request URI with the context-path prefix removed, plus the query string.
   *
   * `userinfo` falls back to `""` (the `Uri.from` default) when there is no remote user, rather than
   * passing `null`.
   */
  protected def parseUri(in: HttpServletRequest): Uri = {
    val requestUri = in.getRequestURI
    val path = if (requestUri.startsWith(prefix)) requestUri.drop(prefixLength) else requestUri
    Uri.from(
      userinfo = Option(in.getRemoteUser).getOrElse(""),
      // NOTE(review): host/port are the *remote* (client) address, not the server's; kept as-is.
      host = removeZoneIndexFromHost(in.getRemoteHost),
      port = in.getRemotePort,
      path = path,
      queryString = Option(in.getQueryString)
    )
  }

  /** Strips everything from the first `%` onwards (an IPv6 zone index such as `%0`). */
  private def removeZoneIndexFromHost(host: String) = zoneIndexPattern.replaceAllIn(host, "")

  /** Wraps a non-null principal in an `X-User-Principal` header. */
  protected def userPrincipalHeader(userPrincipal: Principal): Option[`X-User-Principal`] =
    Option(userPrincipal).map(`X-User-Principal`(_))

  /**
   * Parses all request headers, including every value of multi-valued headers.
   * Headers that fail to parse are skipped.
   */
  protected def parseHeaders(in: HttpServletRequest): List[HttpHeader] = {
    // The Servlet API allows getHeaderNames/getHeaders to return null; treat that as "no headers".
    val names = Option(in.getHeaderNames).map(_.asScala.toList).getOrElse(Nil)
    for {
      name <- names
      value <- Option(in.getHeaders(name)).map(_.asScala.toList).getOrElse(Nil)
      header <- HttpHeader.parse(name, value) match {
        case Ok(parsed, _) => List(parsed)
        case Error(info) =>
          logger.debug(s"Ignoring unparsable request header '$name': ${info.summary}")
          Nil
      }
    } yield header
  }

  /** Maps the servlet protocol string to an Akka `HttpProtocol`, defaulting to HTTP/1.1. */
  protected def parseProtocol(protocol: String): HttpProtocol =
    protocol match {
      case "HTTP/1.0" => `HTTP/1.0`
      case "HTTP/1.1" => `HTTP/1.1`
      case "HTTP/2.0" => `HTTP/2.0`
      case _ => `HTTP/1.1`
    }

  /**
   * Reads the request body into a strict entity.
   *
   * When `contentLength` is positive, exactly that many bytes are read. When it is unknown (negative)
   * and a `Transfer-Encoding` header announces a body (e.g. a chunked upload), the stream is read to
   * EOF. Otherwise the body is empty.
   *
   * @throws RequestProcessingException if the stream ends early or cannot be read
   */
  protected def toHttpEntity(req: HttpServletRequest, contentType: Option[ContentType], contentLength: Int): RequestEntity = {
    // Reads exactly `contentLength` bytes; EOF before that means the declared length was not honoured.
    @tailrec
    def drainRequestInputStream(buf: Array[Byte], inputStream: ServletInputStream, bytesRead: Int = 0): Array[Byte] =
      if (bytesRead < contentLength) {
        val count = inputStream.read(buf, bytesRead, contentLength - bytesRead)
        if (count >= 0) drainRequestInputStream(buf, inputStream, bytesRead + count)
        else {
          throw RequestProcessingException(InternalServerError,
            s"Illegal Servlet request entity, expected length $contentLength but only has length $bytesRead"
          )
        }
      } else buf

    // Reads until EOF, for bodies whose length is not announced up front.
    @tailrec
    def readUntilEof(inputStream: ServletInputStream, acc: ByteArrayOutputStream, chunk: Array[Byte]): Array[Byte] = {
      val count = inputStream.read(chunk)
      if (count < 0) acc.toByteArray
      else {
        acc.write(chunk, 0, count)
        readUntilEof(inputStream, acc, chunk)
      }
    }

    // Gate on Transfer-Encoding so body-less requests (e.g. GET) never touch the input stream.
    val hasBodyOfUnknownLength = contentLength < 0 && req.getHeader("Transfer-Encoding") != null

    val body =
      if (contentLength > 0 || hasBodyOfUnknownLength) {
        try {
          if (contentLength > 0) drainRequestInputStream(new Array[Byte](contentLength), req.getInputStream)
          else readUntilEof(req.getInputStream, new ByteArrayOutputStream(), new Array[Byte](8192))
        } catch {
          case e: IOException =>
            logger.warn(s"Could not read request entity: ${e.getMessage}")
            throw RequestProcessingException(InternalServerError, s"Could not read request entity: ${e.getMessage}")
        }
      } else Array.empty[Byte]

    contentType
      .fold(HttpEntity(body))(HttpEntity(_, body))
  }

  /** Parses a (possibly null) Content-Type string; unparsable values yield `None`. */
  protected def parseContentType(ctString: String): Option[ContentType] =
    Option(ctString).flatMap(ContentType.parse(_).toOption)

  /**
   * Copies an Akka `HttpResponse` onto the servlet response and completes the async context.
   *
   * Only `HttpEntity.Empty` and strict entities are supported. Any other entity kind raises a
   * `RequestProcessingException` before the servlet response is modified.
   */
  protected def writeServletHttpResponse(ac: AsyncContext)(req: HttpRequest, resp: HttpResponse, out: HttpServletResponse): Unit = {
    def unsupported(kind: String): Nothing = {
      logger.debug(s"Cannot handle '$kind'")
      throw RequestProcessingException(InternalServerError, kind)
    }

    // Validate first so that an error response does not inherit the route's status and headers.
    val strictEntity: Option[HttpEntity.Strict] = resp.entity match {
      case HttpEntity.Empty => None
      case strict: HttpEntity.Strict => Some(strict)
      case _: HttpEntity.Chunked => unsupported("Chunked HttpEntity")
      case _: HttpEntity.CloseDelimited => unsupported("CloseDelimited HttpEntity")
      case _: HttpEntity.Default => unsupported("Default HttpEntity")
    }

    out.setStatus(resp.status.intValue)
    resp.headers
      .filterNot(h => unhandledHeaders contains h.lowercaseName)
      .foreach(header => out.addHeader(header.name, header.value))

    strictEntity.foreach { case HttpEntity.Strict(ct, data) =>
      // NoContentType means "render no Content-Type header"; its value ("none/none") is not a real type.
      if (ct != ContentTypes.NoContentType) out.addHeader("Content-Type", ct.value)
      out.addHeader("Content-Length", data.length.toString)
      try {
        val os = out.getOutputStream
        os.write(data.toArray)
        os.flush()
      } catch {
        // Other failures propagate and are reported by the pipeline's failure handling.
        case e: IOException =>
          logger.error(s"Exception while writing HttpResponse body: ${e.getMessage}")
          out.setStatus(InternalServerError.intValue)
      }
    }
    ac.complete()
  }

  /** Request-processing failure carrying the HTTP status to answer with; `msg` is the exception message. */
  case class RequestProcessingException(code: StatusCode, msg: String) extends RuntimeException(msg)

}
