Last active
March 6, 2023 03:07
-
-
Save Ce11an/6775b001bf3bbe65d1f06c9d6f1768ba to your computer and use it in GitHub Desktop.
SurrealDB WebSocket Client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""SurrealDB websocket client library.""" | |
import enum
import json
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pydantic
import websockets
ID = 0


def guid() -> str:
    """Produce the next request identifier.

    A module-wide counter is incremented on every call and wrapped modulo
    ``2**53 - 1`` (presumably to stay inside JavaScript's safe-integer range
    on the server/JSON side -- TODO confirm).

    Returns:
        The updated counter value rendered as a decimal string.
    """
    global ID
    modulus = 2**53 - 1
    next_id = (ID + 1) % modulus
    ID = next_id
    return f"{next_id}"
class SurrealException(Exception):
    """Root of the exception hierarchy used by this SurrealDB client."""


class SurrealAuthenticationException(SurrealException):
    """Raised for errors reported by SurrealDB authentication calls."""


class SurrealPermissionException(SurrealException):
    """Raised for errors reported by SurrealDB around permissions."""
class WebSocketState(enum.Enum):
    """Lifecycle states of the client's WebSocket connection.

    Members:
        CONNECTING: A connection attempt is in progress.
        CONNECTED: The WebSocket is open and usable.
        DISCONNECTED: The WebSocket is closed.
    """

    CONNECTING = 0  # connection attempt in progress
    CONNECTED = 1  # socket open
    DISCONNECTED = 2  # socket closed
class Request(pydantic.BaseModel):
    """A read-only RPC request addressed to a Surreal server.

    Attributes:
        id: The ID of the request.
        method: The name of the RPC method to call.
        params: Positional parameters for the method; a missing or ``None``
            value is normalised to an empty tuple.
    """

    id: str
    method: str
    params: Optional[Tuple] = None

    @pydantic.validator("params", pre=True, always=True)
    def validate_params(cls, value):  # pylint: disable=no-self-argument
        """Substitute an empty tuple when no parameters were supplied."""
        return () if value is None else value

    class Config:
        """Pydantic configuration: request instances cannot be mutated."""

        allow_mutation = False
class ResponseSuccess(pydantic.BaseModel):
    """A read-only, successful RPC response received from a Surreal server.

    Attributes:
        id: The ID of the request this response answers.
        result: The payload returned by the server for that request.
    """

    id: str
    result: Any

    class Config:
        """Pydantic configuration for the response model.

        Attributes:
            allow_mutation: Disabled, so instances are read-only.
        """

        allow_mutation = False
class ResponseError(pydantic.BaseModel):
    """A read-only RPC error reported by a Surreal server.

    Attributes:
        code: The numeric error code sent by the server.
        message: The human-readable error message sent by the server.
    """

    code: int
    message: str

    class Config:
        """Pydantic configuration for the error model.

        Attributes:
            allow_mutation: Disabled, so instances are read-only.
        """

        allow_mutation = False
def _validate_response(
    response: Union[ResponseSuccess, ResponseError],
    exception: Type[Exception] = SurrealException,
) -> ResponseSuccess:
    """Return ``response`` unchanged unless it is an error response.

    Args:
        response: The parsed response received from the server.
        exception: The exception class to raise when ``response`` is a
            :class:`ResponseError`. Defaults to :class:`SurrealException`.

    Returns:
        The original response, which at this point is a
        :class:`ResponseSuccess`.

    Raises:
        Exception: An instance of ``exception`` (``SurrealException`` by
            default) built from the error's ``message``, if ``response`` is a
            :class:`ResponseError`.
    """
    if isinstance(response, ResponseError):
        # Only the message is forwarded; the server's error code is dropped.
        raise exception(response.message)
    return response
class Surreal:
    """Asynchronous client for the WebSocket RPC endpoint of a Surreal server.

    Every public RPC method sends exactly one request over ``self.ws`` and
    then treats the next message read from the socket as its response.

    Attributes:
        url: The WebSocket URL of the Surreal server.
        token: The current authentication token, if any. It is set by
            ``signin`` and cleared by ``invalidate``.
        client_state: The current :class:`WebSocketState` of the connection.
        ws: The WebSocket connection, or ``None`` until ``connect`` succeeds.
    """

    def __init__(self, url: str, token: Optional[str] = None) -> None:
        """Create a client without opening a connection.

        Args:
            url: The WebSocket URL of the Surreal server.
            token: An optional authentication token to start with.
        """
        self.url = url
        self.token = token
        # Nothing is connected (or connecting) until ``connect`` is awaited.
        self.client_state = WebSocketState.DISCONNECTED
        self.ws: Optional[websockets.WebSocketClientProtocol] = None  # type: ignore

    async def __aenter__(self) -> "Surreal":
        """Connect to the Surreal server.

        Returns:
            The connected Surreal client.
        """
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Disconnect from the Surreal server.

        Exceptions raised inside the ``async with`` body are not suppressed.

        Args:
            exc_type: The type of the exception, if one was raised.
            exc_value: The exception instance, if one was raised.
            traceback: The traceback of the exception, if one was raised.
        """
        await self.disconnect()

    async def connect(self) -> None:
        """Open the WebSocket connection to the Surreal server.

        Raises:
            Whatever ``websockets.connect`` raises when the connection cannot
            be established. The client is left DISCONNECTED in that case.
        """
        self.client_state = WebSocketState.CONNECTING
        try:
            self.ws = await websockets.connect(self.url)  # type: ignore
        except BaseException:
            # Do not leave the client stuck in CONNECTING after a failure.
            self.client_state = WebSocketState.DISCONNECTED
            raise
        self.client_state = WebSocketState.CONNECTED

    async def disconnect(self) -> None:
        """Close the connection to the Surreal server.

        Safe to call when no connection was ever opened.
        """
        if self.ws is not None:
            await self.ws.close()  # type: ignore
        self.client_state = WebSocketState.DISCONNECTED

    async def ping(self) -> bool:
        """Ping the Surreal server.

        Returns:
            The ``result`` field of the server's response.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        success = await self._call("ping")
        return success.result

    async def use(self, namespace: str, database: str) -> None:
        """Use a namespace and database.

        Args:
            namespace: The namespace to use.
            database: The database to use.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        await self._call("use", (namespace, database))

    async def signin(self, auth: Dict[str, Any]) -> str:
        """Sign in to the Surreal server and remember the returned token.

        Args:
            auth: The authentication parameters.

        Returns:
            The token returned by the server, also stored on ``self.token``.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        success = await self._call("signin", (auth,), SurrealAuthenticationException)
        token: str = success.result
        self.token = token
        return self.token

    async def info(self) -> Optional[Dict[str, Any]]:
        """Call the ``info`` RPC method.

        Returns:
            The ``result`` field of the server's response.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        success = await self._call("info")
        return success.result

    async def signup(self, auth: Dict[str, Any]) -> None:
        """Sign up to the Surreal server.

        Args:
            auth: The authentication parameters.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("signup", (auth,), SurrealAuthenticationException)

    async def invalidate(self) -> None:
        """Invalidate the current session and forget the stored token.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("invalidate", None, SurrealAuthenticationException)
        self.token = None

    async def authenticate(self) -> None:
        """Authenticate using the stored token.

        ``self.token`` is sent as-is, even when it is ``None``.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call(
            "authenticate", (self.token,), SurrealAuthenticationException
        )

    async def create(self, thing: str, data: Optional[Dict[str, Any]] = None) -> Any:
        """Create a record in the database.

        Args:
            thing: The table or record ID.
            data: The document / record data to insert. When omitted, only
                ``thing`` is sent.

        Returns:
            The ``result`` field of the server's response (presumably the
            created record(s) -- confirm against the server version).

        Raises:
            SurrealPermissionException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        params = (thing,) if data is None else (thing, data)
        success = await self._call("create", params, SurrealPermissionException)
        return success.result

    async def delete(self, thing: str) -> None:
        """Delete all records in a table or a specific record from the database.

        Args:
            thing: The table name or a record ID to delete.

        Raises:
            SurrealPermissionException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("delete", (thing,), SurrealPermissionException)

    async def update(self, thing: str, data: Dict[str, Any]) -> None:
        """Update all records in a table or a specific record in the database.

        This function replaces the current document / record data with the
        specified data.

        Args:
            thing: The table or record ID.
            data: The document / record data to insert.

        Raises:
            SurrealPermissionException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("update", (thing, data), SurrealPermissionException)

    async def kill(self, query: Optional[str] = None) -> None:
        """Kill a live query.

        SurrealDB's ``kill`` RPC method takes the ID of a live query (as
        returned by :meth:`live`).

        Args:
            query: The ID of the live query to kill. When omitted, the request
                is sent with no parameters, for backward compatibility.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        params = None if query is None else (query,)
        await self._call("kill", params)

    async def select(self, thing: str) -> List[Dict[str, Any]]:
        """Select all records in a table or a specific record from the database.

        Args:
            thing: The table or record ID to select.

        Returns:
            The records.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        success = await self._call("select", (thing,))
        return success.result

    async def modify(self, thing: str, data: Dict[str, Any]) -> None:
        """Modify all records or a specific record in the database.

        Applies JSON Patch changes to all records, or a specific record, in the
        database. This function patches the current document / record data with
        the specified JSON Patch data.

        Args:
            thing: The table or record ID.
            data: The data to modify the record with.

        Raises:
            SurrealPermissionException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("modify", (thing, data), SurrealPermissionException)

    async def change(self, thing: str, data: Dict[str, Any]) -> None:
        """Modify all records in a table or a specific record in the database.

        This function merges the current document / record data with the
        specified data.

        Args:
            thing: The table name or the specific record ID to change.
            data: The document / record data to merge in.

        Raises:
            SurrealPermissionException: If the server returns an error.
            SurrealException: If the client is not connected.
        """
        await self._call("change", (thing, data), SurrealPermissionException)

    async def query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Query the database.

        Args:
            query: The query to execute.
            params: The query parameters. When omitted, only the query text
                is sent.

        Returns:
            The ``result`` field of the server's response.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        rpc_params = (query,) if params is None else (query, params)
        success = await self._call("query", rpc_params)
        return success.result

    async def live(self, table: str) -> str:
        """Start a live query on a table.

        Args:
            table: The table name.

        Returns:
            The ``result`` field of the server's response. This is presumably
            the live query's ID, which can be passed to :meth:`kill`.

        Raises:
            SurrealException: If not connected or the server returns an error.
        """
        success = await self._call("live", (table,))
        return success.result

    async def _call(
        self,
        method: str,
        params: Optional[Tuple[Any, ...]] = None,
        exception: Type[Exception] = SurrealException,
    ) -> ResponseSuccess:
        """Send one RPC request with a fresh ID and return the validated response.

        Args:
            method: The RPC method name.
            params: The positional parameters. ``None`` becomes an empty tuple
                via ``Request``'s validator.
            exception: The exception class raised if the server returns an
                error.

        Returns:
            The successful response.
        """
        response = await self._send_receive(
            Request(id=guid(), method=method, params=params)
        )
        return _validate_response(response, exception)

    async def _send_receive(
        self, request: Request
    ) -> Union[ResponseSuccess, ResponseError]:
        """Send a request to the Surreal server and receive a response.

        Args:
            request: The request to send.

        Returns:
            The response from the Surreal server.

        Raises:
            SurrealException: If the client is not connected to the Surreal server.
        """
        await self._send(request)
        return await self._recv()

    async def _send(self, request: Request) -> None:
        """Send a request to the Surreal server as JSON.

        Args:
            request: The request to send.

        Raises:
            SurrealException: If the client is not connected to the Surreal server.
        """
        self._validate_connection()
        await self.ws.send(json.dumps(request.dict()))  # type: ignore

    def _validate_connection(self) -> None:
        """Ensure the client is connected to the Surreal server.

        Raises:
            SurrealException: If ``client_state`` is not CONNECTED.
        """
        if self.client_state != WebSocketState.CONNECTED:
            raise SurrealException("Not connected to Surreal server.")

    async def _recv(self) -> Union[ResponseSuccess, ResponseError]:
        """Receive and parse the next message from the Surreal server.

        NOTE(review): the message's ``id`` is not compared with the request
        that was just sent. Any unsolicited message (for example a live-query
        notification) would be treated as the response to the pending request.

        Returns:
            A :class:`ResponseError` if the message has a truthy ``error``
            field, otherwise a :class:`ResponseSuccess`.

        Raises:
            SurrealException: If the client is not connected to the Surreal server.
        """
        self._validate_connection()
        response = json.loads(await self.ws.recv())  # type: ignore
        if response.get("error"):
            return ResponseError(**response["error"])
        return ResponseSuccess(**response)
async def main():
    """Run an end-to-end example against a local SurrealDB server.

    Signs in as root, selects the ``test`` namespace and database, creates a
    ``user`` record, starts a live query on the ``user`` table, and returns
    the result of a parameterised query over that table.
    """
    new_user = {
        "user": "cellan",
        "pass": "password",
        "DB": "test",
        "NS": "test",
        "SC": "allusers",
        "marketing": True,
        "tags": ["python", "javascript"],
    }
    async with Surreal("ws://127.0.0.1:8000/rpc") as db:
        await db.signin({"user": "root", "pass": "root"})
        await db.use("test", "test")
        await db.create("user", new_user)
        await db.live("user")
        result = await db.query("SELECT * FROM type::table($tb)", {"tb": "user"})
        return result
if __name__ == "__main__":
    import asyncio

    # Run the example coroutine to completion and show what the query returned.
    outcome = asyncio.run(main())
    print(outcome)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
An example of how to connect to SurrealDB over a WebSocket in Python.
Packages: `pydantic` and `websockets` (both imported by the script above).
You can set up the table for users (`user` in the example above) by following this tutorial. This Python implementation was greatly influenced by the JavaScript equivalent.