Refactor UserContextMiddleware to use EventContext class

This update significantly refactors UserContextMiddleware to leverage a new class, EventContext. Instead of resolving event context as a tuple, it now produces an instance of EventContext. Additional adjustments include supporting a business connection ID for event context identification and facilitating backwards compatibility. Tests and other files were also updated accordingly for these changes.
This commit is contained in:
JRoot Junior 2024-04-09 01:26:31 +03:00
parent 7dd18bfa11
commit 924fb755cb
No known key found for this signature in database
GPG key ID: 738964250D5FF6E2
5 changed files with 128 additions and 63 deletions

View file

@ -1,13 +1,32 @@
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
EVENT_CONTEXT_KEY = "event_context"
EVENT_FROM_USER_KEY = "event_from_user" EVENT_FROM_USER_KEY = "event_from_user"
EVENT_CHAT_KEY = "event_chat" EVENT_CHAT_KEY = "event_chat"
EVENT_THREAD_ID_KEY = "event_thread_id" EVENT_THREAD_ID_KEY = "event_thread_id"
@dataclass(frozen=True)
class EventContext:
chat: Optional[Chat] = None
user: Optional[User] = None
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
@property
def user_id(self) -> Optional[int]:
return self.user.id if self.user else None
@property
def chat_id(self) -> Optional[int]:
return self.chat.id if self.chat else None
class UserContextMiddleware(BaseMiddleware): class UserContextMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
@ -17,93 +36,114 @@ class UserContextMiddleware(BaseMiddleware):
) -> Any: ) -> Any:
if not isinstance(event, Update): if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!") raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user, thread_id = self.resolve_event_context(event=event) event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
if user is not None:
data[EVENT_FROM_USER_KEY] = user # Backward compatibility
if chat is not None: if event_context.user is not None:
data[EVENT_CHAT_KEY] = chat data[EVENT_FROM_USER_KEY] = event_context.user
if thread_id is not None: if event_context.chat is not None:
data[EVENT_THREAD_ID_KEY] = thread_id data[EVENT_CHAT_KEY] = event_context.chat
if event_context.thread_id is not None:
data[EVENT_THREAD_ID_KEY] = event_context.thread_id
return await handler(event, data) return await handler(event, data)
@classmethod @classmethod
def resolve_event_context( def resolve_event_context(cls, event: Update) -> EventContext:
cls, event: Update
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
""" """
Resolve chat and user instance from Update object Resolve chat and user instance from Update object
""" """
if event.message: if event.message:
return ( return EventContext(
event.message.chat, chat=event.message.chat,
event.message.from_user, user=event.message.from_user,
event.message.message_thread_id if event.message.is_topic_message else None, thread_id=event.message.message_thread_id
if event.message.is_topic_message
else None,
) )
if event.edited_message: if event.edited_message:
return ( return EventContext(
event.edited_message.chat, chat=event.edited_message.chat,
event.edited_message.from_user, user=event.edited_message.from_user,
event.edited_message.message_thread_id thread_id=event.edited_message.message_thread_id
if event.edited_message.is_topic_message if event.edited_message.is_topic_message
else None, else None,
) )
if event.channel_post: if event.channel_post:
return event.channel_post.chat, None, None return EventContext(chat=event.channel_post.chat)
if event.edited_channel_post: if event.edited_channel_post:
return event.edited_channel_post.chat, None, None return EventContext(chat=event.edited_channel_post.chat)
if event.inline_query: if event.inline_query:
return None, event.inline_query.from_user, None return EventContext(user=event.inline_query.from_user)
if event.chosen_inline_result: if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user, None return EventContext(user=event.chosen_inline_result.from_user)
if event.callback_query: if event.callback_query:
if event.callback_query.message: if event.callback_query.message:
return ( return EventContext(
event.callback_query.message.chat, chat=event.callback_query.message.chat,
event.callback_query.from_user, user=event.callback_query.from_user,
event.callback_query.message.message_thread_id thread_id=event.callback_query.message.message_thread_id
if not isinstance(event.callback_query.message, InaccessibleMessage) if not isinstance(event.callback_query.message, InaccessibleMessage)
and event.callback_query.message.is_topic_message and event.callback_query.message.is_topic_message
else None, else None,
) )
return None, event.callback_query.from_user, None return EventContext(user=event.callback_query.from_user)
if event.shipping_query: if event.shipping_query:
return None, event.shipping_query.from_user, None return EventContext(user=event.shipping_query.from_user)
if event.pre_checkout_query: if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user, None return EventContext(user=event.pre_checkout_query.from_user)
if event.poll_answer: if event.poll_answer:
return event.poll_answer.voter_chat, event.poll_answer.user, None return EventContext(
chat=event.poll_answer.voter_chat,
user=event.poll_answer.user,
)
if event.my_chat_member: if event.my_chat_member:
return event.my_chat_member.chat, event.my_chat_member.from_user, None return EventContext(
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user
)
if event.chat_member: if event.chat_member:
return event.chat_member.chat, event.chat_member.from_user, None return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
if event.chat_join_request: if event.chat_join_request:
return event.chat_join_request.chat, event.chat_join_request.from_user, None return EventContext(
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user
)
if event.message_reaction: if event.message_reaction:
return event.message_reaction.chat, event.message_reaction.user, None return EventContext(
chat=event.message_reaction.chat,
user=event.message_reaction.user,
)
if event.message_reaction_count: if event.message_reaction_count:
return event.message_reaction_count.chat, None, None return EventContext(chat=event.message_reaction_count.chat)
if event.chat_boost: if event.chat_boost:
return event.chat_boost.chat, None, None return EventContext(chat=event.chat_boost.chat)
if event.removed_chat_boost: if event.removed_chat_boost:
return event.removed_chat_boost.chat, None, None return EventContext(chat=event.removed_chat_boost.chat)
if event.deleted_business_messages: if event.deleted_business_messages:
return event.deleted_business_messages.chat, None, None return EventContext(
chat=event.deleted_business_messages.chat,
business_connection_id=event.deleted_business_messages.business_connection_id,
)
if event.business_connection: if event.business_connection:
return None, event.business_connection.user, None return EventContext(
user=event.business_connection.user,
business_connection_id=event.business_connection.id,
)
if event.business_message: if event.business_message:
return ( return EventContext(
event.business_message.chat, chat=event.business_message.chat,
event.business_message.from_user, user=event.business_message.from_user,
event.business_message.message_thread_id thread_id=event.business_message.message_thread_id
if event.business_message.is_topic_message if event.business_message.is_topic_message
else None, else None,
business_connection_id=event.business_message.business_connection_id,
) )
if event.edited_business_message: if event.edited_business_message:
return ( return EventContext(
event.edited_business_message.chat, chat=event.edited_business_message.chat,
event.edited_business_message.from_user, user=event.edited_business_message.from_user,
event.edited_business_message.message_thread_id thread_id=event.edited_business_message.message_thread_id
if event.edited_business_message.is_topic_message if event.edited_business_message.is_topic_message
else None, else None,
business_connection_id=event.edited_business_message.business_connection_id,
) )
return None, None, None return EventContext()

View file

@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext
from aiogram.fsm.context import FSMContext from aiogram.fsm.context import FSMContext
from aiogram.fsm.storage.base import ( from aiogram.fsm.storage.base import (
DEFAULT_DESTINY, DEFAULT_DESTINY,
@ -47,16 +48,13 @@ class FSMContextMiddleware(BaseMiddleware):
data: Dict[str, Any], data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
user = data.get("event_from_user") event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
chat = data.get("event_chat")
thread_id = data.get("event_thread_id")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context( return self.resolve_context(
bot=bot, bot=bot,
chat_id=chat_id, chat_id=event_context.chat_id,
user_id=user_id, user_id=event_context.user_id,
thread_id=thread_id, thread_id=event_context.thread_id,
business_connection_id=event_context.business_connection_id,
destiny=destiny, destiny=destiny,
) )
@ -66,6 +64,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id: Optional[int], chat_id: Optional[int],
user_id: Optional[int], user_id: Optional[int],
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
if chat_id is None: if chat_id is None:
@ -83,6 +82,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id=chat_id, chat_id=chat_id,
user_id=user_id, user_id=user_id,
thread_id=thread_id, thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny, destiny=destiny,
) )
return None return None
@ -93,6 +93,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id: int, chat_id: int,
user_id: int, user_id: int,
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> FSMContext: ) -> FSMContext:
return FSMContext( return FSMContext(
@ -102,6 +103,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id=chat_id, chat_id=chat_id,
bot_id=bot.id, bot_id=bot.id,
thread_id=thread_id, thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny, destiny=destiny,
), ),
) )

