Forum topic in FSM (#1161)

* Base implementation

* Added tests, fixed arguments priority

* Use `Optional[X]` instead of `X | None`

* Added changelog

* Added tests
This commit is contained in:
Alex Root Junior 2023-04-22 19:35:41 +03:00 committed by GitHub
parent 1538bc2e2d
commit 942ba0d520
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 164 additions and 60 deletions

17
CHANGES/1161.feature.rst Normal file
View file

@ -0,0 +1,17 @@
Added support for FSM in Forum topics.
The strategy can be changed in dispatcher:
.. code-block:: python
from aiogram.fsm.strategy import FSMStrategy
...
dispatcher = Dispatcher(
fsm_strategy=FSMStrategy.USER_IN_THREAD,
storage=..., # Any persistent storage
)
.. note::
If you have implemented you own storages you should extend record key generation
with new one attribute - `thread_id`

View file

@ -4,6 +4,10 @@ from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, TelegramObject, Update, User from aiogram.types import Chat, TelegramObject, Update, User
EVENT_FROM_USER_KEY = "event_from_user"
EVENT_CHAT_KEY = "event_chat"
EVENT_THREAD_ID_KEY = "event_thread_id"
class UserContextMiddleware(BaseMiddleware): class UserContextMiddleware(BaseMiddleware):
async def __call__( async def __call__(
@ -14,61 +18,64 @@ class UserContextMiddleware(BaseMiddleware):
) -> Any: ) -> Any:
if not isinstance(event, Update): if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!") raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user = self.resolve_event_context(event=event) chat, user, thread_id = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user): if user is not None:
if user is not None: data[EVENT_FROM_USER_KEY] = user
data["event_from_user"] = user if chat is not None:
if chat is not None: data[EVENT_CHAT_KEY] = chat
data["event_chat"] = chat if thread_id is not None:
return await handler(event, data) data[EVENT_THREAD_ID_KEY] = thread_id
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
@classmethod @classmethod
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]: def resolve_event_context(
cls, event: Update
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
""" """
Resolve chat and user instance from Update object Resolve chat and user instance from Update object
""" """
if event.message: if event.message:
return event.message.chat, event.message.from_user return (
event.message.chat,
event.message.from_user,
event.message.message_thread_id if event.message.is_topic_message else None,
)
if event.edited_message: if event.edited_message:
return event.edited_message.chat, event.edited_message.from_user return (
event.edited_message.chat,
event.edited_message.from_user,
event.edited_message.message_thread_id
if event.edited_message.is_topic_message
else None,
)
if event.channel_post: if event.channel_post:
return event.channel_post.chat, None return event.channel_post.chat, None, None
if event.edited_channel_post: if event.edited_channel_post:
return event.edited_channel_post.chat, None return event.edited_channel_post.chat, None, None
if event.inline_query: if event.inline_query:
return None, event.inline_query.from_user return None, event.inline_query.from_user, None
if event.chosen_inline_result: if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user return None, event.chosen_inline_result.from_user, None
if event.callback_query: if event.callback_query:
if event.callback_query.message: if event.callback_query.message:
return event.callback_query.message.chat, event.callback_query.from_user return (
return None, event.callback_query.from_user event.callback_query.message.chat,
event.callback_query.from_user,
event.callback_query.message.message_thread_id
if event.callback_query.message.is_topic_message
else None,
)
return None, event.callback_query.from_user, None
if event.shipping_query: if event.shipping_query:
return None, event.shipping_query.from_user return None, event.shipping_query.from_user, None
if event.pre_checkout_query: if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user return None, event.pre_checkout_query.from_user, None
if event.poll_answer: if event.poll_answer:
return None, event.poll_answer.user return None, event.poll_answer.user, None
if event.my_chat_member: if event.my_chat_member:
return event.my_chat_member.chat, event.my_chat_member.from_user return event.my_chat_member.chat, event.my_chat_member.from_user, None
if event.chat_member: if event.chat_member:
return event.chat_member.chat, event.chat_member.from_user return event.chat_member.chat, event.chat_member.from_user, None
if event.chat_join_request: if event.chat_join_request:
return event.chat_join_request.chat, event.chat_join_request.from_user return event.chat_join_request.chat, event.chat_join_request.from_user, None
return None, None return None, None, None

View file

@ -47,25 +47,42 @@ class FSMContextMiddleware(BaseMiddleware):
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
user = data.get("event_from_user") user = data.get("event_from_user")
chat = data.get("event_chat") chat = data.get("event_chat")
thread_id = data.get("event_thread_id")
chat_id = chat.id if chat else None chat_id = chat.id if chat else None
user_id = user.id if user else None user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny) return self.resolve_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
)
def resolve_context( def resolve_context(
self, self,
bot: Bot, bot: Bot,
chat_id: Optional[int], chat_id: Optional[int],
user_id: Optional[int], user_id: Optional[int],
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
if chat_id is None: if chat_id is None:
chat_id = user_id chat_id = user_id
if chat_id is not None and user_id is not None: if chat_id is not None and user_id is not None:
chat_id, user_id = apply_strategy( chat_id, user_id, thread_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=self.strategy,
)
return self.get_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
) )
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None return None
def get_context( def get_context(
@ -73,6 +90,7 @@ class FSMContextMiddleware(BaseMiddleware):
bot: Bot, bot: Bot,
chat_id: int, chat_id: int,
user_id: int, user_id: int,
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> FSMContext: ) -> FSMContext:
return FSMContext( return FSMContext(
@ -81,6 +99,7 @@ class FSMContextMiddleware(BaseMiddleware):
user_id=user_id, user_id=user_id,
chat_id=chat_id, chat_id=chat_id,
bot_id=bot.id, bot_id=bot.id,
thread_id=thread_id,
destiny=destiny, destiny=destiny,
), ),
) )

