diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 42c6202d..52e8c0da 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -1,7 +1,7 @@ import inspect from dataclasses import dataclass, field from functools import partial -from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler @@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]] SyncFilter = Callable[[Any], Any] AsyncFilter = Callable[[Any], Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] -HandlerType = Union[CallbackType, BaseHandler] +HandlerType = Union[FilterType, BaseHandler] @dataclass @@ -47,7 +47,7 @@ class FilterObject(CallableMixin): @dataclass class HandlerObject(CallableMixin): callback: HandlerType - filters: List[FilterObject] + filters: Optional[List[FilterObject]] = None def __post_init__(self): super(HandlerObject, self).__post_init__() @@ -56,6 +56,8 @@ class HandlerObject(CallableMixin): self.awaitable = True async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: + if not self.filters: + return True, {} for event_filter in self.filters: check = await event_filter.call(*args, **kwargs) if not check: diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 93115ab7..40afc180 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +from itertools import chain from typing import ( TYPE_CHECKING, Any, @@ -34,15 +35,11 @@ class EventObserver: def __init__(self) -> None: self.handlers: List[HandlerObject] = [] - def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType: + def register(self, callback: HandlerType) -> HandlerType: """ Register callback with filters """ - self.handlers.append( - HandlerObject( - callback=callback, filters=[FilterObject(filter_) for filter_ in filters] - ) - ) + self.handlers.append(HandlerObject(callback=callback)) return callback async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: @@ -51,22 +48,18 @@ class EventObserver: Handler will be called when all its filters is pass. """ for handler in self.handlers: - kwargs_copy = copy.copy(kwargs) - result, data = await handler.check(*args, **kwargs) - if result: - kwargs_copy.update(data) - try: - yield await handler.call(*args, **kwargs_copy) - except SkipHandler: - continue + try: + yield await handler.call(*args, **kwargs) + except SkipHandler: + continue - def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]: + def __call__(self) -> Callable[[CallbackType], CallbackType]: """ Decorator for registering event handlers """ def wrapper(callback: CallbackType) -> CallbackType: - self.register(callback, *args) + self.register(callback) return callback return wrapper @@ -148,16 +141,29 @@ class TelegramEventObserver(EventObserver): Register event handler """ resolved_filters = self.resolve_filters(bound_filters) - return super().register(callback, *filters, *resolved_filters) + self.handlers.append( + HandlerObject( + callback=callback, + filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)], + ) + ) + return callback async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: """ Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs): - yield result - break + for handler in self.handlers: + kwargs_copy = copy.copy(kwargs) + result, data = await handler.check(*args, **kwargs) + if result: + kwargs_copy.update(data) + try: + yield await handler.call(*args, **kwargs_copy) + except SkipHandler: + continue + break def __call__( self, *args: FilterType, **bound_filters: BaseFilter diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index 2f157850..a4029197 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -39,68 +39,38 @@ class MyFilter3(MyFilter1): class TestEventObserver: - @pytest.mark.parametrize( - "count,handler,filters", - ( - pytest.param(5, my_handler, []), - pytest.param(3, my_handler, [lambda event: True]), - pytest.param( - 2, - my_handler, - [lambda event: True, lambda event: False, lambda event: {"ok": True}], - ), - ), - ) - def test_register_filters(self, count, handler, filters): + @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) + def test_register_filters(self, count, handler): observer = EventObserver() for index in range(count): wrapped_handler = functools.partial(handler, index=index) - observer.register(wrapped_handler, *filters) + observer.register(wrapped_handler) registered_handler = observer.handlers[index] assert len(observer.handlers) == index + 1 assert isinstance(registered_handler, HandlerObject) assert registered_handler.callback == wrapped_handler - assert len(registered_handler.filters) == len(filters) + assert not registered_handler.filters - @pytest.mark.parametrize( - "count,handler,filters", - ( - pytest.param(5, my_handler, []), - pytest.param(3, my_handler, [lambda event: True]), - pytest.param( - 2, - my_handler, - [lambda event: True, lambda event: False, lambda event: {"ok": True}], - ), - ), - ) - def test_register_filters_via_decorator(self, count, handler, filters): + @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) + def test_register_filters_via_decorator(self, count, handler): observer = EventObserver() for index in range(count): wrapped_handler = functools.partial(handler, index=index) - observer(*filters)(wrapped_handler) + observer()(wrapped_handler) registered_handler = observer.handlers[index] assert len(observer.handlers) == index + 1 assert isinstance(registered_handler, HandlerObject) assert registered_handler.callback == wrapped_handler - assert len(registered_handler.filters) == len(filters) - - @pytest.mark.asyncio - async def test_trigger_rejected(self): - observer = EventObserver() - observer.register(my_handler, lambda event: False) - - results = [result async for result in observer.trigger(42)] - assert results == [] + assert not registered_handler.filters @pytest.mark.asyncio async def test_trigger_accepted_bool(self): observer = EventObserver() - observer.register(my_handler, lambda event: True) + observer.register(my_handler) results = [result async for result in observer.trigger(42)] assert results == [42] @@ -108,23 +78,12 @@ class TestEventObserver: @pytest.mark.asyncio async def test_trigger_with_skip(self): observer = EventObserver() - observer.register(skip_my_handler, lambda event: True) - observer.register(my_handler, lambda event: False) - observer.register(my_handler, lambda event: True) + observer.register(skip_my_handler) + observer.register(my_handler) + observer.register(my_handler) results = [result async for result in observer.trigger(42)] - assert results == [42] - - @pytest.mark.asyncio - async def test_trigger_right_context_in_handlers(self): - observer = EventObserver() - observer.register( - pipe_handler, lambda event: {"a": 1}, lambda event: False - ) # {"a": 1} should not be in result - observer.register(pipe_handler, lambda event: {"b": 2}) - - results = [result async for result in observer.trigger(42)] - assert results == [((42,), {"b": 2})] + assert results == [42, 42] class TestTelegramEventObserver: @@ -144,9 +103,9 @@ class TestTelegramEventObserver: assert MyFilter in event_observer.filters def test_resolve_filters_chain(self): - router1 = Router() - router2 = Router() - router3 = Router() + router1 = Router(use_builtin_filters=False) + router2 = Router(use_builtin_filters=False) + router3 = Router(use_builtin_filters=False) router1.include_router(router2) router2.include_router(router3) @@ -168,7 +127,7 @@ class TestTelegramEventObserver: assert MyFilter3 in filters_chain3 def test_resolve_filters(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) @@ -189,7 +148,7 @@ class TestTelegramEventObserver: assert observer.resolve_filters({"test": ...}) def test_register(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) @@ -214,7 +173,7 @@ class TestTelegramEventObserver: assert MyFilter1(test="PASS") in callbacks def test_register_decorator(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler @observer() @@ -226,7 +185,7 @@ class TestTelegramEventObserver: @pytest.mark.asyncio async def test_trigger(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) observer.register(my_handler, test="ok") @@ -241,3 +200,38 @@ class TestTelegramEventObserver: results = [result async for result in observer.trigger(message)] assert results == [message] + + @pytest.mark.parametrize( + "count,handler,filters", + ( + [5, my_handler, []], + [3, my_handler, [lambda event: True]], + [2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]], + ), + ) + def test_register_filters_via_decorator(self, count, handler, filters): + router = Router(use_builtin_filters=False) + observer = router.message_handler + + for index in range(count): + wrapped_handler = functools.partial(handler, index=index) + observer(*filters)(wrapped_handler) + registered_handler = observer.handlers[index] + + assert len(observer.handlers) == index + 1 + assert isinstance(registered_handler, HandlerObject) + assert registered_handler.callback == wrapped_handler + assert len(registered_handler.filters) == len(filters) + + # + @pytest.mark.asyncio + async def test_trigger_right_context_in_handlers(self): + router = Router(use_builtin_filters=False) + observer = router.message_handler + observer.register( + pipe_handler, lambda event: {"a": 1}, lambda event: False + ) # {"a": 1} should not be in result + observer.register(pipe_handler, lambda event: {"b": 2}) + + results = [result async for result in observer.trigger(42)] + assert results == [((42,), {"b": 2})]