Last active: August 3, 2018 21:08
-
-
Save jmcardon/a69d3966d0b4b96e3a9b9c4bb40d8480 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.nio.charset.StandardCharsets.UTF_8 | |
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} | |
import cats.effect.{Concurrent, Sync} | |
import cats.syntax.all._ | |
import fs2.async.mutable.Queue | |
import fs2.{Sink, Stream} | |
import org.http4s.websocket.WebsocketBits._ | |
import WebsocketMsg._ | |
import cats.Monad | |
/** A simplified websocket message ADT: a complete, already-reassembled text or
  * binary message that can be turned back into a single websocket frame.
  *
  * Extends `Product with Serializable` so that expressions mixing the
  * case-class variants (e.g. `List(TextMsg("a"), BinaryMsg(Array[Byte]()))`)
  * infer as `WebsocketMsg` rather than `Product with Serializable with WebsocketMsg`.
  */
sealed trait WebsocketMsg extends Product with Serializable {

  /** Converts this message into one websocket frame for sending. */
  def toFrame: WebSocketFrame
}
/** A complete text message.
  *
  * @param content the full message text
  */
final case class TextMsg(content: String) extends WebsocketMsg {

  /** Wraps `content` in a `Text` frame. */
  override def toFrame: WebSocketFrame =
    Text(content)
}
/** A complete binary message.
  *
  * Equality and hashing compare the payload bytes. The compiler-generated
  * versions would compare the `Array[Byte]` field by reference, so two messages
  * carrying identical bytes would be unequal and hash differently.
  *
  * @param content the full message payload
  */
final case class BinaryMsg(content: Array[Byte]) extends WebsocketMsg {

  /** Wraps `content` in a `Binary` frame. */
  def toFrame: WebSocketFrame = Binary(content)

  override def equals(other: Any): Boolean = other match {
    case that: BinaryMsg => java.util.Arrays.equals(content, that.content)
    case _               => false
  }

  override def hashCode(): Int = java.util.Arrays.hashCode(content)
}
/** Companion holding the reassembly states used while collecting message fragments. */
object WebsocketMsg {

  /** Current phase of fragment reassembly. */
  sealed trait State

  /** No fragmented message is being collected. */
  case object Empty extends State

  /** Fragments of a text message are being buffered. */
  case object BufferingText extends State

  /** Fragments of a binary message are being buffered. */
  case object BufferingBinary extends State
}
/** Effectful building blocks used by `WSFSM` to reassemble fragmented
  * websocket messages and to publish outgoing frames.
  *
  * @tparam F the effect type every operation runs in
  */
abstract class FSMAlgebra[F[_]] {

  // ---- reassembly state -------------------------------------------------

  /** Reads the current reassembly `State`. */
  def getState: F[State]

  /** Resets the reassembly state (the bundled implementation sets it to `Empty`). */
  def clearState(): F[Unit]

  // ---- completing a fragmented message ----------------------------------

  /** Joins the buffered fragments with `content`, the final fragment, into a text message. */
  def lastText(content: Array[Byte]): F[WebsocketMsg]

  /** Joins the buffered fragments with `content`, the final fragment, into a binary message. */
  def lastBinary(content: Array[Byte]): F[WebsocketMsg]

  // ---- buffering non-final fragments ------------------------------------

  /** Buffers a non-final binary fragment (`WSFSM` also routes continuation frames here). */
  def fragmentedBinary(content: Array[Byte]): F[Unit]

  /** Buffers a non-final text fragment. */
  def fragmentedText(content: String): F[Unit]

  // ---- outgoing side ----------------------------------------------------

  /** Publishes every message of `w` to the outgoing side. The bundled implementation
    * converts each with `toFrame` and enqueues it.
    */
  def enqueueAll(w: Stream[F, WebsocketMsg]): F[Unit]

  /** The frames published through `enqueueAll`. */
  def out: Stream[F, WebSocketFrame]
}
/** Smart constructor for `FSMAlgebra` plus its private implementation. */
object FSMAlgebra {

