Created
January 8, 2024 14:51
-
-
Save nicolasmelo1/db16a0c716673c187ca8f4707ba2994f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
ASGI config for reflow_server project. | |
It exposes the ASGI callable as a module-level variable named ``application``. | |
For more information on this file, see | |
https://docs.djangoproject.com/en/3.1/howto/deployment/asgi/ | |
""" | |
import os

from django import setup
from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter

# Settings must be selected and the app registry populated *before* importing
# modules that import Django models or read settings at import time (the auth
# middleware imports django.contrib.auth.models; routing.py pulls in the
# consumers, which read settings.CONSUMERS while their classes are defined).
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "reflow_server.settings")
setup()

from reflow_server.authentication.middleware import (  # noqa: E402,F401 -- not referenced in this module
    AuthWebsocketJWTMiddleware,
    AuthWebsocketPublicMiddleware,
)
from .routing import websocket_urlpatterns  # noqa: E402

# Reference:
# https://channels.readthedocs.io/en/stable/tutorial/part_2.html#write-your-first-consumer
_protocol_routes = {
    # Plain HTTP requests are served by Django's regular ASGI handler.
    'http': get_asgi_application(),
    # Websocket connections are dispatched by path using routing.py.
    'websocket': URLRouter(websocket_urlpatterns),
}
application = ProtocolTypeRouter(_protocol_routes)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import importlib
import inspect
import json

from asgiref.sync import async_to_sync, sync_to_async
from channels.exceptions import DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.layers import get_channel_layer
from django.conf import settings
# This is not especially important: it is a small "hack" used to keep the
# consumer code organised.
def get_consumers(key, consumers=None):
    """
    Import and return the consumer classes registered under ``key``.

    Each entry of ``settings.CONSUMERS[key]`` (or of ``consumers[key]`` when a
    mapping is passed explicitly) is a dotted path such as
    ``'reflow_server.notifications.consumers.NotificationConsumer'``. Every
    path is imported and the classes are returned in registration order, ready
    to be unpacked into the bases of a root consumer class.

    Args:
        key: Key of the CONSUMERS mapping to read (e.g. ``'LOGIN_REQUIRED'``).
        consumers: Optional mapping used instead of ``settings.CONSUMERS``
            (handy for tests or registries kept outside settings).

    Returns:
        A tuple with the imported classes.

    Raises:
        AttributeError: if two registered classes expose a callable with the
            same non-dunder name. The classes are merged through multiple
            inheritance, so one method would silently shadow the other.
        KeyError: if ``key`` is not present in the mapping.
    """
    registry = settings.CONSUMERS if consumers is None else consumers

    seen_methods = set()
    kls_list = []
    for consumer in registry[key]:
        # 'package.module.ClassName' -> ('package.module', 'ClassName')
        path, _, kls_name = consumer.rpartition('.')
        kls = getattr(importlib.import_module(path), kls_name)
        kls_list.append(kls)

        kls_methods = {
            name for name in dir(kls)
            if not name.startswith('__') and callable(getattr(kls, name))
        }
        # Set intersection keeps this linear in the total number of methods.
        if not seen_methods.isdisjoint(kls_methods):
            raise AttributeError('Your consumer methods MUST BE unique, found a duplicate method in the following consumer: {}'.format(consumer))
        seen_methods |= kls_methods
    return tuple(kls_list)
# This is also not especially important: it is part of the same "hack" used to
# keep the consumer code organised (see get_consumers).
class BaseConsumer(AsyncWebsocketConsumer):
    """
    Base class for the root websocket consumers used in ``routing.py``.

    A root consumer (such as ``UserConsumer``) inherits from this class *and*
    from the plain-Python consumer classes registered in ``settings.CONSUMERS``
    (collected with ``get_consumers``). The registered consumers themselves are
    ordinary classes and do not inherit from BaseConsumer. BaseConsumer is not
    meant to be routed directly in ``routing.py``.

    Every JSON message received from a client must carry a ``type`` key that
    identifies the event. The message is dispatched to the method named
    ``recieve_<type>`` (the prefix really is spelled ``recieve_``; handler
    names must use that exact spelling). The handler is called with the
    message's ``data`` value, or ``{}`` when ``data`` is absent.

    For example, given this class in ``notifications/consumers.py``::

        class NotificationConsumer:
            async def send_notification(self, event):
                ...  # push data to the client

            async def recieve_notification(self, data):
                ...

            async def recieve_notification_configuration(self, data):
                ...

    register it in ``settings.py``::

        CONSUMERS = {
            'LOGIN_REQUIRED': [
                'reflow_server.notifications.consumers.NotificationConsumer',
            ],
        }

    A client message ``{"type": "notification", "data": "foo"}`` is handled by
    ``recieve_notification("foo")``. A message
    ``{"type": "notification_configuration", "data": "bar"}`` is handled by
    ``recieve_notification_configuration("bar")``.

    ``send_notification`` is not a receive handler. Channels invokes it when a
    channel-layer message with ``'type': 'send_notification'`` reaches the
    group (see ``send_event``). References:
        https://stackoverflow.com/a/50048713/13158385
        https://channels.readthedocs.io/en/latest/topics/channel_layers.html#what-to-send-over-the-channel-layer
        https://channels.readthedocs.io/en/latest/topics/channel_layers.html#groups
        https://channels.readthedocs.io/en/latest/topics/channel_layers.html#using-outside-of-consumers

    Important:
        - Receive handlers must be named ``recieve_<type>``; for type
          ``'notification_configuration'`` that is
          ``recieve_notification_configuration``.
        - Method names must be unique across all registered consumers;
          ``get_consumers`` raises ``AttributeError`` otherwise.
        - Register every consumer in ``settings.CONSUMERS``.
        - Handlers may be ``async def`` (awaited directly) or plain functions
          (run through ``sync_to_async``).
    """

    async def disconnect(self, close_code):
        """
        Leave the per-connection channel-layer group, if one was joined.

        Subclasses such as UserConsumer declare ``group_name`` as a class-level
        template and overwrite it on the instance in ``connect()``. The group is
        discarded only when the instance value differs from the class value.
        ``getattr`` keeps consumers that declare no ``group_name`` at all from
        raising AttributeError here.
        """
        template = getattr(type(self), 'group_name', None)
        group_name = getattr(self, 'group_name', None)
        if group_name is not None and group_name != template:
            print('Closing Group name: %s' % group_name)
            await self.channel_layer.group_discard(
                group_name,
                self.channel_name
            )

    async def receive(self, text_data=None, bytes_data=None):
        """
        Dispatch an incoming client message to its ``recieve_<type>`` handler.

        Channels calls this with ``text_data`` for text frames and with
        ``bytes_data`` for binary frames; either is parsed as JSON. Malformed
        messages, messages without ``type`` and unknown types are answered
        with an error payload instead of raising.
        """
        async def reply_error(reason):
            await self.send(text_data=json.dumps({
                'status': 'error',
                'reason': reason
            }))

        raw = text_data if text_data is not None else bytes_data
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            await reply_error('invalid_json')
            return

        if not isinstance(data, dict) or 'type' not in data:
            await reply_error('missing_type')
            return

        handler = getattr(self, 'recieve_%s' % data['type'], None)
        if not callable(handler):
            await reply_error('no_handler_found_for_type')
            return

        payload = data.get('data', dict())
        # async handlers must be awaited directly (sync_to_async rejects them);
        # sync handlers run in a worker thread so they may use the ORM.
        if inspect.iscoroutinefunction(handler):
            await handler(payload)
        else:
            await sync_to_async(handler)(payload)

    @classmethod
    async def send_event(cls, event_name, group_name, **kwargs):
        """
        Send ``event_name`` with ``kwargs`` as ``data`` to every connection in
        ``group_name`` through the channel layer.

        Raises:
            KeyError: if this consumer has no method to handle ``event_name``.
                Channels maps a message ``type`` to the method of the same name
                with ``.`` replaced by ``_``.
        """
        if not hasattr(cls, event_name.replace('.', '_')):
            raise KeyError('Seems like there is no handler for %s event' % event_name)
        channel_layer = get_channel_layer()
        await channel_layer.group_send(
            '{}'.format(group_name),
            {
                'type': event_name,
                'data': kwargs
            }
        )
class UserConsumer(BaseConsumer, *get_consumers('LOGIN_REQUIRED')):
    """
    Root consumer for authenticated users, routed at ``websocket/user/``.

    It combines ``BaseConsumer`` with every consumer registered under the
    ``'LOGIN_REQUIRED'`` key of ``settings.CONSUMERS``. On connect, it refuses
    users that are missing or not authenticated. Otherwise it adds the
    connection to a per-user group (``user_<id>``), so events can be sent to a
    single user with ``UserConsumer.send_event(event_name, user_id, **data)``.

    To build another root consumer with its own connection rules, inherit from
    ``(BaseConsumer, *get_consumers(YOUR_KEY))`` and add ``YOUR_KEY`` to the
    CONSUMERS dict in ``settings.py``. For example, a consumer that requires a
    ``company_id`` before connecting:

    in ``settings.py``::

        CONSUMERS = {
            # other keys
            'COMPANY_REQUIRED': [
                # your consumers
            ],
        }

    in ``core/consumers.py``::

        class CompanyConsumer(BaseConsumer, *get_consumers('COMPANY_REQUIRED')):
            async def connect(self):
                ...  # your custom connect logic here

    and in the root ``routing.py``::

        websocket_urlpatterns = [
            re_path(r'^websocket/', URLRouter([
                # other consumers
                re_path(r'company/$', AuthWebsocketJWTMiddleware(CompanyConsumer.as_asgi())),
            ])),
        ]
    """
    # Class-level template; connect() stores the formatted name on the instance.
    group_name = 'user_{id}'

    async def connect(self):
        user = self.scope.get('user')
        if user is None or not user.is_authenticated:
            raise DenyConnection('For `user` group types, your user must be authenticated')

        # A custom group per user, so events can be targeted at a specific user.
        self.group_name = self.group_name.format(id=user.id)
        await self.channel_layer.group_add(
            self.group_name,
            self.channel_name
        )
        await self.accept()
        # ``after_connect`` is an optional hook that a registered consumer mixin
        # may provide. BaseConsumer does not define it and it is not part of the
        # Channels consumer API, so calling it unconditionally raised
        # AttributeError after the connection was accepted.
        after_connect = getattr(super(), 'after_connect', None)
        if after_connect is not None:
            await after_connect()

    @classmethod
    async def send_event(cls, event_name, user_id, **kwargs):
        """Send ``event_name`` to every connection of the user ``user_id``."""
        group_name = cls.group_name.format(id=user_id)
        await super().send_event(event_name, group_name, **kwargs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.contrib.auth.models import AnonymousUser | |
from channels.db import database_sync_to_async | |
from channels.middleware import BaseMiddleware | |
# CHANNELS MIDDLEWARE BELOW | |
############################################################################################ | |
class AuthWebsocketJWTMiddleware(BaseMiddleware):
    """
    Channels middleware that resolves the connecting user from a JWT and
    exposes it to the wrapped application as ``scope['user']``.

    When no valid token is found, ``scope['user']`` is an ``AnonymousUser``.

    Since Django 3.0, using the ORM from an async context raises
    ``SynchronousOnlyOperation`` (https://docs.djangoproject.com/en/3.0/topics/async/).
    The user lookup is therefore wrapped with ``database_sync_to_async``, as
    the Channels documentation recommends:
    https://channels.readthedocs.io/en/latest/topics/databases.html#database-sync-to-async
    Background on this approach: https://stackoverflow.com/a/59653335/13158385

    NOTE(review): ``JWT`` and ``UserExtended`` are not imported in the code
    shown here; confirm they are in scope in the real module.
    """

    def __init__(self, inner):
        # The wrapped ASGI application.
        self.inner = inner

    # ------------------------------------------------------------------------------------------
    async def __call__(self, scope, receive, send):
        token = JWT()
        token.extract_jwt_from_scope(scope)
        user = await self.get_user(token.data['id']) if token.is_valid() else AnonymousUser()
        # Hand the inner app a copy of the scope with the resolved user added.
        return await self.inner({**scope, 'user': user}, receive, send)

    # ------------------------------------------------------------------------------------------
    @database_sync_to_async
    def get_user(self, user_id):
        # Any falsy lookup result (no such user) degrades to an anonymous user.
        return UserExtended.authentication_.user_by_user_id(user_id) or AnonymousUser()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.urls import re_path | |
from channels.routing import URLRouter | |
from reflow_server.core.consumers import UserConsumer, PublicConsumer | |
from reflow_server.authentication.middleware import AuthWebsocketJWTMiddleware, AuthWebsocketPublicMiddleware | |
# Routes mounted below the ``websocket/`` prefix. Each consumer is wrapped in the
# matching middleware from reflow_server.authentication.middleware.
# NOTE(review): PublicConsumer is not defined in the consumers code shown here;
# confirm it exists in reflow_server.core.consumers.
_websocket_routes = [
    re_path(r'user/$', AuthWebsocketJWTMiddleware(UserConsumer.as_asgi())),
    re_path(r'public/$', AuthWebsocketPublicMiddleware(PublicConsumer.as_asgi())),
]

websocket_urlpatterns = [
    re_path(r'^websocket/', URLRouter(_websocket_routes)),
]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment