Skip to content

Instantly share code, notes, and snippets.

@sergiocampama
Last active January 25, 2025 14:55
Show Gist options
  • Save sergiocampama/bddf5a92c56cf6204660c1263593d5ea to your computer and use it in GitHub Desktop.
Save sergiocampama/bddf5a92c56cf6204660c1263593d5ea to your computer and use it in GitHub Desktop.
GRPCWebSupportVapor
/*
* Copyright 2025, Sergio Campamá All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import GRPCCore
import Vapor
import ServiceContextModule
/// Vapor middleware that answers protobuf-encoded gRPC-Web requests using the given gRPC services.
///
/// `POST` requests whose content type is `application/grpc-web` or `application/grpc-web+proto` are
/// dispatched to the internal `RPCRouter`; every other request is passed on to the next responder.
struct GRPCWebMiddleware: AsyncMiddleware {
    /// Router with the methods of every service passed to `init(services:)` registered on it.
    private let router: RPCRouter

    /// Creates a middleware serving `services`.
    /// - Parameter services: Services whose methods are registered on the internal router.
    init(services: [any GRPCCore.RegistrableRPCService]) {
        var router = RPCRouter()
        for service in services {
            service.registerMethods(with: &router)
        }
        self.router = router
    }

    func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
        // TODO(kaipi): Support other types of incoming formats, like +json, +thrift and -text
        guard request.method == .POST, Self.isGRPCWebProto(request.headers.contentType) else {
            return try await next.respond(to: request)
        }
        return try await router.handle(request: request)
    }

    /// Whether `contentType` denotes protobuf-encoded gRPC-Web.
    ///
    /// The gRPC-Web protocol says a receiver must assume `+proto` when the message-format suffix is
    /// missing, so bare `application/grpc-web` is accepted too. Media type names are case-insensitive.
    /// `application/grpc-web-text` (base64) is deliberately not matched.
    private static func isGRPCWebProto(_ contentType: HTTPMediaType?) -> Bool {
        guard let contentType, contentType.type.lowercased() == "application" else {
            return false
        }
        let subType = contentType.subType.lowercased()
        return subType == "grpc-web" || subType == "grpc-web+proto"
    }
}
/// `ServiceContext` key under which `RPCRouter.handle(request:)` stores the originating Vapor
/// request, so code handling the RPC can read it back from the `ServerContext`'s service context.
enum RequestKey: ServiceContextKey {
    /// The stored value: the Vapor request that carried the gRPC-Web call.
    typealias Value = Vapor.Request
}
/// An `AsyncSequence` that yields a fixed array of elements, in order.
///
/// Every call to `makeAsyncIterator()` starts again from the first element, so the sequence can be
/// iterated more than once.
struct StaticSequence<T: Sendable>: AsyncSequence {
    typealias Element = T

    /// The elements produced by each iteration, in order.
    let elements: [T]

    init(_ elements: [T]) {
        self.elements = elements
    }

    /// Iterator that walks the elements by index.
    ///
    /// The previous implementation called `Array.removeFirst()`, which shifts every remaining element
    /// on each step and makes a full iteration O(n²); advancing an `IndexingIterator` is O(1).
    struct AsyncIterator: AsyncIteratorProtocol {
        private var base: IndexingIterator<[T]>

        init(elements: [T]) {
            base = elements.makeIterator()
        }

        /// Returns the next element, or `nil` once all elements have been produced.
        /// Never throws in practice; the `throws` signature is kept from the original declaration.
        mutating func next() async throws -> T? {
            base.next()
        }
    }

    func makeAsyncIterator() -> AsyncIterator {
        AsyncIterator(elements: elements)
    }
}
/// Closable writer that buffers every `RPCResponsePart` written to it in an `AsyncThrowingStream`.
///
/// `responseParts()` drains that stream, so it only returns after `finish()` has been called, or
/// rethrows the error passed to `finish(throwing:)`.
final class OutboundRPCResponsePartWriter: ClosableRPCWriterProtocol {
    private let stream: AsyncThrowingStream<RPCResponsePart, any Error>
    private let continuation: AsyncThrowingStream<RPCResponsePart, any Error>.Continuation

    init() {
        let (stream, continuation) = AsyncThrowingStream<RPCResponsePart, any Error>.makeStream()
        self.stream = stream
        self.continuation = continuation
    }

    /// Enqueues each part of `elements`, preserving order.
    func write(contentsOf elements: some Sequence<RPCResponsePart>) async throws {
        elements.forEach { continuation.yield($0) }
    }

    /// Enqueues a single part.
    func write(_ element: RPCResponsePart) async throws {
        try await write(contentsOf: CollectionOfOne(element))
    }

    /// Ends the stream successfully.
    func finish() async {
        continuation.finish()
    }

    /// Ends the stream with `error`, which `responseParts()` will rethrow.
    func finish(throwing error: any Error) async {
        continuation.finish(throwing: error)
    }

    /// Collects every buffered part, waiting until the stream has been finished.
    func responseParts() async throws -> [RPCResponsePart] {
        var collected: [RPCResponsePart] = []
        for try await part in stream {
            collected.append(part)
        }
        return collected
    }
}
extension RPCRouter {
    /// Serves one unary, protobuf-encoded gRPC-Web request through this router.
    ///
    /// Request requirements:
    /// - The body must start with an uncompressed length-prefixed message frame: a 1-byte flags
    ///   value of `0x00`, a 4-byte big-endian length, then the payload. Any bytes after that frame
    ///   are ignored.
    /// - The path must be `/<fully.qualified.Service>/<Method>`.
    ///
    /// Response encoding:
    /// - Each message the handler writes becomes a data frame.
    /// - The final status becomes a trailer frame (flags `0x80`) containing `grpc-status` and, when the
    ///   status message is non-empty, a percent-encoded `grpc-message`.
    ///
    /// - Parameter request: The incoming Vapor request. It is stored in the RPC's
    ///   `ServerContext.serviceContext` under `RequestKey`.
    /// - Returns: A `200 OK` response whose body holds the gRPC-Web response frames.
    /// - Throws: `Abort(.badRequest)` when the body is missing or malformed or the path does not have
    ///   two components; rethrows any error the response stream was finished with.
    func handle(request: Request) async throws -> Response {
        // NOTE(review): `request.body.data` is only non-nil when Vapor has already collected the body;
        // a body still being streamed is rejected here as a bad request. Consider collecting it first.
        guard var bodyBuffer = request.body.data else {
            throw Abort(.badRequest)
        }
        // TODO(kaipi): Support incoming metadata.
        let messageBytes = try Self.readGRPCWebMessage(from: &bodyBuffer)
        let requestParts = StaticSequence([RPCRequestPart.metadata(Metadata()), .message(messageBytes)])

        let pathComponents = request.url.path.pathComponents
        guard pathComponents.count == 2 else {
            throw Abort(.badRequest)
        }
        let methodDescriptor = MethodDescriptor(
            service: ServiceDescriptor(fullyQualifiedService: pathComponents[0].description),
            method: pathComponents[1].description
        )

        let outboundWriter = OutboundRPCResponsePartWriter()
        let requestStream = RPCStream(
            descriptor: methodDescriptor,
            inbound: RPCAsyncSequence(wrapping: requestParts),
            outbound: RPCWriter<RPCResponsePart>.Closable(wrapping: outboundWriter)
        )
        await withServerContextRPCCancellationHandle { cancellationHandle in
            var context = ServerContext(
                descriptor: requestStream.descriptor,
                peer: request.peerAddress?.description ?? "unknown",
                cancellation: cancellationHandle
            )
            context.serviceContext[RequestKey.self] = request
            await handle(stream: requestStream, context: context)
        }
        // The handler has returned, so nothing more will be written. Finishing here guarantees that
        // draining the writer below terminates even if the handler left it open. Finishing an
        // already-finished stream has no effect, so an error it was finished with is still rethrown.
        await outboundWriter.finish()

        let responseParts = try await outboundWriter.responseParts()
        let response = Response(status: .ok, body: .init(buffer: Self.encodeGRPCWebResponseBody(responseParts)))
        response.headers.contentType = .init(type: "application", subType: "grpc-web+proto")
        return response
    }

    /// Reads one uncompressed gRPC-Web data frame from `buffer` and returns its payload.
    /// - Throws: `Abort(.badRequest)` if the flags byte is missing or non-zero (compressed frames are
    ///   not supported), or the length prefix or payload is truncated.
    private static func readGRPCWebMessage(from buffer: inout ByteBuffer) throws -> [UInt8] {
        guard
            buffer.readInteger(endianness: .big, as: UInt8.self) == 0,
            let length = buffer.readInteger(endianness: .big, as: UInt32.self),
            let payload = buffer.readBytes(length: Int(length))
        else {
            throw Abort(.badRequest)
        }
        return payload
    }

    /// Serializes the handler's response parts into gRPC-Web frames.
    private static func encodeGRPCWebResponseBody(_ parts: [RPCResponsePart]) -> ByteBuffer {
        var body = ByteBuffer()
        for part in parts {
            switch part {
            case .metadata:
                // TODO(kaipi): Support outgoing header metadata.
                break
            case .message(let message):
                body.writeInteger(UInt8(0x00), endianness: .big)
                body.writeInteger(UInt32(message.count), endianness: .big)
                body.writeBytes(message)
            case .status(let status, _):
                // TODO(kaipi): Support outgoing trailer metadata.
                // Trailers are encoded as HTTP/1-style header lines inside a frame flagged 0x80.
                var trailers = "grpc-status: \(status.code.rawValue)\r\n"
                if !status.message.isEmpty {
                    trailers += "grpc-message: \(percentEncodedGRPCMessage(status.message))\r\n"
                }
                body.writeInteger(UInt8(0x80), endianness: .big)
                body.writeInteger(UInt32(trailers.utf8.count), endianness: .big)
                body.writeString(trailers)
            }
        }
        return body
    }

    /// Percent-encodes a status message for the `grpc-message` trailer as the gRPC protocol
    /// requires: UTF-8 bytes outside printable ASCII (0x20–0x7E), plus `%` itself, become `%XX`.
    /// This also keeps CR/LF in a message from breaking the trailer block.
    private static func percentEncodedGRPCMessage(_ message: String) -> String {
        var encoded = ""
        encoded.reserveCapacity(message.utf8.count)
        for byte in message.utf8 {
            if byte >= 0x20 && byte <= 0x7E && byte != UInt8(ascii: "%") {
                encoded.unicodeScalars.append(Unicode.Scalar(byte))
            } else {
                let hex = String(byte, radix: 16, uppercase: true)
                encoded += "%" + (hex.count == 1 ? "0" + hex : hex)
            }
        }
        return encoded
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment