package com.xebialabs.deployit.plugin.python;

import java.io.*;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import com.xebialabs.deployit.engine.spi.execution.ExecutionStateListener;
import com.xebialabs.deployit.engine.spi.execution.StepExecutionStateEvent;
import com.xebialabs.deployit.engine.spi.execution.TaskExecutionStateEvent;
import com.xebialabs.deployit.plugin.api.flow.ExecutionContext;
import com.xebialabs.overthere.OverthereConnection;
import com.xebialabs.overthere.OverthereFile;
import com.xebialabs.overthere.OverthereProcess;
import com.xebialabs.overthere.RuntimeIOException;

import static com.xebialabs.deployit.plugin.python.PythonManagingContainer.CONNECT_FROM_DAEMON;
import static com.xebialabs.deployit.plugin.python.PythonManagingContainer.DISCONNECT_FROM_DAEMON;
import static com.xebialabs.deployit.plugin.python.PythonManagingContainer.RUN_SCRIPT_FROM_DAEMON;
import static com.xebialabs.deployit.plugin.python.PythonStep.appendRuntimeScripts;
import static com.xebialabs.deployit.plugin.python.PythonStep.dumpPythonScript;
import static com.xebialabs.deployit.plugin.python.PythonVarsConverter.toPythonString;
import static com.xebialabs.deployit.plugin.remoting.scripts.ScriptUtils.loadScript;
import static com.xebialabs.deployit.plugin.remoting.scripts.ScriptUtils.uploadScript;
import static java.lang.String.format;
import static org.apache.commons.io.IOUtils.closeQuietly;

/**
 * Runs a long-lived Python daemon process on the host of a {@link PythonManagingContainer} and executes
 * uploaded Python scripts through it.
 * <p>
 * The daemon script (runtime scripts + {@code daemon.py}) is uploaded and started over an
 * {@link OverthereConnection}. Commands are written to the daemon's stdin, one per line. Two background
 * threads read the daemon's stdout and stderr, forward ordinary lines to the {@link ExecutionContext} of
 * the operation currently in progress, and react to two marker lines:
 * {@value #DAEMON_CHECKPOINT} (counts down the checkpoint latch) and {@value #DAEMON_EXIT_CODE_MARKER}
 * (records the exit code of the last script).
 * <p>
 * The daemon is stopped when the task becomes passive after executing (see {@link #taskStateChanged}).
 */
@SuppressWarnings("serial")
public class PythonDaemon implements ExecutionStateListener {

    private static final Logger logger = LoggerFactory.getLogger(PythonDaemon.class);

    private static final String DAEMON_SCRIPT_NAME = "daemon.py";
    private static final String DAEMON_SCRIPT_PATH = "python/daemon/daemon.py";
    /** Line prefix that, when read from either stream, counts down the checkpoint latch. */
    private static final String DAEMON_CHECKPOINT = "DEPLOYIT-DAEMON-CHECKPOINT";
    /** Line prefix followed by the exit code of the last script run by the daemon. */
    private static final String DAEMON_EXIT_CODE_MARKER = "DEPLOYIT-DAEMON-EXIT-CODE:";

    /** A buffered partial line is flushed if no complete line has been emitted for roughly this long. */
    private static final int FLUSH_DELAY_MS = 5000;
    /** Period at which the flush timer checks for stale partial lines. */
    private static final int FLUSH_CHECK_INTERVAL_MS = 1000;
    /** Shared (daemon-thread) timer used by all reader threads to flush stale partial lines. */
    private static final Timer flushTimer = new Timer("Daemon-AutoFlushTimer", true);

    private final PythonManagingContainer container;

    private transient OverthereConnection connection;

    private transient OverthereProcess process;

    private transient OutputStream stdin;

    private transient Thread stdoutT;

    private transient Thread stderrT;

    /*
     * Each wait (startup and every script execution) uses a fresh latch with a count of 2. It is counted down
     * once per checkpoint marker line and once per reader thread that loses its stream.
     * NOTE(review): presumably daemon.py prints one checkpoint on stdout and one on stderr - confirm.
     */
    private transient AtomicReference<CountDownLatch> checkpoint = new AtomicReference<CountDownLatch>(new CountDownLatch(2));
    /** Context that daemon output is forwarded to; only set while start/execute is in progress. */
    private transient AtomicReference<ExecutionContext> currentContext = new AtomicReference<ExecutionContext>();
    /** Exit code reported by the daemon for the last script, or -1 if none was reported. */
    private transient AtomicInteger lastExitCode = new AtomicInteger(-1);

