# Scale WebSockets with Redis Pub/Sub in FastAPI.
#
# Originally shared as a GitHub Gist by @ash2shukla (created May 7, 2023 22:18).
import asyncio
import json
import logging
import os

import aioredis
import async_timeout
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.websockets import WebSocket, WebSocketDisconnect
# The ASGI application every route below is registered on.
app = FastAPI()

# Show INFO-and-above records from all loggers, including this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<h3 id="userid"></h3>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="toUser" autocomplete="off"/>
<input type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
const userID = "user-" + Math.floor(Math.random() * 100);
document.getElementById("userid").innerHTML = "You are " + userID;
var ws = new WebSocket("ws://localhost:8000/ws?userid="+userID);
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
const event_data = JSON.parse(event.data);
const formattedMessage = "From: " + event_data.from + " > " + event_data.content
var content = document.createTextNode(formattedMessage)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage(event) {
const input = document.getElementById("messageText")
const toUser = document.getElementById("toUser")
ws.send(JSON.stringify({"to": toUser.value,"content": input.value}))
input.value = ''
event.preventDefault()
}
</script>
</body>
</html>
"""
# Redis connection shared by every WebSocket handler in this process. For messages to
# reach users connected to other app instances, all instances must use the same Redis,
# so the URL can be overridden with the REDIS_URL environment variable.
redis_uri = os.environ.get("REDIS_URL", "redis://localhost:6379")
# decode_responses=True makes pub/sub payloads arrive as str, ready for json.loads().
redis = aioredis.from_url(redis_uri, decode_responses=True)
@app.get("/")
async def get():
    """Serve the single-page chat client stored in the module-level ``html`` string."""
    page = HTMLResponse(content=html)
    return page
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, userid: str):
    """Accept a chat client and relay its traffic through Redis pub/sub.

    ``userid`` is taken from the ``?userid=`` query string. It is attached to the
    socket as ``_userid`` so the relay knows which channel to subscribe to and
    which sender id to stamp on messages this client publishes.
    """
    # NOTE(review): userid is trusted exactly as the client sent it, so any client
    # can claim another user's id and receive that user's messages.
    websocket._userid = userid
    await websocket.accept()
    await redis_connector(websocket)
async def redis_connector(websocket: WebSocket):
    """Relay chat traffic between one WebSocket client and Redis pub/sub.

    Two concurrent loops run until either one finishes:

    * the consumer reads JSON frames from the client and publishes each one to
      the Redis channel ``chat:<to>``;
    * the producer subscribes to ``chat:<websocket._userid>`` and forwards every
      message published there to the client.

    When one loop ends (client disconnect, failed send, Redis error), the other
    is cancelled and the pub/sub connection is closed.

    Args:
        websocket: An accepted WebSocket whose ``_userid`` attribute has already
            been set by the caller.
    """

    async def consumer_handler(ws: WebSocket, r):
        """Publish each JSON message received from ``ws`` through Redis client ``r``."""
        try:
            while True:
                try:
                    message = await ws.receive_json()
                except json.JSONDecodeError:
                    # The bad frame has already been consumed; skip it instead of
                    # tearing down the whole connection.
                    logger.warning("Ignoring non-JSON frame from %s", ws._userid)
                    continue
                if not message:
                    continue
                try:
                    channel_name = f"chat:{message['to']}"
                    payload = {"from": ws._userid, "content": message["content"]}
                except (KeyError, TypeError):
                    # Missing "to"/"content" keys, or a JSON value that is not an object.
                    logger.warning(
                        "Ignoring malformed message from %s: %r", ws._userid, message
                    )
                    continue
                await r.publish(channel_name, json.dumps(payload))
        except WebSocketDisconnect as exc:
            # The client closing the socket is the normal way this loop ends,
            # so it is not reported as an error.
            logger.info("WebSocket of %s disconnected (code=%s)", ws._userid, exc.code)

    async def producer_handler(channel, ws: WebSocket):
        """Forward messages published on this user's channel to ``ws``."""
        await channel.subscribe(f"chat:{ws._userid}")
        while True:
            try:
                # Bound only the Redis poll. If ws.send_json() were inside the
                # timeout, a slow send could be cancelled mid-way and the message
                # silently lost.
                async with async_timeout.timeout(1):
                    message = await channel.get_message(ignore_subscribe_messages=True)
            except asyncio.TimeoutError:
                continue
            if message is not None:
                try:
                    data = json.loads(message["data"])
                    outgoing = {"from": data["from"], "content": data["content"]}
                except (json.JSONDecodeError, KeyError, TypeError):
                    logger.warning(
                        "Ignoring malformed pub/sub payload for %s: %r", ws._userid, message
                    )
                else:
                    await ws.send_json(outgoing)
            # get_message() is called without a timeout, so it may return None
            # immediately; pause briefly so this polling loop does not spin.
            await asyncio.sleep(0.01)

    pubsub = redis.pubsub()
    # asyncio.wait() needs Tasks: passing bare coroutines was deprecated in
    # Python 3.8 and raises TypeError since Python 3.11.
    tasks = [
        asyncio.create_task(consumer_handler(websocket, redis)),
        asyncio.create_task(producer_handler(pubsub, websocket)),
    ]
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # Retrieve the outcome so an unexpected failure is logged here, not
            # later as "Task exception was never retrieved".
            if not task.cancelled() and task.exception() is not None:
                logger.error(
                    "Relay for %s stopped with an error",
                    websocket._userid,
                    exc_info=task.exception(),
                )
    finally:
        for task in tasks:
            task.cancel()
        # Let cancellation finish before closing the pub/sub connection the
        # producer may still be reading from.
        await asyncio.gather(*tasks, return_exceptions=True)
        await pubsub.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment