package ai.digital.deploy.permissions.api.rest.pagination

import ai.digital.deploy.permissions.MessageHandler
import ai.digital.deploy.permissions.exception.PaginationParseException
import org.springframework.core.MethodParameter
import org.springframework.data.domain.{PageRequest, Pageable, Sort}
import org.springframework.stereotype.Component
import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.context.request.NativeWebRequest
import org.springframework.web.method.support.{HandlerMethodArgumentResolver, ModelAndViewContainer}

import scala.util.{Failure, Success, Try}

/**
 * Resolves controller parameters of type [[Pageable]] from the request's query parameters.
 *
 *  - `Paging.PAGE_PARAMETER`: 1-based page number (default 1). It is converted to the 0-based
 *    index [[PageRequest]] expects; values below 1 resolve to the first page.
 *  - `Paging.SIZE_PARAMETER`: page size (default `Int.MaxValue`). Values below 1 also resolve to
 *    `Int.MaxValue`.
 *  - `Order.ORDER_PARAMETER`: `field` or `field:direction` (direction as accepted by
 *    `Sort.Direction.fromString`, e.g. `asc`/`desc`). A missing direction sorts ascending;
 *    a missing or blank parameter yields an unsorted request.
 *
 * Malformed values raise [[PaginationParseException]] carrying the raw parameter value.
 */
@Component
class PaginationResolver(messageHandler: MessageHandler) extends HandlerMethodArgumentResolver {
  private val errorMessageCode: String = "pagination.parse.error"

  override def supportsParameter(parameter: MethodParameter): Boolean = parameter.getParameterType.equals(classOf[Pageable])

  override def resolveArgument(parameter: MethodParameter,
                               mavContainer: ModelAndViewContainer,
                               webRequest: NativeWebRequest,
                               binderFactory: WebDataBinderFactory
  ): AnyRef = {
    // The page is 1-based on the wire, but PageRequest is 0-based. Non-positive values are clamped
    // to the first page, because PageRequest.of rejects negative indices. `value - 1` only runs for
    // value > 0, so it cannot overflow.
    val page = getValue(webRequest, Paging.PAGE_PARAMETER, 1) { value =>
      if (value > 0) value - 1 else 0
    }
    // A non-positive size means "no limit".
    val size = getValue(webRequest, Paging.SIZE_PARAMETER, Int.MaxValue) { value =>
      if (value < 1) Int.MaxValue else value
    }
    PageRequest.of(page, size, getSort(webRequest))
  }

  /**
   * Reads an integer query parameter and post-processes it.
   *
   * @param parameter    query parameter name
   * @param defaultValue used when the parameter is absent or blank
   * @param transform    applied to the parsed (or default) value
   * @throws PaginationParseException if the trimmed value is not a valid `Int`
   */
  private def getValue(webRequest: NativeWebRequest, parameter: String, defaultValue: Int)(transform: Int => Int): Int = {
    val param = webRequest.getParameter(parameter)

    Option(param).map(_.trim).filter(_.nonEmpty) match {
      case None => transform(defaultValue)
      case Some(raw) =>
        Try(raw.toInt) match {
          case Success(value) => transform(value)
          case Failure(exception) => throw parseError(param, exception)
        }
    }
  }

  /**
   * Builds the [[Sort]] from the order parameter (`field` or `field:direction`).
   *
   * @throws PaginationParseException if the direction is not recognised or the field is blank
   */
  private def getSort(webRequest: NativeWebRequest): Sort = {
    val param = webRequest.getParameter(Order.ORDER_PARAMETER)

    Option(param).map(_.trim).filter(_.nonEmpty) match {
      case None => Sort.unsorted()
      case Some(raw) =>
        Try(parseSort(raw)) match {
          case Success(sort) => sort
          // Spring Data signals both an unknown direction and an empty property this way.
          case Failure(exception: IllegalArgumentException) => throw parseError(param, exception)
          case Failure(exception) => throw exception
        }
    }
  }

  /** Splits `raw` at the first ':' into field and direction. With no ':', it sorts ascending by `raw`. */
  private def parseSort(raw: String): Sort = {
    val separator = raw.indexOf(':')
    if (separator < 0)
      Sort.by(raw)
    else {
      val field = raw.substring(0, separator).trim
      // fromString is case-insensitive, so no manual upper-casing is needed.
      val direction = Sort.Direction.fromString(raw.substring(separator + 1).trim)
      Sort.by(direction, field)
    }
  }

  // NOTE(review): sort errors reuse the page/size message code; confirm its template reads
  // sensibly for an order value as well.
  private def parseError(rawValue: String, cause: Throwable): PaginationParseException =
    PaginationParseException(messageHandler.getMessage(errorMessageCode, rawValue), cause)
}
