Last active
July 24, 2025 01:02
-
-
Save niw/e845b803ecd9a6e82519f38aa405b65f to your computer and use it in GitHub Desktop.
A simple Server-Sent Events client implementation on `URLSession`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// Error reported by a chunked data task when the server answers with a non-2xx status.
public enum ChunkedDataTaskError: Swift.Error {
    /// The complete response body, followed by the response carrying the failing status code.
    case notSuccess(_ data: Data, _ httpResponse: HTTPURLResponse)
}
private extension Int {
    /// `true` for HTTP 2xx status codes, i.e. 200 through 299 inclusive.
    var isSuccess: Bool {
        self >= 200 && self < 300
    }
}
private extension Swift.Error {
    /// `true` when this error, bridged to `NSError`, is `NSURLErrorCancelled` in
    /// `NSURLErrorDomain` (a cancelled URL load).
    var isURLCancelledError: Bool {
        let bridged = self as NSError
        guard bridged.domain == NSURLErrorDomain else {
            return false
        }
        return bridged.code == NSURLErrorCancelled
    }
}
/// `URLSessionDataDelegate` that forwards a data task's response body to closures.
///
/// For a 2xx response, `onResponse` runs first, then `onUpdate` receives each body
/// chunk as it arrives, and `onComplete` runs when the task finishes without error.
/// For any other status, `onResponse`/`onUpdate` are not called. The body is instead
/// buffered and reported once through `onError` as
/// `ChunkedDataTaskError.notSuccess(body, response)`.
/// If `onResponse` or `onUpdate` throws, the task is cancelled. The thrown error, not
/// the resulting cancellation error, is then passed to `onError`.
private final class ChunkedDataTaskDelegate: NSObject, URLSessionDataDelegate {
    /// Inspects a 2xx response before any body data arrives; throw to cancel the task.
    let onResponse: (@Sendable (HTTPURLResponse) throws -> Void)?
    /// Receives each body chunk of a 2xx response; throw to cancel the task.
    let onUpdate: @Sendable (Data) throws -> Void
    /// Called when a 2xx response finished without error.
    let onComplete: (@Sendable () -> Void)?
    /// Called with the failure that ended the task.
    let onError: (@Sendable (Swift.Error) -> Void)?

    init(
        onResponse: (@Sendable (HTTPURLResponse) throws -> Void)? = nil,
        onUpdate: @escaping @Sendable (Data) throws -> Void,
        onComplete: (@Sendable () -> Void)? = nil,
        onError: (@Sendable (Swift.Error) -> Void)? = nil
    ) {
        self.onResponse = onResponse
        self.onUpdate = onUpdate
        self.onComplete = onComplete
        self.onError = onError
    }

    /// Response and accumulated body of a non-2xx response.
    private struct NotSuccess {
        var data: Data = Data()
        var httpResponse: HTTPURLResponse
    }

