package com.xebialabs.deployit.deployment.stager

import com.xebialabs.deployit.deployment.planner.PlanSugar._
import com.xebialabs.deployit.deployment.planner._
import com.xebialabs.deployit.deployment.stager.DeploymentStager._StagingContext
import com.xebialabs.deployit.engine.spi.execution.ExecutionStateListener
import com.xebialabs.deployit.plugin.api.flow._
import com.xebialabs.deployit.plugin.api.services.Repository
import com.xebialabs.deployit.plugin.api.udm.artifact.{Artifact, DerivedArtifact, SourceArtifact}
import grizzled.slf4j.Logging

import scala.collection.convert.wrapAll._
import scala.collection.mutable
import scala.language.reflectiveCalls
import java.util.{List => JList}

/**
 * [[MultiDeploymentPlanner]] decorator that adds artifact staging around the plan produced by `wrappedPlanner`.
 *
 * Every [[StageableStep]] found in the wrapped plan is asked to request staging of its artifacts. The resulting
 * staging steps are prepended as a "Stage artifacts" phase. A "Clean up staged artifacts" phase, flagged
 * `alwaysExecuted`, is appended with one cleanup step per target that received staged files. Either phase is
 * omitted when it would contain no plans.
 *
 * @param wrappedPlanner planner that produces the underlying deployment plan
 */
class DeploymentStager(wrappedPlanner: MultiDeploymentPlanner) extends MultiDeploymentPlanner with Logging with PlanSugar {
  /** Structural type for steps that expose the target they operate on (member access is reflective). */
  type GetStagingTarget = { def getStagingTarget: StagingTarget }

  /**
   * Plans the deployment with the wrapped planner, then surrounds its phases with staging and cleanup phases.
   *
   * @return a copy of the wrapped plan whose phases are `stage ::: original ::: cleanup`, where the staging and
   *         cleanup phases are present only when there is something to stage
   */
  def plan(spec: MultiDeltaSpecification, repository: Repository): PhasedPlan = {
    val plan = wrappedPlanner.plan(spec, repository)

    logger.info(s"Staging artifacts for plan: [${plan.getDescription}]")
    val (stagedSteps, stagedTargets) = stagePhases(plan.phases.toList)

    val stagePhase = generatePhase(stepsToPlan(stagedSteps, "Staging to", plan.getListeners), "Stage artifacts", alwaysExecuted = false)
    val cleanupPhase = generatePhase(
      stepsToPlan(stagedTargets.map(new StagedFileCleaningStep(_)), "Clean up staged files on", plan.getListeners),
      "Clean up staged artifacts",
      alwaysExecuted = true
    )

    val newPhases = stagePhase ::: plan.phases.toList ::: cleanupPhase
    plan.copy(phases = newPhases)
  }

  /**
   * Groups steps by the target they operate on and wraps each group in its own [[StepPlan]].
   *
   * @param steps             steps exposing `getStagingTarget`
   * @param descriptionPrefix prefix of each plan's description; the target name is appended to it
   * @param listeners         listeners attached to every generated plan
   * @return one [[StepPlan]] per distinct target; the order of the plans is not specified
   */
  private def stepsToPlan(steps: Iterable[Step with GetStagingTarget], descriptionPrefix: String, listeners: JList[ExecutionStateListener]): Iterable[StepPlan] =
    steps.groupBy(_.getStagingTarget).map { case (target, stepList) =>
      new StepPlan(s"$descriptionPrefix ${target.getName}", stepList, listeners)
    }

  /**
   * Wraps the given plans in a single [[PlanPhase]].
   *
   * No plans yields no phase. A single plan becomes the phase's plan directly. Several plans are combined into a
   * [[ParallelPlan]] that takes its listeners from the first plan.
   *
   * @return an empty list, or a one-element list holding the phase
   */
  private def generatePhase(underlyingPlans: Iterable[ExecutablePlan], description: String, alwaysExecuted: Boolean): List[PlanPhase] =
    // Normalise to a List first: the `Nil` / `::` patterns only match Lists. Any other empty Iterable would
    // otherwise fall through to the last case and fail on `plans.head`.
    underlyingPlans.toList match {
      case Nil =>
        Nil
      case plan :: Nil =>
        new PlanPhase(plan, description, plan.getListeners, alwaysExecuted) :: Nil
      case plans =>
        val parallelPlan = new ParallelPlan(description, plans, plans.head.getListeners)
        new PlanPhase(parallelPlan, description, parallelPlan.getListeners, alwaysExecuted) :: Nil
    }

