Last active
November 8, 2024 04:06
-
-
Save pyar6329/56ddb4ef28f378fbee2102705f21cd72 to your computer and use it in GitHub Desktop.
reqwest middleware sample code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::*; | |
use std::future::Future; | |
use tokio::time::{sleep, timeout, Duration}; | |
pub async fn run_async_fn_with_timeout_and_retry<F, T, E>( | |
func: F, | |
condition: impl Fn(T) -> bool, | |
timeout_err: E, | |
ttl: Duration, | |
total_ttl: Duration, | |
retry_num: u8, | |
) -> Result<T, Error> | |
where | |
F: Future<Output = Result<T, Error>> + Copy, | |
E: Into<Error> + Copy, | |
{ | |
timeout(total_ttl, async move { | |
for i in 0..retry_num { | |
// each functions run with timeout | |
let result = timeout(ttl, func) | |
.await | |
.map_err(|_| timeout_err.into()) | |
.and_then(|i| i); | |
// it returns result if function is success | |
if result.is_ok() { | |
return result; | |
} | |
if condition(result) { | |
continue; | |
} | |
// sleep time: 2^i | |
let sleep_time = 1 << i; | |
// sleep with expotencial backoff | |
let _ = sleep(Duration::from_secs(sleep_time)).await; | |
} | |
// it returns error if retry_num is reached maximum | |
Err(timeout_err.into()) // これだとfor内でOkだとしてもエラーになってしまう | |
}) | |
.await | |
.map_err(|_| timeout_err.into()) | |
.and_then(|i| i.map_err(|e| e.into())) | |
} | |
/// Sample stand-in for the API call (the real version would use reqwest); currently it
/// only builds the value locally and returns it.
///
/// NOTE(review): `....` is a placeholder for the `User` fields, and `user` is built as
/// `Body { body: User { .. } }` while the signature promises `Result<User, Error>`.
/// Either return `user.body` or change the return type to the `Body` type — TODO confirm.
async fn get_user() -> Result<User, Error> { | |
// Placeholder for the API request made with reqwest.
let user = Body { | |
body: User { | |
.... | |
} | |
}; | |
Ok(user) | |
} | |
fn loop_condition(e: User) -> bool { | |
e.body.id == 1 | |
} | |
// Example call site (sketch only; not valid Rust as written):
// - it is a top-level expression that uses `.await` outside any async fn;
// - `xxxx` is a placeholder;
// - `loop_condition` is invoked with a closure although it takes a `User`;
// - only two of the function's six arguments are supplied.
// NOTE(review): presumably this should pass the predicate itself plus the
// timeout error, per-attempt/total durations and retry count — TODO confirm.
run_async_fn_with_timeout_and_retry(get_user(), loop_condition(|data| xxxx), ).await; | |
#[derive(DeserializeOwned, Eq, PartialEq, Clone)] | |
struct WyscoutError { | |
pub error: WyscoutErrorMessage, | |
} | |
#[derive(DeserializeOwned, Eq, PartialEq, Clone)] | |
struct WyscoutErrorMessage { | |
pub code: u8, | |
pub message: String, | |
} | |
/// reqwest middleware that decides, from a response body deserialized as `E`,
/// whether a request should be retried.
struct RetryMiddleware<E> {
    /// Retry predicate: `true` means the response body warrants another attempt.
    /// A plain function pointer (instead of the invalid `impl Fn` field type) keeps the
    /// struct nameable and `Send + Sync + 'static` for any `E: 'static`.
    condition: fn(E) -> bool,
}

impl<E> RetryMiddleware<E> {
    /// Creates a middleware that retries whenever `condition` returns `true`.
    pub fn new(condition: fn(E) -> bool) -> Self {
        Self { condition }
    }

    /// Evaluates the retry predicate for a deserialized response body.
    pub fn should_retry(&self, body: E) -> bool {
        (self.condition)(body)
    }
}

