package com.xebialabs.satellite.streaming

import java.io.{FileInputStream, InputStream}
import java.security.{KeyStore, SecureRandom}
import javax.net.ssl._

import akka.stream.TLSProtocol._
import akka.stream._
import akka.stream.scaladsl._
import akka.util.ByteString
import com.xebialabs.xlplatform.settings.SecuritySettings
import com.xebialabs.xlplatform.utils.ClassLoaderUtils
import grizzled.slf4j.Logging

/**
 * Helpers for optionally wrapping a byte-level connection flow (e.g. a TCP connection) in TLS.
 *
 * When SSL is disabled a `TLSPlacebo` stage is used instead of `TLS`, so callers build the same
 * graph shape regardless of whether encryption is enabled.
 */
object SslStreamingSupport extends Logging {

  object SslConfig extends Logging {

    /**
     * Builds an [[SslConfig]] for the requested mode.
     *
     * @param useSsl   whether the caller wants the connection encrypted
     * @param settings security settings; only consulted when `useSsl` is true
     * @return `Disabled` when `useSsl` is false, otherwise an enabled config with an initialised SSLContext
     * @throws SecurityException when SSL is requested but `settings.enabled` is false
     */
    def apply(useSsl: Boolean, settings: SecuritySettings): SslConfig =
      if (!useSsl) {
        debug("SSL disabled")
        Disabled
      } else if (settings.enabled) {
        debug("SSL enabled")
        Enabled(settings)
      } else {
        warn("Requested ssl encryption but there is no configuration provided.")
        throw new SecurityException("Requested ssl encryption but there is no configuration provided.")
      }

    /** Configuration that routes traffic through `TLSPlacebo` (no encryption). */
    lazy val Disabled = SslConfig(enabled = false, sslContext = null, role = null, closing = null)

    // Role and closing behaviour are left null here; presumably callers choose them via
    // asClient/asServer and the closing helpers on SslConfig before wrapping a flow — confirm at call sites.
    private def Enabled(settings: SecuritySettings) = SslConfig(
      enabled = true,
      sslContext = initSslContext(settings),
      enabledAlgorithms = settings.enabledAlgorithms,
      closing = null
    )
  }

  /**
   * Immutable description of how a connection should be wrapped.
   *
   * @param enabled           whether real TLS is used (`false` selects `TLSPlacebo`)
   * @param sslContext        context handed to `TLS`; null for the disabled configuration
   * @param role              TLS role; null until `asClient`/`asServer` is called
   * @param closing           TLS closing behaviour; null until one of the closing helpers is called
   * @param enabledAlgorithms cipher suites to enable; empty means the engine's default suites
   */
  case class SslConfig(enabled: Boolean, sslContext: SSLContext, role: TLSRole = null, closing: TLSClosing, enabledAlgorithms: Seq[String] = Nil) {
    def asClient: SslConfig = copy(role = Client)

    def asServer: SslConfig = copy(role = Server)

    def ignoreCancel: SslConfig = copy(closing = IgnoreCancel)

    def ignoreComplete: SslConfig = copy(closing = IgnoreComplete)

    def ignoreBoth: SslConfig = copy(closing = IgnoreBoth)

    def eagerClose: SslConfig = copy(closing = EagerClose)
  }

