// rust http server (gist by @afpro)
use std::{
borrow::Cow,
fmt::Debug,
future::Future,
net::SocketAddr,
pin::Pin,
task::{ready, Context, Poll},
time::Instant,
};
use axum::{
extract::{ConnectInfo, Request},
response::Response,
};
use pin_project::{pin_project, pinned_drop};
use tower::Service;
use tower_layer::Layer;
use tracing::{error, info, span, warn, Level, Span};
use uuid::Uuid;
#[derive(Copy, Clone)]
pub struct AccessLog;
impl<S> Layer<S> for AccessLog {
type Service = AccessLogService<S>;
fn layer(&self, inner: S) -> Self::Service {
AccessLogService { inner }
}
}
#[derive(Clone)]
pub struct AccessLogService<S> {
inner: S,
}
impl<S> AccessLogService<S> {
fn extract_remote<B>(req: &Request<B>) -> Cow<'static, str> {
match req.extensions().get::<ConnectInfo<SocketAddr>>() {
Some(ConnectInfo(addr)) => addr.to_string().into(),
None => "unknown".into(),
}
}
}
impl<S, Req, Resp> Service<Request<Req>> for AccessLogService<S>
where
S: Service<Request<Req>, Response = Response<Resp>>,
S::Future: Future<Output = Result<Response<Resp>, S::Error>>,
S::Error: Debug,
{
type Response = Response<Resp>;
type Error = S::Error;
type Future = AccessLogServiceFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Req>) -> Self::Future {
let span = span!(Level::INFO, "request", id=%Uuid::new_v4().simple());
{
let _guard = span.enter();
info!(
target: "request",
remote = %Self::extract_remote(&req),
method = %req.method(),
uri = %req.uri(),
headers = ?req.headers(),
"begin",
);
}
AccessLogServiceFuture::new(span, self.inner.call(req))
}
}
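// Future wrapper around the inner service's future: every poll re-enters the request
// span, the final status and elapsed time are logged exactly once when the inner future
// resolves, and PinnedDrop (below) emits a warning if the future is dropped before
// completion, e.g. when the client disconnects mid-request.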
#[pin_project(PinnedDrop)]
pub struct AccessLogServiceFuture<F> {
span: Span,
done: bool,
start: Instant,
#[pin]
inner: F,
}
impl<F> AccessLogServiceFuture<F> {
fn new(span: Span, inner: F) -> Self {
Self {
span,
done: false,
start: Instant::now(),
inner,
}
}
}
impl<F, B, E> Future for AccessLogServiceFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
E: Debug,
{
type Output = Result<Response<B>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let _guard = this.span.enter();
let result = ready!(this.inner.poll(cx));
if !*this.done {
*this.done = true;
let cost = Instant::now().duration_since(*this.start);
match &result {
Ok(response) => {
if (400..=599).contains(&response.status().as_u16()) {
error!(
target: "request",
status = response.status().as_u16(),
cost = cost.as_millis(),
"end with error status"
);
} else {
info!(
target: "request",
status = response.status().as_u16(),
cost = cost.as_millis(),
"end ok"
);
}
}
Err(err) => {
error!(
target: "request",
cost = cost.as_millis(),
"end with uncached error {:?}", err
);
}
}
}
Poll::Ready(result)
}
}
#[pinned_drop]
impl<F> PinnedDrop for AccessLogServiceFuture<F> {
fn drop(self: Pin<&mut Self>) {
if !self.done {
let _guard = self.span.enter();
let cost = Instant::now().duration_since(self.start);
warn!(
target: "request",
cost = cost.as_millis(),
"request connection dropped before finish",
);
}
}
}
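// Usage sketch (not part of the original gist): attaching the AccessLog layer to an axum
// Router. The `/health` route is hypothetical; remote addresses only show up in the logs
// when the app is served via `into_make_service_with_connect_info::<SocketAddr>()`, as
// run_http_server later in this gist does.
#[allow(dead_code)]
fn router_with_access_log() -> axum::Router {
    use axum::routing::get;
    axum::Router::new()
        .route("/health", get(|| async { "ok" }))
        .layer(AccessLog)
}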
use std::cmp::max;
#[cfg(feature = "config")]
use std::convert::Infallible;
use anyhow::{Context, Result};
use mysql_async::{Conn, Pool, PoolConstraints, PoolOpts};
#[cfg(feature = "config")]
use mysql_async::{Opts, OptsBuilder};
use tracing::info;
#[cfg(feature = "config")]
#[derive(clap::Args, Clone)]
pub struct DataMysqlOpts {
#[clap(
name = "mysql-host",
long = "mysql-host",
default_value = "127.0.0.1",
help = "mysql ip or host"
)]
pub host: String,
#[clap(
name = "mysql-port",
long = "mysql-port",
default_value = "3306",
help = "mysql port"
)]
pub port: u16,
#[clap(
name = "mysql-username",
long = "mysql-user",
default_value = "root",
help = "mysql username"
)]
pub user: String,
#[clap(
name = "mysql-password",
long = "mysql-pass",
default_value = "",
help = "mysql password"
)]
pub pass: String,
#[clap(
name = "mysql-db-name",
long = "mysql-db-name",
default_value = "bot",
help = "mysql database name"
)]
pub name: String,
#[clap(
name = "mysql-max-connection",
long = "mysql-max-connection",
default_value = "128",
help = "mysql database name"
)]
pub max_connection: usize,
}
#[derive(Clone, derive_more::From)]
pub struct DataMysql {
#[from]
pool: Pool,
}
impl DataMysql {
#[cfg(feature = "config")]
pub async fn create_by_opts(opts: &DataMysqlOpts) -> Result<Self> {
Ok(Self {
pool: Pool::new(opts),
})
}
pub async fn get_conn(&self) -> Result<Conn> {
self.pool
.get_conn()
.await
.context("obtain mysql connection")
}
pub async fn dump_info(&self) -> Result<()> {
let conn = self.get_conn().await?;
let (major, minor, patch) = conn.server_version();
info!("connected to mysql {}.{}.{}", major, minor, patch);
Ok(())
}
}
#[cfg(feature = "config")]
impl TryFrom<&DataMysqlOpts> for Opts {
type Error = Infallible;
fn try_from(value: &DataMysqlOpts) -> Result<Self, Self::Error> {
Ok(Opts::from(
OptsBuilder::default()
.ip_or_hostname(&value.host)
.tcp_port(value.port)
.user(Some(&value.user))
.pass(Some(&value.pass))
.db_name(Some(&value.name))
.pool_opts(
PoolOpts::default().with_constraints(
PoolConstraints::new(0, max(value.max_connection, 1))
.expect("create constraint"),
),
),
))
}
}
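// Usage sketch (assumption: the "config" feature is enabled and a MySQL server is
// reachable with the supplied options). Opens the pool from the clap-provided options,
// logs the server version, and checks that a connection can be obtained.
#[cfg(feature = "config")]
#[allow(dead_code)]
async fn mysql_example(opts: &DataMysqlOpts) -> Result<()> {
    let mysql = DataMysql::create_by_opts(opts).await?;
    mysql.dump_info().await?;
    let _conn = mysql.get_conn().await?;
    Ok(())
}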
use std::{
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{Context, Result};
use redis::{
aio::{MultiplexedConnection, PubSub},
cmd, Client,
};
#[cfg(feature = "config")]
use redis::{
ConnectionAddr, ConnectionInfo, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo,
RedisResult,
};
use tokio::sync::RwLock;
use tracing::info;
#[cfg(feature = "config")]
#[derive(clap::Args, Clone)]
pub struct DataRedisOpts {
#[clap(
name = "redis-host",
long = "redis-host",
default_value = "127.0.0.1",
help = "redis ip or host"
)]
pub host: String,
#[clap(
name = "redis-port",
long = "redis-port",
default_value = "6379",
help = "redis port"
)]
pub port: u16,
#[clap(name = "redis-username", long = "redis-user", help = "redis username")]
pub user: Option<String>,
#[clap(name = "redis-password", long = "redis-pass", help = "redis password")]
pub pass: Option<String>,
#[clap(
name = "redis-index",
long = "redis-index",
default_value = "0",
help = "redis database index"
)]
pub index: i64,
}
#[derive(Clone)]
pub struct DataRedis {
client: Client,
cache_io: Arc<RwLock<CacheIo>>,
}
struct CacheIo {
io: Option<MultiplexedConnection>,
expire_at: Instant,
}
impl DataRedis {
#[cfg(feature = "config")]
pub async fn create_by_opts(opts: &DataRedisOpts) -> Result<Self> {
let client = Client::open(opts).context("open redis client")?;
Ok(Self {
client,
cache_io: Arc::new(RwLock::new(CacheIo {
io: None,
expire_at: Instant::now(),
})),
})
}
pub async fn get_conn(&self) -> Result<MultiplexedConnection> {
let now = Instant::now();
// check cache
{
let cache_io = self.cache_io.read().await;
if now < cache_io.expire_at {
if let Some(io) = cache_io.io.clone() {
return Ok(io);
}
}
}
// update cache
let io = {
let mut cache_io = self.cache_io.write().await;
if now < cache_io.expire_at {
if let Some(io) = cache_io.io.clone() {
return Ok(io);
}
}
let io = self
.client
.get_multiplexed_async_connection()
.await
.context("obtain redis connection")?;
cache_io.io = Some(io.clone());
cache_io.expire_at = now + Duration::from_secs(300); // re-connect every 300s (5min)
io
};
Ok(io)
}
pub async fn get_pub_sub(&self) -> Result<PubSub> {
self.client
.get_async_pubsub()
.await
.context("obtain redis pub sub")
}
pub async fn dump_info(&self) -> Result<()> {
let mut conn = self.get_conn().await?;
let info = cmd("info")
.arg("server")
.query_async::<String>(&mut conn)
.await
.context("dump redis server info")?;
let version = info
.lines()
.map(|v| v.trim())
.filter_map(|v| v.strip_prefix("redis_version:"))
.next()
.context("can't extract redis_version from info dump")?;
info!("connected to redis {}", version);
Ok(())
}
}
#[cfg(feature = "config")]
impl IntoConnectionInfo for &DataRedisOpts {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
Ok(ConnectionInfo {
addr: ConnectionAddr::Tcp(self.host.clone(), self.port),
redis: RedisConnectionInfo {
protocol: ProtocolVersion::RESP3,
db: self.index,
username: self.user.clone(),
password: self.pass.clone(),
},
})
}
}
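// Usage sketch (assumption: the "config" feature is enabled and a Redis server is
// reachable with the supplied options). get_conn() hands out a multiplexed connection
// that is cached behind the RwLock and re-established at most once every 300 seconds.
#[cfg(feature = "config")]
#[allow(dead_code)]
async fn redis_example(opts: &DataRedisOpts) -> Result<()> {
    let redis = DataRedis::create_by_opts(opts).await?;
    redis.dump_info().await?;
    // Both calls below normally reuse the cached MultiplexedConnection.
    let _first = redis.get_conn().await?;
    let _second = redis.get_conn().await?;
    Ok(())
}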
use std::io::Read;
use deku::prelude::*;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
#[derive(Debug, thiserror::Error)]
pub enum DekuIOError {
#[error("IOError({})", _0)]
IOError(
#[source]
#[from]
std::io::Error,
),
#[error("DekuError({})", _0)]
DekuError(
#[source]
#[from]
DekuError,
),
}
pub async fn deku_read<T, R, C>(stream: &mut R, context: C) -> Result<T, DekuIOError>
where
R: AsyncBufReadExt + Unpin + ?Sized,
T: DekuReader<'static, C>,
C: Clone,
{
let mut buf = Vec::<u8>::new();
loop {
let io_buf = stream.fill_buf().await?;
let mut input = buf.as_slice().chain(io_buf);
let mut reader = Reader::new(&mut input);
match T::from_reader_with_ctx(&mut reader, context.clone()) {
Ok(v) => {
let consumed = reader.bits_read.div_ceil(8) - buf.len();
stream.consume(consumed);
return Ok(v);
}
Err(DekuError::Incomplete(_)) => {
let consumed = io_buf.len();
buf.extend_from_slice(io_buf);
stream.consume(consumed);
continue;
}
Err(err) => return Err(err.into()),
}
}
}
pub async fn deku_write<T, W, C>(
stream: &mut W,
value: &T,
context: C,
) -> Result<usize, DekuIOError>
where
W: AsyncWriteExt + Unpin + ?Sized,
T: DekuWriter<C>,
{
let mut buf = Vec::<u8>::new();
let mut writer = Writer::new(&mut buf);
value.to_writer(&mut writer, context)?;
stream.write_all(&buf).await?;
Ok(buf.len())
}
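// Round-trip sketch (not part of the original gist): a minimal byte-aligned frame type
// written with deku_write and parsed back with deku_read over an in-memory duplex
// stream. The Frame layout is hypothetical.
#[cfg(test)]
mod deku_io_test {
    use super::{deku_read, deku_write};
    use deku::prelude::*;
    use tokio::io::BufReader;

    #[derive(Debug, PartialEq, DekuRead, DekuWrite)]
    #[deku(endian = "big")]
    struct Frame {
        kind: u8,
        value: u32,
    }

    #[tokio::test]
    async fn round_trip() {
        let (mut writer, reader) = tokio::io::duplex(64);
        let mut reader = BufReader::new(reader);
        let sent = Frame { kind: 1, value: 42 };
        deku_write(&mut writer, &sent, ()).await.expect("write frame");
        let received: Frame = deku_read(&mut reader, ()).await.expect("read frame");
        assert_eq!(sent, received);
    }
}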
use std::{
future::Future,
time::{Duration, Instant},
};
use anyhow::{Context, Error, Result};
use async_trait::async_trait;
use indoc::indoc;
use lazy_static::lazy_static;
use rand::{Rng, rng};
use redis::{AsyncCommands, RedisError, Script};
use tokio::{select, time::sleep};
use uuid::Uuid;
lazy_static! {
static ref SCRIPT_INIT_LOCK: Script = Script::new(indoc! {r#"
if (redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', '2'))
then
return true;
else
return false;
end
"#});
static ref SCRIPT_UPDATE: Script = Script::new(indoc! {r#"
local current = redis.call('GET', KEYS[1]);
if (current ~= ARGV[1])
then
return false;
end
if (redis.call('EXPIRE', KEYS[1], '2'))
then
return true;
else
return false;
end
"#});
}
#[derive(Debug, thiserror::Error)]
pub enum DistributionLockError {
#[error("redis error {0}")]
RedisError(#[source] Error),
#[error("task error {0}")]
TaskError(#[source] Error),
#[error("lock current already acquired by others")]
LockCurrentAcquired,
#[error("lock update during task failed")]
LockUpdateFailed,
#[error("lock update during task failed by error {0}")]
LockUpdateError(#[source] RedisError),
#[error("lock release failed by error {0}")]
LockReleaseError(#[source] RedisError),
}
#[async_trait]
pub trait DistributionLock: AsyncCommands + Sized {
async fn run_task_with_lock<R: Send, F: Future<Output = Result<R>> + Send>(
&mut self,
key: &str,
task: F,
timeout: Option<Duration>,
) -> Result<R, DistributionLockError> {
let value = Uuid::new_v4();
let timeout = timeout.map(|v| Instant::now() + v);
loop {
let init_lock: bool = SCRIPT_INIT_LOCK
.prepare_invoke()
.key(key)
.arg(value)
.invoke_async(self)
.await
.context("init lock")
.map_err(DistributionLockError::RedisError)?;
if init_lock {
break;
}
if matches!(timeout, Some(timeout) if timeout < Instant::now()) {
return Err(DistributionLockError::LockCurrentAcquired);
}
let delay = Duration::from_millis(rng().random_range(10..50));
sleep(delay).await;
}
let update_lock = async {
loop {
match SCRIPT_UPDATE
.prepare_invoke()
.key(key)
.arg(value)
.invoke_async::<bool>(self)
.await
{
Ok(true) => {
sleep(Duration::from_secs(1)).await;
}
Ok(false) => {
return DistributionLockError::LockUpdateFailed;
}
Err(err) => {
return DistributionLockError::LockUpdateError(err);
}
}
}
};
// run task & periodically update redis
let ret = select! {
task_ret = task => {
task_ret.map_err(DistributionLockError::TaskError)?
}
update_ret = update_lock => {
return Err(update_ret);
}
};
// remove redis key
self.del::<_, ()>(&key)
.await
.map_err(DistributionLockError::LockReleaseError)?;
Ok(ret)
}
}
impl<C> DistributionLock for C where C: AsyncCommands + Send {}
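// Usage sketch (assumption: a reachable Redis and an async context). Runs a unit of
// exclusive work under the lock with a 5 second acquisition timeout; the key name and
// task body are hypothetical.
#[allow(dead_code)]
async fn locked_job<C: AsyncCommands + Send>(
    conn: &mut C,
) -> Result<(), DistributionLockError> {
    conn.run_task_with_lock(
        "LOCK:nightly-report",
        async {
            // exclusive work goes here; errors become DistributionLockError::TaskError
            Ok(())
        },
        Some(Duration::from_secs(5)),
    )
    .await
}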
#[cfg(test)]
mod test {
use std::time::Duration;
use redis::{Client, cmd};
use tokio::time::sleep;
use uuid::Uuid;
use super::DistributionLock;
#[tokio::test]
async fn run_lock() {
let test_name = "TEST_LOCK:test_lock";
let cache = Client::open("redis://127.0.0.1").expect("connect redis");
let mut conn = cache
.get_multiplexed_async_connection()
.await
.expect("obtain redis connection");
conn.clone()
.run_task_with_lock(
test_name,
async {
let value: Uuid = cmd("get").arg(test_name).query_async(&mut conn).await?;
sleep(Duration::from_secs(10)).await;
let post_value: Uuid = cmd("get").arg(test_name).query_async(&mut conn).await?;
assert_eq!(value, post_value);
Ok(())
},
None,
)
.await
.expect("run with lock");
}
}
use anyhow::{Context, Result};
use async_trait::async_trait;
use indoc::indoc;
use lazy_static::lazy_static;
use redis::{aio::ConnectionLike, Script};
lazy_static! {
static ref MODIFY_PERMIT_SCRIPT: Script = Script::new(indoc! {r#"
-- keys
local usage_key = KEYS[1];
local timestamp_key = KEYS[2];
-- arguments
local duration_in_sec = tonumber(ARGV[1]);
local permit_per_sec = tonumber(ARGV[2]);
local now_in_sec = tonumber(ARGV[3]);
local submit = tonumber(ARGV[4]);
local acquire = tonumber(ARGV[5]);
-- check current usage
local usage = redis.call('GET', usage_key)
local timestamp = redis.call('GET', timestamp_key)
-- avoid key missing
if usage ~= nil and usage ~= false
then
usage = tonumber(usage) + submit
else
usage = submit
end
if timestamp ~= nil and timestamp ~= false
then
timestamp = tonumber(timestamp)
if timestamp < now_in_sec
then
usage = math.max(usage - (now_in_sec - timestamp) * permit_per_sec, 0)
end
end
-- check permit can be acquired
local quota = math.max(duration_in_sec * permit_per_sec - usage, 0)
local acquired = math.min(quota, acquire)
-- process acquirement
usage = usage + acquired
local expire = math.ceil(usage / permit_per_sec)
redis.call('SET', usage_key, tostring(usage))
redis.call('SET', timestamp_key, tostring(now_in_sec))
redis.call('EXPIRE', usage_key, tostring(expire))
redis.call('EXPIRE', timestamp_key, tostring(expire))
-- finish
return { acquired, quota - acquired }
"#});
}
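// Worked example of the script above (mirrors the unit test below): with
// duration_in_secs = 5 and permit_per_sec = 1 the bucket holds at most 5 permits.
//   t=0 acquire 3 -> usage 3, remaining quota 2
//   t=0 acquire 2 -> usage 5, remaining quota 0
//   t=1 one elapsed second refunds 1 permit -> acquire 3 grants only 1, remaining quota 0
// Both keys expire after ceil(usage / permit_per_sec) seconds, so idle buckets clean
// themselves up.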
#[derive(Copy, Clone)]
pub struct RedisPermitConfig<'a> {
pub name: &'a str,
pub duration_in_secs: u64,
pub permit_per_sec: u64,
}
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct RedisPermitStatus {
pub acquired: u64,
pub quota: u64,
}
#[async_trait]
pub trait RedisPermit: ConnectionLike + Sized {
async fn modify_permit(
&mut self,
config: RedisPermitConfig<'_>,
submit: u64,
acquire: u64,
) -> Result<RedisPermitStatus> {
let mut invoke = MODIFY_PERMIT_SCRIPT.prepare_invoke();
invoke.key(format!("PERMIT:U:{}", config.name));
invoke.key(format!("PERMIT:T:{}", config.name));
invoke.arg(config.duration_in_secs);
invoke.arg(config.permit_per_sec);
invoke.arg(now_in_sec());
invoke.arg(submit);
invoke.arg(acquire);
invoke
.invoke_async::<(u64, u64)>(self)
.await
.map(|(acquired, quota)| RedisPermitStatus { acquired, quota })
.context("invoke script")
}
async fn acquire_permit(
&mut self,
config: RedisPermitConfig<'_>,
acquire: u64,
) -> Result<RedisPermitStatus> {
self.modify_permit(config, 0, acquire).await
}
async fn submit_permit(
&mut self,
config: RedisPermitConfig<'_>,
submit: u64,
) -> Result<RedisPermitStatus> {
self.modify_permit(config, submit, 0).await
}
}
impl<T> RedisPermit for T where T: ConnectionLike + Sized {}
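// Usage sketch (assumption: an open async redis connection). A named bucket allowing
// 10 permits per second over a 60 second window; callers acquire a permit before doing
// work and may report extra usage afterwards. The bucket name is hypothetical.
#[allow(dead_code)]
async fn permit_example<C: ConnectionLike + Send>(conn: &mut C) -> Result<()> {
    let config = RedisPermitConfig {
        name: "api-upstream",
        duration_in_secs: 60,
        permit_per_sec: 10,
    };
    let status = conn.acquire_permit(config, 1).await?;
    if status.acquired == 0 {
        // out of quota: back off or reject the caller
        return Ok(());
    }
    // ... do the rate limited work, then record 2 extra units that were consumed
    conn.submit_permit(config, 2).await?;
    Ok(())
}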
#[cfg(test)]
fn now_in_sec() -> i64 {
test::now_in_sec()
}
#[cfg(not(test))]
fn now_in_sec() -> i64 {
use chrono::Utc;
Utc::now().timestamp()
}
#[cfg(test)]
mod test {
use std::sync::atomic::{AtomicI64, Ordering::Relaxed};
use redis::Client;
use super::{RedisPermit, RedisPermitConfig, RedisPermitStatus};
#[allow(dead_code)]
static TIME_IN_SEC: AtomicI64 = AtomicI64::new(0);
pub fn now_in_sec() -> i64 {
TIME_IN_SEC.load(Relaxed)
}
pub fn set_now_in_sec(v: i64) {
TIME_IN_SEC.store(v, Relaxed)
}
#[tokio::test]
async fn acquire() {
let redis = Client::open("redis://127.0.0.1/").expect("open redis");
let mut conn = redis
.get_multiplexed_async_connection()
.await
.expect("open connection");
let config = RedisPermitConfig {
name: "test-acquire",
duration_in_secs: 5,
permit_per_sec: 1,
};
set_now_in_sec(0);
assert_eq!(
conn.acquire_permit(config, 3).await.expect("acquire"),
RedisPermitStatus {
acquired: 3,
quota: 2,
}
);
assert_eq!(
conn.acquire_permit(config, 2).await.expect("acquire"),
RedisPermitStatus {
acquired: 2,
quota: 0,
}
);
set_now_in_sec(1);
assert_eq!(
conn.acquire_permit(config, 3).await.expect("acquire"),
RedisPermitStatus {
acquired: 1,
quota: 0,
}
);
}
}
use async_trait::async_trait;
use axum::{
extract::{FromRequest, Query, Request},
http::{Method, StatusCode},
response::{IntoResponse, Response},
Json,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::warn;
#[macro_export]
macro_rules! api_response_combine {
($enum_name:ident { $($name:ident($type:ty),)+ }) => {
pub enum $enum_name {
Status($crate::rest::ApiStatus),
Raw(axum::response::Response),
$($name($type),)+
}
impl From<$crate::rest::ApiStatus> for $enum_name {
fn from(value: $crate::rest::ApiStatus) -> Self {
Self::Status(value)
}
}
impl From<axum::response::Response> for $enum_name {
fn from(value: axum::response::Response) -> Self {
Self::Raw(value)
}
}
$(
impl From<$type> for $enum_name {
fn from(value: $type) -> Self {
Self::$name(value)
}
}
)+
impl axum::response::IntoResponse for $enum_name {
fn into_response(self) -> axum::response::Response {
match self {
Self::Status(status) => $crate::rest::ApiResponse::<()>::from(status).into_response(),
Self::Raw(response) => response,
$(Self::$name(v) => v.into_response(),)+
}
}
}
};
}
#[macro_export]
macro_rules! api_require_token {
($request:ident) => {
$crate::api_require_token!(_inner &$request.token)
};
(take $request:ident) => {
$crate::api_require_token!(_inner $request.token)
};
(_inner $token:expr) => {
match $token {
Some(token) => token,
None => {
tracing::warn!("token missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("token missing".to_string()),
)
.into();
}
}
}
}
#[macro_export]
macro_rules! api_require_data {
($request:ident) => {
match &$request.data {
Some(data) => data,
None => {
tracing::warn!("data missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("data missing".to_string()),
)
.into();
}
}
};
(take $request:ident) => {
match $request.data {
Some(data) => data,
None => {
tracing::warn!("data missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("data missing".to_string()),
)
.into();
}
}
};
}
#[macro_export]
macro_rules! api_must_success {
($op:expr, $msg:literal) => {
match $op {
Ok(v) => v,
Err(err) => {
tracing::warn!("{} error: {:?}", $msg, err);
return $crate::rest::ApiStatus::ServerError.into();
}
}
};
}
#[derive(Serialize, Deserialize)]
pub struct ApiRequest<T = ()> {
#[serde(rename = "_t", default, skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
#[serde(
flatten,
default = "Option::default",
skip_serializing_if = "Option::is_none",
bound(
serialize = "Option<T>: Serialize",
deserialize = "Option<T>: Deserialize<'de>",
)
)]
pub data: Option<T>,
}
#[async_trait]
impl<T, S> FromRequest<S> for ApiRequest<T>
where
ApiRequest<T>: DeserializeOwned,
S: Send + Sync,
{
type Rejection = ApiResponse;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
match Query::<ApiRequest<T>>::from_request(req, state).await {
Ok(Query(data)) => Ok(data),
Err(err) => {
warn!("decode request failed: {:?}", err);
Err(ApiResponse::fail(ApiStatus::InvalidRequest, None))
}
}
} else {
match Json::<ApiRequest<T>>::from_request(req, state).await {
Ok(Json(data)) => Ok(data),
Err(err) => {
warn!("decode request failed: {:?}", err);
Err(ApiResponse::fail(ApiStatus::InvalidRequest, None))
}
}
}
}
}
#[derive(Serialize, Deserialize)]
pub struct ApiResponse<T = ()> {
pub status: ApiStatus,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
#[serde(
default = "Option::default",
skip_serializing_if = "Option::is_none",
bound(
serialize = "Option<T>: Serialize",
deserialize = "Option<T>: Deserialize<'de>",
)
)]
pub data: Option<T>,
}
impl<T> From<ApiStatus> for ApiResponse<T> {
fn from(value: ApiStatus) -> Self {
Self {
status: value,
message: None,
data: None,
}
}
}
impl<T> ApiResponse<T> {
pub fn success(data: Option<T>) -> Self {
Self {
status: ApiStatus::Success,
message: None,
data,
}
}
pub fn fail(status: ApiStatus, message: Option<String>) -> Self {
assert!(status != ApiStatus::Success, "fail with success status");
Self {
status,
message,
data: None,
}
}
}
impl<T> From<ApiResponse<T>> for Response
where
ApiResponse<T>: IntoResponse,
{
fn from(value: ApiResponse<T>) -> Self {
value.into_response()
}
}
impl<T> IntoResponse for ApiResponse<T>
where
ApiResponse<T>: Serialize,
{
fn into_response(self) -> Response {
(self.status.http_status_code(), Json(self)).into_response()
}
}
#[derive(Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ApiStatus {
Success,
InvalidRequest,
ServerError,
InvalidToken,
PermissionDenied,
}
impl IntoResponse for ApiStatus {
fn into_response(self) -> Response {
ApiResponse::<()>::from(self).into_response()
}
}
impl From<ApiStatus> for Response {
fn from(value: ApiStatus) -> Self {
value.into_response()
}
}
impl ApiStatus {
pub fn http_status_code(&self) -> StatusCode {
match self {
ApiStatus::Success => StatusCode::OK,
ApiStatus::InvalidToken => StatusCode::UNAUTHORIZED,
ApiStatus::PermissionDenied => StatusCode::FORBIDDEN,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
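// Usage sketch (assumption: this file is the crate's `rest` module, which is what the
// `$crate::rest::...` paths in the macros above require). A typed handler combining
// ApiRequest extraction, the api_require_data! guard, and a response enum produced by
// api_response_combine!. The Echo payload and handler are hypothetical.
#[allow(dead_code)]
mod echo_example {
    use serde::{Deserialize, Serialize};
    use super::{ApiRequest, ApiResponse};

    #[derive(Serialize, Deserialize)]
    pub struct Echo {
        pub text: String,
    }

    crate::api_response_combine!(EchoResponse {
        Echo(ApiResponse<Echo>),
    });

    pub async fn echo(request: ApiRequest<Echo>) -> EchoResponse {
        // returns ApiStatus::InvalidRequest when the body carried no data
        let data = crate::api_require_data!(take request);
        ApiResponse::success(Some(Echo { text: data.text })).into()
    }
}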
use std::net::SocketAddr;
use anyhow::{Context, Result};
use axum::Router;
use tokio::{net::TcpListener, signal::ctrl_c};
use tracing::{error, info, instrument};
pub async fn run_http_server(bind: SocketAddr, router: Router<()>) -> Result<()> {
let tcp_listener = TcpListener::bind(bind)
.await
.with_context(|| format!("can't bind tcp socket {}", bind))?;
info!("bound at {}", bind);
axum::serve(
tcp_listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(graceful_shutdown())
.await
.context("run http")?;
Ok(())
}
#[instrument("graceful-shutdown")]
async fn graceful_shutdown() {
match ctrl_c().await {
Ok(_) => {
info!("CTRL+C pressed, quiting");
}
Err(err) => {
error!("tokio CTRL+C signal handler error {}", err);
}
}
}
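// Usage sketch (assumption: a router built elsewhere; the single health route here is
// hypothetical). Binds 0.0.0.0:8080 and serves with graceful CTRL+C shutdown.
#[allow(dead_code)]
async fn serve_example() -> Result<()> {
    use axum::routing::get;
    let router = Router::new().route("/health", get(|| async { "ok" }));
    run_http_server("0.0.0.0:8080".parse()?, router).await
}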
use std::borrow::Cow;
use clap::{Args, ValueEnum};
use tracing::level_filters::LevelFilter;
use tracing_appender::{non_blocking::WorkerGuard, rolling::Rotation};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
#[derive(Args)]
pub struct TracingSetupOpts {
#[clap(long = "tracing", default_value = "debug", help = "tracing env filter")]
tracing: Cow<'static, str>,
#[clap(
long = "tracing-output-dir",
default_value = "./log",
help = "tracing output dir"
)]
tracing_output_dir: Cow<'static, str>,
#[clap(
long = "tracing-output-prefix",
default_value = "host",
help = "tracing output filename prefix"
)]
tracing_output_prefix: Cow<'static, str>,
#[clap(
long = "tracing-output-suffix",
default_value = "log",
help = "tracing output filename suffix"
)]
tracing_output_suffix: Cow<'static, str>,
#[clap(
long = "tracing-output-rotate",
default_value = "day",
help = "tracing output filename suffix"
)]
tracing_output_rotate: TracingRotateType,
#[clap(
long = "tracing-output-files",
default_value = "7",
help = "tracing output file count"
)]
tracing_output_files: usize,
}
#[derive(ValueEnum, Clone, Copy)]
pub enum TracingRotateType {
Minute,
Hour,
Day,
Never,
}
impl From<TracingRotateType> for Rotation {
fn from(value: TracingRotateType) -> Self {
match value {
TracingRotateType::Minute => Rotation::MINUTELY,
TracingRotateType::Hour => Rotation::HOURLY,
TracingRotateType::Day => Rotation::DAILY,
TracingRotateType::Never => Rotation::NEVER,
}
}
}
impl TracingSetupOpts {
pub fn setup(&self) -> WorkerGuard {
let file_rolling = tracing_appender::rolling::Builder::new()
.filename_prefix(self.tracing_output_prefix.as_ref())
.filename_suffix(self.tracing_output_suffix.as_ref())
.rotation(self.tracing_output_rotate.into())
.max_log_files(self.tracing_output_files) // default 7 files, roughly one week of daily logs
.build(self.tracing_output_dir.as_ref())
.expect("create tracing file rolling output");
let (file_rolling, guard) = tracing_appender::non_blocking(file_rolling);
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer().with_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::DEBUG.into())
.parse_lossy(self.tracing.as_ref()),
),
)
.with(
tracing_subscriber::fmt::layer()
.json()
.with_ansi(false)
.with_writer(file_rolling)
.with_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::DEBUG.into())
.parse_lossy(self.tracing.as_ref()),
),
)
.try_init()
.expect("setup tracing output");
guard
}
}
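// Usage sketch (assumption: TracingSetupOpts is #[clap(flatten)]-ed into the binary's
// command line options). The returned WorkerGuard must be kept alive for the lifetime of
// the program so buffered log lines are flushed to the rolling files on shutdown.
#[allow(dead_code)]
fn tracing_example(opts: &TracingSetupOpts) -> WorkerGuard {
    let guard = opts.setup();
    tracing::info!("tracing initialized");
    guard
}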