View file

@ -15,6 +15,7 @@ class StorageKey:
bot_id: int bot_id: int
chat_id: int chat_id: int
user_id: int user_id: int
thread_id: Optional[int] = None
destiny: str = DEFAULT_DESTINY destiny: str = DEFAULT_DESTINY

View file

@ -70,7 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
parts = [self.prefix] parts = [self.prefix]
if self.with_bot_id: if self.with_bot_id:
parts.append(str(key.bot_id)) parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)]) parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny: if self.with_destiny:
parts.append(key.destiny) parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY: elif key.destiny != DEFAULT_DESTINY:

View file

@ -1,16 +1,24 @@
from enum import Enum, auto from enum import Enum, auto
from typing import Tuple from typing import Optional, Tuple
class FSMStrategy(Enum): class FSMStrategy(Enum):
USER_IN_CHAT = auto() USER_IN_CHAT = auto()
CHAT = auto() CHAT = auto()
GLOBAL_USER = auto() GLOBAL_USER = auto()
USER_IN_THREAD = auto()
def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]: def apply_strategy(
strategy: FSMStrategy,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
) -> Tuple[int, int, Optional[int]]:
if strategy == FSMStrategy.CHAT: if strategy == FSMStrategy.CHAT:
return chat_id, chat_id return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER: if strategy == FSMStrategy.GLOBAL_USER:
return user_id, user_id return user_id, user_id, None
return chat_id, user_id if strategy == FSMStrategy.USER_IN_THREAD:
return chat_id, user_id, thread_id
return chat_id, user_id, None

View file

@ -14,7 +14,7 @@ from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod
from aiogram.types import ( from aiogram.types import (
CallbackQuery, CallbackQuery,
Chat, Chat,
@ -462,9 +462,9 @@ class TestDispatcher:
async def my_handler(event: Any, **kwargs: Any): async def my_handler(event: Any, **kwargs: Any):
assert event == getattr(update, event_type) assert event == getattr(update, event_type)
if has_chat: if has_chat:
assert Chat.get_current(False) assert kwargs["event_chat"]
if has_user: if has_user:
assert User.get_current(False) assert kwargs["event_from_user"]
return kwargs return kwargs
result = await router.feed_update(bot, update, test="PASS") result = await router.feed_update(bot, update, test="PASS")

View file

@ -1,6 +1,9 @@
from unittest.mock import patch
import pytest import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.types import Update
async def next_handler(*args, **kwargs): async def next_handler(*args, **kwargs):
@ -11,3 +14,13 @@ class TestUserContextMiddleware:
async def test_unexpected_event_type(self): async def test_unexpected_event_type(self):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await UserContextMiddleware()(next_handler, object(), {}) await UserContextMiddleware()(next_handler, object(), {})
async def test_call(self):
middleware = UserContextMiddleware()
data = {}
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
await middleware(next_handler, Update(update_id=42), data)
assert data["event_chat"] == 1
assert data["event_from_user"] == 2
assert data["event_thread_id"] == 3

View file

@ -11,6 +11,7 @@ PREFIX = "test"
BOT_ID = 42 BOT_ID = 42
CHAT_ID = -1 CHAT_ID = -1
USER_ID = 2 USER_ID = 2
THREAD_ID = 3
FIELD = "data" FIELD = "data"
@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder:
with pytest.raises(ValueError): with pytest.raises(ValueError):
key_builder.build(key, FIELD) key_builder.build(key, FIELD)
def test_thread_id(self):
key_builder = DefaultKeyBuilder(
prefix=PREFIX,
)
key = StorageKey(
chat_id=CHAT_ID,
user_id=USER_ID,
bot_id=BOT_ID,
thread_id=THREAD_ID,
destiny=DEFAULT_DESTINY,
)
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
def test_create_isolation(self): def test_create_isolation(self):
fake_redis = object() fake_redis = object()
storage = RedisStorage(redis=fake_redis) storage = RedisStorage(redis=fake_redis)

View file

@ -2,19 +2,41 @@ import pytest
from aiogram.fsm.strategy import FSMStrategy, apply_strategy from aiogram.fsm.strategy import FSMStrategy, apply_strategy
CHAT_ID = -42
USER_ID = 42
THREAD_ID = 1
PRIVATE = (USER_ID, USER_ID, None)
CHAT = (CHAT_ID, USER_ID, None)
THREAD = (CHAT_ID, USER_ID, THREAD_ID)
class TestStrategy: class TestStrategy:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"strategy,case,expected", "strategy,case,expected",
[ [
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)], [FSMStrategy.USER_IN_CHAT, CHAT, CHAT],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)], [FSMStrategy.USER_IN_CHAT, PRIVATE, PRIVATE],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)], [FSMStrategy.USER_IN_CHAT, THREAD, CHAT],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)], [FSMStrategy.CHAT, CHAT, (CHAT_ID, CHAT_ID, None)],
[FSMStrategy.CHAT, (42, 42), (42, 42)], [FSMStrategy.CHAT, PRIVATE, (USER_ID, USER_ID, None)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)], [FSMStrategy.CHAT, THREAD, (CHAT_ID, CHAT_ID, None)],
[FSMStrategy.GLOBAL_USER, CHAT, PRIVATE],
[FSMStrategy.GLOBAL_USER, PRIVATE, PRIVATE],
[FSMStrategy.GLOBAL_USER, THREAD, PRIVATE],
[FSMStrategy.USER_IN_THREAD, CHAT, CHAT],
[FSMStrategy.USER_IN_THREAD, PRIVATE, PRIVATE],
[FSMStrategy.USER_IN_THREAD, THREAD, THREAD],
], ],
) )
def test_strategy(self, strategy, case, expected): def test_strategy(self, strategy, case, expected):
chat_id, user_id = case chat_id, user_id, thread_id = case
assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected assert (
apply_strategy(
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=strategy,
)
== expected
)