package com.xebialabs.satellite.streaming

import java.util.zip.{DataFormatException, Deflater, Inflater}

import akka.stream.scaladsl.Flow
import akka.util.{ByteString, ByteStringBuilder}

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuilder

/**
 * Optional zlib compression stages for `ByteString` streams.
 *
 * Wire format written by [[compressFunction]] and read by [[decompressFunction]]: one or more
 * frames, each a 4-byte little-endian payload length followed by that many bytes of zlib data
 * (a `Deflater` created without `nowrap`). Compression emits exactly one frame per element;
 * decompression accepts any number of concatenated frames inside a single element.
 *
 * NOTE(review): every stream element is decoded on its own — frames are not reassembled across
 * element boundaries, so whatever feeds [[decompress]] must deliver whole frames. Confirm this
 * holds for the transport in use.
 */
object Compression {

  /** Size of the scratch buffer used for each deflate/inflate call. */
  private final val BufferSize = 4096

  /** Flow that inflates each element when `streamConfig.compression` is set; pass-through otherwise. */
  def decompress(implicit streamConfig: StreamConfig) = ifEnabled(decompressFunction)

  /** Flow that deflates each element when `streamConfig.compression` is set; pass-through otherwise. */
  def compress(implicit streamConfig: StreamConfig) = ifEnabled(compressFunction)

  // The flag is read once, when the flow is built, not per element.
  private def ifEnabled(function: (ByteString) => ByteString)(implicit streamConfig: StreamConfig) =
    Flow[ByteString].map(if (streamConfig.compression) function else identity[ByteString])

  // Deflater/Inflater are stateful and not thread-safe. Stages of different materialized streams
  // can run concurrently on different threads, so each thread gets its own compressor instead of
  // all streams sharing (and corrupting) a single one.
  private val zipCompressor: ThreadLocal[ZipCompressor] = new ThreadLocal[ZipCompressor] {
    override def initialValue(): ZipCompressor = new ZipCompressor
  }

  /** Compresses one element into a single length-prefixed zlib frame. */
  val compressFunction: (ByteString) => ByteString = { bytes: ByteString =>
    ByteString(zipCompressor.get().compress(bytes.toArray))
  }

  /**
   * Decompresses one element consisting of one or more length-prefixed zlib frames.
   * Throws `java.util.zip.DataFormatException` on malformed input, which fails the stream.
   */
  val decompressFunction: (ByteString) => ByteString = { bytes: ByteString =>
    ByteString(zipCompressor.get().decompress(bytes.toArray))
  }

  /**
   * Frame-level zlib codec backing [[compressFunction]] and [[decompressFunction]].
   *
   * Holds one reusable `Deflater` (level `BEST_SPEED`) and one `Inflater`. Because both are
   * stateful, every use of each is serialized by locking on that object, so an instance may be
   * shared between threads (at the cost of contention).
   */
  class ZipCompressor {

    lazy val deflater = new Deflater(Deflater.BEST_SPEED)
    lazy val inflater = new Inflater()

    /** Deflates `inputBuff` into one frame: 4-byte little-endian length + zlib data. */
    def compress(inputBuff: Array[Byte]): Array[Byte] = deflater.synchronized {
      val compressedDataBuffer = new ArrayBuilder.ofByte
      val buff = new Array[Byte](BufferSize)
      try {
        deflater.setInput(inputBuff)
        deflater.finish()
        while (!deflater.finished) {
          val n = deflater.deflate(buff)
          compressedDataBuffer ++= buff.take(n)
        }
      } finally {
        // Always leave the deflater ready for the next call, even if deflation failed.
        deflater.reset()
      }
      buildOutputBuffer(compressedDataBuffer)
    }

    /** Returns the bytes accumulated in `tempBuff`, prefixed with their length as 4 little-endian bytes. */
    def buildOutputBuffer(tempBuff: ArrayBuilder.ofByte): Array[Byte] = {
      val payload = tempBuff.result()
      val outputSize = payload.length
      val outputBuff = new ArrayBuilder.ofByte
      outputBuff.sizeHint(outputSize + 4)
      outputBuff += (outputSize & 0xff).toByte
      outputBuff += (outputSize >> 8 & 0xff).toByte
      outputBuff += (outputSize >> 16 & 0xff).toByte
      outputBuff += (outputSize >> 24 & 0xff).toByte
      outputBuff ++= payload
      outputBuff.result()
    }

    /** Inflates every frame in `inputBuff` and concatenates the results; empty input yields an empty array. */
    def decompress(inputBuff: Array[Byte]): Array[Byte] =
      readAndDecompressChunks(inputBuff, 0, new ArrayBuilder.ofByte)

    /**
     * Appends the inflated contents of every frame from `index` to the end of `inputBuff` to `outputBuff`.
     *
     * @return everything accumulated in `outputBuff`
     * @throws java.util.zip.DataFormatException if a frame header or payload is truncated, a length
     *                                           is negative, or the zlib data is invalid
     */
    def readAndDecompressChunks(inputBuff: Array[Byte], index: Int, outputBuff: ArrayBuilder.ofByte): Array[Byte] = {
      @tailrec
      def loop(offset: Int): Array[Byte] =
        if (offset >= inputBuff.length) outputBuff.result()
        else {
          if (inputBuff.length - offset < 4)
            throw new DataFormatException(s"Truncated frame header at offset $offset")
          val size = computeSize(inputBuff, offset)
          val start = offset + 4
          val remaining = inputBuff.length - start
          // Compare against the remaining length instead of computing start + size, which could overflow.
          if (size < 0 || size > remaining)
            throw new DataFormatException(
              s"Frame at offset $offset declares $size bytes but $remaining remain")
          outputBuff ++= inflateChunk(inputBuff.slice(start, start + size))
          loop(start + size)
        }

      loop(index)
    }

    /** Reads a 4-byte little-endian Int from `inputBuff` starting at `offset` (default 0). */
    def computeSize(inputBuff: Array[Byte], offset: Int = 0): Int =
      (inputBuff(offset) & 0xff) |
        (inputBuff(offset + 1) & 0xff) << 8 |
        (inputBuff(offset + 2) & 0xff) << 16 |
        (inputBuff(offset + 3) & 0xff) << 24

    /**
     * Inflates one complete zlib stream.
     *
     * @throws java.util.zip.DataFormatException if the data is corrupt, truncated, or needs a preset dictionary
     */
    def inflateChunk(inputBuff: Array[Byte]): Array[Byte] = inflater.synchronized {
      val outputBuff = new ArrayBuilder.ofByte
      val buff = new Array[Byte](BufferSize)
      try {
        inflater.setInput(inputBuff)
        while (!inflater.finished) {
          val n = inflater.inflate(buff)
          // Without this guard, truncated input would loop forever: inflate keeps returning 0.
          if (n == 0 && !inflater.finished && (inflater.needsInput || inflater.needsDictionary))
            throw new DataFormatException("Truncated zlib stream or preset dictionary required")
          outputBuff ++= buff.take(n)
        }
      } finally {
        // Reset even on failure so a corrupt frame cannot poison the next call.
        inflater.reset()
      }
      outputBuff.result()
    }

  }
}
