package com.xebialabs.license

import com.xebialabs.deployit.plugin.api.reflect.Type
import com.xebialabs.deployit.plugin.api.udm.ConfigurationItem
import com.xebialabs.license.LicenseCiCounter.LicensedCiUse
import com.xebialabs.license.service.LicenseTransaction

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.collection.mutable.{Map => MMap}

object LicenseCiCounter {

  /**
    * Snapshot of license usage for a single restricted CI type.
    *
    * @param `type`        the restricted CI type
    * @param allowedAmount the number of CIs of this type the license permits
    * @param actualAmount  the number of CIs of this type currently counted
    */
  case class LicensedCiUse(`type`: Type, allowedAmount: Int, actualAmount: Int) {
    override def toString: String = {
      val limitPart = s"Your license is limited to $allowedAmount ${`type`} CIs"
      s"$limitPart and you currently have $actualAmount."
    }
  }

}

trait LicenseCiCounter {

  /** Maximum number of CIs allowed per restricted type; only these keys are counted by `countRestrictedTypes`. */
  def allowedCiAmounts: Map[Type, Int]

  /** Current CI count for `ciType` (in `InMemoryLicenseCiCounter`: the tracked counter value, or 0 if untracked). */
  def getCiCount(ciType: Type): Int

  /** One [[LicenseCiCounter.LicensedCiUse]] per restricted type; returned as an `Array` (presumably for Java callers). */
  def licensedCisInUse(): Array[LicensedCiUse]

  /** Types subject to license limits (in `InMemoryLicenseCiCounter`: the keys of `allowedCiAmounts`). */
  def restrictedTypes: Set[Type]

  /** Registers the creation of the given configuration items, recording the change in `transaction`. */
  def registerCisCreation(cis: Seq[ConfigurationItem], transaction: LicenseTransaction): Unit

  /** Registers the creation of CIs of the given types, recording the change in `transaction`. */
  def registerTypesCreation(tt: Seq[Type], transaction: LicenseTransaction): Unit

  /** Registers the removal of a single CI of type `t`, recording the change in `transaction`. */
  def registerTypeRemoval(t: Type, transaction: LicenseTransaction): Unit

  /** Registers the removal of CIs of the given types, recording the change in `transaction`. */
  def registerTypesRemoval(tt: Seq[Type], transaction: LicenseTransaction): Unit

  /** Undoes the counter change for type `tt` that was recorded in `transaction`. */
  def rollbackTransaction(tt: Type, transaction: LicenseTransaction): Unit

  /**
    * Validates current counters against the license limitations.
    *
    * @throws AmountOfCisExceededException if the actual amount of CIs exceeds the allowed value for any type;
    *                                      the message lists every violation found by `findViolations`.
    */
  def validate(): Unit = {
    val violations = findViolations()
    if (violations.nonEmpty)
      throw new AmountOfCisExceededException(
        s"The system reached the maximum allowed number of Configuration Items: ${violations.mkString(", ")}. Please check your license.")
  }

  /** Every restricted type whose current count exceeds its allowed amount; empty when the license is respected. */
  protected def findViolations(): Seq[LicensedCiUse]

  /**
    * Counts, for each restricted type (key of `allowedCiAmounts`), how many of `types` are instances of it.
    * A type may contribute to several restricted types when they are related through `instanceOf`.
    * Restricted types with no matching entry are omitted from the result.
    */
  protected def countRestrictedTypes(types: Seq[Type]): Map[Type, Int] =
    allowedCiAmounts.keys.iterator
      .map(restricted => restricted -> types.count(_.instanceOf(restricted)))
      .filter { case (_, matching) => matching > 0 }
      .toMap
}

class InMemoryLicenseCiCounter(val allowedCiAmounts: Map[Type, Int]) extends LicenseCiCounter {

  // This will be used from java
  def this() = this(Map.empty)

  // Backing concurrent map. Kept as a TrieMap-typed reference so counter creation can use the atomic
  // `putIfAbsent`, which is not reachable through the `withDefault` wrapper exposed as `typeCounter`.
  private val counters: TrieMap[Type, AtomicInteger] = TrieMap.empty[Type, AtomicInteger]

