package com.xebialabs.xlplatform.cluster.membership.storage

import java.sql.{Connection, PreparedStatement, SQLException}
import javax.sql.DataSource

import akka.actor.Address
import akka.cluster.Cluster
import com.xebialabs.xlplatform.cluster.DataSourceConfig
import com.xebialabs.xlplatform.cluster.membership.storage.ClusterMembershipManagement._
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import grizzled.slf4j.Logging

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal

/**
 * JDBC-backed [[ClusterMembershipManagement]] that keeps cluster members in a `cluster_members` table,
 * accessed through a HikariCP connection pool.
 *
 * Subclasses provide the fallback JDBC driver class name and the dialect-specific SQL used to
 * register a member and to refresh its heartbeat.
 *
 * Every operation runs in its own transaction (auto-commit disabled). A `SQLException` raised while
 * obtaining the connection or running the statements is logged and reported as `Failure(message)`.
 * Any other non-fatal exception rolls the transaction back and is rethrown, which fails the returned Future.
 *
 * @param dsConfig connection-pool settings used to build the HikariCP pool
 * @param ttl      validity period of a registration/heartbeat; bound into the SQL as whole seconds
 * @param failFast when `true`, HikariCP's `initializationFailTimeout` is 1 ms; when `false` it is -1
 *                 (see HikariCP's documentation of that setting for the exact startup behaviour)
 */
abstract class ClusterMembershipSQLManagement(dsConfig: DataSourceConfig, ttl: FiniteDuration, failFast: Boolean = true) extends ClusterMembershipManagement with Logging {

  /**
   * Fallback JDBC driver class name, used when `dsConfig.driver` is empty.
   *
   * This is read while this class is being constructed (by `datasource`'s initializer). Implement it
   * as a `def` or literal constant: a plain `val` in a subclass is still uninitialised (`null`) then.
   */
  def driverName: String

  /** Connection pool shared by all operations; created eagerly during construction. */
  val datasource: DataSource = initializeDataSource()

  /** Builds the HikariCP pool from `dsConfig`. */
  private[this] def initializeDataSource(): DataSource = {
    // NOTE(review): this interpolates `dsConfig` via its toString; confirm that does not expose the password.
    logger.info(s"Starting DataSource for cluster management: $dsConfig (driverName = $driverName)")
    val cfg = new HikariConfig()
    cfg.setInitializationFailTimeout(if (failFast) 1 else -1)
    cfg.setDriverClassName(dsConfig.driver.getOrElse(driverName))
    cfg.setJdbcUrl(dsConfig.url)
    cfg.setUsername(dsConfig.username)
    cfg.setPassword(dsConfig.password)
    cfg.setPoolName(dsConfig.poolName)
    cfg.setMaximumPoolSize(dsConfig.maxPoolSize)
    cfg.setConnectionTimeout(dsConfig.connectionTimeout)
    cfg.setMinimumIdle(dsConfig.minimumIdle)
    cfg.setIdleTimeout(dsConfig.idleTimeout)
    cfg.setLeakDetectionThreshold(dsConfig.leakConnectionThreshold)
    new HikariDataSource(cfg)
  }

  /**
   * Executes [[registerSelfSql]] for `self` and commits.
   *
   * `self` must carry a host and port; otherwise `Option.get` throws and the returned Future fails.
   */
  override def registerSelf(self: Address)(implicit executionContext: ExecutionContext): Future[Result] =
    Future {
      withConnection { con =>
        withStatement(con, registerSelfSql) { ps =>
          ps.setString(1, self.protocol)
          ps.setString(2, self.system)
          ps.setString(3, self.host.get)
          ps.setInt(4, self.port.get)
          ps.setLong(5, ttl.toSeconds)
          ps.executeUpdate()
        }
        con.commit()
        Success
      }
    }(executionContext)

  /**
   * SQL executed by [[registerSelf]].
   * Parameters: 1 = protocol, 2 = actor system name, 3 = host, 4 = port, 5 = ttl in seconds.
   */
  def registerSelfSql: String

  /**
   * Executes [[heartbeatSql]] for `self` and commits.
   *
   * `self` must carry a host and port; otherwise `Option.get` throws and the returned Future fails.
   */
  override def heartbeat(self: Address)(implicit executionContext: ExecutionContext): Future[Result] =
    Future {
      withConnection { con =>
        withStatement(con, heartbeatSql) { ps =>
          ps.setLong(1, ttl.toSeconds)
          ps.setString(2, self.host.get)
          ps.setInt(3, self.port.get)
          ps.executeUpdate()
        }
        con.commit()
        Success
      }
    }(executionContext)