    // These vars are only touched from the URLSession delegate callbacks below. This code
    // relies on those callbacks being delivered one at a time and in order for a task.
    /// Non-nil when the response status is not 2xx; the body is collected here instead of streamed.
    private nonisolated(unsafe) var notSuccess: NotSuccess?
    /// Error thrown by a callback; reported in place of the cancellation it triggers.
    private nonisolated(unsafe) var callbackError: Swift.Error?

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) {
        guard let httpResponse = response as? HTTPURLResponse else {
            // Only HTTP(S) loads are expected. Report a meaningful error rather than a
            // bare cancellation, which callers could mistake for their own cancel.
            callbackError = URLError(.badServerResponse)
            completionHandler(.cancel)
            return
        }
        guard httpResponse.statusCode.isSuccess else {
            // Non-2xx: collect the whole body and report it on completion as
            // `ChunkedDataTaskError.notSuccess`. `onResponse` is skipped so that a
            // validation failure (e.g. an HTML/JSON error page failing a content-type
            // check) cannot hide the status code and body from the caller.
            notSuccess = NotSuccess(httpResponse: httpResponse)
            completionHandler(.allow)
            return
        }
        do {
            try onResponse?(httpResponse)
        } catch {
            callbackError = error
            completionHandler(.cancel)
            return
        }
        completionHandler(.allow)
    }

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        // A callback already failed and `cancel()` was requested. Cancellation is
        // asynchronous, so drop chunks that are still delivered afterwards.
        if callbackError != nil {
            return
        }
        if notSuccess != nil {
            notSuccess?.data.append(data)
            return
        }
        do {
            try onUpdate(data)
        } catch {
            callbackError = error
            dataTask.cancel()
        }
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: (any Swift.Error)?) {
        if let error {
            // Our own cancellation after a failing callback surfaces as NSURLErrorCancelled;
            // report the callback's error instead so the caller sees the real cause.
            if error.isURLCancelledError, let callbackError {
                onError?(callbackError)
            } else {
                onError?(error)
            }
            return
        }
        if let notSuccess {
            onError?(ChunkedDataTaskError.notSuccess(notSuccess.data, notSuccess.httpResponse))
            return
        }
        onComplete?()
    }
}
// See <https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation>
/// Errors thrown while interpreting a `text/event-stream` response.
public enum ServerSentEventError: Swift.Error {
    /// Body bytes that could not be decoded as UTF-8.
    case invalidData(Data)
    /// The response's MIME type (parameters stripped) was not `text/event-stream`.
    case invalidContentType(String)
    /// The body ended with an unterminated line; carries that leftover text.
    case remainingBuffer(String)
}
/// A single event dispatched from a `text/event-stream` body.
public struct ServerSentEvent: Sendable {
    /// Value of the most recent `id` field (empty if none was received).
    public var lastEventID: String
    /// Event type, taken from the `event` field.
    public var type: String
    /// Event payload, built from the `data` field(s).
    public var data: String
    /// Value of a `retry` field received with this event, if any.
    public var reconnectionTime: Int?

    /// Creates an event; every string defaults to empty and `reconnectionTime` to `nil`.
    public init(lastEventID: String = "", type: String = "", data: String = "", reconnectionTime: Int? = nil) {
        (self.lastEventID, self.type, self.data, self.reconnectionTime) = (lastEventID, type, data, reconnectionTime)
    }
}
private extension String {
    /// Splits off every complete line, where a line ends with CR, LF, or CRLF.
    ///
    /// - Returns: The complete lines (without terminators) and the trailing text after
    ///   the last terminator. A CR at the very end counts as a terminator on its own.
    func extractLines() -> ([Substring], String) {
        let scalars = unicodeScalars
        var result: [Substring] = []
        var lineStart = scalars.startIndex
        var cursor = scalars.startIndex
        while cursor < scalars.endIndex {
            let current = scalars[cursor]
            guard current == "\r" || current == "\n" else {
                scalars.formIndex(after: &cursor)
                continue
            }
            result.append(self[lineStart..<cursor])
            scalars.formIndex(after: &cursor)
            // CR immediately followed by LF is a single CRLF terminator.
            if current == "\r", cursor < scalars.endIndex, scalars[cursor] == "\n" {
                scalars.formIndex(after: &cursor)
            }
            lineStart = cursor
        }
        return (result, String(scalars[lineStart..<scalars.endIndex]))
    }
}
private extension StringProtocol {
    /// Splits an event-stream line into its field name and value.
    ///
    /// The field is the text before the first colon. The value is the text after it,
    /// with one leading U+0020 SPACE removed if present. A line without a colon is a
    /// field with an empty value. A line starting with a colon (a comment) yields an
    /// empty field; the original `split`-based version dropped that empty prefix and
    /// parsed `:data` as a `data` field. Only a single space is stripped, not all
    /// whitespace, as the WHATWG event-stream algorithm requires.
    func fieldValue() -> (SubSequence, SubSequence) {
        guard let colon = firstIndex(of: ":") else {
            return (self[...], self[endIndex...])
        }
        let field = self[..<colon]
        var value = self[index(after: colon)...]
        if value.first == " " {
            value = value.dropFirst()
        }
        return (field, value)
    }
}
private extension UnicodeScalar {
    /// `true` only for the ASCII digits U+0030 "0" through U+0039 "9".
    var isASCIIDigits: Bool {
        (0x30...0x39).contains(value)
    }
}
/// Incremental parser for a `text/event-stream` body, following the WHATWG
/// "interpreting an event stream" algorithm.
///
/// Feed body chunks to `parse(_:onEvent:)` in arrival order, then call
/// `validateCompletion()` once the body has ended. An instance holds per-stream state
/// and has no internal synchronization.
private final class ServerSentEventsParser {
    /// Throws unless the MIME type part of `contentType` is `text/event-stream`.
    ///
    /// Parameters such as `; charset=utf-8` are ignored, and the comparison is
    /// case-insensitive because MIME type names are case-insensitive.
    /// - Throws: `ServerSentEventError.invalidContentType` with the trimmed MIME type.
    func validateContentType(_ contentType: String) throws {
        let mimeType = contentType
            .components(separatedBy: ";")
            .first?
            .trimmingCharacters(in: .whitespaces) ?? ""
        if mimeType.lowercased() != "text/event-stream" {
            throw ServerSentEventError.invalidContentType(mimeType)
        }
    }

