package com.xebialabs.deployit.service.deployment

import java.util.UUID
import ai.digital.deploy.tasker.common.TaskMetadata._
import ai.digital.deploy.tasker.common.TaskType
import com.xebialabs.deployit.deployment.planner.{CheckPointManagerListener, RollbackCompletedListener}
import com.xebialabs.deployit.engine.api.distribution.TaskExecutionWorkerRepository
import com.xebialabs.deployit.engine.api.execution.TaskExecutionState.{ABORTED, EXECUTED, FAILED, STOPPED}
import com.xebialabs.deployit.engine.tasker._
import com.xebialabs.deployit.engine.tasker.repository.PendingTaskRepository
import com.xebialabs.deployit.repository.{WorkDir, WorkDirContext}
import com.xebialabs.deployit.spring.BeanWrapper
import com.xebialabs.deployit.task.TaskMetadataModifier._
import com.xebialabs.deployit.task.WorkdirCleanerTrigger
import grizzled.slf4j.Logging
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Component

import scala.jdk.CollectionConverters._

/**
 * Prepares rollback tasks for existing tasks.
 *
 * The rollback specification is built from the [[CheckPointManagerListener]] and [[WorkdirCleanerTrigger]]
 * stored in the original task's execution context. It is registered with the task execution engine, and the
 * original task is then archived (EXECUTED) or cancelled (STOPPED, ABORTED, FAILED).
 */
@Component
class RollbackService @Autowired()(engine: BeanWrapper[TaskExecutionEngine],
                                   deploymentService: DeploymentService,
                                   workerRepository: TaskExecutionWorkerRepository,
                                   pendingTaskRepository: PendingTaskRepository) extends Logging {

  /**
   * Creates and prepares a rollback task for `task`.
   *
   * The task state is validated before anything is created or stored. An invalid request therefore leaves
   * no pending rollback specification and no prepared rollback task behind.
   *
   * @param task the task to roll back; must be STOPPED, FAILED, ABORTED or EXECUTED
   * @return the id of the newly prepared rollback task
   * @throws IllegalStateException if the task is in any other state, or if its execution context lacks an
   *                               attribute needed to build the rollback
   */
  def rollback(task: Task): String = {
    logger.debug(s"Going to rollback task ${task.getId}")
    ensureRollbackable(task)
    val rollbackSpecification: TaskSpecification = createRollbackSpecification(task)
    // The spec is only stored as pending when getWorker finds a worker for the original task's worker id.
    workerRepository.getWorker(task.getWorkerId).foreach { worker =>
      pendingTaskRepository.store(rollbackSpecification, Some(worker.address))
    }
    doPrepareRollback(task, rollbackSpecification)
  }

  /**
   * Builds the rollback specification for `task`.
   *
   * It uses the checkpoint manager and work directories stored in the task's execution context. It attaches
   * a [[RollbackCompletedListener]] for the original task and tags the specification with the ROLLBACK task
   * type and the id of the task being rolled back.
   *
   * @throws IllegalStateException if a required execution-context attribute is missing
   */
  private[deployment] def createRollbackSpecification(task: Task): TaskSpecification = {
    val context: TaskExecutionContext = task.getContext
    val workdirCleanerTrigger = requiredAttribute(context, classOf[WorkdirCleanerTrigger])
    val checkPointManagerListener = requiredAttribute(context, classOf[CheckPointManagerListener])
    val rollbackSpecification: TaskSpecification =
      buildRollbackSpecification(task, checkPointManagerListener, workdirCleanerTrigger.getWorkDirs.asScala.toList)
    rollbackSpecification.getListeners.add(new RollbackCompletedListener(task.getId))
    putMetadata(rollbackSpecification, TASK_TYPE, TaskType.ROLLBACK.name)
    putMetadata(rollbackSpecification, ROLLBACK_TASK, task.getId)
    logger.info(s"Create rollback task ${rollbackSpecification.getId} for task ${task.getId}")
    rollbackSpecification
  }

  /**
   * Asks the checkpoint manager to prepare the rollback.
   *
   * The result is expanded via [[DeploymentService]] into a full task specification that gets a freshly
   * generated UUID as its task id.
   */
  private def buildRollbackSpecification(task: Task,
                                         partialCommitTrigger: CheckPointManagerListener,
                                         workDirs: List[WorkDir]): TaskSpecification = {
    val rollbackSpec = partialCommitTrigger.checkpointManager.prepareRollback()
    // NOTE(review): the work dir set here is never reset by this class; confirm whether WorkDirContext
    // needs clearing after the full specification is built.
    WorkDirContext.setWorkDir(task.getWorkDir)
    val rollbackTaskId = UUID.randomUUID.toString
    deploymentService.getTaskFullSpecification(rollbackTaskId, rollbackSpec, task.getWorkDir, workDirs: _*)
  }

  /**
   * Looks up the execution-context attribute that is stored under the class name of `attributeType`.
   *
   * @throws IllegalStateException if the attribute is absent
   * @throws ClassCastException    if the attribute is not an instance of `attributeType`
   */
  private def requiredAttribute[T](context: TaskExecutionContext, attributeType: Class[T]): T = {
    val name = attributeType.getName
    Option(context.getAttribute(name): AnyRef)
      .map(attributeType.cast)
      .getOrElse(throw new IllegalStateException(s"Task execution context has no [$name] attribute; cannot create rollback"))
  }

  /**
   * Registers `rollbackSpecification` with the engine as the rollback of `task`. Then archives or cancels
   * `task`, depending on its state.
   *
   * @return the id of the rollback task
   * @throws IllegalStateException if `task` is not STOPPED, FAILED, ABORTED or EXECUTED; this is checked
   *                               before the engine is called
   */
  private[deployment] def doPrepareRollback(task: Task, rollbackSpecification: TaskSpecification): TaskId = {
    ensureRollbackable(task)
    engine.get().prepareRollback(task.getId, rollbackSpecification)
    archive(task)
    rollbackSpecification.getId
  }

  /** Archives an EXECUTED task and cancels a STOPPED, ABORTED or FAILED one. */
  private def archive(task: Task): Unit = task.getState match {
    case EXECUTED => engine.get().archive(task.getId)
    case state if RollbackService.CancellableStates.contains(state) => engine.get().cancel(task.getId)
    case _ => throw notRollbackable(task)
  }

  /** Fails fast, before any side effect, when `task` is in a state that cannot be rolled back. */
  private def ensureRollbackable(task: Task): Unit = {
    val state = task.getState
    if (state != EXECUTED && !RollbackService.CancellableStates.contains(state)) {
      throw notRollbackable(task)
    }
  }

  private def notRollbackable(task: Task): IllegalStateException =
    new IllegalStateException(s"Can only rollback a STOPPED, FAILED, ABORTED or EXECUTED task [${task.getId} (${task.getState})]")
}

object RollbackService {
  /** States in which the original task is cancelled, rather than archived, once its rollback is prepared. */
  private val CancellableStates = Set(STOPPED, ABORTED, FAILED)
}
