package com.xebialabs.xlplatform.cluster.full.downing

import akka.actor.{Actor, ActorLogging, Address, Cancellable}
import akka.cluster.ClusterEvent._
import akka.cluster.MemberStatus.{Down, Exiting}
import akka.cluster._
import com.xebialabs.xlplatform.cluster.XlCluster
import com.xebialabs.xlplatform.cluster.membership.storage.ClusterMembershipManagement
import com.xebialabs.xlplatform.cluster.membership.storage.ClusterMembershipManagement.{Data, Seed}

import scala.collection.immutable
import scala.collection.immutable.SortedSet
import scala.concurrent.Await
import scala.concurrent.duration.{Deadline, DurationInt, FiniteDuration}
import scala.language.postfixOps
import scala.util.{Success, Try}

/** Shared constants and the message protocol of [[LeaderAutoDowningActor]]. */
object LeaderAutoDowningActor {

  /** Self-sent periodic message that triggers a downing evaluation. */
  case object Tick

  /** Result of the subclass's `decide()`: which side of a partition gets downed. */
  sealed trait DownAction

  /** Down every address currently tracked as unreachable. */
  case object DownUnreachable extends DownAction

  /** Down all reachable members, plus joining members that are not unreachable. */
  case object DownReachable extends DownAction

  /** Members in these statuses are not recorded as unreachable. */
  val ignoredStatus: Set[MemberStatus] = Set(Down, Exiting)

  /** Exit code the JVM terminates with after this node has been removed from the cluster. */
  val NODE_DOWNED_EXIT_CODE: Int = 42
}

/**
 * Base actor for leader-driven automatic downing of cluster members.
 *
 * The actor subscribes to cluster domain events and tracks known members and which of them are
 * unreachable. On each periodic [[LeaderAutoDowningActor.Tick]] it acts only when all of these hold:
 *  - this node is the leader;
 *  - this node's own `MemberUp` has been seen;
 *  - some members are unreachable;
 *  - no membership event has arrived for `stableAfter`.
 * When it acts, it asks the subclass to `decide()` which side to down and then calls
 * `down` / `downSelf`.
 *
 * When this member is removed from the cluster, the actor system is terminated and the JVM exits
 * with [[LeaderAutoDowningActor.NODE_DOWNED_EXIT_CODE]].
 *
 * @param stableAfter       how long membership must stay unchanged before acting; also the timeout
 *                          used when querying active seeds
 * @param downRemovalMargin sizes the forced-exit timeout on shutdown (half of it, at least 10 s)
 */
abstract class LeaderAutoDowningActor(val stableAfter: FiniteDuration, val downRemovalMargin: FiniteDuration) extends Actor with ActorLogging {

  import LeaderAutoDowningActor._
  import context.dispatcher

  // Captured once during construction, on the actor's own thread. `shutdownMember()` runs from
  // the `registerOnMemberRemoved` callback (executed by the cluster extension, outside this actor)
  // and from a watchdog thread. Neither may touch the ActorContext, which is not thread-safe.
  private[this] val actorSystem = context.system

  val cluster: Cluster = Cluster(actorSystem)

  val selfAddress: Address = cluster.selfAddress

  /** Periodic self-sent [[LeaderAutoDowningActor.Tick]]; cancelled in `postStop()`. */
  val tickTask: Cancellable = {
    // First tick after max(stableAfter / 2, 500 ms), then a fixed delay of half that.
    // NOTE(review): the fixed delay halves the 500 ms floor to 250 ms - confirm this is intended.
    val interval = (stableAfter / 2).max(500.millis)
    actorSystem.scheduler.scheduleWithFixedDelay(interval, interval / 2, self, Tick)
  }

  /** Whether this node is the current cluster leader, as last reported by `LeaderChanged`. */
  var isLeader: Boolean = false

  /** Becomes true once this node's own `MemberUp` has been observed; no downing happens before. */
  var selfAdded: Boolean = false

  /** Returns a fresh deadline `stableAfter` from now; callers assign it to `stableDeadline`. */
  def resetStableDeadline(): Deadline = Deadline.now + stableAfter

  /** Downing is only considered once this deadline is overdue (no recent membership changes). */
  var stableDeadline: Deadline = resetStableDeadline()

  /** Addresses currently considered unreachable (members in `ignoredStatus` are not recorded). */
  var unreachable: Set[Address] = Set.empty[Address]

  /** Known members, ordered by `Member.ageOrdering`. */
  var members: immutable.SortedSet[Member] = immutable.SortedSet.empty(Member.ageOrdering)

  lazy val membershipManagement: ClusterMembershipManagement = XlCluster(actorSystem).membershipManagement

  /** Subscribes to cluster events and arranges for node shutdown once this member is removed. */
  override def preStart(): Unit = {
    cluster.subscribe(self, ClusterEvent.InitialStateAsEvents, classOf[ClusterDomainEvent])
    cluster.registerOnMemberRemoved {
      shutdownMember()
    }
    super.preStart()
  }

  /** Unsubscribes from cluster events and cancels the periodic tick. */
  override def postStop(): Unit = {
    cluster.unsubscribe(self)
    tickTask.cancel()
    super.postStop()
  }

  def receive: Receive = handleMemberUpdates orElse handleTick

  /** Updates the tracked membership view from cluster domain events. */
  def handleMemberUpdates: Receive = {
    case UnreachableMember(m) => unreachableMember(m)
    case ReachableMember(m) => reachableMember(m)
    case MemberUp(m) => memberUp(m)
    case MemberRemoved(m, _) => memberRemoved(m)

    case LeaderChanged(leaderOption) =>
      isLeader = leaderOption.contains(selfAddress)
  }

