Skip to content

Instantly share code, notes, and snippets.

@dtkav
Created January 26, 2025 00:51
Show Gist options
  • Save dtkav/0c2af846213e80fba469f97bfb04c0ce to your computer and use it in GitHub Desktop.
Save dtkav/0c2af846213e80fba469f97bfb04c0ce to your computer and use it in GitHub Desktop.
// ... (previous imports and constants remain the same) ...
/// Shared state for the document server.
///
/// NOTE(review): most methods that read these fields are outside this excerpt,
/// so several field notes below are inferred from names and marked as such.
pub struct Server {
/// Documents currently held in memory, keyed by a `String` id
/// (presumably the document id — confirm against `get_or_create_doc`).
docs: Arc<DashMap<String, DocWithSyncKv>>,
/// Presumably tracks per-document background tasks so they can be awaited
/// on shutdown — TODO confirm.
doc_worker_tracker: TaskTracker,
/// Optional persistence backend; the tests in this file construct the server
/// with `None`. NOTE(review): `Arc<Box<dyn Store>>` is a double indirection;
/// `Arc<dyn Store>` would suffice if callers allow the change.
store: Option<Arc<Box<dyn Store>>>,
/// NOTE(review): by name, the interval between document checkpoints — confirm.
checkpoint_freq: Duration,
/// Optional token authenticator; presumably consulted by `verify_doc_token`
/// — confirm.
authenticator: Option<Authenticator>,
/// Optional URL prefix. NOTE(review): its use is not visible in this excerpt.
url_prefix: Option<Url>,
/// Presumably signals shutdown to background work — confirm.
cancellation_token: CancellationToken,
/// NOTE(review): by name, enables garbage collection of documents — confirm.
doc_gc: bool,
/// Shared secret that data-plane requests must send as the `client_secret`
/// query parameter. `None` disables the check (see `validate_client_secret`).
client_secret: Option<String>,
}
impl Server {
    /// Creates a server with no documents loaded.
    ///
    /// * `store` — optional persistence backend; it is wrapped in an `Arc`
    ///   before being stored.
    /// * `checkpoint_freq`, `authenticator`, `url_prefix`,
    ///   `cancellation_token`, `doc_gc` — stored exactly as given.
    /// * `client_secret` — when `Some`, the data-plane handlers require
    ///   clients to send this value as the `client_secret` query parameter.
    ///   `None` turns that check off.
    ///
    /// The current implementation always returns `Ok`.
    pub async fn new(
        store: Option<Box<dyn Store>>,
        checkpoint_freq: Duration,
        authenticator: Option<Authenticator>,
        url_prefix: Option<Url>,
        cancellation_token: CancellationToken,
        doc_gc: bool,
        client_secret: Option<String>,
    ) -> Result<Self> {
        let docs = Arc::new(DashMap::new());
        let doc_worker_tracker = TaskTracker::new();
        let shared_store = store.map(Arc::new);

        let server = Self {
            docs,
            doc_worker_tracker,
            store: shared_store,
            checkpoint_freq,
            authenticator,
            url_prefix,
            cancellation_token,
            doc_gc,
            client_secret,
        };
        Ok(server)
    }
    // ... (rest of Server impl remains the same) ...
}
/// Query-string parameters accepted by the WebSocket upgrade handler
/// (`handle_socket_upgrade_full_path`).
///
/// NOTE(review): avoid deriving `Debug` here; it would make the client secret
/// easy to log by accident.
#[derive(Deserialize)]
struct DataPlaneWebSocketQueryParams {
/// Document token, passed to `verify_doc_token`. It travels in the query
/// string rather than a header — presumably because browser WebSocket clients
/// cannot set custom headers.
token: Option<String>,
/// Shared client secret, checked by `validate_client_secret`.
client_secret: Option<String>,
}
/// Query-string parameters accepted by the HTTP data-plane handlers
/// (`get_doc_as_update`, `update_doc`). Those endpoints read the document
/// token from the `Authorization: Bearer` header instead of the query string.
#[derive(Deserialize)]
struct DataPlaneQueryParams {
/// Shared client secret, checked by `validate_client_secret`.
client_secret: Option<String>,
}
/// Enforces the server's optional shared client secret on a data-plane request.
///
/// If the server was built without a `client_secret`, every request passes.
/// Otherwise `provided_secret` (the handlers pass the `client_secret` query
/// parameter) must be present and equal to the configured secret.
///
/// The comparison is written so that it does not stop at the first differing
/// byte. Response timing therefore does not reveal how long a correct prefix
/// an attacker has guessed; only the secret's length can leak.
///
/// NOTE(review): query strings are commonly written to access and proxy logs.
/// Consider whether the secret should travel in a header instead.
///
/// # Errors
///
/// Returns `AppError` with `StatusCode::UNAUTHORIZED` in two cases:
/// - "Client secret required": a secret is configured but none was sent.
/// - "Invalid client secret": the sent secret does not match.
async fn validate_client_secret(
    server_state: &Arc<Server>,
    provided_secret: Option<String>,
) -> Result<(), AppError> {
    /// Byte-wise equality that visits every byte of equal-length inputs
    /// instead of returning at the first mismatch.
    fn secrets_equal(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let required_secret = match server_state.client_secret.as_deref() {
        Some(secret) => secret,
        // No secret configured: the check is disabled.
        None => return Ok(()),
    };

    let provided = provided_secret.ok_or_else(|| {
        AppError(
            StatusCode::UNAUTHORIZED,
            anyhow!("Client secret required"),
        )
    })?;

    if !secrets_equal(provided.as_bytes(), required_secret.as_bytes()) {
        return Err(AppError(
            StatusCode::UNAUTHORIZED,
            anyhow!("Invalid client secret"),
        ));
    }
    Ok(())
}
/// Responds with document `doc_id` encoded via `as_update()`.
///
/// The document is fetched with `get_or_create_doc`, which by its name loads
/// or creates it on demand — confirm.
///
/// Checks run in order:
/// 1. the shared client secret from the query string;
/// 2. the bearer token from the `Authorization` header.
///
/// The token only has to be accepted; the authorization level it grants is
/// discarded here.
///
/// # Errors
///
/// Propagates errors from `validate_client_secret` and `verify_doc_token`.
/// A failure from `get_or_create_doc` is reported as
/// `500 Internal Server Error`.
async fn get_doc_as_update(
    State(server_state): State<Arc<Server>>,
    Path(doc_id): Path<String>,
    Query(params): Query<DataPlaneQueryParams>,
    auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Result<Response, AppError> {
    validate_client_secret(&server_state, params.client_secret).await?;

    {
        let bearer = get_token_from_header(auth_header);
        // Any accepted authorization is enough; its value is not needed below.
        let _ = server_state.verify_doc_token(bearer.as_deref(), &doc_id)?;
    }

    let doc = match server_state.get_or_create_doc(&doc_id).await {
        Ok(doc) => doc,
        Err(err) => return Err((StatusCode::INTERNAL_SERVER_ERROR, err).into()),
    };
    Ok(doc.as_update().into_response())
}
/// Hands the request body for document `doc_id` to `update_doc_inner`.
///
/// The body is presumably an encoded document update — confirm against
/// `update_doc_inner`. Before delegating, the request must pass two checks:
/// - the shared client-secret check, when the server has a secret configured;
/// - bearer-token verification for this specific document.
///
/// The authorization produced by the token check is forwarded along with the
/// body.
async fn update_doc(
    Path(doc_id): Path<String>,
    State(server_state): State<Arc<Server>>,
    Query(params): Query<DataPlaneQueryParams>,
    auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
    body: Bytes,
) -> Result<Response, AppError> {
    // Reject early when a required client secret is missing or wrong.
    validate_client_secret(&server_state, params.client_secret).await?;

    let authorization = {
        let bearer = get_token_from_header(auth_header);
        server_state.verify_doc_token(bearer.as_deref(), &doc_id)?
    };

    update_doc_inner(doc_id, server_state, authorization, body).await
}
async fn handle_socket_upgrade_full_path(
ws: WebSocketUpgrade,
Path((doc_id, doc_id2)): Path<(String, String)>,
Query(params): Query<DataPlaneWebSocketQueryParams>,
State(server_state): State<Arc<Server>>,
) -> Result<Response, AppError> {
validate_client_secret(&server_state, params.client_secret).await?;
if doc_id != doc_id2 {
return Err(AppError(
StatusCode::BAD_REQUEST,
anyhow!("For Yjs compatibility, the doc_id appears twice in the URL. It must be the same in both places, but we got {} and {}.", doc_id, doc_id2),
));
}
let authorization = server_state.verify_doc_token(params.token.as_deref(), &doc_id)?;
handle_socket_upgrade(ws, Path(doc_id), authorization, State(server_state)).await
}
// ... (rest of the code remains the same) ...
#[cfg(test)]
mod test {
    use super::*;
    use y_sweet_core::api_types::Authorization;

    /// Builds an in-memory server (no store, no authenticator, no URL prefix)
    /// whose data-plane endpoints are gated by `client_secret`.
    async fn server_with_client_secret(client_secret: Option<&str>) -> Arc<Server> {
        let server = Server::new(
            None,
            Duration::from_secs(60),
            None,
            None,
            CancellationToken::new(),
            true,
            client_secret.map(str::to_string),
        )
        .await
        .unwrap();
        Arc::new(server)
    }

    #[tokio::test]
    async fn test_with_client_secret() {
        let server = server_with_client_secret(Some("test_secret")).await;
        // A configured secret must be supplied: omitting it is rejected.
        let response = get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams { client_secret: None }),
            None,
        )
        .await;
        assert!(response.is_err());
        assert_eq!(response.unwrap_err().0, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_with_wrong_client_secret() {
        let server = server_with_client_secret(Some("test_secret")).await;
        // A supplied-but-wrong secret is rejected just like a missing one.
        let response = get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams {
                client_secret: Some("wrong_secret".to_string()),
            }),
            None,
        )
        .await;
        assert!(response.is_err());
        assert_eq!(response.unwrap_err().0, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_with_matching_client_secret() {
        let server = server_with_client_secret(Some("test_secret")).await;
        // The exact configured secret is accepted.
        let response = get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams {
                client_secret: Some("test_secret".to_string()),
            }),
            None,
        )
        .await;
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_without_client_secret() {
        let server = server_with_client_secret(None).await;
        // With no secret configured, requests without one are accepted.
        let response = get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams { client_secret: None }),
            None,
        )
        .await;
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_unconfigured_secret_ignores_supplied_value() {
        let server = server_with_client_secret(None).await;
        // A secret sent to a server that does not require one is ignored.
        let response = get_doc_as_update(
            State(server),
            Path("test".to_string()),
            Query(DataPlaneQueryParams {
                client_secret: Some("anything".to_string()),
            }),
            None,
        )
        .await;
        assert!(response.is_ok());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment