package com.xebialabs.xlrelease.plugin.classloading


import com.xebialabs.overthere.util.OverthereUtils
import com.xebialabs.plugin.protocol.xlp.JarURL
import com.xebialabs.plugin.zip.PluginScanner
import com.xebialabs.xlplatform.utils.PerformanceLogging
import com.xebialabs.xlrelease.plugin.classloading.XlrPluginClassLoader.{JarPlugin, Plugin}
import org.slf4j.LoggerFactory

import java.io.{ByteArrayOutputStream, File, InputStream}
import java.net.URL
import java.util
import java.util.jar.JarFile
import scala.collection.mutable
import scala.jdk.CollectionConverters._

/**
 * Companion of [[XlrPluginClassLoader]]: a single-directory factory, the "hotfix" logger and the
 * plugin-root abstractions that the class loader searches.
 */
object XlrPluginClassLoader {
  /** Logger named "hotfix"; the class loader's `logHotfix` warns on it for URLs containing "hotfix". */
  private[XlrPluginClassLoader] val hotfixLogger = LoggerFactory.getLogger("hotfix")

  /** Adds `toUnixPath`, which rewrites the platform file separator to '/'. */
  implicit class NormalizePath(val path: String) extends AnyVal {
    def toUnixPath: String =
      if (File.separatorChar == '/') path // separators are already '/', nothing to rewrite
      else path.replace(File.separatorChar, '/')
  }

  /**
   * Creates a class loader that scans a single plugin directory.
   *
   * @param pluginDirectory   directory scanned for plugin jars
   * @param parentClassLoader parent of the new loader
   */
  def apply(pluginDirectory: File, parentClassLoader: ClassLoader): XlrPluginClassLoader =
    new XlrPluginClassLoader(
      pluginDirectories = Seq(pluginDirectory),
      parentClassLoader = parentClassLoader
    )

  /** A root that the class loader searches for classes and resources. */
  private[XlrPluginClassLoader] trait Plugin {
    /** URLs of the resource `name` inside this root; empty when the root does not contain it. */
    def getResources(name: String): Seq[URL]

    /** Releases whatever this root keeps open. */
    def close(): Unit
  }

  /** Root backed by a jar file; the jar is opened as soon as the instance is constructed. */
  private[XlrPluginClassLoader] case class JarPlugin(file: File) extends Plugin {
    val jarFile = new JarFile(file)

    override def getResources(name: String): Seq[URL] =
      Option(jarFile.getEntry(name)).toList.map(entry => JarURL(file.getAbsolutePath, entry.getName))

    override def close(): Unit = jarFile.close()
  }

  // NOTE(review): private[this] and not referenced anywhere in this object; looks like a removal candidate.
  private[this] case class XLPluginNode(name: String, subNodes: mutable.Map[String, XLPluginNode], url: URL)

  /** Root backed by an already-unpacked directory; resources are plain files below `explodedDir`. */
  private[XlrPluginClassLoader] case class ExplodedPlugin(explodedDir: File) extends Plugin {
    override def getResources(name: String): Seq[URL] =
      List(new File(explodedDir, name)).filter(_.exists()).map(_.toURI.toURL)

    // Plugins in the ext folder need no cleanup: nothing is materialized in that folder.
    override def close(): Unit = ()
  }

}


/**
 * Class loader that serves classes and resources from plugin jars.
 *
 * Every directory in `pluginDirectories` is scanned with `findAllPluginFiles(dir, "jar")`, and each file
 * found becomes a jar-backed plugin root. Lookups walk the roots in that order; single-result lookups
 * return the first match.
 *
 * @param pluginDirectories directories scanned for plugin jars
 * @param parentClassLoader parent handed to [[ClassLoader]]
 */