  /** Allocates a fresh algebra: an unbounded outgoing frame queue, a state
    * cell starting at `Empty`, and a buffered-length counter starting at 0.
    */
  def apply[F[_]](implicit F: Concurrent[F]): F[FSMAlgebra[F]] =
    for {
      msgQueue <- Queue.unbounded[F, WebSocketFrame]
      stateRef <- F.delay(new AtomicReference[State](Empty))
      len <- F.delay(new AtomicInteger(0))
      i <- F.delay(new Impl[F](msgQueue, stateRef, len))
    } yield i

  /** Default implementation.
    *
    * @param msgQueue outgoing frames, drained by `out`
    * @param state    current reassembly state
    * @param msgLen   total byte count of the buffered (non-final) fragments
    */
  private class Impl[F[_]](
      msgQueue: Queue[F, WebSocketFrame],
      state: AtomicReference[State],
      msgLen: AtomicInteger,
  )(implicit F: Sync[F])
      extends FSMAlgebra[F] {

    // Buffered non-final fragments, newest first (fragments are prepended).
    // NOTE(review): only mutated inside F.delay and never locked, so callers
    // presumably process frames sequentially (as `evalMap` in WSFSM does) — confirm.
    @volatile private[this] var internalList: List[Array[Byte]] = Nil

    /** Concatenates the buffered fragments followed by `lastBytes` into one
      * array in arrival order, and empties the fragment buffer.
      * Does not reset `msgLen`; callers do that afterwards.
      */
    private[this] def foldBytesToArray(lastBytes: Array[Byte]): F[Array[Byte]] =
      F.delay {
        val aggregator = new ReverseByteArrayAggregator(msgLen.get() + lastBytes.length)
        // The aggregator fills from the back: the final fragment goes last, then
        // walking the newest-first list places the first fragment at index 0.
        aggregator.aggregate(lastBytes)
        internalList.foreach(aggregator.aggregate)
        internalList = Nil
        aggregator.emit
      }

    private[this] def compareAndSetState(old: State, ns: State): F[Boolean] =
      F.delay(state.compareAndSet(old, ns))

    private[this] def clearMsgLen(): F[Unit] =
      F.delay(msgLen.set(0))

    private[this] def incrementMsgLen(i: Int): F[Unit] =
      F.delay { msgLen.getAndAdd(i); () }

    /** Records a non-final fragment.
      *
      * The state moves to `startState` only from `Empty`. Once buffering has
      * started, later fragments (e.g. continuations of a text message) leave
      * the state unchanged.
      */
    private[this] def bufferFragment(startState: State, bytes: Array[Byte]): F[Unit] =
      for {
        _ <- compareAndSetState(Empty, startState)
        _ <- incrementMsgLen(bytes.length)
        _ <- F.delay(internalList = bytes :: internalList)
      } yield ()

    def getState: F[State] = F.delay(state.get())

    def clearState(): F[Unit] = F.delay(state.set(Empty))

    /** Completes a fragmented message and decodes the joined bytes as UTF-8 text. */
    def lastText(content: Array[Byte]): F[WebsocketMsg] =
      for {
        bytes <- foldBytesToArray(content)
        _ <- clearMsgLen()
      } yield TextMsg(new String(bytes, UTF_8))

    /** Completes a fragmented message as raw binary. */
    def lastBinary(content: Array[Byte]): F[WebsocketMsg] =
      for {
        bytes <- foldBytesToArray(content)
        _ <- clearMsgLen()
      } yield BinaryMsg(bytes)

    def fragmentedBinary(content: Array[Byte]): F[Unit] =
      bufferFragment(BufferingBinary, content)

    /** Buffers the UTF-8 encoding of `content`. */
    def fragmentedText(content: String): F[Unit] =
      bufferFragment(BufferingText, content.getBytes(UTF_8))

    def enqueueAll(w: Stream[F, WebsocketMsg]): F[Unit] =
      w.map(_.toFrame).through(msgQueue.enqueue).compile.drain

    def out: Stream[F, WebSocketFrame] = msgQueue.dequeue
  }

