package com.xebialabs.xlrelease.service

import com.xebialabs.deployit.plugin.api.udm.ConfigurationItem
import com.xebialabs.deployit.repository.RepositoryAdapter
import com.xebialabs.xlrelease.domain._
import com.xebialabs.xlrelease.domain.utils.syntax._
import com.xebialabs.xlrelease.json.JsonUtils.TaskBackupJsonOps
import com.xebialabs.xlrelease.repository.DependencyTargetResolver
import com.xebialabs.xlrelease.repository.Ids._
import com.xebialabs.xlrelease.repository.sql.SqlRepositoryAdapter
import com.xebialabs.xlrelease.repository.sql.persistence.{ReleasePersistence, TaskBackupPersistence}
import com.xebialabs.xlrelease.serialization.json.repository.ResolverRepository
import com.xebialabs.xlrelease.serialization.json.utils.CiSerializerHelper
import com.xebialabs.xlrelease.utils.CiHelper.{getNestedCis, rewriteWithNewId}
import com.xebialabs.xlrelease.variable.VariableHelper.containsVariables
import com.xebialabs.xltype.serialization.CiReference
import grizzled.slf4j.Logging
import io.micrometer.core.annotation.Timed
import org.springframework.beans.factory.annotation.Autowired

import java.util
import scala.jdk.CollectionConverters._

/**
 * SQL-backed [[TaskBackup]] implementation.
 *
 * A backup of a top-level task is stored through [[TaskBackupPersistence]]. It can later be restored by
 * deserializing the stored JSON against the task's release.
 */
