package com.xebialabs.xlrelease.script.jython

import com.xebialabs.xlrelease.script.jython.SandboxAwarePackageManager.doWithPackageManager
import grizzled.slf4j.Logging
import org.python.core.packagecache.{PackageManager, SysPackageManager}
import org.python.core.{PyJavaPackage, PyList, PyObject, PySystemState}

import java.io.File

/**
  * A [[PackageManager]] that forwards every call to one of two [[SysPackageManager]] instances,
  * chosen by whether the script running on the current thread is sandboxed.
  *
  * Stock Jython keeps a single, global package manager in
  * `org.python.core.PySystemState.packageManager`. It is assigned only once, during engine
  * initialization, in `org.python.core.PySystemState#initPackages(java.util.Properties)`.
  *
  * Jython's internals are largely static. They read `PySystemState.packageManager` directly
  * whenever they need to resolve Java packages or classes, for example in
  *  - `org.python.core.PySystemState#add_package(java.lang.String, java.lang.String)`
  *  - `org.python.core.JavaImporter#find_module(java.lang.String, org.python.core.PyObject)`
  *
  * A package manager caches what it has loaded. With one shared instance, a class loaded by an
  * unsandboxed script stays cached and becomes reachable from sandboxed scripts as well.
  *
  * Routing sandboxed and unsandboxed scripts to separate package managers keeps those caches
  * apart and prevents that leak.
  *
  * NOTE(review): only the methods overridden below are forwarded to a delegate. Any state this
  * class inherits from [[PackageManager]] itself lives on the proxy, not on the delegates.
  * Confirm that nothing in Jython reads such inherited state directly.
  */
class SandboxAwarePackageManager extends PackageManager {

  // --- class lookup -------------------------------------------------------------------------

  override def findClass(pkg: String, name: String): Class[_] =
    doWithPackageManager { delegate => delegate.findClass(pkg, name) }

  override def findClass(s: String, s1: String, s2: String): Class[_] =
    doWithPackageManager { delegate => delegate.findClass(s, s1, s2) }

  override def lookupName(name: String): PyObject =
    doWithPackageManager { delegate => delegate.lookupName(name) }

  // --- package queries and creation ---------------------------------------------------------

  override def packageExists(s: String, s1: String): Boolean =
    doWithPackageManager { delegate => delegate.packageExists(s, s1) }

  override def doDir(pyJavaPackage: PyJavaPackage, b: Boolean, b1: Boolean): PyList =
    doWithPackageManager { delegate => delegate.doDir(pyJavaPackage, b, b1) }

  override def makeJavaPackage(name: String, classes: String, jarfile: String): PyJavaPackage =
    doWithPackageManager { delegate => delegate.makeJavaPackage(name, classes, jarfile) }

  override def notifyPackageImport(pkg: String, name: String): Unit =
    doWithPackageManager { delegate => delegate.notifyPackageImport(pkg, name) }

  // --- classpath registration ---------------------------------------------------------------

  override def addDirectory(file: File): Unit =
    doWithPackageManager { delegate => delegate.addDirectory(file) }

  override def addJarDir(s: String, b: Boolean): Unit =
    doWithPackageManager { delegate => delegate.addJarDir(s, b) }

  override def addJar(s: String, b: Boolean): Unit =
    doWithPackageManager { delegate => delegate.addJar(s, b) }
}

/**
  * Owns the two delegate package managers used by the `SandboxAwarePackageManager` proxy. It
  * also holds the per-thread flag that decides which delegate serves the current thread.
  */
object SandboxAwarePackageManager extends Logging {

  /** Per-thread sandbox flag. It is `false` (non-sandboxed) on a thread until changed there. */
  private val sandboxed: ThreadLocal[Boolean] = ThreadLocal.withInitial(() => false)

  // Two independent managers, so anything cached by one is never served through the other.
  // Both receive a null first argument; presumably that means "no on-disk package cache" —
  // TODO confirm against Jython's SysPackageManager constructor.
  // NOTE(review): PySystemState.registry is read when this object is first initialized, so it
  // assumes the registry is already set up by then — confirm.
  private val sandboxedPackageManager: PackageManager = new SysPackageManager(null, PySystemState.registry)
  private val nonSandboxedPackageManager: PackageManager = new SysPackageManager(null, PySystemState.registry)

  private val instance: SandboxAwarePackageManager = new SandboxAwarePackageManager

  /**
    * Marks the current thread as running (or not running) a sandboxed script.
    *
    * The value persists on the thread until changed again. Prefer [[withSandboxed]] on pooled
    * threads so the flag cannot leak into unrelated work.
    *
    * @param sandboxed `true` to route this thread to the sandboxed package manager
    */
  def setSandboxed(sandboxed: Boolean): Unit =
    this.sandboxed.set(sandboxed)

  /**
    * Runs `body` with the current thread's sandbox flag set to `sandboxed`. The previous value
    * is restored afterwards, even if `body` throws.
    *
    * @param sandboxed flag value to use while `body` runs
    * @param body      code to evaluate under that flag
    * @return the result of `body`
    */
  def withSandboxed[T](sandboxed: Boolean)(body: => T): T = {
    val previous = this.sandboxed.get
    this.sandboxed.set(sandboxed)
    try body
    finally this.sandboxed.set(previous)
  }

  /**
    * Applies `fn` to the package manager that belongs to the current thread.
    *
    * @param fn operation to run against the selected package manager
    * @return whatever `fn` returns
    */
  def doWithPackageManager[T](fn: PackageManager => T): T =
    fn(getPackageManager)

  /** The sandboxed manager if the current thread is flagged as sandboxed, otherwise the non-sandboxed one. */
  def getPackageManager: PackageManager =
    if (sandboxed.get) sandboxedPackageManager else nonSandboxedPackageManager

  /** The single shared proxy instance. */
  def getInstance: SandboxAwarePackageManager = instance
}