Last active
April 26, 2025 10:08
-
-
Save afpro/7c76ab62e68d73150d9edcc9344c4ece to your computer and use it in GitHub Desktop.
rust http server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
borrow::Cow, | |
fmt::Debug, | |
future::Future, | |
net::SocketAddr, | |
pin::Pin, | |
task::{ready, Context, Poll}, | |
time::Instant, | |
}; | |
use axum::{ | |
extract::{ConnectInfo, Request}, | |
response::Response, | |
}; | |
use pin_project::{pin_project, pinned_drop}; | |
use tower::Service; | |
use tower_layer::Layer; | |
use tracing::{error, info, span, warn, Level, Span}; | |
use uuid::Uuid; | |
#[derive(Copy, Clone)] | |
pub struct AccessLog; | |
impl<S> Layer<S> for AccessLog { | |
type Service = AccessLogService<S>; | |
fn layer(&self, inner: S) -> Self::Service { | |
AccessLogService { inner } | |
} | |
} | |
#[derive(Clone)] | |
pub struct AccessLogService<S> { | |
inner: S, | |
} | |
impl<S> AccessLogService<S> { | |
fn extract_remote<B>(req: &Request<B>) -> Cow<'static, str> { | |
match req.extensions().get::<ConnectInfo<SocketAddr>>() { | |
Some(ConnectInfo(addr)) => addr.to_string().into(), | |
None => "unknown".into(), | |
} | |
} | |
} | |
impl<S, Req, Resp> Service<Request<Req>> for AccessLogService<S> | |
where | |
S: Service<Request<Req>, Response = Response<Resp>>, | |
S::Future: Future<Output = Result<Response<Resp>, S::Error>>, | |
S::Error: Debug, | |
{ | |
type Response = Response<Resp>; | |
type Error = S::Error; | |
type Future = AccessLogServiceFuture<S::Future>; | |
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | |
self.inner.poll_ready(cx) | |
} | |
fn call(&mut self, req: Request<Req>) -> Self::Future { | |
let span = span!(Level::INFO, "request", id=%Uuid::new_v4().simple()); | |
{ | |
let _guard = span.enter(); | |
info!( | |
target: "request", | |
remote = %Self::extract_remote(&req), | |
method = %req.method(), | |
uri = %req.uri(), | |
headers = ?req.headers(), | |
"begin", | |
); | |
} | |
AccessLogServiceFuture::new(span, self.inner.call(req)) | |
} | |
} | |
#[pin_project(PinnedDrop)] | |
pub struct AccessLogServiceFuture<F> { | |
span: Span, | |
done: bool, | |
start: Instant, | |
#[pin] | |
inner: F, | |
} | |
impl<F> AccessLogServiceFuture<F> { | |
fn new(span: Span, inner: F) -> Self { | |
Self { | |
span, | |
done: false, | |
start: Instant::now(), | |
inner, | |
} | |
} | |
} | |
impl<F, B, E> Future for AccessLogServiceFuture<F> | |
where | |
F: Future<Output = Result<Response<B>, E>>, | |
E: Debug, | |
{ | |
type Output = Result<Response<B>, E>; | |
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | |
let this = self.project(); | |
let _guard = this.span.enter(); | |
let result = ready!(this.inner.poll(cx)); | |
if !*this.done { | |
*this.done = true; | |
let cost = Instant::now().duration_since(*this.start); | |
match &result { | |
Ok(response) => { | |
if (400..=599).contains(&response.status().as_u16()) { | |
error!( | |
target: "request", | |
status = response.status().as_u16(), | |
cost = cost.as_millis(), | |
"end with error status" | |
); | |
} else { | |
info!( | |
target: "request", | |
status = response.status().as_u16(), | |
cost = cost.as_millis(), | |
"end ok" | |
); | |
} | |
} | |
Err(err) => { | |
error!( | |
target: "request", | |
cost = cost.as_millis(), | |
"end with uncached error {:?}", err | |
); | |
} | |
} | |
} | |
Poll::Ready(result) | |
} | |
} | |
#[pinned_drop] | |
impl<F> PinnedDrop for AccessLogServiceFuture<F> { | |
fn drop(self: Pin<&mut Self>) { | |
if !self.done { | |
let _guard = self.span.enter(); | |
let cost = Instant::now().duration_since(self.start); | |
warn!( | |
target: "request", | |
cost = cost.as_millis(), | |
"request connection dropped before finish", | |
); | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp::max; | |
#[cfg(feature = "config")] | |
use std::convert::Infallible; | |
use anyhow::{Context, Result}; | |
use mysql_async::{Conn, Pool, PoolConstraints, PoolOpts}; | |
#[cfg(feature = "config")] | |
use mysql_async::{Opts, OptsBuilder}; | |
use tracing::info; | |
// Command-line options for the MySQL connection pool (requires the `config`
// feature).
//
// Plain `//` comments are used on purpose: clap's derive would turn `///` doc
// comments into help text.
#[cfg(feature = "config")]
#[derive(clap::Args, Clone)]
pub struct DataMysqlOpts {
    // Server IP address or host name.
    #[clap(
        name = "mysql-host",
        long = "mysql-host",
        default_value = "127.0.0.1",
        help = "mysql ip or host"
    )]
    pub host: String,
    // TCP port.
    #[clap(
        name = "mysql-port",
        long = "mysql-port",
        default_value = "3306",
        help = "mysql port"
    )]
    pub port: u16,
    // Login user name.
    #[clap(
        name = "mysql-username",
        long = "mysql-user",
        default_value = "root",
        help = "mysql username"
    )]
    pub user: String,
    // Login password (empty by default).
    #[clap(
        name = "mysql-password",
        long = "mysql-pass",
        default_value = "",
        help = "mysql password"
    )]
    pub pass: String,
    // Database selected after connecting.
    #[clap(
        name = "mysql-db-name",
        long = "mysql-db-name",
        default_value = "bot",
        help = "mysql database name"
    )]
    pub name: String,
    // Upper bound of the pool size; values below 1 are raised to 1 when the
    // pool constraints are built.
    #[clap(
        name = "mysql-max-connection",
        long = "mysql-max-connection",
        default_value = "128",
        help = "mysql max connection count"
    )]
    pub max_connection: usize,
}
#[derive(Clone, derive_more::From)] | |
pub struct DataMysql { | |
#[from] | |
pool: Pool, | |
} | |
impl DataMysql { | |
#[cfg(feature = "config")] | |
pub async fn create_by_opts(opts: &DataMysqlOpts) -> Result<Self> { | |
Ok(Self { | |
pool: Pool::new(opts), | |
}) | |
} | |
pub async fn get_conn(&self) -> Result<Conn> { | |
self.pool | |
.get_conn() | |
.await | |
.context("obtain mysql connection") | |
} | |
pub async fn dump_info(&self) -> Result<()> { | |
let conn = self.get_conn().await?; | |
let (major, minor, patch) = conn.server_version(); | |
info!("connected to mysql {}.{}.{}", major, minor, patch); | |
Ok(()) | |
} | |
} | |
#[cfg(feature = "config")]
impl TryFrom<&DataMysqlOpts> for Opts {
    type Error = Infallible;

    /// Translates CLI options into `mysql_async` connection options.
    ///
    /// The pool is allowed between 0 and `max_connection` connections, with
    /// the maximum clamped to at least 1.
    fn try_from(value: &DataMysqlOpts) -> Result<Self, Self::Error> {
        // min (0) never exceeds the clamped max (>= 1), so construction cannot fail.
        let constraints = PoolConstraints::new(0, max(value.max_connection, 1))
            .expect("create constraint");
        let pool_opts = PoolOpts::default().with_constraints(constraints);
        let builder = OptsBuilder::default()
            .ip_or_hostname(&value.host)
            .tcp_port(value.port)
            .user(Some(&value.user))
            .pass(Some(&value.pass))
            .db_name(Some(&value.name))
            .pool_opts(pool_opts);
        Ok(Opts::from(builder))
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
sync::Arc, | |
time::{Duration, Instant}, | |
}; | |
use anyhow::{Context, Result}; | |
use redis::{ | |
aio::{MultiplexedConnection, PubSub}, | |
cmd, Client, | |
}; | |
#[cfg(feature = "config")] | |
use redis::{ | |
ConnectionAddr, ConnectionInfo, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo, | |
RedisResult, | |
}; | |
use tokio::sync::RwLock; | |
use tracing::info; | |
// Command-line options for connecting to redis (requires the `config` feature).
//
// Plain `//` comments are used on purpose: clap's derive would turn `///` doc
// comments into help text.
#[cfg(feature = "config")]
#[derive(clap::Args, Clone)]
pub struct DataRedisOpts {
    // Server IP address or host name.
    #[clap(
        name = "redis-host",
        long = "redis-host",
        default_value = "127.0.0.1",
        help = "redis ip or host"
    )]
    pub host: String,
    // TCP port.
    #[clap(
        name = "redis-port",
        long = "redis-port",
        default_value = "6379",
        help = "redis port"
    )]
    pub port: u16,
    // Optional user name; `None` when the flag is not given.
    #[clap(name = "redis-username", long = "redis-user", help = "redis username")]
    pub user: Option<String>,
    // Optional password; `None` when the flag is not given.
    #[clap(name = "redis-password", long = "redis-pass", help = "redis password")]
    pub pass: Option<String>,
    // Database index, passed as `db` in the connection info.
    #[clap(
        name = "redis-index",
        long = "redis-index",
        default_value = "0",
        help = "redis database index"
    )]
    pub index: i64,
}
/// Redis access handle: a client plus a shared, periodically refreshed
/// multiplexed connection.
///
/// Clones share the same connection cache through the `Arc`.
#[derive(Clone)]
pub struct DataRedis {
    client: Client,
    // `RwLock` lets readers reuse the cached connection concurrently while a
    // single writer replaces it.
    cache_io: Arc<RwLock<CacheIo>>,
}

