/*
 * Copyright (c) 2010 XebiaLabs B.V. All rights reserved.
 *
 * Your use of XebiaLabs Software and Documentation is subject to the Personal
 * License Agreement.
 *
 * http://www.xebialabs.com/deployit-personal-edition-license-agreement
 *
 * You are granted a personal license (i) to use the Software for your own
 * personal purposes which may be used in a production environment and/or (ii)
 * to use the Documentation to develop your own plugins to the Software.
 * "Documentation" means the how to's and instructions (instruction videos)
 * provided with the Software and/or available on the XebiaLabs website or other
 * websites as well as the provided API documentation, tutorial and access to
 * the source code of the XebiaLabs plugins. You agree not to (i) lease, rent
 * or sublicense the Software or Documentation to any third party, or otherwise
 * use it except as permitted in this agreement; (ii) reverse engineer,
 * decompile, disassemble, or otherwise attempt to determine source code or
 * protocols from the Software, and/or to  (iii) copy the Software or
 * Documentation (which includes the source code of the XebiaLabs plugins). You
 * shall not create or attempt to create any derivative works from the Software
 * except and only to the extent permitted by law. You will preserve XebiaLabs'
 * copyright and legal notices on the Software and Documentation. XebiaLabs
 * retains all rights not expressly granted to You in the Personal License
 * Agreement.
 */

package com.xebialabs.deployit.hostsession.ssh;

import static com.google.common.base.Preconditions.checkNotNull;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.Map;
import java.util.Random;

import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.UserInfo;
import com.xebialabs.deployit.ci.OperatingSystemFamily;
import com.xebialabs.deployit.exception.RuntimeIOException;
import com.xebialabs.deployit.hostsession.CommandExecution;
import com.xebialabs.deployit.hostsession.CommandExecutionCallbackHandler;
import com.xebialabs.deployit.hostsession.HostFile;
import com.xebialabs.deployit.hostsession.HostSession;
import com.xebialabs.deployit.hostsession.common.AbstractHostSession;
import com.xebialabs.deployit.hostsession.common.ErrorStreamToCallbackHandler;
import com.xebialabs.deployit.hostsession.common.InputResponseHandler;
import com.xebialabs.deployit.hostsession.common.OutputStreamToCallbackHandler;

/**
 * A host session over SSH.
 */
abstract class SshHostSession extends AbstractHostSession implements HostSession {

	protected String host;

	protected int port;

	protected String username;

	protected String password;

	protected Session sharedSession;

	private static final String CHANNEL_PURPOSE = "";

	/**
	 * Constructs an SshHostSession
	 * 
	 * @param os
	 *            the operating system of the host
	 * @param temporaryDirectoryPath
	 *            the path of the directory in which to store temporary files
	 * @param host
	 *            the hostname or IP adress of the host
	 * @param port
	 *            the port to connect to
	 * @param username
	 *            the username to connect with
	 * @param password
	 *            the password to connect with
	 */
	public SshHostSession(OperatingSystemFamily os, String temporaryDirectoryPath, String host, int port, String username, String password) {
		super(os, temporaryDirectoryPath);
		this.host = host;
		this.port = port;
		this.username = username;
		this.password = password;
	}

	void open() throws RuntimeIOException {
		if (sharedSession == null) {
			try {
				sharedSession = openSession(CHANNEL_PURPOSE);
			} catch (JSchException exc) {
				throw new RuntimeIOException("Cannot connect to " + this, exc);
			}
		}
	}

	@Override
	public void close() {
		super.close();
		disconnectSharedSession();
	}

	protected Session getSharedSession() {
		if (sharedSession == null) {
			throw new IllegalStateException("Not connected");
		}
		return sharedSession;
	}

	public void disconnectSharedSession() {
		disconnectSession(sharedSession, CHANNEL_PURPOSE);
		sharedSession = null;
	}

	protected Session openSession(String purpose) throws JSchException {
		JSch jsch = new JSch();
		Session session = jsch.getSession(username, host, port);
		session.setUserInfo(getUserInfo());
		session.connect();
		logger.info("Connected to " + this + purpose);
		return session;
	}

	protected void disconnectSession(Session session, String purpose) {
		if (session != null) {
			session.disconnect();
			logger.info("Disconnected from " + this + purpose);
		}
	}