impl<E> Default for RetryMiddleware<E> {
    /// A middleware whose predicate never asks for a retry.
    fn default() -> Self {
        Self::new(|_| false)
    }
}
pub fn wyscout_retry_condition(e: WyscoutError) -> bool { | |
e.error.code.contains([404, 412]) | |
} | |
// reqwest_middleware::Middleware implementation for the custom retry middleware.
// `next.run(req, extensions).await` yields `Result<Response>`.
//
// NOTE(review): unfinished sketch; it does not compile as written:
// - `result.json()` takes the response by value, so returning `Ok(result)` afterwards
//   would use a moved value. Also, `.json().await` yields a `Result<E, _>`, not `E`.
//   To both inspect and return the body, read it once and rebuild the response — TODO.
// - A function stored in a field is called as `(self.condition)(response)`, not
//   `self.condition(response)`.
// - Retrying needs a fresh request per attempt (e.g. `req.try_clone()`) and a
//   re-runnable `next`. The reqwest-retry middleware linked below does this — verify.
// - There is no bound on the number of attempts yet.
// - reqwest-middleware declares `Middleware` with `#[async_trait]`, so this impl
//   presumably needs it too — verify for the crate version in use.
impl<E> Middleware for RetryMiddleware<E> | |
where | |
E: DeserializeOwned | |
{ | |
async fn handle( | |
&self, | |
req: Request, | |
extensions: &mut Extensions, | |
next: Next<'_>, | |
) -> Result<Response> { | |
let result: Response = next.run(req, extensions).await?; | |
let response: E = result.json().await; | |
if self.condition(response) { | |
// TODO (translation of the Japanese placeholder lines below): "write the retry
// logic here" / "break once a request succeeds".
retry処理をここに書く | |
成功したらbreak | |
} | |
Ok(result) | |
} | |
} | |
// ref: https://github.com/TrueLayer/reqwest-middleware/blob/main/reqwest-retry/src/middleware.rs#L110C1-L205
/// Builds a `ClientWithMiddleware` around a rustls-backed `reqwest::Client`. The client
/// uses a per-request timeout of `timeout_sec` seconds and is wrapped in tracing and
/// retry middleware.
///
/// NOTE(review): this sketch does not yet behave as intended:
/// - `headers` (passed to `default_headers`) is not defined in this function. It should
///   presumably be built from `base_header`, which is otherwise unused — TODO confirm.
/// - `condition_fn` is never used, because `RetryMiddleware::default()` is installed
///   instead. It is also typed `Fn(E) -> Error`, while the retry predicates in this
///   file (e.g. `wyscout_retry_condition`) return `bool`.
/// - `retry_policy` is built but never attached to any middleware, so `max_retry`
///   currently has no effect.
/// - `Eq` already implies `PartialEq`, so the `PartialEq` bound is redundant.
///
/// # Errors
/// Returns an error if the underlying `reqwest::Client` cannot be built.
fn build_client_with_retry<E>( | |
base_header: &BaseHeader, | |
timeout_sec: &u8, | |
max_retry: &u8, | |
condition_fn: impl Fn(E) -> Error, | |
) -> Result<ClientWithMiddleware, Error> | |
where | |
E: DeserializeOwned + Eq + PartialEq | |
{ | |
let reqwest_client = reqwest::Client::builder() | |
.default_headers(headers) | |
.use_rustls_tls() | |
.timeout(Duration::from_secs(*timeout_sec as u64)) | |
.build() | |
.context("Failed to create reqwest::Client::builder")?; | |
// set max retry number (NOTE(review): the policy is not passed to any middleware below)
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(*max_retry as u32); | |
// set retry middleware and log (tracing) middleware
let client = ClientBuilder::new(reqwest_client) | |
.with(TracingMiddleware::default()) | |
.with(RetryMiddleware::default()) // TODO: install the custom RetryMiddleware built from `condition_fn` here
.build(); | |
Ok(client) | |
} | |
impl BaseClientWithRetry { | |
pub fn new<T>( | |
base_header: &BaseHeader, | |
base_url: &Url, | |
timeout_sec: &u8, | |
max_retry: &u8, | |
condition_fn: impl Fn(E) -> Error, | |
) -> Result<BaseClientWithRetry, Error> { | |
let client = build_client_with_retry<T>(base_header, timeout_sec, max_retry, condition_fn)?; | |
Ok(Self { | |
client, | |
base_url: base_url.to_owned(), | |
base_header: base_header.to_owned(), | |
}) | |
} | |
BaseClientWithRetry::new::<WyscoutError>(xxxx, wyscout_retry_condition); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment