Add text filter and mechanism for registering builtin filters

This commit is contained in:
Alex Root Junior 2019-11-29 23:16:11 +02:00
parent e37395b161
commit 40b6a61e70
9 changed files with 398 additions and 8 deletions

View file

@ -0,0 +1,20 @@
from typing import Dict, Tuple, Union
from .base import BaseFilter
from .text import Text
__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text")
BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = {
"update": (),
"message": (Text,),
"edited_message": (Text,),
"channel_post": (Text,),
"edited_channel_post": (Text,),
"inline_query": (Text,),
"chosen_inline_result": (),
"callback_query": (Text,),
"shipping_query": (),
"pre_checkout_query": (),
"poll": (),
}

View file

@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import root_validator
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
from aiogram.dispatcher.filters import BaseFilter
class Text(BaseFilter):
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_ignore_case: bool = False
@root_validator
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# Validate that only one text filter type is presented
used_args = set(
key for key, value in values.items() if key != "text_ignore_case" and value is not None
)
if len(used_args) < 1:
raise ValueError(
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
)
if len(used_args) > 1:
raise ValueError(f"Arguments {used_args} cannot be used together")
# Convert single value to list
for arg in used_args:
if isinstance(values[arg], str):
values[arg] = [values[arg]]
return values
async def __call__(
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
) -> Union[bool, Dict[str, Any]]:
if isinstance(obj, Message):
text = obj.text or obj.caption or ""
if not text and obj.poll:
text = obj.poll.question
elif isinstance(obj, CallbackQuery) and obj.data:
text = obj.data
elif isinstance(obj, InlineQuery):
text = obj.query
elif isinstance(obj, Poll):
text = obj.question
else:
return False
if not text:
return False
if self.text_ignore_case:
text = text.lower()
if self.text is not None:
equals = list(map(self.prepare_text, self.text))
return text in equals
if self.text_contains is not None:
contains = list(map(self.prepare_text, self.text_contains))
return all(map(text.__contains__, contains))
if self.text_startswith is not None:
startswith = list(map(self.prepare_text, self.text_startswith))
return any(map(text.startswith, startswith))
if self.text_endswith is not None:
endswith = list(map(self.prepare_text, self.text_endswith))
return any(map(text.endswith, endswith))
# Impossible because the validator prevents this situation
return False # pragma: no cover
def prepare_text(self, text: str):
if self.text_ignore_case:
return str(text).lower()
else:
return str(text)

View file

@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from ..api.types import Chat, Update, User
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
from .filters import BUILTIN_FILTERS
class Router:
@ -38,7 +39,8 @@ class Router:
self.startup = EventObserver()
self.shutdown = EventObserver()
self.observers = {
self.observers: Dict[str, TelegramEventObserver] = {
"update": self.update_handler,
"message": self.message_handler,
"edited_message": self.edited_message_handler,
"channel_post": self.channel_post_handler,
@ -52,6 +54,9 @@ class Router:
}
self.update_handler.register(self._listen_update)
for name, observer in self.observers.items():
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)
@property
def parent_router(self) -> Optional[Router]: