package com.xebialabs.satellite.streaming

import akka.stream.scaladsl.Flow
import akka.stream.stage._
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.util.ByteString

/** Pass-through `ByteString` stage that reports upstream failures to a callback
  * before propagating them downstream unchanged.
  */
object ErrorStage {

  import scala.util.control.NonFatal

  /** Callback invoked with the exception that failed the upstream. */
  type Handler = Throwable => Unit

  /** Input port used by every stage built by [[handlingErrorStage]].
    *
    * NOTE(review): this port (and `out`) is shared by all stage instances created here,
    * whereas Akka's own stages usually declare ports per instance. Kept as-is because
    * `in` is public and may be referenced elsewhere — confirm before changing.
    */
  val in: Inlet[ByteString] = Inlet[ByteString]("ErrorStage.in")

  private val out: Outlet[ByteString] = Outlet[ByteString]("ErrorStage.out")

  private val flowShape: FlowShape[ByteString, ByteString] = FlowShape(in, out)

  /** Builds a `Flow` that forwards every element unchanged and calls `onError`
    * when upstream fails.
    *
    * @param onError called with the upstream failure before it is propagated downstream
    * @return a pass-through flow wrapping [[handlingErrorStage]]
    */
  def flow(onError: Handler): Flow[ByteString, ByteString, akka.NotUsed] =
    Flow[ByteString].via(handlingErrorStage(onError))

  /** Creates the underlying graph stage.
    *
    * Elements are pulled from `in` on demand and pushed to `out` untouched. On upstream
    * failure, `onError` is invoked and the stage is then failed with the same exception.
    * Completion and cancellation are not overridden, so Akka's default handler behaviour
    * applies to them.
    *
    * @param onError called with the upstream failure; if it throws a non-fatal exception,
    *                that exception is attached to the upstream failure as suppressed and
    *                the upstream failure is still the one propagated downstream
    */
  def handlingErrorStage(onError: Handler): GraphStage[FlowShape[ByteString, ByteString]] =
    new GraphStage[FlowShape[ByteString, ByteString]] {

      override def shape: FlowShape[ByteString, ByteString] = flowShape

      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
        new GraphStageLogic(shape) with InHandler with OutHandler {

          override def onPush(): Unit = push(out, grab(in))

          override def onPull(): Unit = pull(in)

          override def onUpstreamFailure(ex: Throwable): Unit = {
            try onError(ex)
            catch {
              // A faulty callback must not mask the real upstream error: keep its
              // exception as a suppressed cause instead of letting it escape.
              case NonFatal(handlerFailure) =>
                if (handlerFailure ne ex) ex.addSuppressed(handlerFailure)
            }
            super.onUpstreamFailure(ex)
          }

          setHandlers(in, out, this)
        }
    }

}
