Add prototype of class-based handlers

This commit is contained in:
Alex Root Junior 2019-12-03 00:03:15 +02:00
parent 2a731f7ce2
commit b82a1a6fb0
11 changed files with 178 additions and 8 deletions

View file

@ -4,16 +4,18 @@ from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[CallbackType, BaseHandler]
@dataclass
class CallableMixin:
callback: Callable
callback: HandlerType
awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False)
@ -44,9 +46,19 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
callback: CallbackType
callback: HandlerType
filters: List[FilterObject]
def __post_init__(self):
super(HandlerObject, self).__post_init__()
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler):
self.awaitable = True
if hasattr(self.callback, "filters"):
self.filters.extend(
FilterObject(event_filter) for event_filter in self.callback.filters
)
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs)

View file

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Type
from pydantic import ValidationError
from ..filters.base import BaseFilter
from .handler import CallbackType, FilterObject, FilterType, HandlerObject
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
@ -24,7 +24,7 @@ class EventObserver:
def __init__(self):
self.handlers: List[HandlerObject] = []
def register(self, callback: CallbackType, *filters: FilterType):
def register(self, callback: HandlerType, *filters: FilterType):
"""
Register callback with filters
@ -91,7 +91,7 @@ class TelegramEventObserver(EventObserver):
yield filter_
registry.append(filter_)
def register(self, callback: CallbackType, *filters: FilterType, **bound_filters: Any):
def register(self, callback: HandlerType, *filters: FilterType, **bound_filters: Any):
resolved_filters = self.resolve_filters(bound_filters)
return super().register(callback, *filters, *resolved_filters)

View file

View file

@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from aiogram import Bot
from aiogram.api.types import TelegramObject
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.event.handler import FilterType # NOQA: F401
class BaseHandlerMixin:
event: TelegramObject
data: Dict[str, Any]
class HandlerBotMixin(BaseHandlerMixin):
@property
def bot(self) -> Bot:
if "bot" in self.data:
return self.data["bot"]
return Bot.get_current()
class BaseHandler(HandlerBotMixin, ABC):
event: TelegramObject
filters: Union[List["FilterType"], Tuple["FilterType"]]
def __init__(self, event: TelegramObject, **kwargs: Any) -> None:
self.event = event
self.data = kwargs
@abstractmethod
async def handle(self) -> Any: # pragma: no cover
pass
def __await__(self):
return self.handle().__await__()

View file

@ -0,0 +1,16 @@
from abc import ABC
from aiogram.api.types import Message
from aiogram.dispatcher.handler.base import BaseHandler
class MessageHandler(BaseHandler, ABC):
event: Message
@property
def from_user(self):
return self.event.from_user
@property
def chat(self):
return self.event.chat

View file

@ -45,9 +45,13 @@ class ContextInstanceMixin:
return cls.__context_instance.get()
@classmethod
def set_current(cls: Type[T], value: T):
def set_current(cls: Type[T], value: T) -> contextvars.Token:
if not isinstance(value, cls):
raise TypeError(
f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
)
cls.__context_instance.set(value)
return cls.__context_instance.set(value)
@classmethod
def reset_current(cls: Type[T], token: contextvars.Token):
cls.__context_instance.reset(token)