    public PythonDaemon(PythonManagingContainer container) {
        this.container = container;
    }

    /**
     * Returns whether the daemon has been started and not stopped since.
     * This only checks that all handles are set, not that the remote process is still running.
     */
    public boolean isAlive() {
        return connection != null && process != null && stdin != null && stdoutT != null && stderrT != null;
    }

    /**
     * Opens a connection to the container's host, uploads the daemon script, starts it, and blocks until
     * the daemon reports it has started.
     * <p>
     * If anything fails before the daemon has started, the connection is closed again, so it is not leaked.
     *
     * @param context receives progress messages and the daemon's output during startup
     */
    public void start(ExecutionContext context) {
        context.logOutput("Starting the daemon on [" + container.getHost() + "]");

        this.connection = container.getHost().getConnection();

        boolean started = false;
        try {
            OverthereFile uploadedDaemonScript = uploadDaemonScript();

            waitForDaemonStart(context, uploadedDaemonScript);
            started = true;
        } finally {
            if (!started) {
                logger.debug("Starting the daemon failed, closing the connection");
                releaseConnection();
            }
        }
    }

    /** Generates the daemon script, dumps it for debugging and uploads it over the current connection. */
    private OverthereFile uploadDaemonScript() {
        String daemonScript = generateDaemonScript();
        dumpPythonScript(DAEMON_SCRIPT_NAME, daemonScript);

        logger.debug("Uploading the daemon script " + DAEMON_SCRIPT_NAME);
        return uploadScript(connection, DAEMON_SCRIPT_NAME, daemonScript);
    }

    /**
     * Builds the daemon script: the container's runtime scripts, a call to the connect function,
     * the daemon loop loaded from {@value #DAEMON_SCRIPT_PATH}, then a call to the disconnect function.
     */
    private String generateDaemonScript() {
        StringBuilder b = new StringBuilder();
        appendRuntimeScripts(container, b);
        b.append("#\n" + CONNECT_FROM_DAEMON + "()\n");
        b.append(loadScript(DAEMON_SCRIPT_PATH));
        b.append("#\n" + DISCONNECT_FROM_DAEMON + "()\n");
        return b.toString();
    }

    /**
     * Starts the daemon process and both stream reader threads, then blocks until the startup checkpoints
     * have been received (or both streams were lost).
     */
    private void waitForDaemonStart(final ExecutionContext ctx, final OverthereFile uploadedDaemonScript) {
        currentContext.set(ctx);
        try {
            logger.debug("Starting the daemon");

            process = connection.startProcess(container.getScriptCommandLine(uploadedDaemonScript));

            logger.debug("Starting the daemon stream stdout and stderr handler threads");
            stdin = process.getStdin();

            startStdoutReaderThread();
            startStderrReaderThread();

            logger.debug("Waiting for the daemon to finish starting");
            waitForCheckpoints();

            logger.debug("The daemon has started");
        } finally {
            currentContext.set(null);
        }
    }

    /** Starts a daemon thread that forwards the process's stdout to {@code logOutput}. */
    private void startStdoutReaderThread() {
        String stdoutTname = "Daemon stdout";
        if(logger.isDebugEnabled()) {
            stdoutTname += " for " + container.getId();
        }

        stdoutT = new Thread(new StdoutThread(process.getStdout()), stdoutTname);
        stdoutT.setDaemon(true);
        stdoutT.start();
    }

    /** Starts a daemon thread that forwards the process's stderr to {@code logError}. */
    private void startStderrReaderThread() {
        String stderrTname = "Daemon stderr";
        if(logger.isDebugEnabled()) {
            stderrTname += " for " + container.getId();
        }

        stderrT = new Thread(new StderrThread(process.getStderr()), stderrTname);
        stderrT.setDaemon(true);
        stderrT.start();
    }

    /**
     * Asks the running daemon to execute an already uploaded script and blocks until it has finished.
     *
     * @param ctx    receives the script's stdout/stderr while it runs
     * @param script the uploaded script to run
     * @return the exit code reported by the daemon, or -1 if none was reported (e.g. the connection was lost)
     * @throws RuntimeIOException if the command cannot be sent to the daemon
     */
    public int executePythonScript(final ExecutionContext ctx, final OverthereFile script) {
        currentContext.set(ctx);
        try {
            try {
                logger.debug("Resetting countdown latch to 2");
                checkpoint.set(new CountDownLatch(2));

                logger.debug("Resetting last exit code to -1");
                lastExitCode.set(-1);

                logger.info("Executing uploaded script [{}] on [{}] (with daemon)", script.getPath(), connection);
                sendLine(RUN_SCRIPT_FROM_DAEMON + "(" + toPythonString(script.getPath()) + ")");

                logger.debug("Waiting for the daemon to finish executing the command");
                waitForCheckpoints();

                int exitCode = lastExitCode.get();
                logger.debug("Returning last exit code [{}]", exitCode);
                return exitCode;
            } catch (IOException exc) {
                throw new RuntimeIOException(format("Cannot execute script [%s] on [%s]", script.getPath(), container.getHost()), exc);
            }
        } finally {
            currentContext.set(null);
        }
    }

    /**
     * Sends {@code EXIT} to the daemon, waits for the process to end, and always releases the connection.
     * Errors are logged rather than propagated. An interrupt is logged and the interrupt flag is restored.
     */
    private void stop() {
        logger.info("Stopping the daemon on [{}]", container.getHost());
        try {
            sendLine("EXIT");
            process.waitFor();
        } catch (IOException exc) {
            logger.error("Error stopping daemon", exc);
        } catch (RuntimeException exc) {
            logger.error("Error stopping daemon", exc);
        } catch (InterruptedException exc) {
            logger.error("Error stopping daemon", exc);
            // Preserve the interrupt for code further up the stack.
            Thread.currentThread().interrupt();
        } finally {
            releaseConnection();
        }
    }

    /**
     * Closes the connection (quietly) and drops every handle to the daemon, so that {@link #isAlive()}
     * returns {@code false}. The reader threads are not stopped here; they end when their streams close.
     */
    private void releaseConnection() {
        closeQuietly(connection);
        connection = null;
        process = null;
        stdin = null;
        stdoutT = null;
        stderrT = null;
    }

    /** Writes one command line, terminated by the host OS line separator, to the daemon's stdin. */
    private void sendLine(String line) throws IOException {
        logger.debug("Sending line [{}] to the daemon", line);
        // NOTE(review): getBytes() uses the JVM platform charset - confirm this matches what the daemon expects.
        stdin.write((line + container.getHost().getOs().getLineSeparator()).getBytes());
        stdin.flush();
    }

    private void countdownCheckpointLatch() {
        logger.debug("Counting down the checkpoint latch");
        checkpoint.get().countDown();
        logger.debug("Done counting down the checkpoint latch");
    }

