package com.xebialabs.xlplatform.endpoints.actors

import com.xebialabs.deployit.repository.WorkDirContext
import com.xebialabs.platform.script.jython._
import com.xebialabs.xlplatform.endpoints.json.ScriptRequest
import com.xebialabs.xlplatform.endpoints.{AuthenticatedData, JythonRequest, JythonResponse}
import com.xebialabs.xlplatform.script.jython.JythonSugarDiscovery
import com.xebialabs.xlplatform.spring.JythonBindingsHolder
import grizzled.slf4j.Logging
import org.apache.pekko.actor._
import org.python.core.{PyDictionary, PyList}
import org.springframework.security.core.context.SecurityContextHolder
import spray.json._

import java.io._
import java.util.{Map => JMap}
import javax.script.{ScriptContext, SimpleScriptContext}
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

/** Factory for [[JythonScriptExecutorActor]] instances. */
object JythonScriptExecutorActor {

  /**
   * Creates the `Props` for a new executor actor. The actor handles a single [[RunScript]]
   * message and then stops itself, so a fresh actor is needed per script run.
   *
   * @return props that instantiate [[JythonScriptExecutorActor]] via its no-arg constructor
   */
  def props(): Props =
    Props[JythonScriptExecutorActor]()
}

/**
 * Message asking a [[JythonScriptExecutorActor]] to execute one endpoint script.
 *
 * @param source  the script to execute
 * @param request request data (entity and query parameters) that is wrapped into a `JythonRequest`
 *                and exposed to the script under the `request` attribute
 * @param auth    the authenticated caller; its authentication is installed in Spring's
 *                `SecurityContextHolder` while the script runs
 */
case class RunScript(
  source: ScriptSource,
  request: ScriptRequest,
  auth: AuthenticatedData
)

/**
 * Outcome of one endpoint script run, sent back to the sender of the [[RunScript]] message.
 *
 * @param entity     response body set by the script (see the companion `apply` for how it is normalised)
 * @param stdout     everything the script wrote to standard output
 * @param stderr     everything the script wrote to standard error
 * @param statusCode status code set by the script, if any
 * @param headers    response headers set by the script
 * @param exception  the failure raised while running the script, if it failed
 */
case class ScriptDone(entity: Any, stdout: String, stderr: String, statusCode: Option[Integer] = None, headers: Map[String, String] = Map(), exception: Option[Throwable] = None)

object ScriptDone {

  /**
   * Builds a [[ScriptDone]] from the [[JythonResponse]] a script populated.
   *
   * A `java.util.Map` entity is converted to an immutable Scala `Map` (top level only). Any other
   * non-null entity is kept as is, and a missing (`null`) entity becomes an empty `Map`. A `null`
   * status code becomes `None`, and a `null` headers map becomes an empty `Map`. All of these
   * values can be reassigned by the script, so none of them is assumed to be non-null.
   *
   * @param jr        the response object exposed to the script
   * @param stdout    captured standard output
   * @param stderr    captured standard error
   * @param exception the script failure, if any
   * @return the message to send back to the requester
   */
  def apply(jr: JythonResponse, stdout: String, stderr: String, exception: Option[Throwable]): ScriptDone = {
    val entity = Option(jr.getEntity).map {
      case m: JMap[_, _] => m.asScala.toMap
      case other => other
    }.getOrElse(Map())
    // Guarded like the entity and status code. An unguarded null here would throw while the
    // reply is built, and the requester would never receive an answer.
    val headers: Map[String, String] = Option(jr.getHeaders).fold(Map.empty[String, String])(_.asScala.toMap)
    ScriptDone(entity, stdout, stderr, Option(jr.getStatusCode), headers, exception)
  }
}

/**
 * Unchecked exception for invalid script output.
 *
 * NOTE(review): not thrown anywhere in this file. Presumably raised by callers when a script's
 * output cannot be turned into a response. Confirm against its usages.
 *
 * @param msg description of the problem, returned by `getMessage`
 */
