package com.xebialabs.xlrelease.script.jython

import com.xebialabs.xlrelease.script.Jsr223EngineFactory
import grizzled.slf4j.Logging
import org.apache.commons.lang.reflect.FieldUtils.readField
import org.python.core._
import org.python.jsr223.{PyScriptEngine, PyScriptEngineFactory}
import org.python.util.PythonInterpreter

import javax.script.ScriptEngine
import scala.util.Try

/**
 * Holds two shared Jython JSR-223 engines: one created while [[SandboxAwarePackageManager]] is in
 * sandboxed mode (whose system state is given a [[JythonScriptClassLoader]]), and one created
 * with the sandbox flag off.
 *
 * Both engines are created when this object is initialised, and [[getScriptEngine]] returns the
 * same two instances on every call.
 */
object JythonEngineInstance extends Jsr223EngineFactory with Logging {
  private val engineFactory: PyScriptEngineFactory = new PyScriptEngineFactory()

  // One-time Jython bootstrap. Object members initialise in declaration order, so this runs before
  // the engines below are created — presumably so both engines resolve Java packages through the
  // sandbox-aware package manager.
  {
    PySystemState.initialize()
    PySystemState.packageManager = SandboxAwarePackageManager.getInstance
  }

  // Typed as PyScriptEngine (not ScriptEngine) so `reload` needs no downcast.
  private val restrictedEngine: PyScriptEngine = createEngine(restricted = true)
  private val unrestrictedEngine: PyScriptEngine = createEngine(restricted = false)

  /**
   * Returns the shared engine for the requested mode.
   *
   * @param restricted `true` for the engine created in sandboxed mode, `false` for the other one
   * @return the same engine instance on every call for a given `restricted` value
   */
  def getScriptEngine(restricted: Boolean): ScriptEngine = {
    logger.debug(s"Jython engine [restricted: $restricted] instance started.")
    if (restricted) restrictedEngine else unrestrictedEngine
  }

  /**
   * Creates a new Jython engine with the [[SandboxAwarePackageManager]] sandbox flag set to
   * `restricted` during creation. The flag is always reset to `false` afterwards, even on failure.
   *
   * @param restricted when `true`, the engine's system state class loader is replaced by a
   *                   [[JythonScriptClassLoader]] whose parent is the current thread's context
   *                   class loader
   */
  private def createEngine(restricted: Boolean): PyScriptEngine = {
    try {
      SandboxAwarePackageManager.setSandboxed(restricted)
      val engine = engineFactory.getScriptEngine.asInstanceOf[PyScriptEngine]
      if (restricted) {
        val parentLoader = Thread.currentThread().getContextClassLoader
        interpreterOf(engine).getSystemState.setClassLoader(new JythonScriptClassLoader(parentLoader))
      }
      engine
    } finally {
      SandboxAwarePackageManager.setSandboxed(false)
    }
  }

  /** Reads the private `interp` field of a [[PyScriptEngine]] via reflection (forcing access). */
  private def interpreterOf(engine: PyScriptEngine): PythonInterpreter =
    readField(engine, "interp", true).asInstanceOf[PythonInterpreter]

  /** Reloads every non-builtin module currently present in `sys.modules` of both engines. */
  def reload(): Unit = {
    logger.info("Reloading jython engine modules")
    reload(restrictedEngine)
    reload(unrestrictedEngine)
  }

  /**
   * Reloads, via `imp.reload`, every module in the engine's `sys.modules` whose name is not in
   * `PySystemState.builtin_module_names`. The work is done while holding the system state's
   * import lock.
   *
   * A module that fails to reload is logged and skipped. Entries that are Java packages or
   * non-module objects, or that have vanished from `sys.modules` in the meantime, are also
   * logged and skipped.
   */
  private def reload(engine: PyScriptEngine): Unit = {
    import scala.jdk.CollectionConverters._
    val interpreter = interpreterOf(engine)
    interpreter.exec("import imp") // make sure imp module is loaded
    val systemState: PySystemState = interpreter.getSystemState
    val importLock = systemState.getImportLock
    importLock.lock()
    try {
      val moduleNames = systemState.modules.asInstanceOf[PyStringMap]
        .keys().listIterator().asScala.map(_.toString).toList
      val builtinModuleNames = PySystemState.builtin_module_names
        .listIterator().asScala.map(_.toString).toSet
      // reload everything already loaded except the builtin modules (e.g. __builtin__, sys)
      val modulesToReload = moduleNames.filterNot(builtinModuleNames)
      val impModule = systemState.modules.__getitem__(new PyString("imp")).asInstanceOf[PyModule]
      modulesToReload.foreach { m =>
        logger.info(s"Reloading module: $m")
        // __finditem__ yields null (instead of raising KeyError) if an earlier reload removed the
        // entry; this keeps one vanished module from aborting the whole loop.
        Option(systemState.modules.__finditem__(new PyString(m))) match {
          case Some(pyModule: PyModule) =>
            Try(impModule.invoke("reload", pyModule)).failed.foreach { t =>
              logger.warn(s"Unable to reload python module $m", t)
            }
          case Some(_: PyJavaPackage) =>
            logger.info(s"Unable to reload $m as it is pyJavaPackage")
          case Some(other) =>
            logger.info(s"Unable to reload $m as it is ${other.getType}")
          case None =>
            logger.info(s"Skipping $m as it is no longer present in sys.modules")
        }
      }
    } finally {
      importLock.unlock()
    }
  }
}