  /** On `Tick`, downs members if the acting conditions hold; other cluster events are ignored. */
  def handleTick: Receive = {
    case Tick =>
      val shouldAct = isLeader && selfAdded && unreachable.nonEmpty && stableDeadline.isOverdue()

      if (shouldAct) {
        decideAndDownMembers()
      }

    // We subscribe to every ClusterDomainEvent; the ones not handled above are dropped here.
    case _: ClusterDomainEvent => // ignore
  }

  /**
   * Records `m` as unreachable unless its status is in `ignoredStatus`. Members that are not
   * `Joining` are also added to `members`. Always restarts the stability deadline.
   */
  def unreachableMember(m: Member): Unit = {
    require(m.address != selfAddress, "selfAddress cannot be unreachable")

    if (!ignoredStatus(m.status)) {
      unreachable += m.address

      if (m.status != MemberStatus.Joining) add(m)
    }

    stableDeadline = resetStableDeadline()
  }

  /** Clears `m` from the unreachable set and restarts the stability deadline. */
  def reachableMember(m: Member): Unit = {
    unreachable -= m.address

    stableDeadline = resetStableDeadline()
  }

  /** Tracks `m`; when it is this node, enables downing decisions. Restarts the stability deadline. */
  def memberUp(m: Member): Unit = {
    add(m)
    if (m.address == selfAddress) selfAdded = true

    stableDeadline = resetStableDeadline()
  }

  /** Stops this actor if `m` is this node; otherwise forgets `m` and restarts the stability deadline. */
  def memberRemoved(m: Member): Unit = {
    if (m.address == selfAddress) {
      context.stop(self)
    } else {
      unreachable -= m.address
      members -= m
      stableDeadline = resetStableDeadline()
    }
  }

  /**
   * Asks the subclass which side to down, then downs those addresses. Other nodes are handled via
   * `down`; this node, if included, is handled last via `downSelf`. Restarts the stability
   * deadline whenever something was downed.
   */
  def decideAndDownMembers(): Unit = {
    val downAction = decide()

    val membersToDown: Set[Address] = downAction match {
      case DownUnreachable => unreachable
      case DownReachable => reachable ++ joining.filterNot(unreachable)
    }

    if (membersToDown.nonEmpty) {
      log.warning(s"Auto-downing decision $downAction: downing ${membersToDown.mkString(", ")}")
      membersToDown.foreach(member => if (member != selfAddress) down(member))
      if (membersToDown.contains(selfAddress)) {
        downSelf()
      }
      stableDeadline = resetStableDeadline()
    }
  }

  /** Chooses which side of the current partition should be downed. */
  def decide(): DownAction

  /** Downs the given (non-self) member. */
  def down(member: Address): Unit

  /** Downs this node. */
  def downSelf(): Unit

  /**
   * Terminates the actor system and exits the JVM with `NODE_DOWNED_EXIT_CODE`.
   *
   * Invoked from the `registerOnMemberRemoved` callback, i.e. outside this actor, so it only uses
   * the `ActorSystem` captured at construction time and never the ActorContext.
   */
  def shutdownMember(): Unit = {
    log.info(s"Shutting down XL Release node to avoid inconsistent cluster state.")
    val system = actorSystem
    // exit JVM when ActorSystem has been terminated
    system.registerOnTermination(System.exit(NODE_DOWNED_EXIT_CODE))
    // shut down ActorSystem
    system.terminate()

    // In case ActorSystem shutdown takes longer than this timeout, exit the JVM forcefully anyway.
    // We must spawn a separate thread to not block the current thread,
    // since that would have blocked the shutdown of the ActorSystem.
    val forcedExitAfter = (downRemovalMargin / 2).max(10.seconds)
    new Thread("leader-auto-downing-forced-exit") {
      override def run(): Unit = {
        if (Try(Await.ready(system.whenTerminated, forcedExitAfter)).isFailure) {
          System.exit(NODE_DOWNED_EXIT_CODE)
        }
      }
    }.start()
  }

  /** Inserts `m`, replacing any previously stored entry equal to it (e.g. an older status). */
  def add(m: Member): Unit = {
    members = members - m + m
  }

  def unreachableMembers: SortedSet[Member] = members.filter(m => unreachable(m.address))

  def reachableMembers: SortedSet[Member] = members.filter(m => !unreachable(m.address))

  def reachable: SortedSet[Address] = reachableMembers.map(_.address)

  /** Addresses of members in `Joining` status, read from the cluster's current state snapshot. */
  def joining: Set[Address] = cluster.state.members.collect {
    case m if m.status == MemberStatus.Joining => m.address
  }

  def isActiveMember(memberAddress: Address): Boolean = getActiveMembers.contains(memberAddress)

  /**
   * Addresses of the active seeds reported by the membership storage.
   *
   * Blocks the calling (actor) thread for up to `stableAfter`. Any failure or unexpected response
   * is logged and yields an empty sequence, so callers treat no member as active.
   */
  def getActiveMembers: Seq[Address] =
    Try(Await.result(membershipManagement.listActiveSeeds(cluster), stableAfter)) match {
      case Success(Data(seeds: Seq[_])) =>
        seeds.collect { case seed: Seed => seed.address }
      case outcome =>
        log.warning(s"Could not determine active cluster seeds, treating none as active: $outcome")
        Seq.empty[Address]
    }

}
