Refactor EventObserver & TelegramEventObserver

This commit is contained in:
Alex Root Junior 2020-01-13 21:17:28 +02:00
parent 3b2df194a9
commit 9907eada32
3 changed files with 86 additions and 84 deletions

View file

@ -1,7 +1,7 @@
import inspect
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[CallbackType, BaseHandler]
HandlerType = Union[FilterType, BaseHandler]
@dataclass
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
callback: HandlerType
filters: List[FilterObject]
filters: Optional[List[FilterObject]] = None
def __post_init__(self):
super(HandlerObject, self).__post_init__()
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
self.awaitable = True
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters:
return True, {}
for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs)
if not check:

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import copy
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
@ -34,15 +35,11 @@ class EventObserver:
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
def register(self, callback: HandlerType) -> HandlerType:
"""
Register callback with filters
"""
self.handlers.append(
HandlerObject(
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
)
)
self.handlers.append(HandlerObject(callback=callback))
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
@ -51,22 +48,18 @@ class EventObserver:
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
kwargs_copy = copy.copy(kwargs)
result, data = await handler.check(*args, **kwargs)
if result:
kwargs_copy.update(data)
try:
yield await handler.call(*args, **kwargs_copy)
except SkipHandler:
continue
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback, *args)
self.register(callback)
return callback
return wrapper
@ -148,16 +141,29 @@ class TelegramEventObserver(EventObserver):
Register event handler
"""
resolved_filters = self.resolve_filters(bound_filters)
return super().register(callback, *filters, *resolved_filters)
self.handlers.append(
HandlerObject(
callback=callback,
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
)
)
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
"""
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
yield result
break
for handler in self.handlers:
kwargs_copy = copy.copy(kwargs)
result, data = await handler.check(*args, **kwargs)
if result:
kwargs_copy.update(data)
try:
yield await handler.call(*args, **kwargs_copy)
except SkipHandler:
continue
break
def __call__(
self, *args: FilterType, **bound_filters: BaseFilter