  /**
   * Walks all phases, collecting staging requests into a fresh staging context.
   *
   * @return the staging steps to run and the set of targets that need cleanup afterwards
   */
  private def stagePhases(phases: List[PlanPhase]) = {
    val stagingContext = new _StagingContext
    phases.foreach(phase => doStage(phase.plan, stagingContext))
    stagingContext.stagingSteps -> stagingContext.cleanupHosts
  }

  /** Recursively walks a plan tree, descending into composite plans until step plans are reached. */
  private def doStage(plan: Plan, stagingContext: _StagingContext): Unit = {
    logger.debug(s"Staging for [${plan.getClass.getSimpleName}(${plan.getDescription})]")
    // NOTE(review): only CompositePlan and StepPlan are handled; any other Plan subtype raises a MatchError.
    plan match {
      case cp: CompositePlan => cp.getSubPlans.foreach(doStage(_, stagingContext))
      case sp: StepPlan => doStage(sp, stagingContext)
    }
  }

  /** Requests staging for every [[StageableStep]] in the step plan; other steps are skipped. */
  private def doStage(stepPlan: StepPlan, stagingContext: _StagingContext): Unit =
    stepPlan.getSteps.foreach {
      case stageable: StageableStep => doStage(stageable, stagingContext)
      case _ => // step has nothing to stage
    }

  /** Lets a single step register the artifacts it wants staged with the context. */
  private def doStage(step: StageableStep, stagingContext: _StagingContext): Unit = {
    logger.debug(s"Preparing stage of artifacts for step [${step.getDescription}]")
    step.requestStaging(stagingContext)
  }
}

object DeploymentStager extends Logging {
  import java.util.{Map => JMap, HashMap => JHashMap}
  import collection.mutable.{Map => MMap}

  /**
   * [[StagingContext]] that records the artifact staging requests made while walking a plan.
   *
   * Requests are de-duplicated on (checksum, placeholders, target). Matching requests share one [[StagingFile]],
   * which is returned to every requester. Targets without a (non-blank) staging directory path are not staged:
   * the artifact is wrapped in a [[JustInTimeFile]] instead.
   */
  private[stager] class _StagingContext extends StagingContext {
    /** De-duplication key for a staging request. `checksum` is null for a derived artifact without a source artifact. */
    case class StagingKey(checksum: String, placeholders: JMap[String, String], target: StagingTarget)

    /** Staged files keyed by request; a LinkedHashMap, so iteration follows first-request order. */
    val stagingFiles: MMap[StagingKey, StagingFile] = new mutable.LinkedHashMap[StagingKey, StagingFile]()
    /** Targets that received at least one staged file and therefore need a cleanup step. */
    val cleanupHosts: mutable.Set[StagingTarget] = new mutable.HashSet[StagingTarget]()

    /**
     * Registers `artifact` for staging on `target`.
     *
     * @return a [[JustInTimeFile]] when the target has no staging directory path. Otherwise, the [[StagingFile]]
     *         for this key, created (and its target recorded for cleanup) on first request.
     */
    def stageArtifact(artifact: Artifact, target: StagingTarget): StagedFile = {
      if (Option(target.getStagingDirectoryPath).forall(_.trim.isEmpty)) {
        new JustInTimeFile(artifact)
      } else {
        stagingFiles.getOrElseUpdate(stagingKey(artifact, target), {
          val file = new StagingFile(artifact, target)
          cleanupHosts += target
          file
        })
      }
    }

    /** Builds the de-duplication key for `artifact` on `target`. */
    private def stagingKey(artifact: Artifact, target: StagingTarget): StagingKey = artifact match {
      case bda: DerivedArtifact[_] if bda.getSourceArtifact != null =>
        StagingKey(bda.getSourceArtifact.getChecksum, snapshot(bda.getPlaceholders), target)
      case bda: DerivedArtifact[_] =>
        // NOTE(review): every source-less derived artifact with equal placeholders on the same target maps to
        // this same key and so shares one staged file -- confirm that is intended.
        StagingKey(null, snapshot(bda.getPlaceholders), target)
      case sa: SourceArtifact =>
        StagingKey(sa.getChecksum, new JHashMap(), target)
    }

    /**
     * Copies `placeholders` so the key does not alias the artifact's live map. A later mutation of that map would
     * otherwise change the key's hash after insertion and break the lookup. A null map stays null.
     */
    private def snapshot(placeholders: JMap[String, String]): JMap[String, String] =
      if (placeholders == null) null else new JHashMap[String, String](placeholders)

    /** One [[StagingStep]] per distinct staged file, in first-request order. */
    def stagingSteps = stagingFiles.values.map(new StagingStep(_))
  }

}