  /**
    * Per-type CI counters: a wrapper over the private backing map, so writes through it land in that map.
    * Looking up a missing type via `apply` yields a fresh zero counter that is not stored. A single shared
    * default instance would let a mutation through one missing key leak into every other missing key.
    */
  val typeCounter: MMap[Type, AtomicInteger] = counters.withDefault(_ => new AtomicInteger(0)) // used in JcrLicenseCiCounterFactory

  /** Returns the live counter for `t`, or a detached zero counter (not stored in the map) if none exists yet. */
  private def getAtomicCiCount(t: Type): AtomicInteger = counters.getOrElse(t, new AtomicInteger(0))

  def getCiCount(t: Type): Int = getAtomicCiCount(t).get()

  def licensedCisInUse(): Array[LicensedCiUse] = allowedCiAmounts.map {
    case (ciType, licensed) => LicensedCiUse(ciType, licensed, getCiCount(ciType))
  }.toArray // for Java interop

  /**
    * Increases the counters according to the types of the given configuration items and performs the validation.
    * See `registerTypesCreation` for the failure behavior: counter changes already applied are recorded in
    * `transaction` and are not reverted by this method.
    */
  def registerCisCreation(cis: Seq[ConfigurationItem], transaction: LicenseTransaction): Unit = {
    registerTypesCreation(cis.map(_.getType), transaction)
  }

  /**
    * Increases the counters according to the given types and performs the validation.
    *
    * Each restricted type is checked against its remaining allowance before its counter is increased, and the
    * increase is recorded in `transaction`. If a type would exceed its allowance, or `validate()` finds a violation
    * afterwards, an [[AmountOfCisExceededException]] is thrown. Increases already applied are NOT reverted here;
    * they stay recorded in `transaction` and are presumably undone by the caller via `rollbackTransaction`.
    */
  def registerTypesCreation(tt: Seq[Type], transaction: LicenseTransaction): Unit = {
    countRestrictedTypes(tt).foreach {
      case (t, cnt) if allowedCiAmounts(t) - getCiCount(t) >= cnt =>
        updateGeneralCounter(t, cnt)
        transaction.registerCreate(t, cnt)
      case (t, _) =>
        throw new AmountOfCisExceededException(s"Unable to create ${t}. Your license is limited to ${allowedCiAmounts(t)} $t CIs and you currently have ${getCiCount(t)}.")
    }

    validate()
  }

  /**
    * Decrements the counters according to the passed types and records the removals in `transaction`.
    */
  def registerTypesRemoval(tt: Seq[Type], transaction: LicenseTransaction): Unit = {
    countRestrictedTypes(tt).foreach {
      case (t, cnt) =>
        updateGeneralCounter(t, -cnt)
        transaction.registerDelete(t, cnt)
    }
  }

  /**
    * Adds `cnt` (which may be negative) to the counter of type `t`, creating the counter first if absent.
    */
  def updateGeneralCounter(t: Type, cnt: Int): Unit = {
    val fresh = new AtomicInteger(0)
    // putIfAbsent is atomic: if two threads race to create the counter, both end up incrementing the same
    // instance instead of one overwriting the other's freshly-put counter (a lost update).
    val counter = counters.putIfAbsent(t, fresh).getOrElse(fresh)
    counter.addAndGet(cnt)
  }

  /**
    * Decrements the counter for the passed type.
    */
  def registerTypeRemoval(t: Type, transaction: LicenseTransaction): Unit = registerTypesRemoval(Seq(t), transaction)

  /** Restricted types whose tracked counter exceeds the allowed amount; untracked types are never reported. */
  override protected def findViolations(): Seq[LicensedCiUse] =
    allowedCiAmounts.toSeq.flatMap { case (t, allowedAmount) =>
      // Read each counter once so the reported amount is the same value that failed the check.
      counters.get(t)
        .map(_.get())
        .filter(_ > allowedAmount)
        .map(actual => LicensedCiUse(t, allowedAmount, actual))
    }

  /**
    * Subtracts `transaction.getCiCount(tt)` from the counter of `tt`.
    * If no counter exists for `tt`, the subtraction is applied to a detached counter and has no effect.
    */
  override def rollbackTransaction(tt: Type, transaction: LicenseTransaction): Unit = getAtomicCiCount(tt).addAndGet(-transaction.getCiCount(tt))

  override def restrictedTypes: Set[Type] = allowedCiAmounts.keys.toSet
}

