package com.xebialabs.deployit.engine.tasker;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.TaskExecutor;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;

import com.xebialabs.deployit.engine.api.execution.StepExecutionState;
import com.xebialabs.deployit.engine.spi.services.RepositoryFactory;

import javassist.util.proxy.ProxyObjectInputStream;

import static com.google.common.base.Preconditions.checkState;
import static com.xebialabs.deployit.engine.api.execution.TaskExecutionState.*;
import static java.lang.String.format;
import static java.util.EnumSet.of;

/**
 * Registers, executes and controls tasks, and recovers persisted tasks on startup.
 * <p>
 * The {@link Registry} and {@link Archive} handed to the constructor are stored in package-private
 * static references ({@link #REGISTRY_REF}, {@link #ARCHIVE_REF}); {@code ARCHIVE_REF} is not read in
 * this class, so it is presumably consumed by other collaborators in this package. Because the
 * references are static, constructing another {@code Engine} replaces them for every user.
 */
public class Engine {
    private static final Logger logger = LoggerFactory.getLogger(Engine.class);

    static final AtomicReference<Registry> REGISTRY_REF = new AtomicReference<Registry>();
    static final AtomicReference<Archive> ARCHIVE_REF = new AtomicReference<Archive>();

    private final TaskExecutor taskExecutor;
    private final File recoveryDir;
    private final RepositoryFactory repository;

    /**
     * @param registry     registry holding the runners of all unfinished tasks (published globally)
     * @param archive      archive for finished tasks (published globally)
     * @param taskExecutor executor on which tasks are run by {@link #execute(String)}
     * @param recoveryDir  path of the directory holding {@code *.task} recovery files
     * @param repository   repository factory handed to every {@code TaskRunner}
     */
    public Engine(Registry registry, Archive archive, TaskExecutor taskExecutor, String recoveryDir, RepositoryFactory repository) {
        REGISTRY_REF.set(registry);
        ARCHIVE_REF.set(archive);
        this.taskExecutor = taskExecutor;
        this.recoveryDir = new File(recoveryDir);
        this.repository = repository;
    }

    /**
     * Adds the default listeners to {@code spec}, creates a task for it and registers that task.
     *
     * @return the id of the newly registered task
     */
    public String register(TaskSpecification spec) {
        registerDefaultTriggers(spec);
        Task task = new Task(spec.getSteps(), spec);
        registerTask(task);
        // Do a no-op so that the task recovery is written.
        requireTaskRunner(task.getId()).pending();
        return task.getId();
    }

    /** Wraps the task in a runner and adds it to the global registry. */
    private void registerTask(Task task) {
        REGISTRY_REF.get().register(new TaskRunner(task, repository));
    }

    /** Appends the archive, (optional) recovery and context-cleanup listeners to the spec's listener list. */
    private void registerDefaultTriggers(TaskSpecification spec) {
        spec.getListeners().add(new ArchiveTaskTrigger());
        if (spec.isRecoverable()) {
            spec.getListeners().add(new TaskRecoveryTrigger(recoveryDir));
        }
        spec.getListeners().add(new OldExecutionContextListenerCleanupTrigger());
    }

    /** @return the registered task with the given id, or {@code null} if none is registered */
    public Task retrieve(String taskid) {
        TaskRunner retrieve = getTaskRunner(taskid);
        return retrieve == null ? null : retrieve.getTask();
    }

