Created
October 2, 2022 16:28
-
-
Save Swoorup/41dfcba8d712fb5fa03a38119a1d7088 to your computer and use it in GitHub Desktop.
Restartable stream that reconnects, using a connect function, on completion or error
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::task::Poll; | |
use futures::{Future, FutureExt, Stream}; | |
use pin_project::pin_project; | |
use tracing::debug; | |
#[pin_project(project = StreamStateProj)] | |
#[derive(Debug, Clone)] | |
enum StreamState<F, S> | |
where | |
F: Future, | |
{ | |
NotConnected, | |
Connected(#[pin] S), | |
Connecting { | |
#[pin] | |
reconnect_attempt: F, | |
}, | |
} | |
impl<F, S> Default for StreamState<F, S> | |
where | |
F: Future, | |
{ | |
fn default() -> Self { | |
Self::NotConnected | |
} | |
} | |
/// Restartable stream on either construction error or when finished | |
#[pin_project] | |
#[derive(Debug, Clone)] | |
pub struct Restartable<Param, F, S, E, Cons> | |
where | |
Param: Clone, | |
F: Future<Output = Result<S, E>>, | |
S: Stream, | |
Cons: Fn(Param) -> F, | |
{ | |
param: Param, | |
create: Cons, | |
#[pin] | |
stream: StreamState<F, S>, | |
} | |
impl<Param, F, S, E, Cons> Restartable<Param, F, S, E, Cons> | |
where | |
Param: Clone, | |
F: Future<Output = Result<S, E>>, | |
S: Stream, | |
Cons: Fn(Param) -> F, | |
{ | |
pub fn of(establish: Cons, param: Param) -> Self { | |
Self { | |
param: param, | |
create: establish, | |
stream: Default::default(), | |
} | |
} | |
} | |
impl<Param, F, S, E, Cons> Stream for Restartable<Param, F, S, E, Cons> | |
where | |
Param: Clone, | |
F: Future<Output = Result<S, E>>, | |
S: Stream, | |
Cons: Fn(Param) -> F, | |
{ | |
type Item = S::Item; | |
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> { | |
let mut me = self.as_mut().project(); | |
match me.stream.as_mut().project() { | |
StreamStateProj::NotConnected => { | |
debug!("Establishing connection..."); | |
let reconnect_attempt = (me.create)(me.param.clone()); | |
me.stream.set(StreamState::Connecting { reconnect_attempt }); | |
cx.waker().wake_by_ref(); | |
Poll::Pending | |
} | |
StreamStateProj::Connected(stream) => { | |
match stream.poll_next(cx) { | |
Poll::Ready(Some(t)) => Poll::Ready(Some(t)), | |
// completed, reset to NotConnected | |
Poll::Ready(None) => { | |
me.stream.set(StreamState::NotConnected); | |
cx.waker().wake_by_ref(); | |
Poll::Pending | |
} | |
Poll::Pending => Poll::Pending, | |
} | |
} | |
StreamStateProj::Connecting { ref mut reconnect_attempt } => { | |
debug!("Connecting"); | |
match reconnect_attempt.poll_unpin(cx) { | |
Poll::Ready(Ok(stream)) => { | |
me.stream.set(StreamState::Connected(stream)); | |
cx.waker().wake_by_ref(); | |
Poll::Pending | |
} | |
// on error mark it as completed | |
Poll::Ready(Err(_)) => Poll::Ready(None), | |
Poll::Pending => Poll::Pending, | |
} | |
} | |
} | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
match self.stream { | |
StreamState::NotConnected | StreamState::Connecting { reconnect_attempt: _ } => (0, Some(0)), | |
StreamState::Connected(ref stream) => stream.size_hint(), | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example usage: