package com.xebialabs.deployit.jetty;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.IOException;

/**
 */
public class ResponseCodeManglingFilter implements Filter {
	protected static final String X_ORIGINAL_STATUS = "X-Original-Status";
	protected static final String X_WANT_OK_STATUS = "X-Want-Ok-Status";

	private static final Logger logger = LoggerFactory.getLogger(ResponseCodeManglingFilter.class);

	@Override
	public void init(final FilterConfig filterConfig) throws ServletException {
		logger.debug("Added Hide responsecode Filter");
	}

	@Override
	public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException {
		final ServletResponse httpResponse;
		if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {
			final String s = ((HttpServletRequest) request).getHeader(X_WANT_OK_STATUS);
			if (s != null && !"".equals(s.trim())) {
				logger.debug("Hiding Response Code if necessary");
				httpResponse = new HttpReponseOkWrapper((HttpServletResponse) response);
			} else {
				httpResponse = response;
			}
		} else {
			httpResponse = response;
		}
		
		chain.doFilter(request, httpResponse);
	}

	@Override
	public void destroy() {

	}

	private class HttpReponseOkWrapper extends HttpServletResponseWrapper {

		private HttpReponseOkWrapper(final HttpServletResponse response) {
			super(response);
		}

		@Override
		public void setStatus(final int sc) {
			int code = _determineCorrectStatusCode(sc);
			super.setStatus(code);
		}

		private int _determineCorrectStatusCode(final int sc) {
			addHeader(X_ORIGINAL_STATUS, Integer.toString(sc));
			if (sc >= 400) {
				logger.info("Hiding status code {} with {}", sc, 207);
				return 207;
			}
			return sc;
		}

		@Override
		public void setStatus(final int sc, final String sm) {
			int code = _determineCorrectStatusCode(sc);
			super.setStatus(code, sm);
		}

		@Override
		public void sendError(final int sc) throws IOException {
			int code = _determineCorrectStatusCode(sc);
			super.sendError(code);
		}

		@Override
		public void sendError(final int sc, final String msg) throws IOException {
			int code = _determineCorrectStatusCode(sc);
			super.sendError(code, msg);
		}
	}
}