    /**
     * Queues the task and submits it to the task executor. The submitted {@link FutureTask} is kept
     * on the runner so {@link #shutdownTasks()} can wait for it.
     *
     * @throws IllegalStateException if no task is registered with the given id
     */
    public void execute(String taskid) {
        TaskRunner runner = requireTaskRunner(taskid);
        runner.queue();
        final FutureTask<Object> future = new FutureTask<Object>(runner, null);
        runner.setThreadHandle(future);
        taskExecutor.execute(future);
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void abort(String taskid) {
        requireTaskRunner(taskid).abort();
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void stop(String taskid) {
        requireTaskRunner(taskid).stop();
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void cancel(String taskid) {
        requireTaskRunner(taskid).cancel();
    }

    /**
     * Marks each of the given steps as skipped, in list order.
     *
     * @throws IllegalStateException if no task is registered with the given id
     */
    public void skipSteps(String taskid, List<Integer> stepNrs) {
        TaskRunner runner = requireTaskRunner(taskid);
        for (Integer stepNr : stepNrs) {
            runner.skip(stepNr);
        }
    }

    /**
     * Removes the skip mark from each of the given steps, in list order.
     *
     * @throws IllegalStateException if no task is registered with the given id
     */
    public void unskipSteps(String taskid, List<Integer> stepNrs) {
        TaskRunner runner = requireTaskRunner(taskid);
        for (Integer stepNr : stepNrs) {
            runner.unskip(stepNr);
        }
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void moveStep(String taskid, int stepNr, int newPosition) {
        requireTaskRunner(taskid).moveStep(stepNr, newPosition);
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void addPauseStep(String taskid, int position) {
        requireTaskRunner(taskid).addPause(position);
    }

    /**
     * @return the registered tasks whose state is PENDING, QUEUED, STOPPED, EXECUTING or EXECUTED
     */
    public List<Task> getAllIncompleteTasks() {
        logger.debug("Finding all incomplete tasks");
        List<Task> tasks = Lists.newArrayList();
        for (TaskRunner eachRunner : REGISTRY_REF.get().tasks()) {
            Task eachTask = eachRunner.getTask();
            logger.debug("Considering task {}", eachTask);

            // although the registry only contains unfinished tasks, there can
            // be a time-window when the task is finished and is yet to be moved
            // to the task archive, so better check the status
            if (of(PENDING, QUEUED, STOPPED, EXECUTING, EXECUTED).contains(eachTask.getState())) {
                logger.debug("Returning task [{}] because it's PENDING, QUEUED, STOPPED, EXECUTING or EXECUTED", eachTask);
                tasks.add(eachTask);
            }
        }
        logger.debug("Returning [{}] tasks", tasks.size());
        return tasks;
    }

    /** @throws IllegalStateException if no task is registered with the given id */
    public void archive(final String taskid) {
        requireTaskRunner(taskid).archive();
    }

    /** @return the runner registered under {@code taskid}, or {@code null} if there is none */
    private TaskRunner getTaskRunner(String taskid) {
        return REGISTRY_REF.get().retrieve(taskid);
    }

    /**
     * Like {@link #getTaskRunner(String)}, but fails with a descriptive message instead of returning
     * {@code null}, so callers don't end in an anonymous NullPointerException.
     *
     * @throws IllegalStateException if no task is registered with the given id
     */
    private TaskRunner requireTaskRunner(String taskid) {
        TaskRunner runner = getTaskRunner(taskid);
        checkState(runner != null, "No task registered with id [%s]", taskid);
        return runner;
    }

    /**
     * Creates the recovery directory if needed and re-registers every task serialized in a
     * {@code *.task} file inside it. Files that cannot be read are logged and skipped.
     *
     * @throws IllegalStateException if the recovery directory cannot be created or listed
     */
    @PostConstruct
    public void recoverTasks() {
        createRecoveryDir();
        File[] files = recoveryDir.listFiles();
        // listFiles() returns null (rather than throwing) when the directory cannot be read.
        checkState(files != null, "Could not list the recovery dir: %s", recoveryDir);
        for (File file : files) {
            if (!file.getName().endsWith(".task")) {
                continue;
            }
            recoverTask(file);
        }
    }

    /**
     * Deserializes a single recovery file, normalizes its state and registers it. Errors are logged,
     * not rethrown, so one corrupt file does not prevent the others from being recovered.
     */
    private void recoverTask(File file) {
        logger.info("Recovering task [{}]", file);
        // NOTE(review): Java native deserialization trusts whatever is in the recovery dir; make sure
        // only this process can write there.
        FileInputStream fileStream = null;
        ObjectInputStream is = null;
        try {
            fileStream = new FileInputStream(file);
            // The ObjectInputStream constructor reads the stream header and can throw; the raw file
            // stream is tracked separately so it is closed even then.
            is = new ProxyObjectInputStream(fileStream);
            Task t = (Task) is.readObject();
            setRecoveryState(t);
            registerTask(t);
        } catch (ClassNotFoundException e) {
            logger.error(format("Could not find serialized class in recovery file [%s]", file), e);
        } catch (IOException e) {
            logger.error(format("Could not read recovery file [%s]", file), e);
        } catch (RuntimeException e) {
            logger.error(format("Could not read recovery file [%s]", file), e);
        } finally {
            Closeables.closeQuietly(is);
            Closeables.closeQuietly(fileStream);
        }
    }

    /**
     * Adjusts a freshly deserialized task: anything not PENDING or EXECUTED becomes STOPPED, and steps
     * recorded as EXECUTING are marked FAILED, presumably because the process that ran them is gone.
     */
    private void setRecoveryState(final Task t) {
        if (!of(PENDING, EXECUTED).contains(t.getState())) {
            t.setState(STOPPED);
        }

        for (TaskStep step : t.getTaskSteps()) {
            if (step.getState() == StepExecutionState.EXECUTING) {
                step.setState(StepExecutionState.FAILED);
            }
        }
    }

    /**
     * Ensures the recovery directory (including missing parents) exists and is a directory.
     *
     * @throws IllegalStateException if it cannot be created or a non-directory file is in the way
     */
    private void createRecoveryDir() {
        if (!recoveryDir.exists() && !recoveryDir.mkdirs()) {
            throw new IllegalStateException("Could not create the recovery dir: " + recoveryDir);
        } else if (recoveryDir.exists() && !recoveryDir.isDirectory()) {
            throw new IllegalStateException("A file exists with the name of the recovery dir, please delete it (" + recoveryDir + ")");
        }
    }

    /** Requests a stop on every executing task, then blocks until those tasks' futures complete. */
    @PreDestroy
    public void shutdownTasks() {
        requestStopOnTasks();
        waitForTasksToBeStopped();
    }

    /**
     * Blocks on the thread handle of every executing task. Stops waiting (and restores the interrupt
     * flag) if the calling thread is interrupted.
     */
    private void waitForTasksToBeStopped() {
        for (TaskRunner task : REGISTRY_REF.get().tasks()) {
            if (!task.isExecuting()) {
                continue;
            }
            String taskId = task.getTask().getId();
            try {
                logger.info("Waiting for task {} to stop", taskId);
                final FutureTask<Object> wrappingTask = task.getThreadHandle();
                if (wrappingTask != null) {
                    wrappingTask.get();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                // Every further get() on an unfinished future would fail immediately, so stop here.
                logger.warn("Interrupted while waiting for tasks to stop; no longer waiting");
                return;
            } catch (ExecutionException e) {
                logger.error(format("Could not wait for task %s to be stopped...", taskId), e);
            }
        }
    }

    /** Calls {@code stop()} on every task runner that reports it is executing. */
    private void requestStopOnTasks() {
        for (TaskRunner task : REGISTRY_REF.get().tasks()) {
            if (task.isExecuting()) {
                task.stop();
            }
        }
    }
}
