mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add tests for content types
This commit is contained in:
parent
8df6c345c3
commit
6ee05fb901
6 changed files with 391 additions and 10 deletions
|
|
@ -58,7 +58,7 @@ from .labeled_price import LabeledPrice
|
|||
from .location import Location
|
||||
from .login_url import LoginUrl
|
||||
from .mask_position import MaskPosition
|
||||
from .message import Message
|
||||
from .message import ContentType, Message
|
||||
from .message_entity import MessageEntity
|
||||
from .order_info import OrderInfo
|
||||
from .passport_data import PassportData
|
||||
|
|
@ -104,6 +104,7 @@ __all__ = (
|
|||
"User",
|
||||
"Chat",
|
||||
"Message",
|
||||
"ContentType",
|
||||
"MessageEntity",
|
||||
"PhotoSize",
|
||||
"Audio",
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from typing import Dict, Tuple, Union
|
|||
|
||||
from .base import BaseFilter
|
||||
from .command import Command, CommandObject
|
||||
from .content_types import ContentTypesFilter
|
||||
from .text import Text
|
||||
from .content_type import ContentTypesFilter
|
||||
|
||||
__all__ = (
|
||||
"BUILTIN_FILTERS",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
from pydantic import validator
|
||||
|
||||
from ...api.types import Message
|
||||
from ...api.types.message import ContentType
|
||||
|
|
@ -10,17 +10,18 @@ from .base import BaseFilter
|
|||
class ContentTypesFilter(BaseFilter):
|
||||
content_types: Optional[List[str]] = None
|
||||
|
||||
@root_validator
|
||||
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "content_types" not in values or not values["content_types"]:
|
||||
values["content_types"] = [ContentType.TEXT]
|
||||
@validator("content_types", always=True)
|
||||
def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if not value:
|
||||
value = [ContentType.TEXT]
|
||||
allowed_content_types = set(ContentType.all())
|
||||
bad_content_types = set(values["content_types"]) - allowed_content_types
|
||||
bad_content_types = set(value) - allowed_content_types
|
||||
if bad_content_types:
|
||||
raise ValueError(f"Invalid content types {bad_content_types} is not allowed here")
|
||||
return values
|
||||
return value
|
||||
|
||||
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
|
||||
if not self.content_types:
|
||||
if not self.content_types: # pragma: no cover
|
||||
# Is impossible but needed for valid typechecking
|
||||
return False
|
||||
return ContentType.ANY in self.content_types or message.content_type in self.content_types
|
||||
Loading…
Add table
Add a link
Reference in a new issue