import asyncio
import uuid

import websockets
from IPython.display import display, HTML
from imjoy_rpc.connection.jupyter_connection import put_buffers, remove_buffers


class ColabWebSocketProxy:
    """Relay messages between a WebSocket server and a notebook comm channel.

    Messages read from the WebSocket at ``uri`` are forwarded to the front end
    through ``self.comm``; messages arriving on the comm with ``type`` equal to
    ``"message"`` or ``"close"`` are applied to the WebSocket.

    Attributes:
        uri: WebSocket URL passed to ``websockets.connect``.
        client_id: Random UUID string used to build the comm target name.
        comm: The comm object handed to the registration callback, or ``None``
            until the front end has opened the comm.
        websocket: The open WebSocket connection, or ``None`` before
            ``connect`` has established it.
        connected_event: Set once the comm has been registered and its
            message handler installed.
    """

    def __init__(self, uri):
        """Store the target ``uri`` and initialise the unconnected state."""
        self.uri = uri
        self.client_id = str(uuid.uuid4())
        self.comm = None
        self.websocket = None
        self.connected_event = asyncio.Event()
        # Strong references to in-flight handler tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._pending_tasks = set()

    async def connect(self):
        """Create a WebSocket connection and start proxying messages.

        Runs until the WebSocket iteration ends. Each server message is sent
        to the front end twice: once as a ``"log"`` record and once as a
        ``"message"`` record (the latter via ``emit`` in the default executor).
        """
        loop = asyncio.get_running_loop()
        async with websockets.connect(self.uri) as websocket:
            self.websocket = websocket
            self._setup_comm()
            async for message in websocket:
                # self.comm is None until the front end opens the comm; wait
                # for registration instead of failing on ``None.send``.
                # After the first time this returns immediately.
                await self.connected_event.wait()
                self.comm.send({"type": "log", "message": message})
                await loop.run_in_executor(
                    None, self.emit, {"type": "message", "data": message}
                )
                print(f"Proxy received from server: {message}")

    def _setup_comm(self):
        """Set up Colab communication channel.

        Registers a comm target named ``colab_ws_proxy_<client_id>`` with the
        kernel's comm manager, then displays a ``<script>`` element containing
        the contents of ``colab_websocket.js`` with the ``<client_id>`` and
        ``<ws_url>`` placeholders substituted.
        """

        def registered(comm, open_msg):
            """Handle registration: remember the comm and install the handler."""
            self.comm = comm

            def msg_cb(msg):
                """Handle a message received on the comm."""
                data = msg["content"]["data"]
                # Payloads without a "type" key are ignored.
                if "type" in data:
                    if "__buffer_paths__" in data:
                        buffer_paths = data["__buffer_paths__"]
                        del data["__buffer_paths__"]
                        # Presumably re-attaches the binary buffers at the
                        # recorded paths (mirror of remove_buffers in emit).
                        put_buffers(data, buffer_paths, msg["buffers"])
                    loop = asyncio.get_running_loop()
                    task = loop.create_task(self._handle_comm_message(data))
                    # Hold the task until it completes so it is not
                    # garbage-collected mid-execution.
                    self._pending_tasks.add(task)
                    task.add_done_callback(self._pending_tasks.discard)

            comm.on_msg(msg_cb)
            # Signal connect() only after the handler is in place.
            self.connected_event.set()

        # NOTE(review): get_ipython is not imported in this file; it is
        # presumably injected by the IPython runtime -- confirm.
        get_ipython().kernel.comm_manager.register_target(
            f"colab_ws_proxy_{self.client_id}", registered
        )

        # Path is relative to the current working directory.
        with open('colab_websocket.js', 'r', encoding='utf-8') as f:
            js_code = f.read()
        js_code = js_code.replace('<client_id>', self.client_id).replace(
            '<ws_url>', self.uri
        )

        display(HTML(f""" <script> {js_code} </script> """))

    async def _handle_comm_message(self, message):
        """Handle incoming messages from JavaScript.

        ``message["type"] == "message"`` sends ``message["data"]`` to the
        WebSocket; ``"close"`` closes the WebSocket. Other types are ignored.
        """
        if message['type'] == 'message':
            await self.websocket.send(message['data'])
            print(f"Sent message to WebSocket server: {message['data']}")
        elif message['type'] == 'close':
            await self.websocket.close()
            print("Closed WebSocket connection")

    def emit(self, msg):
        """Emit a message to the front end through ``self.comm``.

        Buffers extracted by ``remove_buffers`` are sent alongside the message,
        with their paths recorded under ``"__buffer_paths__"``. ``self.comm``
        must already be registered when this is called.
        """
        msg, buffer_paths, buffers = remove_buffers(msg)
        if len(buffers) > 0:
            msg["__buffer_paths__"] = buffer_paths
            self.comm.send(msg, buffers=buffers)
        else:
            self.comm.send(msg)


# Example usage:
uri = "ws://127.0.0.1:8765"  # Local WebSocket server for testing
proxy = ColabWebSocketProxy(uri)


async def test_websocket_proxy():
    """Run the example proxy until its WebSocket connection ends."""
    await proxy.connect()


loop = asyncio.get_event_loop()
# Keep a reference so the scheduled task cannot be garbage-collected.
_proxy_task = loop.create_task(test_websocket_proxy())