From 4fb77a3a2a30f0d39be0f968e63c3f948b553cf3 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 24 Apr 2022 04:19:19 +0300 Subject: [PATCH] Added possibility to combine filters or invert result (#895) * Added possibility to combine filters or invert result --- CHANGES/894.feature.rst | 7 ++ aiogram/dispatcher/event/event.py | 4 +- aiogram/dispatcher/event/handler.py | 15 ++-- aiogram/dispatcher/event/telegram.py | 30 +++++-- aiogram/dispatcher/filters/__init__.py | 4 + aiogram/dispatcher/filters/base.py | 4 +- aiogram/dispatcher/filters/logic.py | 87 +++++++++++++++++++ aiogram/dispatcher/middlewares/manager.py | 4 +- aiogram/dispatcher/router.py | 32 +++---- .../test_filters/test_logic.py | 37 ++++++++ 10 files changed, 184 insertions(+), 40 deletions(-) create mode 100644 CHANGES/894.feature.rst create mode 100644 aiogram/dispatcher/filters/logic.py create mode 100644 tests/test_dispatcher/test_filters/test_logic.py diff --git a/CHANGES/894.feature.rst b/CHANGES/894.feature.rst new file mode 100644 index 00000000..f89f4e07 --- /dev/null +++ b/CHANGES/894.feature.rst @@ -0,0 +1,7 @@ +Added possibility to combine filters or invert result + +Example: +.. code-block:: python + Text(text="demo") | Command(commands=["demo"]) + MyFilter() & AnotherFilter() + ~StateFilter(state='my-state') diff --git a/aiogram/dispatcher/event/event.py b/aiogram/dispatcher/event/event.py index ef87d329..a2e1165c 100644 --- a/aiogram/dispatcher/event/event.py +++ b/aiogram/dispatcher/event/event.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Callable, List -from .handler import CallbackType, HandlerObject, HandlerType +from .handler import CallbackType, HandlerObject class EventObserver: @@ -26,7 +26,7 @@ class EventObserver: def __init__(self) -> None: self.handlers: List[HandlerObject] = [] - def register(self, callback: HandlerType) -> None: + def register(self, callback: CallbackType) -> None: """ Register callback with filters """ diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 813ddf51..af4cff37 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -3,24 +3,19 @@ import contextvars import inspect from dataclasses import dataclass, field from functools import partial -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple from magic_filter import MagicFilter -from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.flags.getter import extract_flags_from_object from aiogram.dispatcher.handler.base import BaseHandler -CallbackType = Callable[..., Awaitable[Any]] -SyncFilter = Callable[..., Any] -AsyncFilter = Callable[..., Awaitable[Any]] -FilterType = Union[SyncFilter, AsyncFilter, BaseFilter, MagicFilter] -HandlerType = Union[FilterType, Type[BaseHandler]] +CallbackType = Callable[..., Any] @dataclass class CallableMixin: - callback: HandlerType + callback: CallbackType awaitable: bool = field(init=False) spec: inspect.FullArgSpec = field(init=False) @@ -50,7 +45,7 @@ class CallableMixin: @dataclass class FilterObject(CallableMixin): - callback: FilterType + callback: CallbackType def __post_init__(self) -> None: # TODO: Make possibility to extract and explain magic from filter object. @@ -63,7 +58,7 @@ class FilterObject(CallableMixin): @dataclass class HandlerObject(CallableMixin): - callback: HandlerType + callback: CallbackType filters: Optional[List[FilterObject]] = None flags: Dict[str, Any] = field(default_factory=dict) diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index fcf3d7d2..ab815de7 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -12,7 +12,7 @@ from ...exceptions import FiltersResolveError from ...types import TelegramObject from ..filters.base import BaseFilter from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler -from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType +from .handler import CallbackType, FilterObject, HandlerObject if TYPE_CHECKING: from aiogram.dispatcher.router import Router @@ -40,7 +40,7 @@ class TelegramEventObserver: # with dummy callback which never will be used self._handler = HandlerObject(callback=lambda: True, filters=[]) - def filter(self, *filters: FilterType, **bound_filters: Any) -> None: + def filter(self, *filters: CallbackType, **bound_filters: Any) -> None: """ Register filter for all handlers of this event observer @@ -51,7 +51,13 @@ class TelegramEventObserver: if self._handler.filters is None: self._handler.filters = [] self._handler.filters.extend( - [FilterObject(filter_) for filter_ in chain(resolved_filters, filters)] + [ + FilterObject(filter_) # type: ignore + for filter_ in chain( + resolved_filters, + filters, + ) + ] ) def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: @@ -96,7 +102,7 @@ class TelegramEventObserver: def resolve_filters( self, - filters: Tuple[FilterType, ...], + filters: Tuple[CallbackType, ...], full_config: Dict[str, Any], ignore_default: bool = True, ) -> List[BaseFilter]: @@ -158,11 +164,11 @@ class TelegramEventObserver: def register( self, - callback: HandlerType, - *filters: FilterType, + callback: CallbackType, + *filters: CallbackType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any, - ) -> HandlerType: + ) -> CallbackType: """ Register event handler """ @@ -174,7 +180,13 @@ class TelegramEventObserver: self.handlers.append( HandlerObject( callback=callback, - filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)], + filters=[ + FilterObject(filter_) # type: ignore + for filter_ in chain( + resolved_filters, + filters, + ) + ], flags=flags, ) ) @@ -216,7 +228,7 @@ class TelegramEventObserver: return UNHANDLED def __call__( - self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any + self, *args: CallbackType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any ) -> Callable[[CallbackType], CallbackType]: """ Decorator for registering event handlers diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index a38b57af..4caa137c 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -19,6 +19,7 @@ from .chat_member_updated import ( from .command import Command, CommandObject from .content_types import ContentTypesFilter from .exception import ExceptionMessageFilter, ExceptionTypeFilter +from .logic import and_f, invert_f, or_f from .magic_data import MagicData from .state import StateFilter from .text import Text @@ -47,6 +48,9 @@ __all__ = ( "IS_NOT_MEMBER", "JOIN_TRANSITION", "LEAVE_TRANSITION", + "and_f", + "or_f", + "invert_f", ) _ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,) diff --git a/aiogram/dispatcher/filters/base.py b/aiogram/dispatcher/filters/base.py index d2bb99cf..877f98f6 100644 --- a/aiogram/dispatcher/filters/base.py +++ b/aiogram/dispatcher/filters/base.py @@ -3,8 +3,10 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union from pydantic import BaseModel +from aiogram.dispatcher.filters.logic import _LogicFilter -class BaseFilter(ABC, BaseModel): + +class BaseFilter(BaseModel, ABC, _LogicFilter): """ If you want to register own filters like builtin filters you will need to write subclass of this class with overriding the :code:`__call__` diff --git a/aiogram/dispatcher/filters/logic.py b/aiogram/dispatcher/filters/logic.py new file mode 100644 index 00000000..9a43956b --- /dev/null +++ b/aiogram/dispatcher/filters/logic.py @@ -0,0 +1,87 @@ +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union + +if TYPE_CHECKING: + from aiogram.dispatcher.event.handler import CallbackType, FilterObject + + +class _LogicFilter: + __call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]] + + def __and__(self, other: "CallbackType") -> "_AndFilter": + return and_f(self, other) + + def __or__(self, other: "CallbackType") -> "_OrFilter": + return or_f(self, other) + + def __invert__(self) -> "_InvertFilter": + return invert_f(self) + + def __await__(self): # type: ignore # pragma: no cover + # Is needed only for inspection and this method is never be called + return self.__call__ + + +class _InvertFilter(_LogicFilter): + __slots__ = ("target",) + + def __init__(self, target: "FilterObject") -> None: + self.target = target + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return not bool(await self.target.call(*args, **kwargs)) + + +class _AndFilter(_LogicFilter): + __slots__ = ("targets",) + + def __init__(self, *targets: "FilterObject") -> None: + self.targets = targets + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + final_result = {} + + for target in self.targets: + result = await target.call(*args, **kwargs) + if not result: + return False + if isinstance(result, dict): + final_result.update(result) + + if final_result: + return final_result + return True + + +class _OrFilter(_LogicFilter): + __slots__ = ("targets",) + + def __init__(self, *targets: "FilterObject") -> None: + self.targets = targets + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + for target in self.targets: + result = await target.call(*args, **kwargs) + if not result: + continue + if isinstance(result, dict): + return result + return bool(result) + return False + + +def and_f(target1: "CallbackType", target2: "CallbackType") -> _AndFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _AndFilter(FilterObject(target1), FilterObject(target2)) + + +def or_f(target1: "CallbackType", target2: "CallbackType") -> _OrFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _OrFilter(FilterObject(target1), FilterObject(target2)) + + +def invert_f(target: "CallbackType") -> _InvertFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _InvertFilter(FilterObject(target)) diff --git a/aiogram/dispatcher/middlewares/manager.py b/aiogram/dispatcher/middlewares/manager.py index 89892e77..9f132f58 100644 --- a/aiogram/dispatcher/middlewares/manager.py +++ b/aiogram/dispatcher/middlewares/manager.py @@ -2,7 +2,7 @@ import functools from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType -from aiogram.dispatcher.event.handler import HandlerType +from aiogram.dispatcher.event.handler import CallbackType from aiogram.types import TelegramObject @@ -49,7 +49,7 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]): @staticmethod def wrap_middlewares( - middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType + middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType ) -> NextMiddlewareType[MiddlewareEventType]: @functools.wraps(handler) def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any: diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 32e82195..a1a20fba 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -8,7 +8,7 @@ from ..utils.imports import import_module from ..utils.warnings import CodeHasNoEffect from .event.bases import REJECTED, UNHANDLED from .event.event import EventObserver -from .event.handler import HandlerType +from .event.handler import CallbackType from .event.telegram import TelegramEventObserver from .filters import BUILTIN_FILTERS @@ -396,7 +396,7 @@ class Router: ) return self.errors - def register_message(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_message(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_message(...)` is deprecated and will be removed in version 3.2 " "use `Router.message.register(...)`", @@ -405,7 +405,7 @@ class Router: ) return self.message.register(*args, **kwargs) - def register_edited_message(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_edited_message(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_edited_message(...)` is deprecated and will be removed in version 3.2 " "use `Router.edited_message.register(...)`", @@ -414,7 +414,7 @@ class Router: ) return self.edited_message.register(*args, **kwargs) - def register_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_channel_post(...)` is deprecated and will be removed in version 3.2 " "use `Router.channel_post.register(...)`", @@ -423,7 +423,7 @@ class Router: ) return self.channel_post.register(*args, **kwargs) - def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_edited_channel_post(...)` is deprecated and will be removed in version 3.2 " "use `Router.edited_channel_post.register(...)`", @@ -432,7 +432,7 @@ class Router: ) return self.edited_channel_post.register(*args, **kwargs) - def register_inline_query(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_inline_query(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_inline_query(...)` is deprecated and will be removed in version 3.2 " "use `Router.inline_query.register(...)`", @@ -441,7 +441,7 @@ class Router: ) return self.inline_query.register(*args, **kwargs) - def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_chosen_inline_result(...)` is deprecated and will be removed in version 3.2 " "use `Router.chosen_inline_result.register(...)`", @@ -450,7 +450,7 @@ class Router: ) return self.chosen_inline_result.register(*args, **kwargs) - def register_callback_query(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_callback_query(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_callback_query(...)` is deprecated and will be removed in version 3.2 " "use `Router.callback_query.register(...)`", @@ -459,7 +459,7 @@ class Router: ) return self.callback_query.register(*args, **kwargs) - def register_shipping_query(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_shipping_query(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_shipping_query(...)` is deprecated and will be removed in version 3.2 " "use `Router.shipping_query.register(...)`", @@ -468,7 +468,7 @@ class Router: ) return self.shipping_query.register(*args, **kwargs) - def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_pre_checkout_query(...)` is deprecated and will be removed in version 3.2 " "use `Router.pre_checkout_query.register(...)`", @@ -477,7 +477,7 @@ class Router: ) return self.pre_checkout_query.register(*args, **kwargs) - def register_poll(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_poll(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_poll(...)` is deprecated and will be removed in version 3.2 " "use `Router.poll.register(...)`", @@ -486,7 +486,7 @@ class Router: ) return self.poll.register(*args, **kwargs) - def register_poll_answer(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_poll_answer(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_poll_answer(...)` is deprecated and will be removed in version 3.2 " "use `Router.poll_answer.register(...)`", @@ -495,7 +495,7 @@ class Router: ) return self.poll_answer.register(*args, **kwargs) - def register_my_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_my_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_my_chat_member(...)` is deprecated and will be removed in version 3.2 " "use `Router.my_chat_member.register(...)`", @@ -504,7 +504,7 @@ class Router: ) return self.my_chat_member.register(*args, **kwargs) - def register_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_chat_member(...)` is deprecated and will be removed in version 3.2 " "use `Router.chat_member.register(...)`", @@ -513,7 +513,7 @@ class Router: ) return self.chat_member.register(*args, **kwargs) - def register_chat_join_request(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_chat_join_request(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_chat_join_request(...)` is deprecated and will be removed in version 3.2 " "use `Router.chat_join_request.register(...)`", @@ -522,7 +522,7 @@ class Router: ) return self.chat_join_request.register(*args, **kwargs) - def register_errors(self, *args: Any, **kwargs: Any) -> HandlerType: + def register_errors(self, *args: Any, **kwargs: Any) -> CallbackType: warnings.warn( "`Router.register_errors(...)` is deprecated and will be removed in version 3.2 " "use `Router.errors.register(...)`", diff --git a/tests/test_dispatcher/test_filters/test_logic.py b/tests/test_dispatcher/test_filters/test_logic.py new file mode 100644 index 00000000..ccbf1cb5 --- /dev/null +++ b/tests/test_dispatcher/test_filters/test_logic.py @@ -0,0 +1,37 @@ +import pytest + +from aiogram.dispatcher.filters import Text, and_f, invert_f, or_f +from aiogram.dispatcher.filters.logic import _AndFilter, _InvertFilter, _OrFilter + + +class TestLogic: + @pytest.mark.parametrize( + "obj,case,result", + [ + [True, and_f(lambda t: t is True, lambda t: t is True), True], + [True, and_f(lambda t: t is True, lambda t: t is False), False], + [True, and_f(lambda t: t is False, lambda t: t is False), False], + [True, and_f(lambda t: {"t": t}, lambda t: t is False), False], + [True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}], + [True, or_f(lambda t: t is True, lambda t: t is True), True], + [True, or_f(lambda t: t is True, lambda t: t is False), True], + [True, or_f(lambda t: t is False, lambda t: t is False), False], + [True, or_f(lambda t: t is False, lambda t: t is True), True], + [True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}], + [True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}], + [True, invert_f(lambda t: t is False), True], + ], + ) + async def test_logic(self, obj, case, result): + assert await case(obj) == result + + @pytest.mark.parametrize( + "case,type_", + [ + [Text(text="test") | Text(text="test"), _OrFilter], + [Text(text="test") & Text(text="test"), _AndFilter], + [~Text(text="test"), _InvertFilter], + ], + ) + def test_dunder_methods(self, case, type_): + assert isinstance(case, type_)