class InvalidScriptOutputException(msg: String)
  extends RuntimeException(msg)

/**
 * Single-use actor that runs one endpoint script and replies to the sender with a [[ScriptDone]].
 *
 * On a [[RunScript]] message it:
 *  1. calls `WorkDirContext.initWorkdir("endpoint")`;
 *  2. installs the caller's authentication in `SecurityContextHolder`;
 *  3. builds the script context (request, client bindings, empty response);
 *  4. runs the script.
 *
 * It then replies with the script's response, stdout and stderr. Script failures are logged and
 * reported through `ScriptDone.exception` instead of being rethrown.
 *
 * Whatever happens, the authentication is cleared and the actor stops itself afterwards.
 */
class JythonScriptExecutorActor extends Actor with DefaultJsonProtocol with Logging with JythonSupport {

  override def receive: Actor.Receive = {
    case RunScript(source, request, auth) =>
      WorkDirContext.initWorkdir("endpoint")
      SecurityContextHolder.getContext.setAuthentication(auth.toAuthentication)
      try {
        val contextHelper = ScriptContextHelper().withRequest(request).withClients().withEmptyResponse()

        implicit val jythonContext: JythonContext = contextHelper.jythonContext

        val failure: Option[Throwable] = Try(executeScript(source)) match {
          case Success(_) => None
          case Failure(e) =>
            logger.error("Jython error occurred: ", e)
            Some(e)
        }

        sender() ! ScriptDone(contextHelper.jythonResponse, contextHelper.stdOut, contextHelper.stdErr, failure)
      } finally {
        // Spring's default SecurityContextHolder strategy is thread-local, and dispatcher threads are
        // shared with other actors. Clear the authentication even if building the script context or
        // the reply throws, so it cannot leak into unrelated work on this thread.
        SecurityContextHolder.getContext.setAuthentication(null)
        // Always stop. Otherwise an exception here would leave a restarted, idle actor behind.
        self ! PoisonPill
      }
  }
}

/** Companion providing the `ScriptContextHelper()` factory. */
object ScriptContextHelper {

  /**
   * Creates a new helper with its own stdout/stderr buffers. No request, client or response
   * attributes are registered until the corresponding `with*` methods are called.
   */
  def apply(): ScriptContextHelper = new ScriptContextHelper()
}

/**
 * Assembles the `javax.script` context an endpoint script is evaluated in.
 *
 * A helper owns:
 *  - two in-memory writers for the script's stdout and stderr;
 *  - a `SimpleScriptContext` wired to those writers;
 *  - a [[JythonContext]] built on top of that context;
 *  - the [[JythonResponse]] the script fills in.
 *
 * The fluent `with*` methods register engine-scope attributes on the context and return `this`,
 * so a fully prepared helper reads as
 * `ScriptContextHelper().withRequest(r).withClients().withEmptyResponse()`.
 *
 * The vals are initialised top to bottom, and each depends on the ones before it
 * (writers -> `context` -> `jythonContext`). Keep that order when editing.
 */
class ScriptContextHelper {

  /** Engine-scope attribute name of the request object. */
  final val requestParamName = "request"

  /** Engine-scope attribute name of the response object. */
  final val responseParamName = "response"

  /** Receives everything the script prints to standard output. */
  val outWriter: StringWriter = new StringWriter()

  /** Receives everything the script prints to standard error. */
  val errWriter: StringWriter = new StringWriter()

  /** Script context whose writer and error writer feed `outWriter` and `errWriter`. */
  val context: SimpleScriptContext = createContext

  /** Builds a new `SimpleScriptContext` that redirects output into this helper's buffers. */
  def createContext: SimpleScriptContext = {
    val scriptContext = new SimpleScriptContext
    val stdoutSink = new PrintWriter(outWriter)
    val stderrSink = new PrintWriter(errWriter)
    scriptContext.setWriter(stdoutSink)
    scriptContext.setErrorWriter(stderrSink)
    scriptContext
  }

  /**
   * Jython context for this run, evaluated against `context`. Its libraries are, in order:
   *  1. `Syntactic.loggerLib`;
   *  2. the code produced by `Syntactic.wrapperCodeWithLib` for every global binding name plus
   *     the response name;
   *  3. the resources returned by `JythonSugarDiscovery.getExtensionResources`.
   */
  val jythonContext: JythonContext =
    JythonContext.withLibrariesAndFactory(
      (Syntactic.loggerLib +: Syntactic.wrapperCodeWithLib(
        JythonBindingsHolder.getBindings.keySet.asScala.toSeq :+ responseParamName
      )) ++ JythonSugarDiscovery.getExtensionResources
    )(context)

  /** The response object the script populates; registered on the context by `withEmptyResponse()`. */
  val jythonResponse = new JythonResponse

  /** Text the script has written to standard output so far. */
  def stdOut: String = outWriter.toString

  /** Text the script has written to standard error so far. */
  def stdErr: String = errWriter.toString

  /**
   * Exposes `req` to the script as a [[JythonRequest]] under `requestParamName`.
   *
   * The JSON entity is converted recursively into Java/Jython values. Query parameters are
   * flattened as follows:
   *  - a single empty value becomes `null`;
   *  - a single value becomes that value;
   *  - any other `List` becomes a `java.util.List`;
   *  - anything else is passed through unchanged.
   *
   * @param req the request to expose
   * @return this helper, for chaining
   */
  def withRequest(req: ScriptRequest): ScriptContextHelper = {
    // JSON -> Java/Jython values:
    //   - integral numbers become Integer, or Long when they do not fit an Int; other numbers become Double;
    //   - arrays become PyList, objects PyDictionary, booleans Boolean;
    //   - anything else (JSON null) becomes null.
    def toJython(json: JsValue): Object = json match {
      case JsString(text) => text
      case JsNumber(number) =>
        if (number.isValidInt) Int.box(number.toInt)
        else if (number.isValidLong) Long.box(number.toLong)
        else Double.box(number.toDouble)
      case JsArray(items) =>
        val list = new PyList()
        list.addAll(items.map(toJython).asJava)
        list
      case JsObject(members) =>
        val dict = new PyDictionary()
        dict.putAll(members.map { case (key, value) => key -> toJython(value) }.asJava)
        dict
      case JsTrue => java.lang.Boolean.TRUE
      case JsFalse => java.lang.Boolean.FALSE
      case _ => null
    }

    val flattenedQuery = req.query.map {
      case (name: String, List("")) => (name, null)
      case (name: String, only :: Nil) => (name, only)
      case (name, values: List[_]) => (name, values.asJava)
      case (name, other) => (name, other)
    }

    val queryDict = new PyDictionary()
    queryDict.putAll(flattenedQuery.asJava)

    context.setAttribute(requestParamName, new JythonRequest(toJython(req.entity), queryDict), ScriptContext.ENGINE_SCOPE)
    this
  }

  /**
   * Exposes `jythonResponse` to the script under `responseParamName`.
   *
   * @return this helper, for chaining
   */
  def withEmptyResponse(): ScriptContextHelper = {
    context.setAttribute(responseParamName, jythonResponse, ScriptContext.ENGINE_SCOPE)
    this
  }

  /**
   * Registers every entry of `JythonBindingsHolder.getBindings` as an engine-scope attribute,
   * using the binding's key as the attribute name.
   *
   * @return this helper, for chaining
   */
  def withClients(): ScriptContextHelper = {
    for ((bindingName, binding) <- JythonBindingsHolder.getBindings.asScala)
      context.setAttribute(bindingName, binding, ScriptContext.ENGINE_SCOPE)
    this
  }
}
