package com.xebialabs.xlplatform.jmx

import com.xebialabs.deployit.util.PasswordEncrypter

import java.io.{FileInputStream, InputStream}
import java.lang.management.ManagementFactory
import java.rmi.registry.LocateRegistry
import java.security.KeyStore
import java.util
import javax.management.remote.rmi.RMIConnectorServer
import javax.management.remote.{JMXConnectorServerFactory, JMXServiceURL}
import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext}
import javax.rmi.ssl.{SslRMIClientSocketFactory, SslRMIServerSocketFactory}
import com.xebialabs.xlplatform.settings.SecuritySettings
import com.xebialabs.xlplatform.utils.{ClassLoaderUtils, SecureRandomHolder}
import grizzled.slf4j.Logging


/**
 * Starts an RMI-based JMX connector server that exposes the platform MBean server.
 *
 * The server is only started when `JMXSettings.enabled` is set. When the nested SSL
 * [[SecuritySettings]] are enabled, SSL RMI socket factories are installed on the connector.
 */
class JMXAgent extends Logging {

  /** Used to turn the configured key store / key passwords into plain text (`ensureDecrypted`). */
  lazy val passwordEncrypter: PasswordEncrypter = PasswordEncrypter.getInstance()

  /**
   * Creates an RMI registry on `jmxSettings.port` and starts a JMX connector server on it.
   * Does nothing when `jmxSettings.enabled` is false.
   *
   * When `hostname` is non-empty it is written to the JVM-wide `java.rmi.server.hostname`
   * system property. When it is empty, the service URL is built without a host part.
   *
   * NOTE(review): the registry is created with the default (non-SSL) socket factories even
   * when SSL is enabled for the connector itself.
   *
   * @param jmxSettings JMX configuration: enabled flag, hostname, port and SSL settings
   */
  def start(jmxSettings: JMXSettings): Unit = {
    if (jmxSettings.enabled) {
      debug("Starting JMX server.")
      val hostname = jmxSettings.hostname
      val port = jmxSettings.port

      if (hostname.nonEmpty) {
        debug(s"Binding to $hostname:$port.")
        // Process-wide setting: affects every RMI stub exported by this JVM.
        System.setProperty("java.rmi.server.hostname", hostname)
      } else {
        debug(s"Binding to (localhost):$port.")
      }
      LocateRegistry.createRegistry(port)
      val mBeanServer = ManagementFactory.getPlatformMBeanServer
      val env = new java.util.HashMap[String, Any]
      enableSSL(jmxSettings.ssl, env)
      val url = new JMXServiceURL(s"service:jmx:rmi:///jndi/rmi://$hostname:$port/jmxrmi")
      val connectorServer = JMXConnectorServerFactory.newJMXConnectorServer(url, env, mBeanServer)

      connectorServer.start()
      debug(s"JMX server started at $url.")
    }
  }

  /**
   * Adds SSL RMI client/server socket factories to `env` when `securitySettings.enabled` is set.
   * Otherwise `env` is left untouched. Client authentication is not required.
   */
  private def enableSSL(securitySettings: SecuritySettings, env: util.HashMap[String, Any]): Unit = {
    if (securitySettings.enabled) {
      env.put(RMIConnectorServer.RMI_CLIENT_SOCKET_FACTORY_ATTRIBUTE, new SslRMIClientSocketFactory)
      env.put(RMIConnectorServer.RMI_SERVER_SOCKET_FACTORY_ATTRIBUTE,
        new SslRMIServerSocketFactory(
          initSslContext(securitySettings),
          securitySettings.enabledAlgorithms.toList.toArray, // enabled cipher suites
          Array(securitySettings.protocol),                  // enabled protocols
          false)                                             // needClientAuth
      )
    }
  }

  /**
   * Builds an [[SSLContext]] for `settings.protocol`. Its key managers come from the configured
   * key store, if any. Its randomness source is `SecureRandomHolder.get()`.
   */
  def initSslContext(settings: SecuritySettings): SSLContext = {
    val context = SSLContext.getInstance(settings.protocol)
    // A null TrustManager array lets the installed providers supply the default trust managers.
    context.init(createKeyManagersIfPossible(settings), null, SecureRandomHolder.get())
    context
  }

  /**
   * Returns the key managers for `settings.keyStore`, or an empty array when no key store is
   * configured. Both passwords are passed through `passwordEncrypter.ensureDecrypted` first.
   */
  private def createKeyManagersIfPossible(settings: SecuritySettings): Array[KeyManager] = {
    def createKeyManager(keyStoreResource: String,
                         keyStorePassword: Array[Char],
                         keyPassword: Array[Char]): Array[KeyManager] = {
      val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
      // KeyStore.load does not close its stream, so close it here to avoid leaking a handle.
      val in = loadResource(keyStoreResource)
      try keyStore.load(in, keyStorePassword)
      finally in.close()
      val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
      keyManagerFactory.init(keyStore, keyPassword)
      keyManagerFactory.getKeyManagers
    }

    settings.keyStore
      .map { ks =>
        createKeyManager(
          ks,
          passwordEncrypter.ensureDecrypted(settings.keyStorePassword).toCharArray,
          passwordEncrypter.ensureDecrypted(settings.keyPassword).toCharArray)
      }
      .toArray
      .flatten
  }

  /**
   * Opens `resource` from the class path of `ClassLoaderUtils.classLoader`. If it is not found
   * there, `resource` is opened as a file-system path. The caller must close the returned stream.
   *
   * @throws java.io.FileNotFoundException if it is neither a class-path resource nor an openable file
   */
  private def loadResource(resource: String): InputStream =
    Option(ClassLoaderUtils.classLoader.getResourceAsStream(resource)).getOrElse(new FileInputStream(resource))

}
