Source code for djangochannelsrestframework.permissions

from typing import Dict, Any, Type


from channels.consumer import AsyncConsumer
from rest_framework.permissions import BasePermission as DRFBasePermission

from djangochannelsrestframework.scope_utils import ensure_async
from djangochannelsrestframework.scope_utils import request_from_scope


class OperationHolderMixin:
    def __and__(self, other):
        return OperandHolder(AND, self, other)

    def __or__(self, other):
        return OperandHolder(OR, self, other)

    def __rand__(self, other):
        return OperandHolder(AND, other, self)

    def __ror__(self, other):
        return OperandHolder(OR, other, self)

    def __invert__(self):
        return SingleOperandHolder(NOT, self)


class SingleOperandHolder(OperationHolderMixin):
    op1_class: "Type[BasePermission | OperationHolderMixin]"

    def __init__(
        self, operator_class, op1_class: "Type[ BasePermission | OperationHolderMixin]"
    ):
        self.operator_class = operator_class
        self.op1_class = op1_class

    def __call__(self, *args, **kwargs):
        op1 = _ensure_base(self.op1_class)
        return self.operator_class(op1)


class OperandHolder(OperationHolderMixin):
    op1_class: "Type[BasePermission | OperationHolderMixin]"
    op2_class: "Type[BasePermission | OperationHolderMixin]"

    def __init__(
        self,
        operator_class,
        op1_class: "Type[BasePermission | OperationHolderMixin]",
        op2_class: "Type[BasePermission | OperationHolderMixin]",
    ):
        self.operator_class = operator_class
        self.op1_class = op1_class
        self.op2_class = op2_class

    def __call__(self, *args, **kwargs):
        op1 = _ensure_base(self.op1_class)
        op2 = _ensure_base(self.op2_class)
        return self.operator_class(op1, op2)


class AND:
    def __init__(self, op1: "BasePermission", op2: "BasePermission"):
        self.op1 = op1
        self.op2 = op2

    async def has_permission(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
    ):
        return await self.op1.has_permission(
            scope, consumer, action, **kwargs
        ) and await self.op2.has_permission(scope, consumer, action, **kwargs)

    async def can_connect(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, **kwargs
    ) -> bool:
        return await self.op1.can_connect(
            scope, consumer, **kwargs
        ) and await self.op2.can_connect(scope, consumer, **kwargs)


class OR:
    def __init__(self, op1: "BasePermission", op2: "BasePermission"):
        self.op1 = op1
        self.op2 = op2

    async def has_permission(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
    ):
        return await self.op1.has_permission(
            scope, consumer, action, **kwargs
        ) or await self.op2.has_permission(scope, consumer, action, **kwargs)

    async def can_connect(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, **kwargs
    ) -> bool:
        return await self.op1.can_connect(
            scope, consumer, **kwargs
        ) or await self.op2.can_connect(scope, consumer, **kwargs)


class NOT:
    def __init__(self, op1: "BasePermission"):
        self.op1 = op1

    async def has_permission(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
    ):
        return not await self.op1.has_permission(scope, consumer, action, **kwargs)

    async def can_connect(
        self, scope: Dict[str, Any], consumer: AsyncConsumer, **kwargs
    ) -> bool:
        return not await self.op1.can_connect(scope, consumer, **kwargs)


class BasePermissionMetaclass(OperationHolderMixin, type):
    pass


[docs] class BasePermission(metaclass=BasePermissionMetaclass): """Base permission class Notes: You should extend this class and override the `has_permission` method to create your own permission class. You can also over override`can_connect` to determine if a websocket connection should even be permitted. Methods: async has_permission (scope, consumer, action, **kwargs) """
[docs] async def has_permission( self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs ) -> bool: """ Called on every websocket message sent before the corresponding action handler is called. """ pass
[docs] async def can_connect( self, scope: Dict[str, Any], consumer: AsyncConsumer, message=None ) -> bool: """ Called during connection to validate if a given client can establish a websocket connection. By default, this returns True and permits all connections to be made. """ return True
[docs] class AllowAny(BasePermission): """Always allow"""
[docs] async def has_permission( self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs ) -> bool: return True
[docs] class IsAuthenticated(BasePermission): """Allow authenticated users"""
[docs] async def has_permission( self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs ) -> bool: user = scope.get("user") if not user: return False return user.pk and user.is_authenticated
class WrappedDRFPermission(BasePermission): permission: DRFBasePermission mapped_actions = { "create": "PUT", "update": "PATCH", "list": "GET", "retrieve": "GET", "connect": "HEAD", } def __init__(self, permission: DRFBasePermission): super().__init__() self.permission = permission async def has_permission( self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs ) -> bool: request = request_from_scope(scope) request.method = self.mapped_actions.get(action, action.upper()) return await ensure_async(self.permission.has_permission)(request, consumer) async def can_connect( self, scope: Dict[str, Any], consumer: AsyncConsumer, message=None ) -> bool: request = request_from_scope(scope) request.method = self.mapped_actions.get("connect", "CONNECT") return await ensure_async(self.permission.has_permission)(request, consumer) def _ensure_base(cls: Type[BasePermission | DRFBasePermission]) -> BasePermission: if issubclass(cls, DRFBasePermission): return WrappedDRFPermission(permission=cls()) return cls()