  /**
   * SQL executed by [[heartbeat]].
   * Parameters: 1 = ttl in seconds, 2 = host, 3 = port.
   */
  def heartbeatSql: String

  /**
   * Lists members sharing the local node's protocol and actor-system name whose `ttl` lies in the
   * future, wrapped in `Data`.
   *
   * Rows are prepended while reading, so the returned list is the reverse of the SQL ordering.
   * With the default [[listActiveSeedsSql]] that means descending `ttl`.
   */
  override def listActiveSeeds(cluster: Cluster)(implicit executionContext: ExecutionContext): Future[Result] =
    Future {
      withConnection { con =>
        val protocol = cluster.selfAddress.protocol
        val system = cluster.selfAddress.system
        withStatement(con, listActiveSeedsSql) { call =>
          call.setString(1, protocol)
          call.setString(2, system)
          withResource(call.executeQuery()) { rows =>
            // Local accumulation only; prepending intentionally keeps the original (reversed) order.
            var seeds: List[Seed] = Nil
            while (rows.next()) {
              val seed = Seed(
                Address(protocol, system, rows.getString("host"), rows.getInt("port")),
                rows.getTimestamp("ttl").toInstant
              )
              logger.info(s"Found Seed $seed")
              seeds = seed :: seeds
            }
            Data(seeds)
          }
        }
      }
    }(executionContext)

  /**
   * Query used by [[listActiveSeeds]]. It must select `host`, `port` and `ttl` columns.
   * Parameters: 1 = protocol, 2 = actor system name.
   */
  def listActiveSeedsSql: String = {
    """
      |SELECT host, port, ttl
      |FROM cluster_members
      |WHERE protocol = ? AND system = ? AND ttl > CURRENT_TIMESTAMP ORDER BY ttl""".stripMargin
  }

  /**
   * Deletes every `cluster_members` row matching the seed's host and port, then commits.
   * Protocol and system name are not part of the match.
   */
  override def deregisterSeed(seed: Address)(implicit executionContext: ExecutionContext): Future[Result] =
    Future {
      withConnection { con =>
        withStatement(con, "DELETE FROM cluster_members WHERE host = ? AND port = ?") { call =>
          call.setString(1, seed.host.get)
          call.setInt(2, seed.port.get)
          call.execute()
        }
        con.commit()
        Success
      }
    }(executionContext)

  /**
   * Runs `b` on a pooled connection with auto-commit disabled, always returning the connection to the pool.
   *
   * A `SQLException` (including one raised while acquiring or configuring the connection) rolls back,
   * is logged, and is turned into `Failure(message)`. Any other non-fatal exception rolls back and is
   * rethrown unchanged. `b` is responsible for calling `commit()`.
   */
  def withConnection[A >: Result](b: Connection => A): A = {
    val acquired: Either[SQLException, Connection] =
      try Right(datasource.getConnection)
      catch { case e: SQLException => Left(e) }

    acquired match {
      case Left(e) => sqlFailure(e)
      case Right(c) =>
        try {
          // Inside the try so the connection is still closed if this call fails.
          c.setAutoCommit(false)
          b(c)
        } catch {
          case e: SQLException =>
            rollbackQuietly(c)
            sqlFailure(e)
          case NonFatal(e) =>
            rollbackQuietly(c)
            throw e
        } finally {
          c.close()
        }
    }
  }

  /** Logs `e` and converts it into the `Failure` result reported to callers. */
  private[this] def sqlFailure(e: SQLException): Result = {
    logger.error("Error", e)
    Failure(e.getMessage)
  }

  /** Rolls back `c`, logging instead of throwing so that the original error is not masked. */
  private[this] def rollbackQuietly(c: Connection): Unit =
    try c.rollback()
    catch {
      case NonFatal(e) => logger.warn("Rollback of cluster membership transaction failed", e)
    }

  /** Prepares `sql` on `con`, passes the statement to `f`, and closes it afterwards. */
  private[this] def withStatement[T](con: Connection, sql: String)(f: PreparedStatement => T): T =
    withResource(con.prepareStatement(sql))(f)

  /** Applies `f` to `resource` and closes the resource afterwards, even if `f` throws. */
  private[this] def withResource[R <: AutoCloseable, T](resource: R)(f: R => T): T =
    try f(resource)
    finally resource.close()
}