    /**
     * Blocks until the current checkpoint latch reaches zero.
     *
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored)
     */
    private void waitForCheckpoints() {
        try {
            logger.debug("Waiting for the checkpoint latch");
            checkpoint.get().await();
            logger.debug("Done waiting for the checkpoint latch");
        } catch (InterruptedException exc) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted waiting for checkpoints", exc);
        }
    }

    @Override
    public void stepStateChanged(StepExecutionStateEvent event) {
        // Do nothing
    }

    /** Stops a running daemon once the task becomes passive after executing. */
    @Override
    public void taskStateChanged(TaskExecutionStateEvent event) {
        if (isAlive() && event.currentState().isPassiveAfterExecuting()) {
            stop();
        }
    }

    /**
     * Reads one of the daemon's output streams character by character and splits it into lines
     * (CR, LF and CRLF are all accepted as line terminators).
     * <p>
     * Each line is either handled as a marker (checkpoint / exit code) or passed to {@link #log(String)}.
     * A shared timer flushes a buffered partial line if no complete line has been emitted for about
     * {@value #FLUSH_DELAY_MS} ms, so that prompts without a trailing newline still show up.
     */
    private abstract class StreamThread implements Runnable {

        // NOTE(review): InputStreamReader without a charset uses the JVM platform default.
        private Reader out;
        private StringBuffer lineBuffer;

        StreamThread(InputStream out) {
            this.out = new InputStreamReader(out);
            this.lineBuffer = new StringBuffer();
        }

        @Override
        public void run() {
            cleanMDC();

            // Deadline after which the timer may flush a partial line; guarded by lineBuffer.
            final long[] flushAfter = new long[1];
            final TimerTask flushTimerTask = new TimerTask() {
                @Override
                public void run() {
                    synchronized (lineBuffer) {
                        if (flushAfter[0] < System.currentTimeMillis()) {
                            if (lineBuffer.length() > 0) {
                                String line = lineBuffer.toString();
                                lineBuffer.setLength(0);
                                logger.debug("Partial line has not been updated for " + FLUSH_DELAY_MS + "ms, flushing line: [{}]", line);
                                lineReceived(line);
                            }
                            flushAfter[0] = System.currentTimeMillis() + FLUSH_DELAY_MS;
                        }
                    }
                }
            };
            flushTimer.schedule(flushTimerTask, FLUSH_DELAY_MS, FLUSH_CHECK_INTERVAL_MS);
            try {
                char prevC = '\0';
                for (;;) {
                    try {
                        int cInt = out.read();
                        if (cInt == -1) {
                            connectionLost();
                            break;
                        } else {
                            char c = (char) cInt;
                            try {
                                if (c != '\r' && c != '\n') {
                                    logger.trace("Adding char [{}] to line buffer", c);
                                    synchronized (lineBuffer) {
                                        lineBuffer.append(c);
                                    }
                                } else {
                                    // A CRLF pair terminates a single line.
                                    if (c == '\n' && prevC == '\r') {
                                        logger.trace("Skipping LF after CR");
                                        continue;
                                    }
                                    synchronized (lineBuffer) {
                                        flushAfter[0] = System.currentTimeMillis() + FLUSH_DELAY_MS;
                                        String line = lineBuffer.toString();
                                        lineBuffer.setLength(0);
                                        logger.trace("Newline found, flushing line: [{}]", line);
                                        lineReceived(line);
                                    }
                                }
                            } finally {
                                prevC = c;
                            }
                        }
                    } catch (IOException exc) {
                        // The stream is treated as lost, exactly like EOF. Stop reading instead of
                        // spinning on a broken stream and counting down the latch on every failed read.
                        exceptionReceived(exc);
                        break;
                    }
                }
            } finally {
                flushTimerTask.cancel();
            }
        }

        /**
         * Removes all MDC entries except {@code username} and {@code taskId}.
         * The thread may have inherited other entries from its creator, depending on the SLF4J binding.
         */
        private void cleanMDC() {
            String username = MDC.get("username");
            String taskId = MDC.get("taskId");
            MDC.clear();
            if(username != null) {
                MDC.put("username", username);
            }
            if(taskId != null) {
                MDC.put("taskId", taskId);
            }
        }

        /** Dispatches a complete (or force-flushed) line: marker handling, or forwarding to {@link #log}. */
        void lineReceived(String line) {
            if (line.startsWith(DAEMON_CHECKPOINT)) {
                logger.debug("Detected checkpoint.");
                countdownCheckpointLatch();
            } else if (line.startsWith(DAEMON_EXIT_CODE_MARKER)) {
                logger.debug("Detected exit code. Parsing it.");
                // Trim so that stray whitespace around the number does not make parsing fail.
                String exitCodeStr = line.substring(DAEMON_EXIT_CODE_MARKER.length()).trim();
                try {
                    int exitCode = Integer.parseInt(exitCodeStr);
                    logger.debug("Exit code parsed: {}", exitCode);
                    lastExitCode.set(exitCode);
                } catch (NumberFormatException ignored) {
                    // Leave lastExitCode unchanged (-1 unless set earlier) when the marker is malformed.
                    logger.debug("Cannot parse exit code [{}]", exitCodeStr);
                }
            } else {
                log(line);
            }
        }

        /** Forwards an ordinary output line to the current execution context, if any. */
        abstract void log(String line);

        /** Reports a read error to the log and the current context, then treats the stream as lost. */
        void exceptionReceived(IOException exc) {
            logger.debug("Error in connection to the daemon", exc);
            ExecutionContext c = currentContext.get();
            if (c != null) {
                c.logError("Error in connection to the daemon", exc);
            }
            connectionLost();
        }

        /** Counts down the latch so that a waiter is not blocked forever on a dead stream. */
        void connectionLost() {
            logger.debug("Lost connection to the daemon");
            countdownCheckpointLatch();
        }

    }

    /** Reader for the daemon's stdout; forwards lines to {@code ExecutionContext.logOutput}. */
    private class StdoutThread extends StreamThread {
        StdoutThread(InputStream out) {
            super(out);
        }

        @Override
        void log(String line) {
            ExecutionContext c = currentContext.get();
            if (c != null) {
                c.logOutput(line);
            }
        }
    }

    /** Reader for the daemon's stderr; forwards lines to {@code ExecutionContext.logError}. */
    private class StderrThread extends StreamThread {
        StderrThread(InputStream out) {
            super(out);
        }

        @Override
        void log(String line) {
            ExecutionContext c = currentContext.get();
            if (c != null) {
                c.logError(line);
            }
        }
    }

}
