diff --git a/aiogram/api/types/message.py b/aiogram/api/types/message.py index 56b35579..dbfdd9e9 100644 --- a/aiogram/api/types/message.py +++ b/aiogram/api/types/message.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional from pydantic import Field +from ...utils import helper from .base import TelegramObject if TYPE_CHECKING: # pragma: no cover @@ -152,3 +153,97 @@ class Message(TelegramObject): reply_markup: Optional[InlineKeyboardMarkup] = None """Inline keyboard attached to the message. login_url buttons are represented as ordinary url buttons.""" + + @property + def content_type(self): + if self.text: + return ContentType.TEXT + if self.audio: + return ContentType.AUDIO + if self.animation: + return ContentType.ANIMATION + if self.document: + return ContentType.DOCUMENT + if self.game: + return ContentType.GAME + if self.photo: + return ContentType.PHOTO + if self.sticker: + return ContentType.STICKER + if self.video: + return ContentType.VIDEO + if self.video_note: + return ContentType.VIDEO_NOTE + if self.voice: + return ContentType.VOICE + if self.contact: + return ContentType.CONTACT + if self.venue: + return ContentType.VENUE + if self.location: + return ContentType.LOCATION + if self.new_chat_members: + return ContentType.NEW_CHAT_MEMBERS + if self.left_chat_member: + return ContentType.LEFT_CHAT_MEMBER + if self.invoice: + return ContentType.INVOICE + if self.successful_payment: + return ContentType.SUCCESSFUL_PAYMENT + if self.connected_website: + return ContentType.CONNECTED_WEBSITE + if self.migrate_from_chat_id: + return ContentType.MIGRATE_FROM_CHAT_ID + if self.migrate_to_chat_id: + return ContentType.MIGRATE_TO_CHAT_ID + if self.pinned_message: + return ContentType.PINNED_MESSAGE + if self.new_chat_title: + return ContentType.NEW_CHAT_TITLE + if self.new_chat_photo: + return ContentType.NEW_CHAT_PHOTO + if self.delete_chat_photo: + return ContentType.DELETE_CHAT_PHOTO + if self.group_chat_created: + return ContentType.GROUP_CHAT_CREATED + if self.passport_data: + return ContentType.PASSPORT_DATA + if self.poll: + return ContentType.POLL + + return ContentType.UNKNOWN + + +class ContentType(helper.Helper): + mode = helper.HelperMode.snake_case + + TEXT = helper.Item() # text + AUDIO = helper.Item() # audio + DOCUMENT = helper.Item() # document + ANIMATION = helper.Item() # animation + GAME = helper.Item() # game + PHOTO = helper.Item() # photo + STICKER = helper.Item() # sticker + VIDEO = helper.Item() # video + VIDEO_NOTE = helper.Item() # video_note + VOICE = helper.Item() # voice + CONTACT = helper.Item() # contact + LOCATION = helper.Item() # location + VENUE = helper.Item() # venue + NEW_CHAT_MEMBERS = helper.Item() # new_chat_member + LEFT_CHAT_MEMBER = helper.Item() # left_chat_member + INVOICE = helper.Item() # invoice + SUCCESSFUL_PAYMENT = helper.Item() # successful_payment + CONNECTED_WEBSITE = helper.Item() # connected_website + MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id + MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id + PINNED_MESSAGE = helper.Item() # pinned_message + NEW_CHAT_TITLE = helper.Item() # new_chat_title + NEW_CHAT_PHOTO = helper.Item() # new_chat_photo + DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo + GROUP_CHAT_CREATED = helper.Item() # group_chat_created + PASSPORT_DATA = helper.Item() # passport_data + POLL = helper.Item() + + UNKNOWN = helper.Item() # unknown + ANY = helper.Item() # any diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 4cc2c6e9..bafbe17d 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -3,15 +3,23 @@ from typing import Dict, Tuple, Union from .base import BaseFilter from .command import Command, CommandObject from .text import Text +from .content_type import ContentTypesFilter -__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text", "Command", "CommandObject") +__all__ = ( + "BUILTIN_FILTERS", + "BaseFilter", + "Text", + "Command", + "CommandObject", + "ContentTypesFilter", +) BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = { "update": (), - "message": (Text, Command), - "edited_message": (Text, Command), - "channel_post": (Text,), - "edited_channel_post": (Text,), + "message": (Text, Command, ContentTypesFilter), + "edited_message": (Text, Command, ContentTypesFilter), + "channel_post": (Text, ContentTypesFilter), + "edited_channel_post": (Text, ContentTypesFilter), "inline_query": (Text,), "chosen_inline_result": (), "callback_query": (Text,), diff --git a/aiogram/dispatcher/filters/content_type.py b/aiogram/dispatcher/filters/content_type.py new file mode 100644 index 00000000..ec30a376 --- /dev/null +++ b/aiogram/dispatcher/filters/content_type.py @@ -0,0 +1,26 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import root_validator + +from ...api.types import Message +from ...api.types.message import ContentType +from .base import BaseFilter + + +class ContentTypesFilter(BaseFilter): + content_types: Optional[List[str]] = None + + @root_validator + def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "content_types" not in values or not values["content_types"]: + values["content_types"] = [ContentType.TEXT] + allowed_content_types = set(ContentType.all()) + bad_content_types = set(values["content_types"]) - allowed_content_types + if bad_content_types: + raise ValueError(f"Invalid content types {bad_content_types} is not allowed here") + return values + + async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]: + if not self.content_types: + return False + return ContentType.ANY in self.content_types or message.content_type in self.content_types diff --git a/aiogram/dispatcher/handler/__init__.py b/aiogram/dispatcher/handler/__init__.py index 1762b5bc..15dddab6 100644 --- a/aiogram/dispatcher/handler/__init__.py +++ b/aiogram/dispatcher/handler/__init__.py @@ -1,4 +1,4 @@ from .base import BaseHandler, BaseHandlerMixin -from .message import MessageHandler +from .message import MessageHandler, MessageHandlerCommandMixin -__all__ = ("BaseHandler", "BaseHandlerMixin", "MessageHandler") +__all__ = ("BaseHandler", "BaseHandlerMixin", "MessageHandler", "MessageHandlerCommandMixin")