Last active
December 1, 2023 11:19
-
-
Save Tobi-De/437717c792f90814d10eae31eb8d12a5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sse-starlette
starlette
redis
psycopg[c]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from logging import getLogger
from typing import AsyncGenerator, Generator

import psycopg
import redis
import redis.asyncio as async_redis
from sse_starlette import EventSourceResponse, ServerSentEvent
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
# Emit DEBUG records so the logger.debug calls in the brokers below are visible.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Name of the Redis hash whose fields are channel names and whose values are
# the number of currently open SSE listeners on that channel.
COUNT_KEY = "REDIS_COUNT_KEY"
class PostgresBroker:
    """Deliver Server-Sent Events through PostgreSQL LISTEN/NOTIFY.

    Every ``listen`` stream and every ``notify`` call opens its own
    autocommit connection and closes it when done.
    """

    def __init__(self, dbname, user, password, host=None) -> None:
        # Keyword arguments forwarded to psycopg's ``connect`` for each connection.
        self.db_params = {
            "client_encoding": "UTF8",
            "dbname": dbname,
            "user": user,
            "password": password,
            "host": host or "127.0.0.1",
        }

    async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]:
        """Yield one ServerSentEvent per notification received on *channel*.

        Each notification payload is decoded as JSON and its keys are passed
        as keyword arguments to ``ServerSentEvent``. The database connection
        is closed when the generator is closed (e.g. the client disconnects).

        *channel* comes straight from the request path, so it is quoted as an
        SQL identifier instead of being interpolated into the statement. A
        quoted identifier is matched case-sensitively, which agrees with
        ``pg_notify`` in ``notify``.
        """
        from psycopg import sql

        # autocommit: LISTEN only takes effect once its transaction commits.
        async with await psycopg.AsyncConnection.connect(
            **self.db_params,
            autocommit=True,
        ) as connection:
            logger.debug(f"Listening to {channel}")
            await connection.execute(
                sql.SQL("LISTEN {}").format(sql.Identifier(channel))
            )
            async for notify_message in connection.notifies():
                payload = json.loads(notify_message.payload)
                logger.debug(f"Data received from {channel}")
                yield ServerSentEvent(**payload)

    def notify(self, channel: str, sse_payload: dict) -> None:
        """Publish *sse_payload*, JSON-encoded, to every listener on *channel*.

        ``pg_notify`` takes the channel and payload as bound parameters, so
        quotes inside the payload (or the channel) cannot break or inject
        into the SQL statement.
        """
        with psycopg.Connection.connect(
            **self.db_params,
            autocommit=True,
        ) as connection:
            logger.debug(f"Publishing to {channel}: {sse_payload}")
            connection.execute(
                "SELECT pg_notify(%s, %s)",
                (channel, json.dumps(sse_payload)),
            )
class RedisBroker:
    """Deliver Server-Sent Events through Redis pub/sub and count listeners.

    The number of open ``listen`` streams per channel is kept in the Redis
    hash ``COUNT_KEY`` (field = channel name, value = listener count).
    """

    def __init__(self, redis_url: str) -> None:
        # Async client for the request-serving side; sync client for
        # ``notify``, which is called from plain synchronous code.
        self._client = async_redis.from_url(redis_url)
        self._sync_client = redis.from_url(redis_url)

    @asynccontextmanager
    async def increment(self, channel: str) -> AsyncGenerator[None, None]:
        """Count one listener on *channel* for the lifetime of the ``async with``."""
        logger.debug(f"Incrementing {channel}")
        await self._client.hincrby(COUNT_KEY, channel, 1)
        try:
            yield
        finally:
            # NOTE(review): if the stream is torn down by cancellation, confirm
            # this await is not itself cancelled, or the count will drift up.
            logger.debug(f"Decrementing {channel}")
            await self._client.hincrby(COUNT_KEY, channel, -1)

    async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]:
        """Yield one ServerSentEvent per JSON message published on *channel*.

        Each call uses its own pub/sub connection. A single PubSub shared by
        concurrent listeners would let one stream consume messages meant for
        another and would mix every subscribed channel together. The pub/sub
        connection is closed when the generator is closed.
        """
        async with self.increment(channel), self._client.pubsub() as pubsub:
            logger.debug(f"Listening to {channel}")
            await pubsub.subscribe(channel)
            # listen() waits on the socket rather than polling, so an idle
            # stream does not spin the event loop.
            async for message in pubsub.listen():
                if message["type"] != "message":
                    continue  # subscribe confirmations and other control replies
                payload = json.loads(message["data"].decode())
                logger.debug(f"Data received from {channel}")
                yield ServerSentEvent(**payload)

    async def value(self) -> dict[str, int]:
        """Return the current listener count for every channel seen so far."""
        return {
            k.decode(): int(v)
            for k, v in (await self._client.hgetall(COUNT_KEY)).items()
        }

    def notify(self, channel: str, sse_payload: dict) -> None:
        """Publish *sse_payload*, JSON-encoded, on *channel*.

        Uses the blocking client; calling this from inside the event loop
        would stall it.
        """
        logger.debug(f"Publishing to {channel}: {sse_payload}")
        self._sync_client.publish(channel=channel, message=json.dumps(sse_payload))
# Broker shared by all requests. Set REDIS_URL to point at another server;
# the default is the previously hard-coded local instance.
broker = RedisBroker(os.environ.get("REDIS_URL", "redis://localhost:6379"))
# To use PostgreSQL LISTEN/NOTIFY instead, swap in a PostgresBroker and read the
# password from the environment rather than committing it to source, e.g.:
# broker = PostgresBroker(dbname="estate_sh", user="postgres", password=os.environ["PGPASSWORD"])
async def sse(request: Request):
    """Stream Server-Sent Events for the channel named in the request path."""
    path_params = request.path_params
    channel = path_params.get("channel")
    logger.info(f"New SSE connection to {channel}")
    event_stream = broker.listen(channel)
    return EventSourceResponse(event_stream)
async def count(_: Request):
    """Return the broker's per-channel listener counts as a JSON object.

    Only RedisBroker defines ``value()``; with a PostgresBroker this
    endpoint would raise AttributeError.
    """
    listener_counts = await broker.value()
    return JSONResponse(listener_counts)
# "/count" must come before the catch-all "/{channel}" route; otherwise a
# request for /count would be treated as an SSE stream for channel "count".
routes = [
    Route("/count", count),
    Route("/{channel}", sse),
]
app = Starlette(routes=routes)
if __name__ == "__main__":
    # Demo publisher: every 4 seconds, push an incrementing number to
    # "test_channel" so connected SSE clients have something to receive.
    message_number = 0
    while True:
        print("Sending message")
        broker.notify("test_channel", {"data": message_number})
        message_number += 1
        time.sleep(4)

# Usage:
#   stream events with curl (assumes the app is being served on port 8001 — confirm):
#     curl -N http://localhost:8001/test_channel
#   send events by running this file directly (presumably saved as test.py):
#     python test.py
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment