import asyncio
import websockets
import uuid
from IPython.display import display, HTML
from imjoy_rpc.connection.jupyter_connection import put_buffers, remove_buffers

class ColabWebSocketProxy:
    """Relay messages between a WebSocket server and a notebook front-end.

    The Python side opens the WebSocket to ``uri``, injects the contents of
    ``colab_websocket.js`` into the cell output, and forwards traffic in both
    directions over a comm registered under the target name
    ``colab_ws_proxy_<client_id>``.
    """

    def __init__(self, uri):
        """Store the server URI and initialise the not-yet-open channels.

        Args:
            uri: WebSocket URL handed to ``websockets.connect`` and
                substituted for ``<ws_url>`` in the injected script.
        """
        self.uri = uri
        # Unique per instance; used in the comm target name and substituted
        # for ``<client_id>`` in the injected script.
        self.client_id = str(uuid.uuid4())
        self.comm = None
        self.websocket = None
        # Set once the front-end has opened the comm (see _on_comm_open).
        self.connected_event = asyncio.Event()
        # Strong references to in-flight tasks: the event loop itself only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._pending_tasks = set()

    async def connect(self):
        """Create a WebSocket connection and start proxying messages.

        Opens the WebSocket, registers the comm target and injects the
        front-end script, waits until the front-end has opened the comm, then
        forwards every server message to the front-end until the WebSocket
        iteration ends.

        NOTE(review): presumably the kernel can only deliver the comm-open
        message while no cell is blocking on this coroutine, so it should be
        scheduled as a background task rather than awaited directly in a
        cell - confirm against the target notebook environment.
        """
        loop = asyncio.get_running_loop()
        async with websockets.connect(self.uri) as websocket:
            self.websocket = websocket
            self._setup_comm()
            # self.comm is None until the front-end opens the comm; without
            # this wait a server message arriving first would fail on
            # ``self.comm.send``.
            await self.connected_event.wait()
            async for message in websocket:
                self.comm.send({"type": "log", "message": message})
                # emit() is synchronous; run it in the default executor so
                # the event loop is not blocked while it sends.
                await loop.run_in_executor(
                    None, self.emit, {"type": "message", "data": message}
                )
                print(f"Proxy received from server: {message}")

    def _setup_comm(self):
        """Set up Colab communication channel.

        Registers ``_on_comm_open`` for this instance's comm target, then
        reads ``colab_websocket.js`` (relative to the current working
        directory), fills in the ``<client_id>`` and ``<ws_url>``
        placeholders and displays it inside a ``<script>`` element.

        Raises:
            OSError: if ``colab_websocket.js`` cannot be opened.
        """
        get_ipython().kernel.comm_manager.register_target(
            f"colab_ws_proxy_{self.client_id}", self._on_comm_open
        )

        with open('colab_websocket.js', 'r', encoding='utf-8') as f:
            js_code = f.read()

        js_code = js_code.replace('<client_id>', self.client_id).replace('<ws_url>', self.uri)

        display(HTML(f"""
            <script>
                {js_code}
            </script>
        """))

    def _on_comm_open(self, comm, open_msg):
        """Handle registration of the comm by the front-end.

        Args:
            comm: The newly opened comm; stored on ``self.comm``.
            open_msg: The comm-open message (unused).
        """
        self.comm = comm
        comm.on_msg(self._on_comm_message)
        # Unblock connect(), which waits for the comm before forwarding.
        self.connected_event.set()

    def _on_comm_message(self, msg):
        """Handle a message sent by the front-end over the comm.

        Messages whose data has no ``"type"`` key are ignored. Binary
        buffers listed under ``"__buffer_paths__"`` are merged back into the
        data before it is dispatched to ``_handle_comm_message``.
        """
        data = msg["content"]["data"]
        if "type" not in data:
            return
        if "__buffer_paths__" in data:
            buffer_paths = data.pop("__buffer_paths__")
            put_buffers(data, buffer_paths, msg["buffers"])
        task = asyncio.get_running_loop().create_task(
            self._handle_comm_message(data)
        )
        # Hold the task until it completes so it cannot be garbage-collected
        # mid-flight.
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    async def _handle_comm_message(self, message):
        """Handle incoming messages from JavaScript.

        Args:
            message: Dict with a ``"type"`` key. ``"message"`` sends
                ``message["data"]`` to the WebSocket server; ``"close"``
                closes the WebSocket. Any other type is ignored.
        """
        if message['type'] == 'message':
            await self.websocket.send(message['data'])
            print(f"Sent message to WebSocket server: {message['data']}")
        elif message['type'] == 'close':
            await self.websocket.close()
            print("Closed WebSocket connection")

    def emit(self, msg):
        """Emit a message to the front-end over the comm.

        Buffers extracted by ``remove_buffers`` are sent alongside the
        message, with their locations recorded under ``"__buffer_paths__"``.

        Args:
            msg: The message to send.
        """
        msg, buffer_paths, buffers = remove_buffers(msg)
        if buffers:
            msg["__buffer_paths__"] = buffer_paths
            self.comm.send(msg, buffers=buffers)
        else:
            self.comm.send(msg)

# Example usage:
uri = "ws://127.0.0.1:8765"  # Local WebSocket server for testing
proxy = ColabWebSocketProxy(uri)


async def test_websocket_proxy():
    """Run the example proxy until ``proxy.connect()`` returns."""
    await proxy.connect()


# Schedule the proxy on the event loop instead of awaiting it here, so this
# code returns immediately and the proxy keeps running in the background.
loop = asyncio.get_event_loop()
# Keep a reference to the task: the event loop only holds weak references,
# so an unreferenced task could be garbage-collected before it completes.
proxy_task = loop.create_task(test_websocket_proxy())