  /** Fixed-size byte buffer filled from the back: each `aggregate` call copies
    * its array immediately in front of the previously aggregated one.
    *
    * @param size total number of bytes that will be aggregated; may be 0
    */
  private class ReverseByteArrayAggregator(size: Int) {
    // Zero is valid: RFC 6455 allows empty fragments, so a reassembled message
    // can be empty. The old `size > 0` check made such messages fail the stream.
    require(size >= 0)

    private[this] val internal = new Array[Byte](size)
    private[this] var nextIx: Int = size

    def aggregate(arr: Array[Byte]): ReverseByteArrayAggregator = {
      nextIx -= arr.length
      if (nextIx < 0)
        throw new ArrayIndexOutOfBoundsException("Size will exceed append size")
      else {
        System.arraycopy(arr, 0, internal, nextIx, arr.length)
        this
      }
    }

    /** The backing array (not copied). */
    def emit: Array[Byte] = internal
  }
}
/** Websocket frame handler that reassembles fragmented messages, passes each
  * complete message to `f`, and publishes the messages `f` emits.
  *
  * @param f   handler invoked once per complete incoming message
  * @param alg state, buffering and outgoing-queue operations
  */
final class WSFSM[F[_]](f: WebsocketMsg => Stream[F, WebsocketMsg],
                        alg: FSMAlgebra[F])(implicit F: Monad[F]) {

  /** Handles a text frame: buffers it when `last` is false, otherwise finishes the message. */
  def handleText(content: String, last: Boolean): F[Unit] =
    if (!last) alg.fragmentedText(content)
    else finish(content.getBytes(UTF_8), TextMsg(content))

  /** Handles a binary (or continuation) frame: buffers it when `last` is false,
    * otherwise finishes the message.
    */
  def handleBinary(content: Array[Byte], last: Boolean): F[Unit] =
    if (!last) alg.fragmentedBinary(content)
    else finish(content, BinaryMsg(content))

  /** Completes the current message and hands it to `f`.
    *
    * Reads the state, resets it, then either joins the buffered fragments with
    * `lastBytes` (typed by the buffering state) or, if nothing was buffered,
    * uses `unfragmented` as-is. Both arguments are by-name so only the branch
    * actually taken evaluates its argument.
    */
  private[this] def finish(lastBytes: => Array[Byte], unfragmented: => WebsocketMsg): F[Unit] =
    alg.getState.flatMap { current =>
      alg.clearState().flatMap { _ =>
        val complete: F[WebsocketMsg] = current match {
          case BufferingBinary => alg.lastBinary(lastBytes)
          case BufferingText   => alg.lastText(lastBytes)
          case Empty           => F.pure(unfragmented)
        }
        complete.flatMap(message => alg.enqueueAll(f(message)))
      }
    }

  /** Frames produced in response to incoming messages. */
  def send: Stream[F, WebSocketFrame] = alg.out

  /** Consumes incoming frames one at a time. */
  def recv: Sink[F, WebSocketFrame] =
    frames => frames.evalMap(dispatch)

  /** Routes a frame to its handler; continuation frames go to the binary path,
    * and any other frame kind (control frames, etc.) is ignored.
    */
  private[this] def dispatch(frame: WebSocketFrame): F[Unit] = frame match {
    case Text(text, fin)         => handleText(text, fin)
    case Binary(bytes, fin)      => handleBinary(bytes, fin)
    case Continuation(bytes, fin) => handleBinary(bytes, fin)
    case _                       => F.unit
  }
}
/** Smart constructor for `WSFSM`. */
object WSFSM {

  /** Allocates a fresh `FSMAlgebra` and wraps it in a `WSFSM` that calls `f`
    * for every complete incoming message.
    */
  def apply[F[_]: Concurrent](f: WebsocketMsg => Stream[F, WebsocketMsg]): F[WSFSM[F]] =
    for (alg <- FSMAlgebra[F]) yield new WSFSM[F](f, alg)
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.