/// Cached connection state used by `DataRedis::get_conn`.
struct CacheIo {
    // `None` until the first connection has been established.
    io: Option<MultiplexedConnection>,
    // The cached connection is reused only while the current time is before this.
    expire_at: Instant,
}
impl DataRedis { | |
#[cfg(feature = "config")] | |
pub async fn create_by_opts(opts: &DataRedisOpts) -> Result<Self> { | |
let client = Client::open(opts).context("open redis client")?; | |
Ok(Self { | |
client, | |
cache_io: Arc::new(RwLock::new(CacheIo { | |
io: None, | |
expire_at: Instant::now(), | |
})), | |
}) | |
} | |
pub async fn get_conn(&self) -> Result<MultiplexedConnection> { | |
let now = Instant::now(); | |
// check cache | |
{ | |
let cache_io = self.cache_io.read().await; | |
if now < cache_io.expire_at { | |
if let Some(io) = cache_io.io.clone() { | |
return Ok(io); | |
} | |
} | |
} | |
// update cache | |
let io = { | |
let mut cache_io = self.cache_io.write().await; | |
if now < cache_io.expire_at { | |
if let Some(io) = cache_io.io.clone() { | |
return Ok(io); | |
} | |
} | |
let io = self | |
.client | |
.get_multiplexed_async_connection() | |
.await | |
.context("obtain redis connection")?; | |
cache_io.io = Some(io.clone()); | |
cache_io.expire_at = now + Duration::from_secs(300); // re-connect every 300s (5min) | |
io | |
}; | |
Ok(io) | |
} | |
pub async fn get_pub_sub(&self) -> Result<PubSub> { | |
self.client | |
.get_async_pubsub() | |
.await | |
.context("obtain redis pub sub") | |
} | |
pub async fn dump_info(&self) -> Result<()> { | |
let mut conn = self.get_conn().await?; | |
let info = cmd("info") | |
.arg("server") | |
.query_async::<String>(&mut conn) | |
.await | |
.context("dump redis server info")?; | |
let version = info | |
.lines() | |
.map(|v| v.trim()) | |
.filter_map(|v| v.strip_prefix("redis_version:")) | |
.next() | |
.context("can't extract redis_version from info dump")?; | |
info!("connected to redis {}", version); | |
Ok(()) | |
} | |
} | |
#[cfg(feature = "config")]
impl IntoConnectionInfo for &DataRedisOpts {
    /// Describes a plain TCP connection speaking RESP3.
    ///
    /// Database `index` is selected, and the optional username/password are
    /// passed through unchanged.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        let redis = RedisConnectionInfo {
            protocol: ProtocolVersion::RESP3,
            db: self.index,
            username: self.user.clone(),
            password: self.pass.clone(),
        };
        let addr = ConnectionAddr::Tcp(self.host.clone(), self.port);
        Ok(ConnectionInfo { addr, redis })
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::io::Read; | |
use deku::prelude::*; | |
use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; | |
/// Error returned by [`deku_read`] / [`deku_write`].
///
/// Each variant wraps either the underlying I/O error or the deku
/// encode/decode error, and both convert in via `From` so `?` works.
#[derive(Debug, thiserror::Error)]
pub enum DekuIOError {
    /// The async stream failed (including a premature end of stream).
    #[error("IOError({})", _0)]
    IOError(
        #[source]
        #[from]
        std::io::Error,
    ),
    /// deku rejected the bytes (decode) or the value (encode).
    #[error("DekuError({})", _0)]
    DekuError(
        #[source]
        #[from]
        DekuError,
    ),
}
pub async fn deku_read<T, R, C>(stream: &mut R, context: C) -> Result<T, DekuIOError> | |
where | |
R: AsyncBufReadExt + Unpin + ?Sized, | |
T: DekuReader<'static, C>, | |
C: Clone, | |
{ | |
let mut buf = Vec::<u8>::new(); | |
loop { | |
let io_buf = stream.fill_buf().await?; | |
let mut input = buf.as_slice().chain(io_buf); | |
let mut reader = Reader::new(&mut input); | |
match T::from_reader_with_ctx(&mut reader, context.clone()) { | |
Ok(v) => { | |
let consumed = reader.bits_read.div_ceil(8) - buf.len(); | |
stream.consume(consumed); | |
return Ok(v); | |
} | |
Err(DekuError::Incomplete(_)) => { | |
let consumed = io_buf.len(); | |
buf.extend_from_slice(io_buf); | |
stream.consume(consumed); | |
continue; | |
} | |
Err(err) => return Err(err.into()), | |
} | |
} | |
} | |
pub async fn deku_write<T, W, C>( | |
stream: &mut W, | |
value: &T, | |
context: C, | |
) -> Result<usize, DekuIOError> | |
where | |
W: AsyncWriteExt + Unpin + ?Sized, | |
T: DekuWriter<C>, | |
{ | |
let mut buf = Vec::<u8>::new(); | |
let mut writer = Writer::new(&mut buf); | |
value.to_writer(&mut writer, context)?; | |
stream.write_all(&buf).await?; | |
Ok(buf.len()) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
future::Future, | |
time::{Duration, Instant}, | |
}; | |
use anyhow::{Context, Error, Result}; | |
use async_trait::async_trait; | |
use indoc::indoc; | |
use lazy_static::lazy_static; | |
use rand::{Rng, rng}; | |
use redis::{AsyncCommands, RedisError, Script}; | |
use tokio::{select, time::sleep}; | |
use uuid::Uuid; | |
lazy_static! {
    // Tries to take the lock: sets KEYS[1] to the holder token ARGV[1] only if
    // the key does not exist (NX), with a 2-second expiry (EX 2).
    // Returns true when the key was set, false when someone else holds it.
    static ref SCRIPT_INIT_LOCK: Script = Script::new(indoc! {r#"
        if (redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', '2'))
        then
        return true;
        else
        return false;
        end
    "#});
    // Keeps the lock alive: if KEYS[1] still holds this holder's token
    // ARGV[1], its expiry is reset to 2 seconds and true is returned.
    // Otherwise (lost or taken over) false is returned.
    // NOTE(review): EXPIRE replies with an integer (0 or 1), and every number
    // is truthy in Lua. The second `if` therefore always takes the `true`
    // branch; only the ownership check can yield false.
    static ref SCRIPT_UPDATE: Script = Script::new(indoc! {r#"
        local current = redis.call('GET', KEYS[1]);
        if (current ~= ARGV[1])
        then
        return false;
        end
        if (redis.call('EXPIRE', KEYS[1], '2'))
        then
        return true;
        else
        return false;
        end
    "#});
}
/// Failure modes of `DistributionLock::run_task_with_lock`.
#[derive(Debug, thiserror::Error)]
pub enum DistributionLockError {
    /// Redis failed while trying to acquire the lock.
    #[error("redis error {0}")]
    RedisError(#[source] Error),
    /// The task run under the lock returned an error.
    #[error("task error {0}")]
    TaskError(#[source] Error),
    /// The acquire timeout elapsed while another holder kept the lock.
    #[error("lock current already acquired by others")]
    LockCurrentAcquired,
    /// While the task ran, the lock key no longer held this holder's token.
    #[error("lock update during task failed")]
    LockUpdateFailed,
    /// While the task ran, refreshing the lock failed with a redis error.
    #[error("lock update during task failed by error {0}")]
    LockUpdateError(#[source] RedisError),
    /// Removing the lock key after the task failed with a redis error.
    #[error("lock release failed by error {0}")]
    LockReleaseError(#[source] RedisError),
}
#[async_trait] | |
pub trait DistributionLock: AsyncCommands + Sized { | |
async fn run_task_with_lock<R: Send, F: Future<Output = Result<R>> + Send>( | |
&mut self, | |
key: &str, | |
task: F, | |
timeout: Option<Duration>, | |
) -> Result<R, DistributionLockError> { | |
let value = Uuid::new_v4(); | |
let timeout = timeout.map(|v| Instant::now() + v); | |
loop { | |
let init_lock: bool = SCRIPT_INIT_LOCK | |
.prepare_invoke() | |
.key(key) | |
.arg(value) | |
.invoke_async(self) | |
.await | |
.context("init lock") | |
.map_err(DistributionLockError::RedisError)?; | |
if init_lock { | |
break; | |
} | |
if matches!(timeout, Some(timeout) if timeout < Instant::now()) { | |
return Err(DistributionLockError::LockCurrentAcquired); | |
} | |
let delay = Duration::from_millis(rng().random_range(10..50)); | |
sleep(delay).await; | |
} | |
let update_lock = async { | |
loop { | |
match SCRIPT_UPDATE | |
.prepare_invoke() | |
.key(key) | |
.arg(value) | |
.invoke_async::<bool>(self) | |
.await | |
{ | |
Ok(true) => { | |
sleep(Duration::from_secs(1)).await; | |
} | |
Ok(false) => { | |
return DistributionLockError::LockUpdateFailed; | |
} | |
Err(err) => { | |
return DistributionLockError::LockUpdateError(err); | |
} | |
} | |
} | |
}; | |
// run task & periodically update redis | |
let ret = select! { | |
task_ret = task => { | |
task_ret.map_err(DistributionLockError::TaskError)? | |
} | |
update_ret = update_lock => { | |
return Err(update_ret); | |
} | |
}; | |
// remove redis key | |
self.del::<_, ()>(&key) | |
.await | |
.map_err(DistributionLockError::LockReleaseError)?; | |
Ok(ret) | |
} | |
} | |
impl<C> DistributionLock for C where C: AsyncCommands + Send {} | |
// Integration test: needs a redis server reachable at 127.0.0.1.
#[cfg(test)]
mod test {
    use std::time::Duration;
    use redis::{Client, cmd};
    use tokio::time::sleep;
    use uuid::Uuid;
    use super::DistributionLock;
    // Holds the lock for 10 s, much longer than its 2-second expiry, and checks
    // that the key still holds the same token afterwards. This shows the
    // background refresh kept the lock alive.
    #[tokio::test]
    async fn run_lock() {
        let test_name = "TEST_LOCK:test_lock";
        let cache = Client::open("redis://127.0.0.1").expect("connect redis");
        let mut conn = cache
            .get_multiplexed_async_connection()
            .await
            .expect("obtain redis connection");
        // A clone takes the lock so `conn` stays free for the task's own reads.
        conn.clone()
            .run_task_with_lock(
                test_name,
                async {
                    let value: Uuid = cmd("get").arg(test_name).query_async(&mut conn).await?;
                    sleep(Duration::from_secs(10)).await;
                    let post_value: Uuid = cmd("get").arg(test_name).query_async(&mut conn).await?;
                    assert_eq!(value, post_value);
                    Ok(())
                },
                None,
            )
            .await
            .expect("run with lock");
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::{Context, Result}; | |
use async_trait::async_trait; | |
use indoc::indoc; | |
use lazy_static::lazy_static; | |
use redis::{aio::ConnectionLike, Script}; | |
lazy_static! { | |
static ref MODIFY_PERMIT_SCRIPT: Script = Script::new(indoc! {r#" | |
-- keys | |
local usage_key = KEYS[1]; | |
local timestamp_key = KEYS[2]; | |
-- arguments | |
local duration_in_sec = tonumber(ARGV[1]); | |
local permit_per_sec = tonumber(ARGV[2]); | |
local now_in_sec = tonumber(ARGV[3]); | |
local submit = tonumber(ARGV[4]); | |
local acquire = tonumber(ARGV[5]); | |
-- check current usage | |
local usage = redis.call('GET', usage_key) | |
local timestamp = redis.call('GET', timestamp_key) | |
-- avoid key missing | |
if usage ~= nil and usage ~= false | |
then | |
usage = tonumber(usage) + submit | |
else | |
usage = submit | |
end | |
if timestamp ~= nil and timestamp ~= false | |
then | |
timestamp = tonumber(timestamp) | |
if timestamp < now_in_sec | |
then | |
usage = math.max(usage - (now_in_sec - timestamp) * permit_per_sec, 0) | |
end | |
end | |
-- check permit can be acquired | |
local quota = math.max(duration_in_sec * permit_per_sec - usage, 0) | |
local acquired = math.min(quota, acquire) | |
-- process acquirement | |
usage = usage + acquired | |
local expire = math.ceil(usage / permit_per_sec) | |
redis.call('SET', usage_key, tostring(usage)) | |
redis.call('SET', timestamp_key, tostring(now_in_sec)) | |
redis.call('EXPIRE', usage_key, tostring(expire)) | |
redis.call('EXPIRE', timestamp_key, tostring(expire)) | |
-- finish | |
return { acquired, quota - acquired } | |
"#}); | |
} | |
/// Identifies a permit bucket and its rate.
///
/// `name` is embedded in the redis keys `PERMIT:U:{name}` (usage) and
/// `PERMIT:T:{name}` (last update time). The bucket holds at most
/// `duration_in_secs * permit_per_sec` permits and frees `permit_per_sec`
/// permits per elapsed second.
#[derive(Copy, Clone)]
pub struct RedisPermitConfig<'a> {
    pub name: &'a str,
    pub duration_in_secs: u64,
    // NOTE(review): the script divides by this when computing key expiry,
    // so it presumably must be > 0 — confirm callers never pass 0.
    pub permit_per_sec: u64,
}

/// Outcome of a single permit operation.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct RedisPermitStatus {
    /// Permits granted by this call (never more than requested).
    pub acquired: u64,
    /// Permits still available after this call.
    pub quota: u64,
}
#[async_trait] | |
pub trait RedisPermit: ConnectionLike + Sized { | |
async fn modify_permit( | |
&mut self, | |
config: RedisPermitConfig<'_>, | |
submit: u64, | |
acquire: u64, | |
) -> Result<RedisPermitStatus> { | |
let mut invoke = MODIFY_PERMIT_SCRIPT.prepare_invoke(); | |
invoke.key(format!("PERMIT:U:{}", config.name)); | |
invoke.key(format!("PERMIT:T:{}", config.name)); | |
invoke.arg(config.duration_in_secs); | |
invoke.arg(config.permit_per_sec); | |
invoke.arg(now_in_sec()); | |
invoke.arg(submit); | |
invoke.arg(acquire); | |
invoke | |
.invoke_async::<(u64, u64)>(self) | |
.await | |
.map(|(acquired, quota)| RedisPermitStatus { acquired, quota }) | |
.context("invoke script") | |
} | |
async fn acquire_permit( | |
&mut self, | |
config: RedisPermitConfig<'_>, | |
acquire: u64, | |
) -> Result<RedisPermitStatus> { | |
self.modify_permit(config, 0, acquire).await | |
} | |
async fn submit_permit( | |
&mut self, | |
config: RedisPermitConfig<'_>, | |
submit: u64, | |
) -> Result<RedisPermitStatus> { | |
self.modify_permit(config, submit, 0).await | |
} | |
} | |
impl<T> RedisPermit for T where T: ConnectionLike + Sized {} | |
#[cfg(test)] | |
fn now_in_sec() -> i64 { | |
test::now_in_sec() | |
} | |
#[cfg(not(test))] | |
fn now_in_sec() -> i64 { | |
use chrono::Utc; | |
Utc::now().timestamp() | |
} | |
// Tests for the permit script. They need a redis server at 127.0.0.1 and use
// a fake clock so elapsed time is deterministic.
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicI64, Ordering::Relaxed};
    use redis::Client;
    use super::{RedisPermit, RedisPermitConfig, RedisPermitStatus};
    // Fake "now" in seconds, read by the crate's `now_in_sec` in test builds.
    // NOTE(review): reason for `allow(dead_code)` is not visible here — confirm it is still needed.
    #[allow(dead_code)]
    static TIME_IN_SEC: AtomicI64 = AtomicI64::new(0);
    pub fn now_in_sec() -> i64 {
        TIME_IN_SEC.load(Relaxed)
    }
    pub fn set_now_in_sec(v: i64) {
        TIME_IN_SEC.store(v, Relaxed)
    }
    // Capacity is 5 permits (5 s * 1/s).
    // NOTE(review): keys `PERMIT:U:test-acquire` / `PERMIT:T:test-acquire` left
    // over from a recent run could change the expected numbers.
    #[tokio::test]
    async fn acquire() {
        let redis = Client::open("redis://127.0.0.1/").expect("open redis");
        let mut conn = redis
            .get_multiplexed_async_connection()
            .await
            .expect("open connection");
        let config = RedisPermitConfig {
            name: "test-acquire",
            duration_in_secs: 5,
            permit_per_sec: 1,
        };
        set_now_in_sec(0);
        // 3 of 5 granted, 2 left.
        assert_eq!(
            conn.acquire_permit(config, 3).await.expect("acquire"),
            RedisPermitStatus {
                acquired: 3,
                quota: 2,
            }
        );
        // The remaining 2 granted, bucket now full.
        assert_eq!(
            conn.acquire_permit(config, 2).await.expect("acquire"),
            RedisPermitStatus {
                acquired: 2,
                quota: 0,
            }
        );
        set_now_in_sec(1);
        // One second drains one permit, so only 1 of the 3 requested is granted.
        assert_eq!(
            conn.acquire_permit(config, 3).await.expect("acquire"),
            RedisPermitStatus {
                acquired: 1,
                quota: 0,
            }
        );
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use async_trait::async_trait; | |
use axum::{ | |
extract::{FromRequest, Query, Request}, | |
http::{Method, StatusCode}, | |
response::{IntoResponse, Response}, | |
Json, | |
}; | |
use serde::{de::DeserializeOwned, Deserialize, Serialize}; | |
use tracing::warn; | |
/// Declares a handler return enum named `$enum_name`.
///
/// The enum can hold any of:
/// - an `ApiStatus` (`Status`);
/// - a raw axum `Response` (`Raw`);
/// - one of the listed payload types.
///
/// Each `Name(Type),` entry becomes a variant; the trailing comma is required
/// by the pattern. The macro implements `From` for every wrapped type and
/// `IntoResponse` for the enum, rendering `Status` through
/// `ApiResponse::<()>`.
///
/// Each listed `Type` must implement `IntoResponse` and be distinct from the
/// others (and from `ApiStatus`/`Response`), otherwise the generated `From`
/// impls conflict.
///
/// NOTE(review): the `.into_response()` calls are resolved where the macro is
/// invoked, so that module presumably needs `axum::response::IntoResponse` in
/// scope — confirm.
#[macro_export]
macro_rules! api_response_combine {
    ($enum_name:ident { $($name:ident($type:ty),)+ }) => {
        pub enum $enum_name {
            Status($crate::rest::ApiStatus),
            Raw(axum::response::Response),
            $($name($type),)+
        }
        impl From<$crate::rest::ApiStatus> for $enum_name {
            fn from(value: $crate::rest::ApiStatus) -> Self {
                Self::Status(value)
            }
        }
        impl From<axum::response::Response> for $enum_name {
            fn from(value: axum::response::Response) -> Self {
                Self::Raw(value)
            }
        }
        // one `From` impl per listed payload type
        $(
            impl From<$type> for $enum_name {
                fn from(value: $type) -> Self {
                    Self::$name(value)
                }
            }
        )+
        impl axum::response::IntoResponse for $enum_name {
            fn into_response(self) -> axum::response::Response {
                match self {
                    Self::Status(status) => $crate::rest::ApiResponse::<()>::from(status).into_response(),
                    Self::Raw(response) => response,
                    $(Self::$name(v) => v.into_response(),)+
                }
            }
        }
    };
}
/// Extracts the token from an `ApiRequest`, or returns early from the enclosing
/// handler with `ApiStatus::InvalidRequest` and the message "token missing".
///
/// * `api_require_token!(req)` borrows `req.token`.
/// * `api_require_token!(take req)` moves `req.token` out.
///
/// The early return is `ApiResponse::fail(..).into()`, so the enclosing
/// function's return type must be obtainable via `.into()` from that value.
#[macro_export]
macro_rules! api_require_token {
    ($request:ident) => {
        $crate::api_require_token!(_inner &$request.token)
    };
    (take $request:ident) => {
        $crate::api_require_token!(_inner $request.token)
    };
    // Internal arm shared by the two public forms.
    (_inner $token:expr) => {
        match $token {
            Some(token) => token,
            None => {
                tracing::warn!("token missing");
                return $crate::rest::ApiResponse::fail(
                    $crate::rest::ApiStatus::InvalidRequest,
                    Some("token missing".to_string()),
                )
                .into();
            }
        }
    }
}
/// Extracts the payload from an `ApiRequest`, or returns early from the
/// enclosing handler with `ApiStatus::InvalidRequest` and the message
/// "data missing".
///
/// * `api_require_data!(req)` borrows `req.data`.
/// * `api_require_data!(take req)` moves `req.data` out.
///
/// The early return is `ApiResponse::fail(..).into()`, so the enclosing
/// function's return type must be obtainable via `.into()` from that value.
#[macro_export]
macro_rules! api_require_data {
    ($request:ident) => {
        match &$request.data {
            Some(data) => data,
            None => {
                tracing::warn!("data missing");
                return $crate::rest::ApiResponse::fail(
                    $crate::rest::ApiStatus::InvalidRequest,
                    Some("data missing".to_string()),
                )
                .into();
            }
        }
    };
    (take $request:ident) => {
        match $request.data {
            Some(data) => data,
            None => {
                tracing::warn!("data missing");
                return $crate::rest::ApiResponse::fail(
                    $crate::rest::ApiStatus::InvalidRequest,
                    Some("data missing".to_string()),
                )
                .into();
            }
        }
    };
}
/// Unwraps a `Result`. On `Err` it logs `"<msg> error: <err:?>"` at warn level
/// and returns `ApiStatus::ServerError.into()` from the enclosing handler.
#[macro_export]
macro_rules! api_must_success {
    ($op:expr, $msg:literal) => {
        match $op {
            Ok(v) => v,
            Err(err) => {
                tracing::warn!("{} error: {:?}", $msg, err);
                return $crate::rest::ApiStatus::ServerError.into();
            }
        }
    };
}
/// Request envelope shared by API endpoints.
///
/// Decoded from the query string for `GET` and from a JSON body otherwise
/// (see the `FromRequest` impl).
#[derive(Serialize, Deserialize)]
pub struct ApiRequest<T = ()> {
    /// Optional auth token, carried in the `_t` field.
    #[serde(rename = "_t", default, skip_serializing_if = "Option::is_none")]
    pub token: Option<String>,
    /// Endpoint payload, flattened into the same object as `_t`.
    // The explicit `bound` replaces serde's inferred `T: Serialize/Deserialize`
    // bounds with bounds on `Option<T>`.
    #[serde(
        flatten,
        default = "Option::default",
        skip_serializing_if = "Option::is_none",
        bound(
            serialize = "Option<T>: Serialize",
            deserialize = "Option<T>: Deserialize<'de>",
        )
    )]
    pub data: Option<T>,
}
#[async_trait] | |
impl<T, S> FromRequest<S> for ApiRequest<T> | |
where | |
ApiRequest<T>: DeserializeOwned, | |
S: Send + Sync, | |
{ | |
type Rejection = ApiResponse; | |
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { | |
if req.method() == Method::GET { | |
match Query::<ApiRequest<T>>::from_request(req, state).await { | |
Ok(Query(data)) => Ok(data), | |
Err(err) => { | |
warn!("decode request failed: {:?}", err); | |
Err(ApiResponse::fail(ApiStatus::InvalidRequest, None)) | |
} | |
} | |
} else { | |
match Json::<ApiRequest<T>>::from_request(req, state).await { | |
Ok(Json(data)) => Ok(data), | |
Err(err) => { | |
warn!("decode request failed: {:?}", err); | |
Err(ApiResponse::fail(ApiStatus::InvalidRequest, None)) | |
} | |
} | |
} | |
} | |
} | |
/// JSON envelope returned by API endpoints.
///
/// Serialized as an object with `status`, plus `message` and `data` when they
/// are `Some`.
#[derive(Serialize, Deserialize)]
pub struct ApiResponse<T = ()> {
    /// Outcome; also selects the HTTP status code in `into_response`.
    pub status: ApiStatus,
    /// Optional human-readable detail (e.g. "token missing").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Optional endpoint payload.
    // The explicit `bound` puts the serde bounds on `Option<T>` instead of `T`.
    #[serde(
        default = "Option::default",
        skip_serializing_if = "Option::is_none",
        bound(
            serialize = "Option<T>: Serialize",
            deserialize = "Option<T>: Deserialize<'de>",
        )
    )]
    pub data: Option<T>,
}
impl<T> From<ApiStatus> for ApiResponse<T> { | |
fn from(value: ApiStatus) -> Self { | |
Self { | |
status: value, | |
message: None, | |
data: None, | |
} | |
} | |
} | |
impl<T> ApiResponse<T> { | |
pub fn success(data: Option<T>) -> Self { | |
Self { | |
status: ApiStatus::Success, | |
message: None, | |
data, | |
} | |
} | |
pub fn fail(status: ApiStatus, message: Option<String>) -> Self { | |
assert!(status != ApiStatus::Success, "fail with success status"); | |
Self { | |
status, | |
message, | |
data: None, | |
} | |
} | |
} | |
impl<T> From<ApiResponse<T>> for Response | |
where | |
ApiResponse<T>: IntoResponse, | |
{ | |
fn from(value: ApiResponse<T>) -> Self { | |
value.into_response() | |
} | |
} | |
impl<T> IntoResponse for ApiResponse<T> | |
where | |
ApiResponse<T>: Serialize, | |
{ | |
fn into_response(self) -> Response { | |
(self.status.http_status_code(), Json(self)).into_response() | |
} | |
} | |
/// Application-level outcome code.
///
/// Serialized in snake_case (`"success"`, `"invalid_request"`,
/// `"server_error"`, `"invalid_token"`, `"permission_denied"`).
#[derive(Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ApiStatus {
    Success,
    InvalidRequest,
    ServerError,
    InvalidToken,
    PermissionDenied,
}
impl IntoResponse for ApiStatus { | |
fn into_response(self) -> Response { | |
ApiResponse::<()>::from(self).into_response() | |
} | |
} | |
impl From<ApiStatus> for Response { | |
fn from(value: ApiStatus) -> Self { | |
value.into_response() | |
} | |
} | |
impl ApiStatus { | |
pub fn http_status_code(&self) -> StatusCode { | |
match self { | |
ApiStatus::Success => StatusCode::OK, | |
ApiStatus::InvalidToken => StatusCode::UNAUTHORIZED, | |
ApiStatus::PermissionDenied => StatusCode::FORBIDDEN, | |
_ => StatusCode::INTERNAL_SERVER_ERROR, | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::net::SocketAddr; | |
use anyhow::{Context, Result}; | |
use axum::Router; | |
use tokio::{net::TcpListener, signal::ctrl_c}; | |
use tracing::{error, info, instrument}; | |
pub async fn run_http_server(bind: SocketAddr, router: Router<()>) -> Result<()> { | |
let tcp_listener = TcpListener::bind(bind) | |
.await | |
.with_context(|| format!("can't bind tcp socket {}", bind))?; | |
info!("bound at {}", bind); | |
axum::serve( | |
tcp_listener, | |
router.into_make_service_with_connect_info::<SocketAddr>(), | |
) | |
.with_graceful_shutdown(graceful_shutdown()) | |
.await | |
.context("run http")?; | |
Ok(()) | |
} | |
#[instrument("graceful-shutdown")] | |
async fn graceful_shutdown() { | |
match ctrl_c().await { | |
Ok(_) => { | |
info!("CTRL+C pressed, quiting"); | |
} | |
Err(err) => { | |
error!("tokio CTRL+C signal handler error {}", err); | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::borrow::Cow; | |
use clap::{Args, ValueEnum}; | |
use tracing::level_filters::LevelFilter; | |
use tracing_appender::{non_blocking::WorkerGuard, rolling::Rotation}; | |
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; | |
#[derive(Args)] | |
pub struct TracingSetupOpts { | |
#[clap(long = "tracing", default_value = "debug", help = "tracing env filter")] | |
tracing: Cow<'static, str>, | |
#[clap( | |
long = "tracing-output-dir", | |
default_value = "./log", | |
help = "tracing output dir" | |
)] | |
tracing_output_dir: Cow<'static, str>, | |
#[clap( | |
long = "tracing-output-prefix", | |
default_value = "host", | |
help = "tracing output filename prefix" | |
)] | |
tracing_output_prefix: Cow<'static, str>, | |
#[clap( | |
long = "tracing-output-suffix", | |
default_value = "log", | |
help = "tracing output filename suffix" | |
)] | |
tracing_output_suffix: Cow<'static, str>, | |
#[clap( | |
long = "tracing-output-rotate", | |
default_value = "day", | |
help = "tracing output filename suffix" | |
)] | |
tracing_output_rotate: TracingRotateType, | |
#[clap( | |
long = "tracing-output-files", | |
default_value = "7", | |
help = "tracing output file count" | |
)] | |
tracing_output_files: usize, | |
} | |
// Rotation period for the log file appender (`--tracing-output-rotate`).
// `//` comments on purpose: clap's ValueEnum derive would turn `///` doc
// comments into possible-value help text.
#[derive(ValueEnum, Clone, Copy)]
pub enum TracingRotateType {
    Minute,
    Hour,
    Day,
    // Never rotate: always write to a single file.
    Never,
}
impl From<TracingRotateType> for Rotation { | |
fn from(value: TracingRotateType) -> Self { | |
match value { | |
TracingRotateType::Minute => Rotation::MINUTELY, | |
TracingRotateType::Hour => Rotation::HOURLY, | |
TracingRotateType::Day => Rotation::DAILY, | |
TracingRotateType::Never => Rotation::NEVER, | |
} | |
} | |
} | |
impl TracingSetupOpts { | |
pub fn setup(&self) -> WorkerGuard { | |
let file_rolling = tracing_appender::rolling::Builder::new() | |
.filename_prefix(self.tracing_output_prefix.as_ref()) | |
.filename_suffix(self.tracing_output_suffix.as_ref()) | |
.rotation(self.tracing_output_rotate.into()) | |
.max_log_files(self.tracing_output_files) // week | |
.build(self.tracing_output_dir.as_ref()) | |
.expect("create tracing file rolling output"); | |
let (file_rolling, guard) = tracing_appender::non_blocking(file_rolling); | |
tracing_subscriber::registry() | |
.with( | |
tracing_subscriber::fmt::layer().with_filter( | |
EnvFilter::builder() | |
.with_default_directive(LevelFilter::DEBUG.into()) | |
.parse_lossy(self.tracing.as_ref()), | |
), | |
) | |
.with( | |
tracing_subscriber::fmt::layer() | |
.json() | |
.with_ansi(false) | |
.with_writer(file_rolling) | |
.with_filter( | |
EnvFilter::builder() | |
.with_default_directive(LevelFilter::DEBUG.into()) | |
.parse_lossy(self.tracing.as_ref()), | |
), | |
) | |
.try_init() | |
.expect("setup tracing output"); | |
guard | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment