Created
January 26, 2025 00:51
-
-
Save dtkav/0c2af846213e80fba469f97bfb04c0ce to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ... (previous imports and constants remain the same) ... | |
pub struct Server { | |
docs: Arc<DashMap<String, DocWithSyncKv>>, | |
doc_worker_tracker: TaskTracker, | |
store: Option<Arc<Box<dyn Store>>>, | |
checkpoint_freq: Duration, | |
authenticator: Option<Authenticator>, | |
url_prefix: Option<Url>, | |
cancellation_token: CancellationToken, | |
doc_gc: bool, | |
client_secret: Option<String>, // Now optional | |
} | |
impl Server { | |
pub async fn new( | |
store: Option<Box<dyn Store>>, | |
checkpoint_freq: Duration, | |
authenticator: Option<Authenticator>, | |
url_prefix: Option<Url>, | |
cancellation_token: CancellationToken, | |
doc_gc: bool, | |
client_secret: Option<String>, // Now optional parameter | |
) -> Result<Self> { | |
Ok(Self { | |
docs: Arc::new(DashMap::new()), | |
doc_worker_tracker: TaskTracker::new(), | |
store: store.map(Arc::new), | |
checkpoint_freq, | |
authenticator, | |
url_prefix, | |
cancellation_token, | |
doc_gc, | |
client_secret, | |
}) | |
} | |
// ... (rest of Server impl remains the same) ... | |
} | |
#[derive(Deserialize)] | |
struct DataPlaneWebSocketQueryParams { | |
token: Option<String>, | |
client_secret: Option<String>, | |
} | |
#[derive(Deserialize)] | |
struct DataPlaneQueryParams { | |
client_secret: Option<String>, | |
} | |
async fn validate_client_secret( | |
server_state: &Arc<Server>, | |
provided_secret: Option<String>, | |
) -> Result<(), AppError> { | |
if let Some(required_secret) = &server_state.client_secret { | |
let provided = provided_secret.ok_or_else(|| { | |
AppError( | |
StatusCode::UNAUTHORIZED, | |
anyhow!("Client secret required"), | |
) | |
})?; | |
if provided != *required_secret { | |
return Err(AppError( | |
StatusCode::UNAUTHORIZED, | |
anyhow!("Invalid client secret"), | |
)); | |
} | |
} | |
Ok(()) | |
} | |
async fn get_doc_as_update( | |
State(server_state): State<Arc<Server>>, | |
Path(doc_id): Path<String>, | |
Query(params): Query<DataPlaneQueryParams>, | |
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>, | |
) -> Result<Response, AppError> { | |
validate_client_secret(&server_state, params.client_secret).await?; | |
let token = get_token_from_header(auth_header); | |
let _ = server_state.verify_doc_token(token.as_deref(), &doc_id)?; | |
let dwskv = server_state | |
.get_or_create_doc(&doc_id) | |
.await | |
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; | |
let update = dwskv.as_update(); | |
Ok(update.into_response()) | |
} | |
async fn update_doc( | |
Path(doc_id): Path<String>, | |
State(server_state): State<Arc<Server>>, | |
Query(params): Query<DataPlaneQueryParams>, | |
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>, | |
body: Bytes, | |
) -> Result<Response, AppError> { | |
validate_client_secret(&server_state, params.client_secret).await?; | |
let token = get_token_from_header(auth_header); | |
let authorization = server_state.verify_doc_token(token.as_deref(), &doc_id)?; | |
update_doc_inner(doc_id, server_state, authorization, body).await | |
} | |
async fn handle_socket_upgrade_full_path( | |
ws: WebSocketUpgrade, | |
Path((doc_id, doc_id2)): Path<(String, String)>, | |
Query(params): Query<DataPlaneWebSocketQueryParams>, | |
State(server_state): State<Arc<Server>>, | |
) -> Result<Response, AppError> { | |
validate_client_secret(&server_state, params.client_secret).await?; | |
if doc_id != doc_id2 { | |
return Err(AppError( | |
StatusCode::BAD_REQUEST, | |
anyhow!("For Yjs compatibility, the doc_id appears twice in the URL. It must be the same in both places, but we got {} and {}.", doc_id, doc_id2), | |
)); | |
} | |
let authorization = server_state.verify_doc_token(params.token.as_deref(), &doc_id)?; | |
handle_socket_upgrade(ws, Path(doc_id), authorization, State(server_state)).await | |
} | |
// ... (rest of the code remains the same) ... | |
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a server with no store, authenticator or URL prefix, doc GC enabled, and the
    /// given client secret.
    async fn server_with_secret(client_secret: Option<&str>) -> Arc<Server> {
        let server = Server::new(
            None,
            Duration::from_secs(60),
            None,
            None,
            CancellationToken::new(),
            true,
            client_secret.map(str::to_owned),
        )
        .await
        .expect("Server::new never returns Err");
        Arc::new(server)
    }

    /// Requests document "test" via `get_doc_as_update` with the given secret and no auth header.
    async fn fetch_test_doc(
        server: Arc<Server>,
        client_secret: Option<&str>,
    ) -> Result<Response, AppError> {
        get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams {
                client_secret: client_secret.map(str::to_owned),
            }),
            None,
        )
        .await
    }

    /// Asserts that `result` is an error carrying `401 Unauthorized`.
    fn assert_unauthorized(result: Result<Response, AppError>) {
        match result {
            Err(err) => assert_eq!(err.0, StatusCode::UNAUTHORIZED),
            Ok(_) => panic!("expected 401 Unauthorized, got a successful response"),
        }
    }

    #[tokio::test]
    async fn test_with_client_secret() {
        // A configured secret must reject requests that omit it.
        let server = server_with_secret(Some("test_secret")).await;
        assert_unauthorized(fetch_test_doc(server, None).await);
    }

    #[tokio::test]
    async fn test_with_wrong_client_secret() {
        let server = server_with_secret(Some("test_secret")).await;
        assert_unauthorized(fetch_test_doc(server, Some("not_the_secret")).await);
    }

    #[tokio::test]
    async fn test_with_correct_client_secret() {
        let server = server_with_secret(Some("test_secret")).await;
        assert!(fetch_test_doc(server, Some("test_secret")).await.is_ok());
    }

    #[tokio::test]
    async fn test_without_client_secret() {
        // With no secret configured, requests without one must succeed.
        let server = server_with_secret(None).await;
        assert!(fetch_test_doc(server, None).await.is_ok());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment