import asyncio
import json
import typing
from functools import partial
from typing import Dict, List, Type
from channels.consumer import AsyncConsumer
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from djangochannelsrestframework.permissions import BasePermission
from djangochannelsrestframework.settings import api_settings
from django.http import HttpRequest, HttpResponse
from django.http.response import Http404
from django.template.response import SimpleTemplateResponse
from rest_framework.exceptions import PermissionDenied, MethodNotAllowed, APIException
from rest_framework.response import Response
def ensure_async(method: typing.Callable):
if asyncio.iscoroutinefunction(method):
return method
return database_sync_to_async(method)
[docs]class AsyncAPIConsumer(AsyncJsonWebsocketConsumer, metaclass=APIConsumerMetaclass):
"""
Be very inspired by django rest framework ViewSets
"""
# use django rest framework permissions
# The following policies may be set at either globally, or per-view.
# take the default values set for django rest framework!
permission_classes = (
api_settings.DEFAULT_PERMISSION_CLASSES
) # type: List[Type[BasePermission]]
groups = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.groups = set(self.groups or [])
async def add_group(self, name: str):
if not isinstance(self.groups, set):
self.groups = set(self.groups)
if name not in self.groups:
await self.channel_layer.group_add(name, self.channel_name)
self.groups.add(name)
async def remove_group(self, name: str):
if not isinstance(self.groups, set):
self.groups = set(self.groups)
if name in self.groups:
await self.channel_layer.group_discard(name, self.channel_name)
self.groups.remove(name)
[docs] async def get_permissions(self, action: str, **kwargs):
"""
Instantiates and returns the list of permissions that this view requires.
"""
return [permission() for permission in self.permission_classes]
[docs] async def check_permissions(self, action: str, **kwargs):
"""
Check if the action should be permitted.
Raises an appropriate exception if the request is not permitted.
"""
for permission in await self.get_permissions(action=action, **kwargs):
if not await ensure_async(permission.has_permission)(
scope=self.scope, consumer=self, action=action, **kwargs
):
raise PermissionDenied()
[docs] async def handle_exception(self, exc: Exception, action: str, request_id):
"""
Handle any exception that occurs, by sending an appropriate message
"""
if isinstance(exc, APIException):
await self.reply(
action=action,
errors=self._format_errors(exc.detail),
status=exc.status_code,
request_id=request_id,
)
elif exc == Http404 or isinstance(exc, Http404):
await self.reply(
action=action,
errors=self._format_errors("Not found"),
status=404,
request_id=request_id,
)
else:
raise exc
def _format_errors(self, errors):
if isinstance(errors, list):
return errors
elif isinstance(errors, str):
return [errors]
elif isinstance(errors, dict):
return [errors]
[docs] async def handle_action(self, action: str, request_id: str, **kwargs):
"""
run the action.
"""
try:
await self.check_permissions(action, **kwargs)
if action not in self.available_actions:
raise MethodNotAllowed(method=action)
method_name = self.available_actions[action]
method = getattr(self, method_name)
reply = partial(self.reply, action=action, request_id=request_id)
# the @action decorator will wrap non-async action into async ones.
response = await method(request_id=request_id, action=action, **kwargs)
if isinstance(response, tuple):
data, status = response
await reply(data=data, status=status)
except Exception as exc:
await self.handle_exception(exc, action=action, request_id=request_id)
[docs] async def receive_json(self, content: typing.Dict, **kwargs):
"""
Called with decoded JSON content.
"""
# TODO assert format, if does not match return message.
request_id = content.pop("request_id")
action = content.pop("action")
await self.handle_action(action, request_id=request_id, **content)
async def reply(
self, action: str, data=None, errors=None, status=200, request_id=None
):
if errors is None:
errors = []
payload = {
"errors": errors,
"data": data,
"action": action,
"response_status": status,
"request_id": request_id,
}
await self.send_json(payload)
[docs]class DjangoViewAsConsumer(AsyncAPIConsumer):
view = None
@property
def dumpy_url_config(self):
return
# maps actions to HTTP methods
actions = {} # type: Dict[str, str]
[docs] async def receive_json(self, content: typing.Dict, **kwargs):
"""
Called with decoded JSON content.
"""
# TODO assert format, if does not match return message.
request_id = content.pop("request_id")
action = content.pop("action")
await self.handle_action(action, request_id=request_id, **content)
[docs] async def handle_action(self, action: str, request_id: str, **kwargs):
"""
run the action.
"""
try:
await self.check_permissions(action, **kwargs)
if action not in self.actions:
raise MethodNotAllowed(method=action)
content, status = await self.call_view(action=action, **kwargs)
await self.reply(
action=action, request_id=request_id, data=content, status=status
)
except Exception as exc:
await self.handle_exception(exc, action=action, request_id=request_id)
@database_sync_to_async
def call_view(self, action: str, **kwargs):
request = HttpRequest()
request.path = self.scope.get("path")
request.session = self.scope.get("session", None)
request.META["HTTP_CONTENT_TYPE"] = "application/json"
request.META["HTTP_ACCEPT"] = "application/json"
for (header_name, value) in self.scope.get("headers", []):
request.META[header_name.decode("utf-8")] = value.decode("utf-8")
args, view_kwargs = self.get_view_args(action=action, **kwargs)
request.method = self.actions[action]
request.POST = json.dumps(kwargs.get("data", {}))
if self.scope.get("cookies"):
request.COOKIES = self.scope.get("cookies")
view = getattr(self.__class__, "view")
response = view(request, *args, **view_kwargs)
status = response.status_code
if isinstance(response, Response):
data = response.data
try:
# check if we can json encode it!
# there must be a better way fo doing this?
json.dumps(data)
return data, status
except Exception as e:
pass
if isinstance(response, SimpleTemplateResponse):
response.render()
response_content = response.content
if isinstance(response_content, bytes):
try:
response_content = response_content.decode("utf-8")
except Exception as e:
response_content = response_content.hex()
return response_content, status
def get_view_args(self, action: str, **kwargs):
return [], {}
[docs]def view_as_consumer(
wrapped_view: typing.Callable[[HttpRequest], HttpResponse],
mapped_actions: typing.Optional[typing.Dict[str, str]] = None,
) -> DjangoViewAsConsumer:
"""
Wrap a django View so that it will be triggered by actions over this json
websocket consumer.
"""
if mapped_actions is None:
mapped_actions = {
"create": "PUT",
"update": "PATCH",
"list": "GET",
"retrieve": "GET",
}
class DjangoViewWrapper(DjangoViewAsConsumer):
view = wrapped_view
actions = mapped_actions
return DjangoViewWrapper()