Created November 28, 2023 17:43
-
-
Save Tobi-De/c059c878d6b0a51ce7ad207e7b6e4658 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import TYPE_CHECKING | |
import structlog | |
from .config import get_redis_url, get_client_tracking_enabled | |
from functools import wraps | |
import threading | |
import redis | |
import asyncio | |
if TYPE_CHECKING:
    # Imported only for static type checking: "Broker" is used as a string
    # annotation in this module, so .brokers is never imported at runtime
    # (presumably to avoid an import cycle — confirm).
    from .brokers import Broker

# structlog logger named "client_tracking"; used for the debug messages logged
# when a channel's counter is incremented or decremented.
logger = structlog.stdlib.get_logger("client_tracking")
class RedisCounterMap:
    """Per-channel client counter stored in a single Redis hash.

    Each channel is one field of the hash ``hash_key``, so every process
    connected to the same Redis server shares the same counts.
    """

    hash_key = "sse_relay_server:channels"

    # Decrement a field and, if it dropped to zero or below, delete it — as a
    # single server-side step. Lua scripts run atomically in Redis; issuing
    # HINCRBY and HDEL as two separate commands would let another process's
    # HINCRBY land in between, and the HDEL would then erase a live count.
    _DECREMENT_SCRIPT = """
local count = redis.call('HINCRBY', KEYS[1], ARGV[1], -1)
if count <= 0 then
    redis.call('HDEL', KEYS[1], ARGV[1])
end
return count
"""

    def __init__(self, redis_url: str) -> None:
        """Create a Redis client from ``redis_url`` via ``redis.Redis.from_url``."""
        self._redis = redis.Redis.from_url(redis_url)

    def increment(self, channel: str) -> None:
        """Add one client to ``channel``'s count (creating the field if absent)."""
        self._redis.hincrby(self.hash_key, channel, 1)

    def decrement(self, channel: str) -> None:
        """Remove one client from ``channel``; drop the field once it reaches zero."""
        self._redis.eval(self._DECREMENT_SCRIPT, 1, self.hash_key, channel)

    def value(self) -> dict[str, int]:
        """Return a snapshot ``{channel: count}`` read from the Redis hash."""
        # Keys/values come back as bytes (the client is not created with
        # decode_responses), hence the decode/int conversion.
        return {
            k.decode(): int(v) for k, v in self._redis.hgetall(self.hash_key).items()
        }

    def reset(self) -> None:
        """Delete the whole hash, zeroing every channel's count."""
        self._redis.delete(self.hash_key)
class DictCounterMap:
    """In-process per-channel client counter guarded by a ``threading.Lock``.

    Used when no Redis URL is configured; counts are visible only to the
    current process.
    """

    def __init__(self) -> None:
        self._map: dict[str, int] = {}
        self.lock = threading.Lock()

    def increment(self, channel: str) -> None:
        """Add one client to ``channel``'s count."""
        with self.lock:
            self._map[channel] = self._map.get(channel, 0) + 1

    def decrement(self, channel: str) -> None:
        """Remove one client from ``channel``; unknown channels are ignored.

        The entry is deleted when its count would reach zero, so ``value()``
        only lists channels that still have clients.
        """
        with self.lock:
            count = self._map.get(channel)
            if count is None:
                return
            if count <= 1:
                del self._map[channel]
            else:
                self._map[channel] = count - 1

    def value(self) -> dict[str, int]:
        """Return a snapshot copy of ``{channel: count}``."""
        # Hold the lock like every other accessor so the copy cannot
        # interleave with a concurrent mutation or reset.
        with self.lock:
            return self._map.copy()

    def reset(self) -> None:
        """Forget every channel's count."""
        with self.lock:
            self._map = {}
def _get_counter_map():
    """Build the counter backend.

    Returns a Redis-backed counter when a Redis URL is configured, otherwise
    an in-process, lock-protected dict counter.
    """
    redis_url = get_redis_url()
    if not redis_url:
        return DictCounterMap()
    return RedisCounterMap(redis_url)
# Counter backend, chosen once at import time: Redis-backed when a Redis URL is
# configured, otherwise an in-process dict.
_counter = _get_counter_map()
# Tracking flag, also read once at import time; count_clients, get_count_value
# and reset_count_value all consult it.
client_tracking_enabled = get_client_tracking_enabled()
def count_clients(func: "Broker.listen"):
    """Decorate a broker's ``listen`` async generator to count active clients.

    When client tracking is disabled (checked once, at decoration time),
    ``func`` is returned unchanged. Otherwise the channel's counter is
    incremented when a listener starts. It is decremented exactly once when
    the listener stops for any reason: task cancellation, the consumer
    closing the generator (``aclose()``), the wrapped generator finishing,
    or the wrapped generator raising.
    """
    if not client_tracking_enabled:
        return func

    @wraps(func)
    async def wrapper(instance: "Broker", channel: str):
        logger.debug(f"Incrementing counter using {_counter.__class__.__name__}")
        # Increment outside the try: if it fails there is nothing to undo.
        _counter.increment(channel)
        try:
            async for event in func(instance, channel):
                yield event
        finally:
            # `finally` rather than `except asyncio.CancelledError` so the
            # count is also released on normal exhaustion, on errors, and on
            # GeneratorExit from aclose(). No await here, which is required
            # while an async generator is being closed.
            _counter.decrement(channel)
            logger.debug(f"Decrementing counter using {_counter.__class__.__name__}")

    return wrapper
def get_count_value():
    """Return a ``{channel: client_count}`` snapshot, or ``{}`` when tracking is off."""
    return _counter.value() if client_tracking_enabled else {}
def reset_count_value():
    """Zero every channel's client count.

    Returns ``{}`` when tracking is disabled. Otherwise it returns whatever
    the backend's ``reset()`` returns, which is ``None`` for both backends in
    this module.
    """
    if client_tracking_enabled:
        return _counter.reset()
    return {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.