package ai.digital.config.server.client

import java.io.IOException
import java.security.GeneralSecurityException
import java.util
import java.util.Collections

import ai.digital.config.ConfigServerProfiles
import grizzled.slf4j.Logging
import org.apache.http.impl.client.HttpClients
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.cloud.config.client.ConfigClientProperties.{AUTHORIZATION, STATE_HEADER, TOKEN_HEADER}
import org.springframework.cloud.config.client.{ConfigClientProperties, ConfigClientStateHolder, ConfigServicePropertySourceLocator}
import org.springframework.cloud.configuration.SSLContextFactory
import org.springframework.context.annotation.Profile
import org.springframework.core.env.Environment
import org.springframework.http._
import org.springframework.http.client.{ClientHttpRequestFactory, HttpComponentsClientHttpRequestFactory, SimpleClientHttpRequestFactory}
import org.springframework.stereotype.Component
import org.springframework.util.{Base64Utils, StringUtils}
import org.springframework.web.client.{HttpClientErrorException, ResourceAccessException, RestTemplate}

import scala.jdk.CollectionConverters._


/** Constants shared by users of the `RestTemplateHelper` component. */
object RestTemplateHelper {

  /** Relative request path of the actuator refresh endpoint. */
  val actuatorRefreshPath: String = "/actuator/refresh"

}

/**
 * Issues HTTP requests against the Config Server URL(s) listed in the config-client properties.
 *
 * The component is only registered under the `ConfigServerProfiles.NOT_CONFIG_SERVER` profile.
 *
 * @param defaultRestTemplate optional injected template; when it is `null` a template is built
 *                            from the client properties on every `execute` call
 * @param defaultProperties   base config-client properties; overridden with `environment` per call
 * @param environment         environment used to override the client properties
 */