  type SslFlow = BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, _]

  /** A byte-in/byte-out stream stage, e.g. an outgoing TCP connection. */
  type ByteStringFlow[Mat] = Flow[ByteString, ByteString, Mat]

  /** Misspelled alias of `ByteStringFlow`, kept so existing callers keep compiling. */
  type ByteStingFlow[Mat] = ByteStringFlow[Mat]

  /**
   * Wraps `tcpConnection` so plaintext bytes go in and come out, with TLS (or a placebo stage when
   * `sslConfig.enabled` is false) applied between this flow and the connection.
   *
   * @return a flow whose materialized value is that of `tcpConnection`
   */
  def wrapWithSsl[Mat](sslConfig: SslConfig, tcpConnection: ByteStingFlow[Mat]): ByteStingFlow[Mat] =
    sslWrapper[Mat](sslFlow(sslConfig), tcpConnection)

  /** Real `TLS` stage when enabled, otherwise a `TLSPlacebo` (no encryption). */
  private def sslFlow(sslConfig: SslConfig) =
    if (sslConfig.enabled) {
      debug("Real ssl config")
      TLS(sslConfig.sslContext, newSessionNegotiation(sslConfig.enabledAlgorithms), sslConfig.role, sslConfig.closing)
    } else {
      debug("Placebo ssl config")
      TLSPlacebo()
    }

  // `withCipherSuites()` with no arguments would enable *no* cipher suites and make every
  // handshake fail, so an empty list falls back to the engine's default suites.
  private def newSessionNegotiation(cipherSuites: Seq[String]): NegotiateNewSession =
    if (cipherSuites.isEmpty) NegotiateNewSession
    else NegotiateNewSession.withCipherSuites(cipherSuites: _*)

  /**
   * Joins `tls` with `tcpConnection`: outgoing bytes are wrapped as `SendBytes`, pass through the
   * TLS stage into the connection, and the connection's output flows back through the TLS stage.
   * Only `SessionBytes` payloads are emitted; other `SslTlsInbound` messages are dropped by `collect`.
   */
  private def sslWrapper[Mat](tls: SslFlow, tcpConnection: ByteStringFlow[Mat]): ByteStringFlow[Mat] =
    Flow.fromGraph(GraphDSL.create(tls, tcpConnection)((_, connMat) => connMat) { implicit builder =>
      (tlsShape, conn) =>
        import GraphDSL.Implicits._
        val sendBytes = builder.add(Flow[ByteString].map(bytes => SendBytes(bytes)))
        sendBytes.outlet ~> tlsShape.in1
        tlsShape.out1 ~> conn ~> tlsShape.in2
        val inbound = tlsShape.out2.collect { case SessionBytes(_, bytes) => bytes }
        FlowShape(sendBytes.in, inbound.outlet)
    })

  /**
   * Creates an `SSLContext` for `settings.protocol`, initialised with key managers from the
   * optional key store and trust managers from the trust store described by `settings`.
   */
  def initSslContext(settings: SecuritySettings): SSLContext = {
    val context = SSLContext.getInstance(settings.protocol)
    context.init(createKeyManagersIfPossible(settings), createTrustManagers(settings), new SecureRandom)
    context
  }

  /** Key managers built from `settings.keyStore` when one is configured, otherwise an empty array. */
  private def createKeyManagersIfPossible(settings: SecuritySettings): Array[KeyManager] = {
    def createKeyManager(keyStoreResource: String, keyStorePassword: Array[Char], keyPassword: Array[Char]) = {
      val keyStore = loadKeyStore(keyStoreResource, keyStorePassword)
      val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
      keyManagerFactory.init(keyStore, keyPassword)
      keyManagerFactory.getKeyManagers
    }
    settings.keyStore.map(ks => createKeyManager(ks, settings.keyStorePassword.toCharArray, settings.keyPassword.toCharArray)).toArray.flatten
  }

  /** Trust managers built from the trust store named by `settings.trustStore`. */
  private def createTrustManagers(settings: SecuritySettings): Array[TrustManager] = {
    val trustStore = loadKeyStore(settings.trustStore, settings.trustStorePassword.toCharArray)
    val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
    trustManagerFactory.init(trustStore)
    trustManagerFactory.getTrustManagers
  }

  /** Loads a key store of the JVM's default type from `resource`, always closing the stream afterwards. */
  private def loadKeyStore(resource: String, password: Array[Char]): KeyStore = {
    val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
    val in = loadResource(resource)
    try keyStore.load(in, password) finally in.close()
    keyStore
  }

  /**
   * Opens `resource` from the class loader, falling back to treating it as a file-system path.
   * The caller owns (and must close) the returned stream.
   */
  private def loadResource(resource: String): InputStream =
    Option(ClassLoaderUtils.classLoader.getResourceAsStream(resource)).getOrElse(new FileInputStream(resource))
}