class XlrPluginClassLoader(pluginDirectories: Iterable[File], parentClassLoader: ClassLoader)
  extends ClassLoader(parentClassLoader) with PluginScanner with PerformanceLogging {

  /**
   * Defines `name` from the first plugin root that contains its `.class` file.
   *
   * Invoked by `ClassLoader.loadClass` (not overridden here) once parent delegation has failed.
   * A class URL coming from a hotfix location is reported through [[logHotfix]], just like resources.
   *
   * @throws ClassNotFoundException if no plugin root contains the class
   */
  override def findClass(name: String): Class[_] = logWithTime(s"Loading class $name") {
    findResourceUrl(convertClassName(name))
      .map(logHotfix)
      .map(loadClassFromUrl(name, _))
      .getOrElse(
        throw new ClassNotFoundException(
          // Flatten to a single line; newlines become spaces so the sentences stay separated.
          s"""A plugin could not be loaded due to a missing class ($name). Please remove the offending plugin to successfully start the server.
             |Classes related to JCR were removed from the server because of the migration from JCR to SQL.
             |If the plugin depends on these classes and its functionality is required, please contact support to fix your configuration.
             |$name not found""".stripMargin.replaceAll("\r?\n", " "))
      )
  }

  /**
   * Reads the bytes behind `resourceUrl`, then defines and resolves `className` in this loader.
   *
   * @throws ClassFormatError if the URL yields no bytes
   */
  private def loadClassFromUrl(className: String, resourceUrl: URL): Class[_] = {
    logger.trace(s"Loading class from url $resourceUrl")
    import com.xebialabs.xlplatform.utils.ResourceManagement._
    // `using` comes from the project's ResourceManagement; presumably it closes the stream afterwards.
    using(resourceUrl.openStream()) { classInputStream =>
      val bytes: Array[Byte] = readFully(classInputStream)
      if (bytes.isEmpty) {
        throw new ClassFormatError("Could not load class. Empty stream returned")
      }
      definePackageIfNeeded(className)
      val clazz = defineClass(className, bytes, 0, bytes.length)
      resolveClass(clazz)
      clazz
    }
  }

  /**
   * Defines the package of `className` in this loader unless it is already defined.
   * All spec/impl attributes and the seal base are left null.
   */
  private def definePackageIfNeeded(className: String): Unit = {
    // Everything before the last '.'; "" for a class in the unnamed package.
    val packageName: String = className.split('.').init.mkString(".")
    // getDefinedPackage returns null when this loader has not defined the package yet.
    if (Option(getDefinedPackage(packageName)).isEmpty) {
      definePackage(packageName, null, null, null, null, null, null, null)
    }
  }

  /**
   * Warns on the "hotfix" logger when `url` contains "hotfix".
   *
   * @return `url` unchanged (may be null)
   */
  def logHotfix(url: URL): URL = {
    if (url != null && url.toString.contains("hotfix")) {
      XlrPluginClassLoader.hotfixLogger.warn(s"Loading class/resource from hotfix: $url")
    }
    url
  }

  /** First URL for resource `name` across the plugin roots, or null if none has it. */
  override def findResource(name: String): URL = logWithTime(s"Loading resource $name")(logHotfix(findResourceUrl(name).orNull))

  /** All URLs for resource `name` across the plugin roots, in root order; hotfix URLs are reported. */
  override def findResources(name: String): util.Enumeration[URL] = logWithTime(s"Loading resources $name") {
    resourcesByName(name).map { url =>
      logger.trace(s"Found $url for $name")
      logHotfix(url)
    }.iterator.asJavaEnumeration
  }

  private def findResourceUrl(name: String): Option[URL] = resourceByName(name)

  /** `a.b.C` -> `a/b/C.class`, the resource path of a class file. */
  private def convertClassName(className: String) = className.replace('.', '/').concat(".class")

  private def resourceByName(resourcePath: String): Option[URL] = {
    resourcesByName(resourcePath).headOption
  }

  /** Asks every plugin root, in order, for `resourcePath`, and concatenates the results. */
  private def resourcesByName(resourcePath: String): Seq[URL] = {
    classPathRoots.flatMap(_.getResources(resourcePath)).toSeq
  }

  /** Copies `is` into a byte array via OverthereUtils.write (presumably until end of stream). */
  private def readFully(is: InputStream) = {
    val os = new ByteArrayOutputStream()
    OverthereUtils.write(is, os)
    os.toByteArray
  }

  // NOTE(review): plain var (not @volatile); concurrent lookups vs. refresh/clear presumably rely on
  // callers' synchronization — confirm.
  private var classPathRoots: Iterable[Plugin] = readDirs

  /**
   * Detaches all plugin roots, then closes every one of them.
   *
   * Roots are detached first so lookups stop seeing plugins that are about to be closed. Every root
   * gets a close attempt even if an earlier one fails. The first non-fatal failure is rethrown with
   * any later ones attached as suppressed exceptions.
   */
  def clearClasspathRoots(): Unit = {
    import scala.util.control.NonFatal
    val roots = classPathRoots
    classPathRoots = roots.empty
    val failures: List[Throwable] = roots.toList.flatMap { root =>
      try {
        root.close()
        None
      } catch {
        case NonFatal(e) => Some(e)
      }
    }
    failures match {
      case first :: rest =>
        rest.foreach(first.addSuppressed)
        throw first
      case Nil => ()
    }
  }

  /**
   * Rescans `pluginDirectories` and replaces the current plugin roots.
   * The previous roots are not closed here; call [[clearClasspathRoots]] first to release them.
   */
  def refreshDirs(): Unit = classPathRoots = readDirs

  /** One jar-backed plugin root per jar file found in each plugin directory. */
  private def readDirs: Iterable[Plugin] = {
    pluginDirectories.flatMap { pluginDirectory =>
      findAllPluginFiles(pluginDirectory, "jar").map(JarPlugin)
    }
  }
}
