# Source code for djangochannelsrestframework.observer.generics

from copy import deepcopy

from django.db.models import Model
from functools import partial
from typing import Dict, Type, Optional, Set, List, Iterable

from channels.db import database_sync_to_async
from rest_framework import status

from djangochannelsrestframework.consumers import APIConsumerMetaclass
from djangochannelsrestframework.decorators import action
from djangochannelsrestframework.generics import GenericAsyncAPIConsumer
from djangochannelsrestframework.mixins import RetrieveModelMixin
from djangochannelsrestframework.observer import ModelObserver


class _GenericModelObserver:
    """Unbound placeholder for a model observer declared on a consumer class.

    Decorating an async handler with this class records the handler; the
    ``groups`` and ``serializer`` decorator methods record the callables that
    produce group names and serialize payloads.  ``bind_to_model`` later
    builds a real :class:`ModelObserver` from the recorded pieces once the
    concrete model class is known.
    """

    def __init__(self, func, **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        self.func = func
        self._serializer = None
        self._group_names = None

    def bind_to_model(
        self, model_cls: Type[Model], name: str, many_to_many=False
    ) -> ModelObserver:
        """Create a :class:`ModelObserver` for *model_cls* from this placeholder.

        Args:
            model_cls: Model class whose changes should be observed.
            name: Value passed to the observer as its ``partition``.
            many_to_many: Forwarded to the observer's ``many_to_many`` option.

        Returns:
            The newly created observer, configured with the recorded
            group-name and serializer callables (either may be ``None``).
        """
        bound_observer = ModelObserver(
            func=self.func, model_cls=model_cls, partition=name, many_to_many=many_to_many
        )
        # Hand the recorded callables over to the concrete observer, in the
        # same order as before (groups first, then serializer).
        bound_observer.groups(self._group_names)
        bound_observer.serializer(self._serializer)
        return bound_observer

    def groups(self, func):
        """Record *func* as the group-name callable and return ``self``.

        Returning ``self`` lets this be used as a decorator that keeps the
        placeholder bound to the decorated name.
        """
        self._group_names = func
        return self

    def serializer(self, func):
        """Record *func* as the serializer callable and return ``self``."""
        self._serializer = func
        return self


class ObserverAPIConsumerMetaclass(APIConsumerMetaclass):
    """Metaclass that binds generic model observers to a consumer's model.

    When the class being created declares ``queryset`` in its own body, every
    :class:`_GenericModelObserver` found in that body or inherited from its
    bases is replaced in the new class namespace by a :class:`ModelObserver`
    bound to ``queryset.model``.  The partition name is
    ``"<module>.<ClassName>.<attribute>"``, so it is unique per class.
    """

    def __new__(mcs, name, bases, body) -> Type[GenericAsyncAPIConsumer]:
        queryset = body.get("queryset", None)
        # NOTE: only read from the class body itself, not from bases.
        many_to_many = body.get("observer_many_to_many_relationships", False)

        if queryset is not None:

            def bind(attr_name, generic_observer):
                # Build the concrete observer for this class's model.
                return generic_observer.bind_to_model(
                    model_cls=queryset.model,
                    name=f"{body['__module__']}.{name}.{attr_name}",
                    many_to_many=many_to_many,
                )

            # Observers declared directly on the class being created.
            for attr_name, attr in list(body.items()):
                if isinstance(attr, _GenericModelObserver):
                    body[attr_name] = bind(attr_name, attr)

            # Observers inherited from bases.  A name already defined in the
            # class body, or already seen on an earlier base, takes
            # precedence, as in normal attribute lookup.  Without this check a
            # base's observer would overwrite a subclass override, and a later
            # base would beat an earlier one.
            seen = set(body)
            for base in bases:
                for attr_name in dir(base):
                    if attr_name in seen:
                        continue
                    seen.add(attr_name)
                    attr = getattr(base, attr_name)
                    if isinstance(attr, _GenericModelObserver):
                        body[attr_name] = bind(attr_name, attr)

        return super().__new__(mcs, name, bases, body)


class ObserverConsumerMixin(metaclass=ObserverAPIConsumerMetaclass):
    """Track which client request ids are subscribed through which groups.

    ``subscribed_requests`` maps a group name to the set of request ids that
    subscribed via that group.  Observed changes for a group are then
    replied to once per request id returned by :meth:`_requests_for`.
    """

    observer_many_to_many_relationships = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.subscribed_requests = {}  # type: Dict[str, Set[str]]

    def _subscribe(self, request_id: str, groups: Set[str]):
        """Record *request_id* as subscribed under each of *groups*."""
        for group in groups:
            self.subscribed_requests.setdefault(group, set()).add(request_id)

    def _unsubscribe(self, request_id: str):
        """Remove *request_id* from every group.

        Groups that have no request ids left afterwards are forgotten.
        """
        empty_groups = []
        for group, request_ids in self.subscribed_requests.items():
            request_ids.discard(request_id)
            if not request_ids:
                empty_groups.append(group)

        # Remove after the loop; the dict must not change size while iterating.
        self._unsubscribe_groups(empty_groups)

    def _unsubscribe_groups(self, groups: Iterable[str]):
        """Forget all request ids recorded for *groups*; unknown groups are ignored."""
        for group in groups:
            self.subscribed_requests.pop(group, None)

    def _requests_for(self, group: Optional[str]) -> Set[str]:
        """Return the request ids subscribed via *group*.

        If *group* is falsy, return the union over all groups.

        A new set is always returned, never the internal one.  Callers that
        iterate it across ``await`` points are therefore unaffected by
        concurrent subscribe/unsubscribe calls.
        """
        if not group:
            return set().union(*self.subscribed_requests.values())
        return set(self.subscribed_requests.get(group, ()))


class ObserverModelInstanceMixin(ObserverConsumerMixin, RetrieveModelMixin):
    """
    Use this as a mixin with :class:`~djangochannelsrestframework.generics.GenericAsyncAPIConsumer`.

    You can also set the ``observer_many_to_many_relationships = True`` class property
    to ensure many-to-many relationship changes are tracked by the subscription.

    .. code-block:: python

        # consumers.py
        from djangochannelsrestframework.generics import GenericAsyncAPIConsumer
        from djangochannelsrestframework.observer.generics import ObserverModelInstanceMixin

        from .serializers import UserSerializer
        from .models import User

        class MyConsumer(ObserverModelInstanceMixin, GenericAsyncAPIConsumer):
            queryset = User.objects.all()
            serializer_class = UserSerializer
            observer_many_to_many_relationships = True
    """

    @action()
    async def subscribe_instance(self, request_id=None, **kwargs):
        """
        Subscribe this consumer to changes of a single model instance.

        The instance is looked up with ``self.get_object(**kwargs)``, run via
        ``database_sync_to_async``.  The consumer is then subscribed to the
        groups that ``handle_instance_change`` yields for that instance.
        ``request_id`` is recorded against each of those groups, so later
        change notifications for them are replied to with this ``request_id``.

        Args:
            request_id (str): A unique identifier for the subscription request.
            **kwargs: Lookup parameters used to retrieve the model instance.

        Returns:
            tuple: ``(None, HTTP 201)``.

        Raises:
            ValueError: If `request_id` is not provided.
        """
        if request_id is None:
            raise ValueError("request_id must have a value set")
        # subscribe!
        instance = await database_sync_to_async(self.get_object)(**kwargs)
        groups = set(await self.handle_instance_change.subscribe(instance=instance))
        self._subscribe(request_id, groups)
        return None, status.HTTP_201_CREATED

    @action()
    async def unsubscribe_instance(self, request_id: Optional[str] = None, **kwargs):
        """
        Unsubscribe this consumer from changes of a single model instance.

        The instance is looked up with ``self.get_object(**kwargs)``, and
        ``handle_instance_change.unsubscribe`` is awaited for it.  The local
        bookkeeping is then updated as follows:

        * Without ``request_id``: all request ids recorded for the instance's
          groups are forgotten.
        * With ``request_id``: that request id is removed from *every* group
          it is recorded under, not only this instance's groups.

        Args:
            request_id (str, optional): Identifier of the subscription to drop.
            **kwargs: Lookup parameters used to retrieve the model instance.

        Returns:
            tuple: ``(None, HTTP 204)``.
        """
        instance = await database_sync_to_async(self.get_object)(**kwargs)
        # NOTE(review): unsubscribe() is awaited even when request_id is given
        # and other request ids may still be recorded under the same groups.
        # Whether that detaches the consumer from those groups depends on
        # ModelObserver.unsubscribe -- confirm.
        groups = await self.handle_instance_change.unsubscribe(instance=instance)

        if request_id is None:
            self._unsubscribe_groups(groups)
        else:
            self._unsubscribe(request_id)

        return None, status.HTTP_204_NO_CONTENT

    # Observer for the consumer's queryset model.  The metaclass replaces this
    # placeholder with a bound ModelObserver in any subclass that declares
    # ``queryset``.  Each observed change is forwarded to
    # handle_observed_action with the message contents as keyword arguments.
    @_GenericModelObserver
    async def handle_instance_change(
        self, message: Dict, group=None, action=None, **kwargs
    ):
        await self.handle_observed_action(
            action=action,
            group=group,
            **message,
        )

    # Group-name generator: a single group per (observer, model, pk).
    # ``groups`` returns the same placeholder, so the name stays bound to it.
    @handle_instance_change.groups
    def handle_instance_change(self: ModelObserver, instance, *args, **kwargs):
        # one channel for all updates.
        yield "{}-model-{}-pk-{}".format(
            self.func.__name__.replace("_", "."), self.model_label, instance.pk
        )

    async def handle_observed_action(
        self, action: str, group: Optional[str] = None, **kwargs
    ):
        """
        Push an observed change to every request subscribed via *group*.

        Permissions are checked first.  If the check fails, the error is passed
        to ``handle_exception`` and nothing is sent to subscribers.

        For each subscribed request id:

        * On ``"delete"``, *kwargs* is sent back with status 204.
        * Otherwise, ``self.retrieve`` is called.  When it returns a
          ``(data, status)`` tuple, that tuple is sent as the reply.

        Errors for a single request id go to ``handle_exception`` with that
        request id, and processing continues with the next one.
        """
        try:
            await self.check_permissions(action, **kwargs)
        except Exception as exc:
            await self.handle_exception(exc, action=action, request_id=None)
            # Previously execution fell through here and the change was still
            # sent to every subscriber despite the failed permission check.
            return

        # Iterate over a snapshot: the awaits below yield control, and the
        # subscription bookkeeping may change in the meantime.
        for request_id in set(self._requests_for(group)):
            try:
                reply = partial(self.reply, action=action, request_id=request_id)

                if action == "delete":
                    await reply(data=kwargs, status=status.HTTP_204_NO_CONTENT)
                    # send the delete
                    continue

                # the @action decorator will wrap non-async action into async ones.
                response = await self.retrieve(
                    request_id=request_id, action=action, **kwargs
                )

                if isinstance(response, tuple):
                    # Do not name this `status`: that would shadow the imported
                    # rest_framework `status` module for the whole function.
                    data, response_status = response
                    await reply(data=data, status=response_status)
            except Exception as exc:
                await self.handle_exception(exc, action=action, request_id=request_id)