	public HostFile getFile(String hostPath) throws RuntimeIOException {
		return getFile(hostPath, false);
	}

	protected abstract HostFile getFile(String hostPath, boolean isTempFile) throws RuntimeIOException;

	public HostFile getFile(HostFile parent, String child) throws RuntimeIOException {
		return getFile(parent, child, false);
	}

	protected HostFile getFile(HostFile parent, String child, boolean isTempFile) throws RuntimeIOException {
		if (!(parent instanceof SshHostFile)) {
			throw new IllegalStateException("parent is not a file on an SSH host");
		}
		if (parent.getSession() != this) {
			throw new IllegalStateException("parent is not a file in this session");
		}
		return getFile(parent.getPath() + getHostOperatingSystem().getFileSeparator() + child, isTempFile);
	}

	public HostFile getTempFile(String prefix, String suffix) throws RuntimeIOException {
		checkNotNull(prefix);
		if (suffix == null) {
			suffix = ".tmp";
		}

		Random r = new Random();
		String infix = "";
		for (int i = 0; i < AbstractHostSession.MAX_TEMP_RETRIES; i++) {
			HostFile f = getFile(getTemporaryDirectory().getPath() + getHostOperatingSystem().getFileSeparator() + prefix + infix + suffix, true);
			if (!f.exists()) {
				if (logger.isDebugEnabled())
					logger.debug("Created temporary file " + f);

				return f;
			}
			infix = "-" + Long.toString(Math.abs(r.nextLong()));
		}
		throw new RuntimeIOException("Cannot generate a unique temporary file name on " + this);
	}

	@SuppressWarnings("unchecked")
	public int execute(CommandExecutionCallbackHandler handler, String... commandLine) throws RuntimeIOException {
		return execute(handler, Collections.EMPTY_MAP, commandLine);
	}

	public int execute(CommandExecutionCallbackHandler handler, Map<String, String> inputResponse, String... commandLine) throws RuntimeIOException {
		String command = encodeCommandLine(false, commandLine);
		String commandWithHiddenPassword = encodeCommandLine(true, commandLine);
		try {
			ChannelExec channel = (ChannelExec) getSharedSession().openChannel("exec");
			Thread outputCopierThread = null;
			Thread errorCopierThread = null;
			try {
				// set up command
				channel.setPty(true);
				channel.setCommand(command);

				// set up streams
				InputStream remoteStdout = channel.getInputStream();
				InputStream remoteStderr = channel.getErrStream();
				OutputStream remoteStdin = channel.getOutputStream();

				// prepare to capture output
				CommandExecutionCallbackHandler responseHandler = getInputResponseHandler(handler, remoteStdin, inputResponse);
				outputCopierThread = new Thread(new OutputStreamToCallbackHandler(remoteStdout, responseHandler));
				outputCopierThread.start();
				errorCopierThread = new Thread(new ErrorStreamToCallbackHandler(remoteStderr, responseHandler));
				errorCopierThread.start();

				// execute the command
				channel.connect();
				logger.info("Executing remote command \"" + commandWithHiddenPassword + "\" on " + this);

				int exitValue = waitForExitStatus(channel);
				if (logger.isDebugEnabled())
					logger.debug("Finished executing remote command \"" + commandWithHiddenPassword + "\" on " + this + " with exit value " + exitValue);
				return exitValue;
			} finally {
				channel.disconnect();
				if (outputCopierThread != null) {
					try {
						outputCopierThread.join();
					} catch (InterruptedException ignored) {
					}
				}
				if (errorCopierThread != null) {
					try {
						errorCopierThread.join();
					} catch (InterruptedException ignored) {
					}
				}
			}
		} catch (IOException exc) {
			throw new RuntimeIOException("Cannot execute remote command \"" + commandWithHiddenPassword + "\" on " + this, exc);
		} catch (JSchException exc) {
			throw new RuntimeIOException("Cannot execute remote command \"" + commandWithHiddenPassword + "\" on " + this, exc);
		}
	}

	protected CommandExecutionCallbackHandler getInputResponseHandler(CommandExecutionCallbackHandler originalHandler, OutputStream remoteStdin,
			Map<String, String> inputResponse) {
		return new InputResponseHandler(originalHandler, remoteStdin, inputResponse);
	}