View file

@ -16,6 +16,7 @@ class StorageKey:
chat_id: int chat_id: int
user_id: int user_id: int
thread_id: Optional[int] = None thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
destiny: str = DEFAULT_DESTINY destiny: str = DEFAULT_DESTINY

View file

@ -53,17 +53,20 @@ class DefaultKeyBuilder(KeyBuilder):
prefix: str = "fsm", prefix: str = "fsm",
separator: str = ":", separator: str = ":",
with_bot_id: bool = False, with_bot_id: bool = False,
with_business_connection_id: bool = False,
with_destiny: bool = False, with_destiny: bool = False,
) -> None: ) -> None:
""" """
:param prefix: prefix for all records :param prefix: prefix for all records
:param separator: separator :param separator: separator
:param with_bot_id: include Bot id in the key :param with_bot_id: include Bot id in the key
:param with_business_connection_id: include business connection id
:param with_destiny: include destiny key :param with_destiny: include destiny key
""" """
self.prefix = prefix self.prefix = prefix
self.separator = separator self.separator = separator
self.with_bot_id = with_bot_id self.with_bot_id = with_bot_id
self.with_business_connection_id = with_business_connection_id
self.with_destiny = with_destiny self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str: def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
@ -74,6 +77,8 @@ class DefaultKeyBuilder(KeyBuilder):
if key.thread_id: if key.thread_id:
parts.append(str(key.thread_id)) parts.append(str(key.thread_id))
parts.append(str(key.user_id)) parts.append(str(key.user_id))
if self.with_business_connection_id and key.business_connection_id:
parts.append(str(key.business_connection_id))
if self.with_destiny: if self.with_destiny:
parts.append(key.destiny) parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY: elif key.destiny != DEFAULT_DESTINY:

View file

@ -2,8 +2,11 @@ from unittest.mock import patch
import pytest import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.middlewares.user_context import (
from aiogram.types import Update EventContext,
UserContextMiddleware,
)
from aiogram.types import Chat, Update, User
async def next_handler(*args, **kwargs): async def next_handler(*args, **kwargs):
@ -18,9 +21,23 @@ class TestUserContextMiddleware:
async def test_call(self): async def test_call(self):
middleware = UserContextMiddleware() middleware = UserContextMiddleware()
data = {} data = {}
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
chat = Chat(id=1, type="private", title="Test")
user = User(id=2, first_name="Test", is_bot=False)
thread_id = 3
with patch.object(
UserContextMiddleware,
"resolve_event_context",
return_value=EventContext(user=user, chat=chat, thread_id=3),
):
await middleware(next_handler, Update(update_id=42), data) await middleware(next_handler, Update(update_id=42), data)
assert data["event_chat"] == 1 event_context = data["event_context"]
assert data["event_from_user"] == 2 assert isinstance(event_context, EventContext)
assert data["event_thread_id"] == 3 assert event_context.chat is chat
assert event_context.user is user
assert event_context.thread_id == thread_id
assert data["event_chat"] is chat
assert data["event_from_user"] is user
assert data["event_thread_id"] == thread_id