package com.xebialabs.license.service.aws

import java.io.{BufferedReader, StringReader}
import java.nio.charset.StandardCharsets
import java.util.Base64

import software.amazon.awssdk.services.marketplacemetering.MarketplaceMeteringClient
import software.amazon.awssdk.services.marketplacemetering.model.RegisterUsageRequest
import com.hierynomus.asn1.ASN1InputStream
import com.hierynomus.asn1.encodingrules.der.DERDecoder
import com.hierynomus.asn1.types.constructed.ASN1Sequence
import com.hierynomus.asn1.types.string.ASN1BitString
import com.hierynomus.asn1.types.{ASN1Object, ASN1Tag}
import com.xebialabs.license.{ASN1Utils, LicenseVerificationException}
import com.xebialabs.license.service.LicenseConfig
import com.xebialabs.xlplatform.utils.ResourceManagement
import grizzled.slf4j.Logging
import org.bouncycastle.asn1.pkcs.RSAPublicKey
import org.bouncycastle.crypto.digests.SHA256Digest
import org.bouncycastle.crypto.engines.RSABlindedEngine
import org.bouncycastle.crypto.params.RSAKeyParameters
import org.bouncycastle.crypto.signers.PSSSigner
import scala.jdk.CollectionConverters._

/**
 * Mixin that registers product usage with AWS Marketplace Metering (`RegisterUsage`) and
 * verifies the signature AWS returns, using the product code and public key from
 * `licenseConfig.license.aws`.
 */
trait AWSMetering extends Logging {
  /** License configuration; its `license.aws` section supplies product code and public key data. */
  val licenseConfig: LicenseConfig
  // Implementations are expected to generate a new UUID upon every instantiation of the product
  // (metered license service). It is sent to AWS RegisterUsage as the request nonce, with the
  // intent of preventing replay attacks.
  val nonce: String

  // Lazy on purpose: an eager val would be evaluated during trait initialization, before an
  // implementing class has initialized `licenseConfig` if that class defines it in its body
  // (rather than as a constructor parameter), resulting in a NullPointerException.
  private[this] lazy val aws = licenseConfig.license.aws
  // Built on first use and reused for every subsequent RegisterUsage call.
  private[this] lazy val meteredService = MarketplaceMeteringClient.builder().build()

  /**
   * Calls AWS Marketplace `RegisterUsage` with this instance's nonce and the configured product
   * code / public key version, then verifies the signature contained in the response.
   *
   * Exceptions raised by the AWS SDK call itself are not caught here and propagate to the caller.
   *
   * @throws LicenseVerificationException if the returned signature is malformed or does not verify
   */
  def registerUsage(): Unit = {
    logger.info(s"Calling AWS RegisterUsage (product code: ${aws.productCode}, nonce: $nonce)")
    val request = RegisterUsageRequest.builder()
      .nonce(nonce)
      .productCode(aws.productCode)
      .publicKeyVersion(aws.publicKeyVersion)
      .build()

    val response = meteredService.registerUsage(request)
    val signature = response.signature()
    SignatureVerifier(aws.productCode, nonce, aws.publicKeyVersion, aws.publicKey).verify(signature)
  }

}

/**
 * Verifies the signature returned by AWS Marketplace Metering `RegisterUsage`.
 *
 * The signature is handled as a JWS compact serialization (`header.payload.signature`). The
 * signing input `header.payload` is checked against the base64url-decoded third segment using
 * RSASSA-PSS with SHA-256 and a 32-byte salt, keyed by `publicKey`.
 *
 * NOTE(review): `nonce` and `publicKeyVersion` are carried here but never compared against the
 * decoded JWT payload. This class therefore only proves that the payload was signed with the
 * configured key, not that it answers *this* request. Replay protection depends on checking the
 * payload's nonce; confirm whether that is done elsewhere.
 *
 * @param productCode      AWS Marketplace product code (used in error messages)
 * @param nonce            nonce that was sent with the RegisterUsage request (currently unused, see note)
 * @param publicKeyVersion public key version that was requested (currently unused, see note)
 * @param publicKey        PEM text whose base64 body is a DER SEQUENCE containing a BIT STRING with a
 *                         PKCS#1 `RSAPublicKey` (i.e. an X.509 SubjectPublicKeyInfo)
 */