	public CommandExecution startExecute(String... commandLine) {
		final String command = encodeCommandLine(false, commandLine);
		final String commandWithHiddenPassword = encodeCommandLine(true, commandLine);
		try {
			final ChannelExec channel = (ChannelExec) getSharedSession().openChannel("exec");
			// set up command
			channel.setPty(true);
			channel.setCommand(command);

			channel.connect();

			logger.info("Executing remote command \"" + commandWithHiddenPassword + "\" on " + this + " and passing control to caller");

			return new CommandExecution() {
				public OutputStream getStdin() {
					try {
						return channel.getOutputStream();
					} catch (IOException exc) {
						throw new RuntimeIOException("Cannot open output stream to remote stdin");
					}
				}

				public InputStream getStdout() {
					try {
						return channel.getInputStream();
					} catch (IOException exc) {
						throw new RuntimeIOException("Cannot open input stream from remote stdout");
					}
				}

				public InputStream getStderr() {
					try {
						return channel.getErrStream();
					} catch (IOException exc) {
						throw new RuntimeIOException("Cannot open input stream from remote stderr");
					}
				}

				public int waitFor() {
					try {
						int exitValue = waitForExitStatus(channel);
						logger.info("Finished executing remote command \"" + commandWithHiddenPassword + "\" on " + this + " with exit value " + exitValue
								+ " (control was passed to caller)");
						return exitValue;
					} finally {
						channel.disconnect();
					}
				}
			};
		} catch (JSchException exc) {
			throw new RuntimeIOException("Cannot execute remote command \"" + commandWithHiddenPassword + "\" on " + this, exc);
		}

	}

	static int waitForExitStatus(ChannelExec channel) {
		while (true) {
			if (channel.isClosed()) {
				return channel.getExitStatus();
			}
			try {
				Thread.sleep(1000);
			} catch (Exception ee) {
			}
		}
	}

	protected UserInfo getUserInfo() {
		return new UserInfo() {
			public boolean promptPassword(String prompt) {
				return true;
			}

			public String getPassword() {
				return password;
			}

			public boolean promptPassphrase(String prompt) {
				return false;
			}

			public String getPassphrase() {
				return null;
			}

			public boolean promptYesNo(String prompt) {
				return true;
			}

			public void showMessage(String msg) {
				logger.info("Message recieved while connecting to " + username + "@" + host + ":" + port + ": " + msg);
			}
		};
	}

	public String getHost() {
		return host;
	}

	public void setHost(String host) {
		this.host = host;
	}

	public int getPort() {
		return port;
	}

	public void setPort(int port) {
		this.port = port;
	}

	public String getUsername() {
		return username;
	}

	public void setUsername(String username) {
		this.username = username;
	}

	public String getPassword() {
		return password;
	}

	public void setPassword(String password) {
		this.password = password;
	}

	public static String encodeCommandLine(boolean hidePassword, String... commandLine) {
		if (commandLine == null || commandLine.length == 0) {
			throw new IllegalStateException("Cannot execute an empty command line");
		}

		StringBuilder sb = new StringBuilder();
		boolean passwordKeywordSeen = false;
		for (int i = 0; i < commandLine.length; i++) {
			if (i != 0) {
				sb.append(' ');
			}
			if (commandLine[i] == null) {
				sb.append("null");
			} else if (commandLine[i].length() == 0) {
				sb.append("\" \"");
			} else {
				if (passwordKeywordSeen && hidePassword) {
					for (int j = 0; j < commandLine[i].length(); j++) {
						sb.append("*");
					}
				} else {
					for (int j = 0; j < commandLine[i].length(); j++) {
						char c = commandLine[i].charAt(j);
						if (" '\"\\;()${}".indexOf(c) != -1) {
							sb.append('\\');
						}
						sb.append(c);
					}
				}
				// catch 'password' or '-password'
				passwordKeywordSeen = StringUtils.endsWithIgnoreCase(commandLine[i], "password");
			}
		}
		return sb.toString();
	}

	public String toString() {
		return username + "@" + host + ":" + port;
	}

	private static Logger logger = Logger.getLogger(SshHostSession.class);

}