    private static let lineFeed: UInt8 = 0x0A
    private static let carriageReturn: UInt8 = 0x0D

    /// Bytes of the current, not yet terminated line. They are kept as raw bytes rather
    /// than a `String` so that a multi-byte UTF-8 character split across two network
    /// chunks is decoded only once it is complete.
    private var buffer = Data()
    /// Set when the last consumed byte was a CR. An LF that follows (possibly at the
    /// start of the next chunk) completes the same CRLF and must not form an empty line.
    private var skipsLeadingLineFeed = false
    private var eventTypeBuffer: String = ""
    private var dataBuffer: String = ""
    private var lastEventIDBuffer: String = ""
    private var reconnectionTime: Int? = nil

    /// Parses one body chunk and calls `onEvent` for every event it completes.
    ///
    /// - Throws: `ServerSentEventError.invalidData` with the bytes of a complete line
    ///   that is not valid UTF-8.
    func parse(_ data: Data, onEvent: (ServerSentEvent) -> Void) throws {
        buffer.append(data)
        for line in try takeCompleteLines() {
            process(line: line, onEvent: onEvent)
        }
    }

    /// Throws if the body ended in the middle of a line.
    ///
    /// - Throws: `ServerSentEventError.remainingBuffer` with the unterminated bytes
    ///   decoded as UTF-8 (invalid sequences replaced).
    func validateCompletion() throws {
        if !buffer.isEmpty {
            throw ServerSentEventError.remainingBuffer(String(decoding: buffer, as: UTF8.self))
        }
    }

    /// Removes every CR-, LF- or CRLF-terminated line from `buffer` and returns them decoded.
    ///
    /// Splitting is done on bytes. 0x0D and 0x0A never occur inside a multi-byte UTF-8
    /// sequence, so each complete line can be decoded independently.
    private func takeCompleteLines() throws -> [String] {
        var lines: [String] = []
        var lineStart = buffer.startIndex
        var index = buffer.startIndex
        while index < buffer.endIndex {
            let byte = buffer[index]
            if skipsLeadingLineFeed {
                skipsLeadingLineFeed = false
                if byte == Self.lineFeed {
                    // Second half of a CRLF whose CR already ended the previous line.
                    buffer.formIndex(after: &index)
                    lineStart = index
                    continue
                }
            }
            if byte == Self.carriageReturn || byte == Self.lineFeed {
                let lineBytes = buffer[lineStart..<index]
                guard let line = String(data: lineBytes, encoding: .utf8) else {
                    throw ServerSentEventError.invalidData(Data(lineBytes))
                }
                lines.append(line)
                skipsLeadingLineFeed = byte == Self.carriageReturn
                buffer.formIndex(after: &index)
                lineStart = index
            } else {
                buffer.formIndex(after: &index)
            }
        }
        buffer.removeSubrange(buffer.startIndex..<lineStart)
        return lines
    }

    /// Applies one decoded line to the parser state; a blank line dispatches an event.
    private func process(line: String, onEvent: (ServerSentEvent) -> Void) {
        if line.isEmpty {
            dispatchEvent(onEvent)
            return
        }
        // Lines starting with a colon are comments.
        if line.hasPrefix(":") {
            return
        }
        let (field, value) = line.fieldValue()
        if field.isEmpty {
            return
        }
        switch String(field) {
        case "event":
            eventTypeBuffer = String(value)
        case "data":
            // Each `data` line is appended, so multi-line payloads are joined with LF.
            dataBuffer.append(contentsOf: value)
            dataBuffer.append("\n")
        case "id":
            // IDs containing NULL are ignored.
            if !value.contains("\0") {
                lastEventIDBuffer = String(value)
            }
        case "retry":
            if value.unicodeScalars.allSatisfy(\.isASCIIDigits), let retryValue = Int(value) {
                reconnectionTime = retryValue
            }
        default:
            // Unknown fields are ignored.
            break
        }
    }

    /// Dispatches the buffered event and resets the per-event buffers.
    ///
    /// If the data buffer is empty, nothing is dispatched and only the event type is
    /// cleared. A pending `retry` value is kept so that the next dispatched event
    /// still carries it.
    private func dispatchEvent(_ onEvent: (ServerSentEvent) -> Void) {
        guard !dataBuffer.isEmpty else {
            eventTypeBuffer = ""
            return
        }
        // Drop the LF appended after the last `data` line.
        if dataBuffer.hasSuffix("\n") {
            dataBuffer.removeLast()
        }
        onEvent(ServerSentEvent(
            lastEventID: lastEventIDBuffer,
            type: eventTypeBuffer.isEmpty ? "message" : eventTypeBuffer,
            data: dataBuffer,
            reconnectionTime: reconnectionTime
        ))
        eventTypeBuffer = ""
        dataBuffer = ""
        reconnectionTime = nil
    }
}
extension URLSession {
    /// Starts a data task for `request` and exposes its body as a stream of Server-Sent Events.
    ///
    /// The stream finishes normally when the body ends cleanly. It finishes with an
    /// error when the task fails, the Content-Type (if present) is rejected, or the body
    /// cannot be parsed. Cancelling the consumer cancels the underlying task.
    func serverSentEvents(for request: URLRequest) -> some AsyncSequence<ServerSentEvent, Swift.Error> & Sendable {
        AsyncThrowingStream { continuation in
            nonisolated(unsafe) let parser = ServerSentEventsParser()
            let task = dataTask(with: request)

            let delegate = ChunkedDataTaskDelegate(
                onResponse: { httpResponse in
                    // A missing Content-Type is tolerated: the body is parsed anyway.
                    guard let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") else {
                        return
                    }
                    try parser.validateContentType(contentType)
                },
                onUpdate: { chunk in
                    try parser.parse(chunk) { continuation.yield($0) }
                },
                onComplete: {
                    do {
                        try parser.validateCompletion()
                    } catch {
                        continuation.finish(throwing: error)
                        return
                    }
                    continuation.finish()
                },
                onError: { continuation.finish(throwing: $0) }
            )
            task.delegate = delegate

            continuation.onTermination = { termination in
                if case .cancelled = termination {
                    task.cancel()
                }
            }
            task.resume()
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment