Created
July 19, 2023 01:26
-
-
Save kyle-mccarthy/73ab6c78e6d3bf0819fc7c00b90161f4 to your computer and use it in GitHub Desktop.
Tonic bidirectional stream client utility
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "tonic-stream-wrapper"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
tonic = { version = "0.9" }
futures = { version = "0.3" }
tracing = { version = "0.1" }
thiserror = "1.0"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::future::Future;

use futures::{future, Stream, StreamExt, TryStreamExt};
use tokio::{
    spawn,
    sync::{broadcast, mpsc, oneshot, watch},
    task::{JoinError, JoinHandle},
};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tonic::{Response, Status, Streaming};
use tracing::{error, warn};
#[derive(Debug, thiserror::Error, Clone)] | |
pub enum Error { | |
#[error("gRPC error: {0}")] | |
Status(#[from] tonic::Status), | |
#[error("Stream closed")] | |
StreamClosed, | |
} | |
struct RequestWrapper<T> { | |
inner: T, | |
request_id_sender: oneshot::Sender<usize>, | |
} | |
#[derive(Clone)] | |
struct ResponseWrapper<T> { | |
inner: T, | |
request_id: usize, | |
} | |
pub struct OrderedStream<Req, Res> { | |
request_sender: mpsc::Sender<RequestWrapper<Req>>, | |
response_sender: broadcast::Sender<ResponseWrapper<Result<Res, Status>>>, | |
background_pipeline: JoinHandle<()>, | |
} | |
impl<Req, Res> OrderedStream<Req, Res> | |
where | |
Res: Send + Clone + 'static, | |
Req: Send + 'static, | |
{ | |
pub fn new<F, O>(request_fn: F) -> Self | |
where | |
F: FnOnce(Box<dyn Stream<Item = Req> + Send>) -> O, | |
F: Send + 'static, | |
O: Future<Output = Result<Response<Streaming<Res>>, Status>> + Send, | |
{ | |
let (request_sender, request_receiver) = mpsc::channel::<RequestWrapper<Req>>(32); | |
let (response_sender, _response_receiver) = | |
broadcast::channel::<ResponseWrapper<Result<Res, Status>>>(32); | |
let background_pipeline = spawn({ | |
let response_sender = response_sender.clone(); | |
async move { | |
let outgoing = ReceiverStream::new(request_receiver).enumerate().map( | |
|( | |
request_id, | |
RequestWrapper { | |
inner, | |
request_id_sender, | |
}, | |
)| { | |
// We don't care about the result here, an error just means that the receiver side | |
// of the oneshot is gone. | |
let _ = request_id_sender.send(request_id); | |
inner | |
}, | |
); | |
let incoming = match request_fn(Box::new(outgoing)).await { | |
Err(status) => { | |
error!( | |
"Stream initialization returned an error, exiting early. (status = {:?})", | |
status | |
); | |
return; | |
} | |
Ok(response) => response.into_inner(), | |
}; | |
incoming | |
.enumerate() | |
.map(|(request_id, inner)| ResponseWrapper { request_id, inner }) | |
.for_each(|response| { | |
if response_sender.send(response).is_err() { | |
warn!("Response received but no receivers on the response channel"); | |
} | |
future::ready(()) | |
}) | |
.await; | |
} | |
}); | |
Self { | |
request_sender, | |
response_sender, | |
background_pipeline, | |
} | |
} | |
pub async fn send(&self, message: Req) -> Result<Res, Error> { | |
let (sender, receiver) = oneshot::channel::<usize>(); | |
let response_receiver = self.response_sender.subscribe(); | |
self.request_sender | |
.send(RequestWrapper { | |
inner: message, | |
request_id_sender: sender, | |
}) | |
.await | |
.map_err(|_| Error::StreamClosed)?; | |
let request_id = receiver.await.map_err(|_| Error::StreamClosed)?; | |
let mut response_stream = BroadcastStream::new(response_receiver) | |
.filter_map(|result| { | |
let value = match result { | |
Ok(wrapper) if wrapper.request_id == request_id => Some(wrapper.inner), | |
_ => None, | |
}; | |
future::ready(value) | |
}) | |
.map_err(Error::Status) | |
.take(1) | |
.fuse(); | |
response_stream | |
.next() | |
.await | |
.expect("Every request has a response") | |
} | |
pub async fn close(self) -> Result<(), JoinError> { | |
let Self { | |
background_pipeline, | |
request_sender, | |
.. | |
} = self; | |
// need to drop the request sender or awaiting the join handle will hang | |
drop(request_sender); | |
background_pipeline.await | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment