Created February 3, 2020, 11:08
-
-
Save frederikbosch/de51d7d329c5c2097e72361f1c4a0764 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::sync::Arc; | |
use std::sync::Mutex; | |
use tokio::sync::{mpsc, watch}; | |
use tokio::sync::RwLock; | |
use tonic::{Code, Request, Response, Status}; | |
use def::health_check_response::ServingStatus; | |
use def::health_server; | |
use def::{HealthCheckRequest, HealthCheckResponse}; | |
/// Generated gRPC health-checking types for the `grpc.health.v1` proto package,
/// pulled in from the build-time output via `tonic::include_proto!`.
/// Supplies `HealthCheckRequest`, `HealthCheckResponse`,
/// `health_check_response::ServingStatus` and the `health_server` module.
pub mod def { | |
tonic::include_proto!("grpc.health.v1"); | |
} | |
/// Result of a health RPC handler: a wrapped response message or a gRPC `Status`.
type HealthResult<T> = Result<Response<T>, Status>; | |
/// Body of a server-streaming response: the receiving half of a tokio `mpsc`
/// channel whose items are either messages or per-item `Status` errors.
type ResponseStream<T> = mpsc::Receiver<Result<T, Status>>; | |
pub struct ServiceRegister { | |
services: Mutex<HashMap<String, ServingStatus>>, | |
signaller: Mutex<watch::Sender<Option<(String, ServingStatus)>>>, | |
subscriber: watch::Receiver<Option<(String, ServingStatus)>>, | |
} | |
impl ServiceRegister { | |
pub fn new(initial: ServingStatus) -> Self { | |
let mut services = HashMap::new(); | |
services.insert(String::new(), initial); | |
let (tx, rx) = watch::channel(None); | |
ServiceRegister { | |
services: Mutex::new(services), | |
signaller: Mutex::new(tx), | |
subscriber: rx, | |
} | |
} | |
pub fn get_status(&self, service_name: &str) -> Option<ServingStatus> { | |
if let Ok(services) = self.services.lock() { | |
services.get(service_name).copied() | |
} else { | |
None | |
} | |
} | |
pub fn set_status(&mut self, service_name: &str, status: ServingStatus) -> Result<(), ()> { | |
if let Ok(writer) = self.services.get_mut() { | |
let last = writer.insert(service_name.to_string(), status); | |
if last.is_none() || last.unwrap() != status { | |
if let Ok(sig) = self.signaller.lock() { | |
sig.broadcast(Some((service_name.to_string(), status))) | |
.map_err(|_| ())?; | |
} | |
} | |
Ok(()) | |
} else { | |
Err(()) | |
} | |
} | |
pub fn subscribe (&self) -> watch::Receiver<Option<(String, ServingStatus)>> { | |
self.subscriber.clone() | |
} | |
pub fn shutdown(&mut self) -> Result<(), ()> { | |
if let Ok(writer) = self.services.get_mut() { | |
for (name, status) in writer.iter_mut() { | |
if *status != ServingStatus::NotServing { | |
*status = ServingStatus::NotServing; | |
if let Ok(sig) = self.signaller.lock() { | |
sig.broadcast(Some((name.to_string(), *status))).map_err(|_| ())?; | |
} | |
} | |
} | |
Ok(()) | |
} else { | |
Err(()) | |
} | |
} | |
} | |
pub struct HealthCheckService { | |
register: Arc<RwLock<ServiceRegister>>, | |
} | |
impl HealthCheckService { | |
pub fn new(register: Arc<RwLock<ServiceRegister>>) -> Self { | |
HealthCheckService { register } | |
} | |
} | |
#[tonic::async_trait] | |
impl health_server::Health for HealthCheckService { | |
async fn check( | |
&self, | |
request: Request<HealthCheckRequest>, | |
) -> HealthResult<HealthCheckResponse> { | |
match self.register.read().await.get_status(&request.get_ref().service) { | |
Some(status) => { | |
let response = Response::new(HealthCheckResponse { | |
status: status.clone() as i32, | |
}); | |
Ok(response) | |
} | |
None => Err(Status::new(Code::NotFound, "")), | |
} | |
} | |
type WatchStream = ResponseStream<HealthCheckResponse>; | |
async fn watch(&self, request: Request<HealthCheckRequest>) -> HealthResult<Self::WatchStream> { | |
let name = &request.get_ref().service; | |
let (mut tx, res_rx) = mpsc::channel(10); | |
if let Some(status) = self.register.read().await.get_status(name) { | |
let _ = tx.send(Ok(HealthCheckResponse { | |
status: status as i32, | |
})) | |
.await; | |
let mut rx = self.register.read().await.subscribe(); | |
while let Some(value) = rx.recv().await { | |
match value { | |
Some((changed_name, status)) => { | |
if *name == changed_name { | |
let _ = tx.send(Ok(HealthCheckResponse { | |
status: status as i32, | |
})) | |
.await; | |
} | |
}, | |
_ => {}, | |
} | |
} | |
Ok(Response::new(res_rx)) | |
} else { | |
Err(Status::new(Code::NotFound, "")) | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.