@Profile(Array(ConfigServerProfiles.NOT_CONFIG_SERVER))
@Component
class RestTemplateHelper(@Autowired(required = false) private val defaultRestTemplate: RestTemplate,
                         private val defaultProperties: ConfigClientProperties,
                         private val environment: Environment) extends Logging {

  /** Relative request path of the actuator refresh endpoint (same value as on the companion object). */
  val actuatorRefreshPath: String = "/actuator/refresh"

  /**
   * Joins a base URI and a path with exactly one `/` between them, then removes every
   * occurrence of `replace` from the joined string.
   *
   * @param uri     base URI, with or without a trailing slash
   * @param path    path to append, with or without a leading slash
   * @param replace substring to strip from the result; the default empty string strips nothing
   * @return the joined (and stripped) URI
   */
  def buildUri(uri: String, path: String, replace: String = ""): String = {
    val joined = (uri.endsWith("/"), path.startsWith("/")) match {
      // Both sides bring a slash: drop one so the result does not contain "//".
      case (true, true)   => uri + path.substring(1)
      case (false, false) => s"$uri/$path"
      case _              => uri + path
    }
    joined.replace(replace, "")
  }

  /**
   * Sends the request to the configured server URLs in order and returns the body of the
   * first `200 OK` response.
   *
   * Per URL: a `404` moves on to the next URL; a `ResourceAccessException` is logged and the
   * next URL is tried, unless it was the last one, in which case the exception is rethrown.
   *
   * @param path           request path appended to each server URI
   * @param method         HTTP method to use
   * @param responseType   class the response body is converted to
   * @param bodyMaybe      optional request body
   * @param mediaTypeMaybe optional media type used for both `Content-Type` and `Accept`;
   *                       falls back to the media type from the client properties
   * @param uriVariables   variables expanded into the URI template
   * @param replacePath    substring removed from the built URI (see `buildUri`)
   * @tparam B request body type
   * @tparam T response body type
   * @return `Some(body)` for the first successful response (the wrapped body may be `null`
   *         when the response carries none), or `None` when no URL produced a response
   * @throws IllegalStateException     when a server answers without an exception but not with `200 OK`
   * @throws HttpClientErrorException  for any client error other than `404`
   * @throws ResourceAccessException   when the last URL cannot be reached
   */
  def execute[B, T](path: String, method: HttpMethod, responseType: Class[T],
                    bodyMaybe: Option[B] = None,
                    mediaTypeMaybe: Option[String] = None,
                    uriVariables: Map[String, _] = Map.empty,
                    replacePath: String = ""): Option[T] = {
    val properties = this.defaultProperties.`override`(environment)
    val restTemplate = Option(defaultRestTemplate).getOrElse(getSecureRestTemplate(properties))

    val noOfUrls = properties.getUri.length
    if (noOfUrls > 1)
      logger.info("Multiple Config Server Urls found listed.")

    // The same media type is used for Content-Type and Accept, so parse it once.
    val contentType = MediaType.parseMediaType(mediaTypeMaybe.getOrElse(properties.getMediaType))
    val acceptHeader = Collections.singletonList(contentType)

    val state = ConfigClientStateHolder.getState
    val token = properties.getToken

    // Performs the request against the URL at `index`; None means "try the next URL".
    def attempt(index: Int): Option[T] = {
      val credentials = properties.getCredentials(index)
      val uri = credentials.getUri

      logger.debug(s"Using config from server at : $uri")

      try {
        val headers = new HttpHeaders
        headers.setContentType(contentType)
        headers.setAccept(acceptHeader)
        addAuthorizationToken(properties, headers, credentials.getUsername, credentials.getPassword)
        if (StringUtils.hasText(token))
          headers.add(TOKEN_HEADER, token)
        if (StringUtils.hasText(state) && properties.isSendState)
          headers.add(STATE_HEADER, state)

        val entity =
          bodyMaybe
            .map(body => new HttpEntity[B](body, headers))
            .getOrElse(new HttpEntity[Void](headers))

        val response: ResponseEntity[T] = restTemplate.exchange(
          buildUri(uri, path, replacePath),
          method,
          entity,
          responseType,
          uriVariables.asJava
        )

        // Check for null separately so the error message never dereferences a null response.
        if (response == null)
          throw new IllegalStateException(s"Central configuration request failed: no response received from $uri")
        if (response.getStatusCode != HttpStatus.OK)
          throw new IllegalStateException(s"Central configuration request failed with status code ${response.getStatusCode}: ${response.getBody}")

        Some(response.getBody)
      } catch {
        // Only 404 is tolerated; any other client error does not match and propagates.
        case e: HttpClientErrorException if e.getStatusCode == HttpStatus.NOT_FOUND =>
          None
        case e: ResourceAccessException =>
          logger.warn(s"Connection Timeout Exception on URL - $uri. The next URL to be checked if available")
          if (index == noOfUrls - 1)
            throw e
          None
      }
    }

    // The iterator is lazy: URLs after the first successful one are never contacted, and
    // the index never runs past the configured URLs.
    (0 until noOfUrls).iterator.map(attempt).find(_.isDefined).flatten
  }

  /**
   * Adds an `Authorization` header, either HTTP Basic built from `username`/`password` or
   * the explicit authorization header value from the client properties.
   *
   * @throws IllegalStateException when both a password and an explicit authorization header are set
   */
  private def addAuthorizationToken(configClientProperties: ConfigClientProperties,
                                    httpHeaders: HttpHeaders,
                                    username: String,
                                    password: String): Unit = {
    val authorization = configClientProperties.getHeaders.get(AUTHORIZATION)
    if (password != null && authorization != null)
      throw new IllegalStateException("You must set either 'password' or 'authorization'")

    if (password != null) {
      val token = Base64Utils.encode((username + ":" + password).getBytes)
      httpHeaders.add("Authorization", "Basic " + new String(token))
    } else if (authorization != null) {
      httpHeaders.add("Authorization", authorization)
    }
  }

  /**
   * Builds a `RestTemplate` from the client properties: timeouts, optional TLS, and an
   * interceptor that adds the configured headers (except the authorization header).
   *
   * @throws IllegalStateException when the read or connect timeout is negative
   */
  private def getSecureRestTemplate(client: ConfigClientProperties): RestTemplate = {
    if (client.getRequestReadTimeout < 0)
      throw new IllegalStateException("Invalid Value for Read Timeout set.")

    if (client.getRequestConnectTimeout < 0)
      throw new IllegalStateException("Invalid Value for Connect Timeout set.")

    val template = new RestTemplate(createHttpRequestFactory(client))

    // Work on a copy so the properties' own header map is not modified.
    val headers = new util.HashMap[String, String](client.getHeaders)
    headers.remove(AUTHORIZATION) // To avoid redundant addition of header
    if (!headers.isEmpty)
      template.setInterceptors(util.Arrays.asList(new ConfigServicePropertySourceLocator.GenericRequestHeaderInterceptor(headers)))
    template
  }

  /**
   * Creates the request factory: an Apache HttpClient based factory with a custom SSL context
   * when TLS is enabled, otherwise a `SimpleClientHttpRequestFactory`. Both get the
   * read/connect timeouts from the client properties.
   *
   * @throws IllegalStateException when the TLS context cannot be created
   */
  private def createHttpRequestFactory(client: ConfigClientProperties): ClientHttpRequestFactory =
    if (client.getTls.isEnabled) {
      try {
        val sslContext = new SSLContextFactory(client.getTls).createSSLContext
        val httpClient = HttpClients.custom.setSSLContext(sslContext).build
        val factory = new HttpComponentsClientHttpRequestFactory(httpClient)
        factory.setReadTimeout(client.getRequestReadTimeout)
        factory.setConnectTimeout(client.getRequestConnectTimeout)
        factory
      } catch {
        case ex@(_: GeneralSecurityException | _: IOException) =>
          // Pass the exception as the throwable so the stack trace is logged.
          logger.error("Failed to create config client with TLS.", ex)
          throw new IllegalStateException("Failed to create config client with TLS.", ex)
      }
    } else {
      val factory = new SimpleClientHttpRequestFactory
      factory.setReadTimeout(client.getRequestReadTimeout)
      factory.setConnectTimeout(client.getRequestConnectTimeout)
      factory
    }
}
