package ai.digital.deploy.profiling.persistence

import ai.digital.deploy.profiling.ApplicationProfiler.PROFILING
import ai.digital.deploy.profiling.web.CorrelationalIdListener
import com.xebialabs.deployit.core.sql.spring.ExtendedArgumentPreparedStatementSetter
import com.xebialabs.deployit.core.sql.spring.Setter
import grizzled.slf4j.Logging
import org.aspectj.lang.ProceedingJoinPoint
import org.aspectj.lang.annotation.Around
import org.aspectj.lang.annotation.Aspect
import org.slf4j.MDC
import org.springframework.context.annotation.Profile
import org.springframework.jdbc.core.PreparedStatementCreator
import org.springframework.jdbc.core.PreparedStatementSetter
import org.springframework.jdbc.core.SqlProvider
import org.springframework.stereotype.Component

import java.time.Duration
import java.time.Instant
import java.util

/**
 * Profiling aspect that logs every intercepted `JdbcTemplate.query*` invocation together with its
 * execution time, its SQL text and its bind values (when these can be determined from the arguments).
 *
 * The bean is only registered when the `PROFILING` Spring profile is active. Each intercepted call is
 * logged at info level once it completes, whether it returned normally or threw.
 */
@Aspect
@Component
@Profile(Array(PROFILING))
class DatabaseProfiler extends Logging {

  /**
   * Intercepts the `query*` overloads whose arguments start with `(String, Class, ..)` or
   * `(String, RowMapper, ..)`; for these the bind values (if any) are looked up at argument index 2.
   */
  @Around("execution(* org.springframework.jdbc.core.JdbcTemplate.query*(String,Class,..)) || " +
    "execution(* org.springframework.jdbc.core.JdbcTemplate.query*(String,org.springframework.jdbc.core.RowMapper,..))")
  def processFewQueryMethods(pjp: ProceedingJoinPoint): AnyRef = processLogEntry(pjp, 2)


  /**
   * Intercepts every other `query*` overload; for these the bind values (if any) are looked up at
   * argument index 1.
   */
  @Around("execution(* org.springframework.jdbc.core.JdbcTemplate.query*(..)) && " +
    "!(execution(* org.springframework.jdbc.core.JdbcTemplate.query*(String,Class,..)) || " +
    "execution(* org.springframework.jdbc.core.JdbcTemplate.query*(String,org.springframework.jdbc.core.RowMapper,..)))")
  def processRemainingQueryMethods(pjp: ProceedingJoinPoint): AnyRef = processLogEntry(pjp, 1)

  // TODO: add pointcuts for the execute and update methods as well

  /**
   * Proceeds with the intercepted call and logs it once it has finished.
   *
   * @param pjp              the intercepted JdbcTemplate call
   * @param positionOfValues index of the argument that holds the bind values
   * @return whatever the intercepted call returned; an exception it throws is propagated unchanged
   *         (the call is still logged)
   */
  def processLogEntry(pjp: ProceedingJoinPoint, positionOfValues: Int): AnyRef = {
    // System.nanoTime is the monotonic elapsed-time clock; Instant.now is wall-clock time, which can
    // jump and may have coarse resolution.
    val startNanos = System.nanoTime()
    try pjp.proceed()
    finally logExecution(pjp.getArgs, positionOfValues, Duration.ofNanos(System.nanoTime() - startNanos))
  }

  /**
   * Builds and writes the info-level log line for one intercepted call.
   *
   * The line is made of optional segments: correlation id (from the MDC), execution time (always
   * present), SQL text and bind values.
   */
  private def logExecution(args: Array[AnyRef], positionOfValues: Int, elapsed: Duration): Unit =
    try {
      val segments = Seq(
        Option(MDC.get(CorrelationalIdListener.CORRELATION_ID)).map(id => s"[CorrelationID : $id]"),
        Some(f"[Execution Time : ${elapsed.toNanos}%014d nano seconds] "),
        getQuery(args).map(sql => s"[Query: $sql] "),
        getValues(args, positionOfValues).map(values => s"[Values : $values]")
      ).flatten
      logger.info(segments.mkString)
    } catch {
      // Profiling must never replace the result or the exception of the profiled query.
      case scala.util.control.NonFatal(e) => logger.warn("Could not log profiled JDBC query", e)
    }

  /**
   * Extracts the SQL text from the first argument.
   *
   * The first argument is either the SQL string itself or a PreparedStatementCreator that also
   * implements SqlProvider.
   *
   * @return the SQL text, or None if there are no arguments or the SQL cannot be determined
   */
  private def getQuery(args: Array[AnyRef]): Option[String] =
    Option(args).flatMap(_.headOption).flatMap {
      // Not every PreparedStatementCreator exposes its SQL (e.g. a lambda). Those yield None rather than
      // failing an unchecked cast after the query has already run.
      case creator: PreparedStatementCreator with SqlProvider => Option(creator.getSql)
      case sql: String => Option(sql)
      case _ => None
    }

  /**
   * Renders the bind values found at argument index `positionOfValues`.
   *
   * @return a printable form of the values, a hint when they are hidden inside an opaque
   *         PreparedStatementSetter, or None when there is no such argument or it is not recognised
   */
  private def getValues(args: Array[AnyRef], positionOfValues: Int): Option[String] =
    Option(args)
      .filter(_.length > positionOfValues)
      .map(allArgs => allArgs(positionOfValues))
      .flatMap {
        case values: Array[Object] => Option(util.Arrays.toString(values))
        // Project setter types; presumably their toString renders the bound values — TODO confirm.
        case setter @ (_: ExtendedArgumentPreparedStatementSetter | _: Setter) => Option(setter.toString)
        case _: PreparedStatementSetter => Some("<Could not get parameter values. If trace level for " +
          "'org.springframework.jdbc.core' is enabled, please check the log statements from the org.springframework.jdbc.core" +
          " package>")
        case _ => None
      }
}
