diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 39eb7810..0da5f621 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -9,7 +9,7 @@ import aiohttp from aiohttp.helpers import sentinel from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ - RegexpCommandsFilter, StateFilter, Text + RegexpCommandsFilter, StateFilter, Text, IdFilter from .handler import Handler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -114,6 +114,11 @@ class Dispatcher(DataMixin, ContextInstanceMixin): filters_factory.bind(ExceptionsFilter, event_handlers=[ self.errors_handlers ]) + filters_factory.bind(IdFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + self.callback_query_handlers, self.inline_query_handlers + ]) def __del__(self): self.stop_polling() diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 2ae959cf..eb4a5a52 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,5 +1,5 @@ from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \ - ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text + ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text, IdFilter from .factory import FiltersFactory from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \ check_filters, get_filter_spec, get_filters_spec @@ -23,6 +23,7 @@ __all__ = [ 'Regexp', 'StateFilter', 'Text', + 'IdFilter', 'get_filter_spec', 'get_filters_spec', 'execute_filter', diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index c68bae72..f3bbdba7 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -491,3 +491,69 @@ class ExceptionsFilter(BoundFilter): return True except: return False + + +class IdFilter(Filter): + + def __init__(self, + user_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + ): + """ + :param user_id: + :param chat_id: + """ + if user_id is None and chat_id is None: + raise ValueError("Both user_id and chat_id can't be None") + + self.user_id = user_id + self.chat_id = chat_id + + # both params should be convertible to int if they aren't None + # here we checks it + # also, by default in Telegram chat_id and user_id are Integer, + # so for convenience we cast them to int + if self.user_id: + self.user_id = int(self.user_id) + if self.chat_id: + self.chat_id = int(self.chat_id) + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + result = {} + if 'user' in full_config: + result['user_id'] = full_config.pop('user') + elif 'user_id' in full_config: + result['user_id'] = full_config.pop('user_id') + + if 'chat' in full_config: + result['chat_id'] = full_config.pop('chat') + elif 'chat_id' in full_config: + result['chat_id'] = full_config.pop('chat_id') + + return result + + async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]): + if isinstance(obj, Message): + user_id = obj.from_user.id + chat_id = obj.chat.id + elif isinstance(obj, CallbackQuery): + user_id = obj.from_user.id + chat_id = None + if obj.message is not None: + # if the button was sent with message + chat_id = obj.message.chat.id + elif isinstance(obj, InlineQuery): + user_id = obj.from_user.id + chat_id = None + else: + return False + + if self.user_id and self.chat_id: + return self.user_id == user_id and self.chat_id == chat_id + elif self.user_id: + return self.user_id == user_id + elif self.chat_id: + return self.chat_id == chat_id + + return False