case class SignatureVerifier(productCode: String, nonce: String, publicKeyVersion: Int, publicKey: String) {

  /**
   * Verifies `signature` against the configured public key.
   *
   * @param signature the compact JWS string returned by RegisterUsage
   * @throws LicenseVerificationException if the signature is malformed or does not verify
   * @throws IllegalStateException        if the configured public key cannot be decoded
   */
  def verify(signature: String): Unit = {
    val signer = newSigner
    val (header, payload, payloadSignature) = jwtComponents(signature)
    // The JWS signing input is the ASCII text "<header>.<payload>" exactly as received.
    val signingInput = s"$header.$payload".getBytes(StandardCharsets.UTF_8)
    signer.update(signingInput, 0, signingInput.length)
    val decodedSignatureBytes = Base64.getUrlDecoder.decode(payloadSignature.getBytes(StandardCharsets.UTF_8))
    if (!signer.verifySignature(decodedSignatureBytes)) {
      throw new LicenseVerificationException(s"Failed to verify the RegisterUsage signature from AWS (product code: $productCode)")
    }
  }

  /** Creates a PSS verifier (SHA-256, 32-byte salt, implicit trailer) initialized with the public key. */
  private def newSigner = {
    val signer = new PSSSigner(new RSABlindedEngine(), new SHA256Digest(), 32, PSSSigner.TRAILER_IMPLICIT)
    val key = readPublicKey
    // `false` = initialize for verification; the key parameters are a public (non-private) key.
    signer.init(false, new RSAKeyParameters(false, key.getModulus, key.getPublicExponent))
    signer
  }

  /**
   * Splits a compact JWS string into its header, payload and signature segments.
   *
   * @throws LicenseVerificationException if `signature` is null or does not have exactly three segments
   */
  def jwtComponents(signature: String): (String, String, String) =
    // Limit -1 keeps trailing empty segments, so "a.b." is seen as three parts and "a.b.c.d" as four.
    Option(signature).map(_.split("\\.", -1)) match {
      case Some(Array(header, payload, payloadSignature)) => (header, payload, payloadSignature)
      case _ =>
        throw new LicenseVerificationException(
          s"Malformed RegisterUsage signature from AWS (product code: $productCode): expected three dot-separated segments")
    }

  /**
   * Parses `publicKey` (PEM armor optional) into a BouncyCastle PKCS#1 RSA public key.
   *
   * Lines are trimmed, `-----BEGIN` / `-----END` lines are dropped, and the remaining text is
   * base64-decoded as DER. The first BIT STRING inside the top-level SEQUENCE is parsed as the key.
   *
   * @throws IllegalStateException if the DER root is not a SEQUENCE or contains no BIT STRING
   */
  def readPublicKey: RSAPublicKey = {
    val beginMarker = "-----BEGIN"
    val endMarker = "-----END"
    ResourceManagement.using(new BufferedReader(new StringReader(publicKey))) { reader =>
      // Trim *before* testing for the armor markers so indented PEM text (e.g. from config) works.
      val base64Body = reader.lines().iterator().asScala
        .map(_.trim)
        .filterNot(line => line.startsWith(beginMarker) || line.startsWith(endMarker))
        .mkString
      val derParser = new ASN1InputStream(new DERDecoder, Base64.getDecoder.decode(base64Body))
      val root: ASN1Object[_] = derParser.readObject()
      root match {
        case sequence: ASN1Sequence if sequence.getTag == ASN1Tag.SEQUENCE =>
          sequence.asScala
            .collectFirst { case bitString: ASN1BitString => ASN1Utils.constructBytes(bitString.getValue) }
            .map(bytes => RSAPublicKey.getInstance(bytes))
            .getOrElse(throw new IllegalStateException(s"Could not find public key data in $publicKey"))
        case _ =>
          throw new IllegalStateException(s"Could not decode the ASN.1 public key data, got $root")
      }
    }
  }

}