class SqlTaskBackup @Autowired()(taskBackupPersistence: TaskBackupPersistence,
                                 releasePersistence: ReleasePersistence,
                                 val repositoryAdapter: SqlRepositoryAdapter,
                                 dependencyTargetResolver: DependencyTargetResolver)
  extends Logging
    with TaskBackup {

  /**
   * Backs up every task in `tasks`.
   *
   * Tasks are grouped by their release. When `releaseToBackup` is null, the release of each group is read
   * from the repository. Otherwise `releaseToBackup` is used for every group.
   *
   * @param tasks           tasks to back up; each must be a top-level task (see [[backupTask]])
   * @param releaseToBackup release to take the task state from, or null to load it from the repository
   */
  @Timed
  override def backupTasks(tasks: util.List[Task], releaseToBackup: Release): Unit = {
    // `releaseTasks` rather than `tasks`, so the method parameter is not shadowed.
    tasks.asScala.groupBy(_.getRelease).foreach { case (taskRelease, releaseTasks) =>
      // getOrElse takes its default by name: the repository is only read when no release was supplied.
      val release = Option(releaseToBackup).getOrElse(repositoryAdapter.read[Release](taskRelease.getId))
      releaseTasks.foreach(task => backupTask(task, release))
    }
  }

  /**
   * Stores a backup of a single top-level task.
   *
   * The task that gets persisted (via `upsertTaskBackup`) is the one found in the release under
   * `task.getId`. It is not necessarily the `task` instance passed in.
   *
   * @param task            task to back up; must be a direct child of its phase
   * @param releaseToBackup release to take the task from, or null to read it by the release id derived from
   *                        the task id
   * @throws IllegalStateException if `task`'s container is not its phase
   */
  @Timed
  override def backupTask(task: Task, releaseToBackup: Release): Unit = {
    if (task.getContainer != task.getPhase) {
      throw new IllegalStateException(s"Task ${task.getId} should be a top level task in order to be backed up.")
    }
    val taskId = task.getId
    val release = Option(releaseToBackup).getOrElse(repositoryAdapter.read[Release](releaseIdFrom(taskId)))
    taskBackupPersistence.upsertTaskBackup(release.getTask(taskId))
  }

  /**
   * REL-6704: when restoring a [[CustomScriptTask]], take its Python script input properties from the
   * task currently in the release (e.g. the failed/completed one), not from the original backup.
   *
   * A backed-up value that is a String containing variables is kept as-is. Nothing happens unless both
   * `dst` and the release's task `taskId` are CustomScriptTasks with a non-null Python script.
   *
   * @param release release holding the current version of the task
   * @param taskId  id of the current task in `release` to copy from
   * @param dst     the deserialized backup that receives the values
   */
  private def copyPythonScriptInputPropertiesIfNecessary(release: Release, taskId: String, dst: Task): Unit = {
    dst match {
      case dstCustomScriptTask: CustomScriptTask if dstCustomScriptTask.getPythonScript != null =>
        Option(release.getTask(taskId)).foreach {
          case srcCustomScriptTask: CustomScriptTask if srcCustomScriptTask.getPythonScript != null =>
            for (property <- srcCustomScriptTask.getPythonScript.getInputProperties.asScala) {
              logger.debug(s"Copying property [${property.getName}] on phase restore")
              val currentValue = dstCustomScriptTask.getPythonScript.getProperty[Object](property.getName)
              currentValue match {
                // Keep variable placeholders from the backup untouched.
                case s: String if containsVariables(s) => ()
                case _ =>
                  dstCustomScriptTask.getPythonScript.setProperty(
                    property.getName,
                    srcCustomScriptTask.getPythonScript.getProperty[Object](property.getName))
              }
            }
          case _ => ()
        }
      case _ => ()
    }
  }

  /**
   * Restores `task` from its backup.
   *
   * The backup is looked up by `originalTaskId` when given, otherwise by `task.getId`. The restored copy is
   * then:
   *  - rewritten so its ids live under the phase of `task`;
   *  - attached to `task`'s container;
   *  - given re-resolved dependency targets;
   *  - used to replace the existing task in the release.
   *
   * The release is persisted unless `inMemory` is true.
   *
   * @param task           task to restore; its id and container determine where the backup is placed
   * @param originalTaskId id the backup was stored under, or null to use `task.getId`
   * @param inMemory       when true, only the in-memory release is modified
   * @return the restored task, or `task` itself when no backup exists
   */
  @Timed
  override def restoreTask[T <: Task](task: T, originalTaskId: String = null, inMemory: Boolean = false): T = {
    val taskId: String = task.getId
    val taskToRestoreId = Option(originalTaskId).getOrElse(taskId)
    doWithBackup(taskToRestoreId) { (release: Release, backup: Task) =>
      rewriteWithNewId(backup, backup.getId.replace(phaseIdFrom(backup.getId), phaseIdFrom(taskId)))
      backup.setContainer(task.getContainer)
      backup.allDependencies.foreach(dependencyTargetResolver.resolveTarget)
      copyPythonScriptInputPropertiesIfNecessary(release, taskToRestoreId, backup)
      release.replaceTask(backup)
      if (!inMemory) {
        releasePersistence.update(release)
      }
      backup.asInstanceOf[T]
    }.getOrElse {
      logger.debug(s"Trying to restore task $taskId which doesn't have a backup.")
      task
    }
  }

  /**
   * Removes the variable with id `variableId` from the stored backup of `task`.
   * Does nothing when the task has no backup.
   */
  @Timed
  override def removeVariable(task: UserInputTask, variableId: String): Unit = {
    doWithBackup(task.getId) { (_, backup: UserInputTask) =>
      backup.setVariables(backup.getVariables.asScala.filter(v => v.getId != variableId).asJava)
      taskBackupPersistence.updateTaskBackupJson(backup)
      backup
    }
  }

  /**
   * Loads the backup of `taskId`, if any, and applies `process` to it.
   *
   * The task's release is always read from the repository, because its CI uid is needed to find the backup.
   * The backup JSON is deserialized with an [[InMemoryReleaseSqlRepositoryAdapter]]. That adapter resolves
   * ids belonging to that release against the loaded release object.
   *
   * @return the result of `process`, or None when no backup is stored for the task
   */
  private def doWithBackup[T <: Task](taskId: String)(process: (Release, T) => T): Option[T] = {
    val release = repositoryAdapter.read[Release](releaseIdFrom(taskId))
    for {
      taskJson <- taskBackupPersistence.findTaskBackup(taskId, release.getCiUid)
      backup = CiSerializerHelper.deserialize(taskJson.withoutUnknownProps,
        new InMemoryReleaseSqlRepositoryAdapter(release, repositoryAdapter)).asInstanceOf[T]
    } yield process(release, backup)
  }

  /**
   * Repository adapter used when deserializing a task backup.
   *
   * Ids of configuration items inside `release` (except attachments) are looked up in the in-memory
   * `release`. Everything else is delegated to `sqlRepositoryAdapter`.
   *
   * NOTE(review): the base [[RepositoryAdapter]] is constructed with nulls, so presumably only the methods
   * overridden here are safe to call. Confirm before using this adapter elsewhere.
   */
  class InMemoryReleaseSqlRepositoryAdapter(release: Release, sqlRepositoryAdapter: SqlRepositoryAdapter)
    extends RepositoryAdapter(null, null, null, null)
      with ResolverRepository {

    /** True when `itemId` denotes a configuration item nested in this adapter's `release`. */
    private def isInThisRelease(itemId: String): Boolean =
      isInRelease(itemId) && releaseIdFrom(itemId) == release.getId

    /**
     * Finds the CI with id `itemId` among the nested CIs of `release`.
     *
     * @throws NoSuchElementException (same type the previous bare `Option.get` threw) naming the missing id
     */
    private def findInThisRelease(itemId: String): ConfigurationItem =
      getNestedCis(release).asScala
        .find(ci => ci.getId == itemId)
        .getOrElse(throw new NoSuchElementException(
          s"Configuration item [$itemId] not found in release [${release.getId}]"))

    /** Attachments always come from the database; other CIs of `release` come from the in-memory release. */
    override def read[T <: ConfigurationItem](id: String): T = id match {
      case attachmentId if isAttachmentId(attachmentId) => sqlRepositoryAdapter.read(attachmentId)
      case itemId if isInThisRelease(itemId) => findInThisRelease(itemId).asInstanceOf[T]
      case _ => sqlRepositoryAdapter.read(id)
    }

    /** Reads each id via `read(id)` and drops null results; `depth` is ignored. */
    override def read[T <: ConfigurationItem](ids: util.List[String], depth: Integer): util.List[T] = {
      ids.asScala.map(this.read[T]).filter(_ != null).asJava
    }

    /** Resolves CIs of `release` from memory; everything else goes to `sqlRepositoryAdapter`. */
    override def resolve(id: String, ciReference: CiReference): ConfigurationItem = {
      id match {
        case itemId if isInThisRelease(itemId) => findInThisRelease(itemId)
        case _ => sqlRepositoryAdapter.resolve(id, ciReference)
      }
    }

    /** Delegated unchanged to `sqlRepositoryAdapter`. */
    override def read[T <: ConfigurationItem](id: String, target: ConfigurationItem): T = {
      sqlRepositoryAdapter.read(id, target)
    }
  }

}
