mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Refactor UserContextMiddleware to use EventContext class
This update significantly refactors UserContextMiddleware to leverage a new class, EventContext. Instead of resolving event context as a tuple, it now produces an instance of EventContext. Additional adjustments include supporting a business connection ID for event context identification and facilitating backwards compatibility. Tests and other files were also updated accordingly for these changes.
This commit is contained in:
parent
7dd18bfa11
commit
924fb755cb
5 changed files with 128 additions and 63 deletions
|
|
@ -1,13 +1,32 @@
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
|
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
|
||||||
|
|
||||||
|
EVENT_CONTEXT_KEY = "event_context"
|
||||||
|
|
||||||
EVENT_FROM_USER_KEY = "event_from_user"
|
EVENT_FROM_USER_KEY = "event_from_user"
|
||||||
EVENT_CHAT_KEY = "event_chat"
|
EVENT_CHAT_KEY = "event_chat"
|
||||||
EVENT_THREAD_ID_KEY = "event_thread_id"
|
EVENT_THREAD_ID_KEY = "event_thread_id"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class EventContext:
|
||||||
|
chat: Optional[Chat] = None
|
||||||
|
user: Optional[User] = None
|
||||||
|
thread_id: Optional[int] = None
|
||||||
|
business_connection_id: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self) -> Optional[int]:
|
||||||
|
return self.user.id if self.user else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_id(self) -> Optional[int]:
|
||||||
|
return self.chat.id if self.chat else None
|
||||||
|
|
||||||
|
|
||||||
class UserContextMiddleware(BaseMiddleware):
|
class UserContextMiddleware(BaseMiddleware):
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -17,93 +36,114 @@ class UserContextMiddleware(BaseMiddleware):
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not isinstance(event, Update):
|
if not isinstance(event, Update):
|
||||||
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
|
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
|
||||||
chat, user, thread_id = self.resolve_event_context(event=event)
|
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
|
||||||
if user is not None:
|
|
||||||
data[EVENT_FROM_USER_KEY] = user
|
# Backward compatibility
|
||||||
if chat is not None:
|
if event_context.user is not None:
|
||||||
data[EVENT_CHAT_KEY] = chat
|
data[EVENT_FROM_USER_KEY] = event_context.user
|
||||||
if thread_id is not None:
|
if event_context.chat is not None:
|
||||||
data[EVENT_THREAD_ID_KEY] = thread_id
|
data[EVENT_CHAT_KEY] = event_context.chat
|
||||||
|
if event_context.thread_id is not None:
|
||||||
|
data[EVENT_THREAD_ID_KEY] = event_context.thread_id
|
||||||
|
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_event_context(
|
def resolve_event_context(cls, event: Update) -> EventContext:
|
||||||
cls, event: Update
|
|
||||||
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
|
|
||||||
"""
|
"""
|
||||||
Resolve chat and user instance from Update object
|
Resolve chat and user instance from Update object
|
||||||
"""
|
"""
|
||||||
if event.message:
|
if event.message:
|
||||||
return (
|
return EventContext(
|
||||||
event.message.chat,
|
chat=event.message.chat,
|
||||||
event.message.from_user,
|
user=event.message.from_user,
|
||||||
event.message.message_thread_id if event.message.is_topic_message else None,
|
thread_id=event.message.message_thread_id
|
||||||
|
if event.message.is_topic_message
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
if event.edited_message:
|
if event.edited_message:
|
||||||
return (
|
return EventContext(
|
||||||
event.edited_message.chat,
|
chat=event.edited_message.chat,
|
||||||
event.edited_message.from_user,
|
user=event.edited_message.from_user,
|
||||||
event.edited_message.message_thread_id
|
thread_id=event.edited_message.message_thread_id
|
||||||
if event.edited_message.is_topic_message
|
if event.edited_message.is_topic_message
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
if event.channel_post:
|
if event.channel_post:
|
||||||
return event.channel_post.chat, None, None
|
return EventContext(chat=event.channel_post.chat)
|
||||||
if event.edited_channel_post:
|
if event.edited_channel_post:
|
||||||
return event.edited_channel_post.chat, None, None
|
return EventContext(chat=event.edited_channel_post.chat)
|
||||||
if event.inline_query:
|
if event.inline_query:
|
||||||
return None, event.inline_query.from_user, None
|
return EventContext(user=event.inline_query.from_user)
|
||||||
if event.chosen_inline_result:
|
if event.chosen_inline_result:
|
||||||
return None, event.chosen_inline_result.from_user, None
|
return EventContext(user=event.chosen_inline_result.from_user)
|
||||||
if event.callback_query:
|
if event.callback_query:
|
||||||
if event.callback_query.message:
|
if event.callback_query.message:
|
||||||
return (
|
return EventContext(
|
||||||
event.callback_query.message.chat,
|
chat=event.callback_query.message.chat,
|
||||||
event.callback_query.from_user,
|
user=event.callback_query.from_user,
|
||||||
event.callback_query.message.message_thread_id
|
thread_id=event.callback_query.message.message_thread_id
|
||||||
if not isinstance(event.callback_query.message, InaccessibleMessage)
|
if not isinstance(event.callback_query.message, InaccessibleMessage)
|
||||||
and event.callback_query.message.is_topic_message
|
and event.callback_query.message.is_topic_message
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
return None, event.callback_query.from_user, None
|
return EventContext(user=event.callback_query.from_user)
|
||||||
if event.shipping_query:
|
if event.shipping_query:
|
||||||
return None, event.shipping_query.from_user, None
|
return EventContext(user=event.shipping_query.from_user)
|
||||||
if event.pre_checkout_query:
|
if event.pre_checkout_query:
|
||||||
return None, event.pre_checkout_query.from_user, None
|
return EventContext(user=event.pre_checkout_query.from_user)
|
||||||
if event.poll_answer:
|
if event.poll_answer:
|
||||||
return event.poll_answer.voter_chat, event.poll_answer.user, None
|
return EventContext(
|
||||||
|
chat=event.poll_answer.voter_chat,
|
||||||
|
user=event.poll_answer.user,
|
||||||
|
)
|
||||||
if event.my_chat_member:
|
if event.my_chat_member:
|
||||||
return event.my_chat_member.chat, event.my_chat_member.from_user, None
|
return EventContext(
|
||||||
|
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user
|
||||||
|
)
|
||||||
if event.chat_member:
|
if event.chat_member:
|
||||||
return event.chat_member.chat, event.chat_member.from_user, None
|
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
|
||||||
if event.chat_join_request:
|
if event.chat_join_request:
|
||||||
return event.chat_join_request.chat, event.chat_join_request.from_user, None
|
return EventContext(
|
||||||
|
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user
|
||||||
|
)
|
||||||
if event.message_reaction:
|
if event.message_reaction:
|
||||||
return event.message_reaction.chat, event.message_reaction.user, None
|
return EventContext(
|
||||||
|
chat=event.message_reaction.chat,
|
||||||
|
user=event.message_reaction.user,
|
||||||
|
)
|
||||||
if event.message_reaction_count:
|
if event.message_reaction_count:
|
||||||
return event.message_reaction_count.chat, None, None
|
return EventContext(chat=event.message_reaction_count.chat)
|
||||||
if event.chat_boost:
|
if event.chat_boost:
|
||||||
return event.chat_boost.chat, None, None
|
return EventContext(chat=event.chat_boost.chat)
|
||||||
if event.removed_chat_boost:
|
if event.removed_chat_boost:
|
||||||
return event.removed_chat_boost.chat, None, None
|
return EventContext(chat=event.removed_chat_boost.chat)
|
||||||
if event.deleted_business_messages:
|
if event.deleted_business_messages:
|
||||||
return event.deleted_business_messages.chat, None, None
|
return EventContext(
|
||||||
|
chat=event.deleted_business_messages.chat,
|
||||||
|
business_connection_id=event.deleted_business_messages.business_connection_id,
|
||||||
|
)
|
||||||
if event.business_connection:
|
if event.business_connection:
|
||||||
return None, event.business_connection.user, None
|
return EventContext(
|
||||||
|
user=event.business_connection.user,
|
||||||
|
business_connection_id=event.business_connection.id,
|
||||||
|
)
|
||||||
if event.business_message:
|
if event.business_message:
|
||||||
return (
|
return EventContext(
|
||||||
event.business_message.chat,
|
chat=event.business_message.chat,
|
||||||
event.business_message.from_user,
|
user=event.business_message.from_user,
|
||||||
event.business_message.message_thread_id
|
thread_id=event.business_message.message_thread_id
|
||||||
if event.business_message.is_topic_message
|
if event.business_message.is_topic_message
|
||||||
else None,
|
else None,
|
||||||
|
business_connection_id=event.business_message.business_connection_id,
|
||||||
)
|
)
|
||||||
if event.edited_business_message:
|
if event.edited_business_message:
|
||||||
return (
|
return EventContext(
|
||||||
event.edited_business_message.chat,
|
chat=event.edited_business_message.chat,
|
||||||
event.edited_business_message.from_user,
|
user=event.edited_business_message.from_user,
|
||||||
event.edited_business_message.message_thread_id
|
thread_id=event.edited_business_message.message_thread_id
|
||||||
if event.edited_business_message.is_topic_message
|
if event.edited_business_message.is_topic_message
|
||||||
else None,
|
else None,
|
||||||
|
business_connection_id=event.edited_business_message.business_connection_id,
|
||||||
)
|
)
|
||||||
return None, None, None
|
return EventContext()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
|
from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext
|
||||||
from aiogram.fsm.context import FSMContext
|
from aiogram.fsm.context import FSMContext
|
||||||
from aiogram.fsm.storage.base import (
|
from aiogram.fsm.storage.base import (
|
||||||
DEFAULT_DESTINY,
|
DEFAULT_DESTINY,
|
||||||
|
|
@ -47,16 +48,13 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> Optional[FSMContext]:
|
) -> Optional[FSMContext]:
|
||||||
user = data.get("event_from_user")
|
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
|
||||||
chat = data.get("event_chat")
|
|
||||||
thread_id = data.get("event_thread_id")
|
|
||||||
chat_id = chat.id if chat else None
|
|
||||||
user_id = user.id if user else None
|
|
||||||
return self.resolve_context(
|
return self.resolve_context(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
chat_id=chat_id,
|
chat_id=event_context.chat_id,
|
||||||
user_id=user_id,
|
user_id=event_context.user_id,
|
||||||
thread_id=thread_id,
|
thread_id=event_context.thread_id,
|
||||||
|
business_connection_id=event_context.business_connection_id,
|
||||||
destiny=destiny,
|
destiny=destiny,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -66,6 +64,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
chat_id: Optional[int],
|
chat_id: Optional[int],
|
||||||
user_id: Optional[int],
|
user_id: Optional[int],
|
||||||
thread_id: Optional[int] = None,
|
thread_id: Optional[int] = None,
|
||||||
|
business_connection_id: Optional[str] = None,
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> Optional[FSMContext]:
|
) -> Optional[FSMContext]:
|
||||||
if chat_id is None:
|
if chat_id is None:
|
||||||
|
|
@ -83,6 +82,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
business_connection_id=business_connection_id,
|
||||||
destiny=destiny,
|
destiny=destiny,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
@ -93,6 +93,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
thread_id: Optional[int] = None,
|
thread_id: Optional[int] = None,
|
||||||
|
business_connection_id: Optional[str] = None,
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> FSMContext:
|
) -> FSMContext:
|
||||||
return FSMContext(
|
return FSMContext(
|
||||||
|
|
@ -102,6 +103,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
bot_id=bot.id,
|
bot_id=bot.id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
business_connection_id=business_connection_id,
|
||||||
destiny=destiny,
|
destiny=destiny,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ class StorageKey:
|
||||||
chat_id: int
|
chat_id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
thread_id: Optional[int] = None
|
thread_id: Optional[int] = None
|
||||||
|
business_connection_id: Optional[str] = None
|
||||||
destiny: str = DEFAULT_DESTINY
|
destiny: str = DEFAULT_DESTINY
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,17 +53,20 @@ class DefaultKeyBuilder(KeyBuilder):
|
||||||
prefix: str = "fsm",
|
prefix: str = "fsm",
|
||||||
separator: str = ":",
|
separator: str = ":",
|
||||||
with_bot_id: bool = False,
|
with_bot_id: bool = False,
|
||||||
|
with_business_connection_id: bool = False,
|
||||||
with_destiny: bool = False,
|
with_destiny: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param prefix: prefix for all records
|
:param prefix: prefix for all records
|
||||||
:param separator: separator
|
:param separator: separator
|
||||||
:param with_bot_id: include Bot id in the key
|
:param with_bot_id: include Bot id in the key
|
||||||
|
:param with_business_connection_id: include business connection id
|
||||||
:param with_destiny: include destiny key
|
:param with_destiny: include destiny key
|
||||||
"""
|
"""
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.separator = separator
|
self.separator = separator
|
||||||
self.with_bot_id = with_bot_id
|
self.with_bot_id = with_bot_id
|
||||||
|
self.with_business_connection_id = with_business_connection_id
|
||||||
self.with_destiny = with_destiny
|
self.with_destiny = with_destiny
|
||||||
|
|
||||||
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||||
|
|
@ -74,6 +77,8 @@ class DefaultKeyBuilder(KeyBuilder):
|
||||||
if key.thread_id:
|
if key.thread_id:
|
||||||
parts.append(str(key.thread_id))
|
parts.append(str(key.thread_id))
|
||||||
parts.append(str(key.user_id))
|
parts.append(str(key.user_id))
|
||||||
|
if self.with_business_connection_id and key.business_connection_id:
|
||||||
|
parts.append(str(key.business_connection_id))
|
||||||
if self.with_destiny:
|
if self.with_destiny:
|
||||||
parts.append(key.destiny)
|
parts.append(key.destiny)
|
||||||
elif key.destiny != DEFAULT_DESTINY:
|
elif key.destiny != DEFAULT_DESTINY:
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,11 @@ from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
|
from aiogram.dispatcher.middlewares.user_context import (
|
||||||
from aiogram.types import Update
|
EventContext,
|
||||||
|
UserContextMiddleware,
|
||||||
|
)
|
||||||
|
from aiogram.types import Chat, Update, User
|
||||||
|
|
||||||
|
|
||||||
async def next_handler(*args, **kwargs):
|
async def next_handler(*args, **kwargs):
|
||||||
|
|
@ -18,9 +21,23 @@ class TestUserContextMiddleware:
|
||||||
async def test_call(self):
|
async def test_call(self):
|
||||||
middleware = UserContextMiddleware()
|
middleware = UserContextMiddleware()
|
||||||
data = {}
|
data = {}
|
||||||
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
|
|
||||||
|
chat = Chat(id=1, type="private", title="Test")
|
||||||
|
user = User(id=2, first_name="Test", is_bot=False)
|
||||||
|
thread_id = 3
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
UserContextMiddleware,
|
||||||
|
"resolve_event_context",
|
||||||
|
return_value=EventContext(user=user, chat=chat, thread_id=3),
|
||||||
|
):
|
||||||
await middleware(next_handler, Update(update_id=42), data)
|
await middleware(next_handler, Update(update_id=42), data)
|
||||||
|
|
||||||
assert data["event_chat"] == 1
|
event_context = data["event_context"]
|
||||||
assert data["event_from_user"] == 2
|
assert isinstance(event_context, EventContext)
|
||||||
assert data["event_thread_id"] == 3
|
assert event_context.chat is chat
|
||||||
|
assert event_context.user is user
|
||||||
|
assert event_context.thread_id == thread_id
|
||||||
|
assert data["event_chat"] is chat
|
||||||
|
assert data["event_from_user"] is user
|
||||||
|
assert data["event_thread_id"] == thread_id
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue