package com.xebialabs.platform.script.jython

import java.io.Writer
import javax.script.{ScriptContext, ScriptEngine, ScriptException}

import com.xebialabs.platform.script.jython.JythonSupport._
import grizzled.slf4j.Logging

/**
 * Companion holding the function type aliases and the writer decorators used by the
 * `JythonSupport` trait.
 */
object JythonSupport {

  /** Rewrites the source text of an expression before it is evaluated. */
  type PreprocessExpression = Function1[String, String]

  /** Post-processes an evaluation result; receives the engine, a result key and the raw value. */
  type ResultProcessor = Function3[ScriptEngine, String, AnyRef, AnyRef]

  /** Decorator that the trait installs in place of a script context's output writer. */
  val outWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator

  /** Decorator that the trait installs in place of a script context's error writer. */
  val errorWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator
}

/**
 * Mixin for evaluating Jython expressions and scripts with the `jython` engine provided by
 * `EngineInstance`.
 *
 * While a script runs, the script context's output and error writers are replaced by the
 * decorators held in the `JythonSupport` companion. The original writers are registered with
 * those decorators. After the script finishes, the decoration is undone, whether the script
 * completed normally or threw.
 */
trait JythonSupport extends Logging {

  import com.xebialabs.platform.script.jython.EngineInstance._

  private val doNotPreprocess: PreprocessExpression = expr => expr
  private val identityResultProcessor: ResultProcessor = (_, _, value) => value

  /**
   * Evaluates a single expression and casts the processed result to `T`.
   *
   * @param expression      Jython source to evaluate; must be non-null and non-empty
   * @param preprocess      transformation applied to `expression` before evaluation (identity by default)
   * @param resultKey       key handed to `resultProcessor`
   * @param resultProcessor post-processes the raw value returned by the engine (identity by default)
   * @param jythonContext   supplies the script context and the library scripts run before the expression
   * @tparam T expected result type; the cast is unchecked (erased), so a mismatch surfaces at the call site
   * @return the processed result of evaluating the expression
   * @throws IllegalArgumentException if `expression` is null or empty
   * @throws JythonException if the engine raises a `ScriptException`
   */
  def evaluateExpression[T](expression: String, preprocess: PreprocessExpression = doNotPreprocess, resultKey: String = "None", resultProcessor: ResultProcessor = identityResultProcessor)(implicit jythonContext: JythonContext): T = {
    require(Option(expression).exists(_.nonEmpty), "Expression must be defined")
    val result = executeScript(ScriptSource.byContent(preprocess(expression)), resultKey, resultProcessor)
    result.asInstanceOf[T]
  }

  /**
   * Builds a fresh script context from `jythonContext`, runs every library script in it, then
   * runs `scriptSource` in the same context.
   *
   * The results of the library scripts are discarded. Each library is still passed through
   * `resultProcessor` with the same `resultKey`.
   *
   * @return the processed result of `scriptSource`
   * @throws JythonException if any script (library or main) raises a `ScriptException`
   */
  def executeScript(scriptSource: ScriptSource, resultKey: String = "None", resultProcessor: ResultProcessor = identityResultProcessor)(implicit jythonContext: JythonContext): AnyRef = {
    val scriptContext = jythonContext.buildScriptContext
    // NOTE(review): this mutates the engine's default context. If `jython` is shared between
    // threads, concurrent callers will race here — confirm how EngineInstance scopes it.
    jython.setContext(scriptContext)

    jythonContext.libraries.foreach(runtimeScript =>
      executeScript(runtimeScript, scriptContext, resultKey, resultProcessor)
    )
    executeScript(scriptSource, scriptContext, resultKey, resultProcessor)
  }

  /**
   * Evaluates one script in the given context, with the context's writers decorated for the
   * duration of the call.
   *
   * @return `resultProcessor` applied to the engine, `resultKey` and the raw evaluation result
   * @throws JythonException wrapping any `ScriptException` raised by the engine
   */
  def executeScript(scriptSource: ScriptSource, scriptContext: ScriptContext, resultKey: String, resultProcessor: ResultProcessor): AnyRef = {
    trace(s"Evaluating script\n${scriptSource.scriptContent}")
    withThreadLocalWriter(scriptContext) {
      try {
        val result = jython.eval(scriptSource.scriptContent, scriptContext)
        resultProcessor.apply(jython, resultKey, result)
      } catch {
        case ex: ScriptException => throw JythonException(scriptSource, ex)
      }
    }
  }

  /**
   * Runs `fn` with the context's writers decorated.
   *
   * The decoration is always removed afterwards, even when `fn` throws. Without this, a failing
   * script would leave its writer registered in the global decorators and the context pointing
   * at the decorators.
   */
  private def withThreadLocalWriter(scriptContext: ScriptContext)(fn: => AnyRef): AnyRef = {
    addLoggerDecoration(scriptContext)
    try fn
    finally removeLoggerDecoration(scriptContext)
  }

  /**
   * Replaces the context's output and error writers with the shared decorators.
   *
   * Each original writer is first registered with its decorator. A writer that is null, or is
   * already a decorator, is left untouched.
   */
  private def addLoggerDecoration(scriptContext: ScriptContext): Unit = {
    def add(writer: Writer, decorator: ThreadLocalWriterDecorator, setWriter: ThreadLocalWriterDecorator => Unit): Unit =
      writer match {
        case _: ThreadLocalWriterDecorator => // already decorated; nothing to do
        case null => // no writer to wrap
        case w =>
          decorator.registerWriter(w)
          setWriter(decorator)
      }

    add(scriptContext.getWriter, outWriterDecorator, decorator => scriptContext.setWriter(decorator))
    add(scriptContext.getErrorWriter, errorWriterDecorator, decorator => scriptContext.setErrorWriter(decorator))
  }

  /**
   * Undoes `addLoggerDecoration`.
   *
   * For each context writer that is currently a decorator, the writer returned by the
   * decorator's `getWriter` is put back on the context, and the decorator's registration is
   * removed.
   */
  private def removeLoggerDecoration(scriptContext: ScriptContext): Unit = {
    // NOTE(review): this unwraps any decorator it finds. That includes one that was already
    // present on entry, which `addLoggerDecoration` deliberately skipped. Confirm that nested use
    // of a single context is not expected.
    def remove(writer: Writer, decorator: ThreadLocalWriterDecorator, restoreWriter: ThreadLocalWriterDecorator => Unit): Unit =
      writer match {
        case _: ThreadLocalWriterDecorator =>
          restoreWriter(decorator)
          decorator.removeWriter()
        case _ =>
      }

    remove(scriptContext.getWriter, outWriterDecorator, decorator => scriptContext.setWriter(decorator.getWriter))
    remove(scriptContext.getErrorWriter, errorWriterDecorator, decorator => scriptContext.setErrorWriter(decorator.getWriter))
  }
}
