diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index 7ca2ac6b..5efcd0f8 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -5,6 +5,7 @@ from collections import defaultdict from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, cast +from aiogram import Bot from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.types import Message, TelegramObject @@ -56,8 +57,9 @@ class BaseMediaGroupAggregator(ABC): class RedisMediaGroupAggregator(BaseMediaGroupAggregator): redis: "Redis" - def __init__(self, redis: "Redis") -> None: + def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None: self.redis = redis + self.ttl_sec = ttl_sec @staticmethod def get_group_key(media_group_id: str) -> str: @@ -73,9 +75,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): async def add_into_group(self, media_group_id: str, media: Message) -> int: async with self.redis.pipeline(transaction=True) as pipe: - pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=TTL_SEC) + pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec) pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json()) - pipe.expire(self.get_group_key(media_group_id), TTL_SEC) + pipe.expire(self.get_group_key(media_group_id), self.ttl_sec) res = await pipe.execute() return cast(int, res[1]) @@ -110,16 +112,17 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): - def __init__(self) -> None: + def __init__(self, ttl_sec: int = TTL_SEC) -> None: self.groups: dict[str, list[Message]] = defaultdict(list) self.last_message_timers: dict[str, float] = {} self.locks: dict[str, bool] = {} + self.ttl_sec = ttl_sec def remove_expired_objects(self) -> None: expired_group_ids = [] current_time = time.time() for group_id, last_message_time in self.last_message_timers.items(): - if current_time - last_message_time > TTL_SEC: + if current_time - last_message_time > self.ttl_sec: expired_group_ids.append(group_id) else: break # the list is sorted in ascending order @@ -188,7 +191,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): album = await self.media_group_aggregator.get_group(event.media_group_id) if not album: return None - album.sort(key=lambda msg: msg.message_id) + album = sorted( + (msg.as_(cast(Bot, data.get("bot"))) for msg in album), + key=lambda msg: msg.message_id, + ) data.update(album=album) result = await handler(album[0], data) await self.media_group_aggregator.delete_group(event.media_group_id) diff --git a/aiogram/filters/media_group.py b/aiogram/filters/media_group.py index 3c8422f4..a685bb23 100644 --- a/aiogram/filters/media_group.py +++ b/aiogram/filters/media_group.py @@ -54,7 +54,7 @@ class MediaGroupFilter(Filter): ) async def __call__( - self, message: Message, album: list[Message] = None + self, message: Message, album: list[Message] | None = None ) -> Literal[False] | dict[str, Any]: media_count = len(album or []) if not (self.min_media_count <= media_count <= self.max_media_count): diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py index f9fd7373..91b4bdba 100644 --- a/tests/test_dispatcher/test_middlewares/test_media_group.py +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -1,17 +1,41 @@ import asyncio - -from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware -from aiogram.types import Message, Chat +import time from datetime import datetime -from typing import Any +from typing import Any, Awaitable, Callable +from unittest import mock + import pytest +from redis.asyncio.client import Redis + +from aiogram.dispatcher.middlewares.media_group import ( + BaseMediaGroupAggregator, + MediaGroupAggregatorMiddleware, + MemoryMediaGroupAggregator, + RedisMediaGroupAggregator, +) +from aiogram.types import Chat, Message + + +def _get_message(message_id: int, **kwargs): + chat = Chat(id=1, type="private", title="Test") + return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs) + + +async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any: + start_sleep = asyncio.Event() + real_sleep = asyncio.sleep + + async def mock_sleep(*args, **kwargs): + start_sleep.set() + await real_sleep(0) + + with mock.patch("asyncio.sleep", mock_sleep): + task1 = func(*args, **kwargs) + await start_sleep.wait() + return task1 class TestMediaGroupAggregatorMiddleware: - def _get_message(self, message_id: int, **kwargs): - chat = Chat(id=1, type="private", title="Test") - return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs) - def get_middleware(self): return MediaGroupAggregatorMiddleware(delay=0.1) @@ -22,7 +46,7 @@ class TestMediaGroupAggregatorMiddleware: nonlocal is_called is_called = True - await self.get_middleware()(next_handler, self._get_message(1), {}) + await self.get_middleware()(next_handler, _get_message(1), {}) assert is_called async def test_called_once_for_album(self): @@ -36,13 +60,26 @@ class TestMediaGroupAggregatorMiddleware: album = data.get("album") await asyncio.gather( - middleware(next_handler, self._get_message(1, media_group_id="42"), {}), - middleware(next_handler, self._get_message(2, media_group_id="42"), {}), + middleware(next_handler, _get_message(1, media_group_id="42"), {}), + middleware(next_handler, _get_message(2, media_group_id="42"), {}), ) assert album is not None assert len(album) == 2 assert counter == 1 + async def test_bot_object_saved(self, bot): + middleware = self.get_middleware() + event = album = None + + async def next_handler(message: Message, data: dict[str, Any]): + nonlocal event, album + event = message + album = data.get("album") + + await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot}) + assert event.bot is bot + assert all(msg.bot is bot for msg in album) + async def test_propagate_first_media_in_album(self): middleware = self.get_middleware() first_message = None @@ -51,13 +88,29 @@ class TestMediaGroupAggregatorMiddleware: nonlocal first_message first_message = message - await asyncio.gather( - middleware(next_handler, self._get_message(2, media_group_id="42"), {}), - middleware(next_handler, self._get_message(1, media_group_id="42"), {}), + task1 = await wait_until_func_call_sleep( + asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {}) ) + await middleware(next_handler, _get_message(1, media_group_id="42"), {}) + await task1 assert isinstance(first_message, Message) assert first_message.message_id == 1 + async def test_skip_propagating_if_data_deleted(self): + middleware = self.get_middleware() + counter = 0 + + async def next_handler(*args, **kwargs): + nonlocal counter + counter += 1 + + task1 = await wait_until_func_call_sleep( + asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {}) + ) + await middleware.media_group_aggregator.delete_group("42") + await task1 + assert counter == 0 + async def test_different_albums_non_interfere(self): middleware = self.get_middleware() counter = 0 @@ -69,8 +122,8 @@ class TestMediaGroupAggregatorMiddleware: albums.append(data.get("album")) await asyncio.gather( - middleware(next_handler, self._get_message(1, media_group_id="1"), {}), - middleware(next_handler, self._get_message(2, media_group_id="2"), {}), + middleware(next_handler, _get_message(1, media_group_id="1"), {}), + middleware(next_handler, _get_message(2, media_group_id="2"), {}), ) assert counter == 2 assert len(albums) == 2 @@ -80,18 +133,76 @@ class TestMediaGroupAggregatorMiddleware: album = None async def failed_handler(*args, **kwargs): - raise Exception("Failed") + raise RuntimeError("Failed") async def working_handler(_, data: dict[str, Any]): nonlocal album album = data.get("album") - first_message = self._get_message(1, media_group_id="42") - second_message = self._get_message(2, media_group_id="42") - with pytest.raises(Exception): + first_message = _get_message(1, media_group_id="42") + second_message = _get_message(2, media_group_id="42") + with pytest.raises(RuntimeError): await asyncio.gather( middleware(failed_handler, first_message, {}), middleware(failed_handler, second_message, {}), ) await middleware(working_handler, first_message, {}) assert len(album) == 2 + + +def test_message_deduplication(): + message_1, message_2 = _get_message(1), _get_message(2) + res = [message_1, message_2] + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res + + +@pytest.fixture(params=["memory", "redis"], scope="function") +async def aggregator(request): + if request.param == "memory": + yield MemoryMediaGroupAggregator() + else: + redis = Redis.from_url(request.getfixturevalue("redis_server")) + yield RedisMediaGroupAggregator(redis) + keys = await redis.keys("media_group:*") + if keys: + await redis.delete(*keys) + await redis.aclose() + + +class TestMediaGroupAggregator: + async def test_group_creating(self, aggregator: BaseMediaGroupAggregator): + msg1 = _get_message(1) + msg2 = _get_message(2) + assert await aggregator.add_into_group("42", msg1) == 1 + assert await aggregator.add_into_group("42", msg2) == 2 + assert {msg.message_id for msg in await aggregator.get_group("42")} == { + msg1.message_id, + msg2.message_id, + } + await aggregator.delete_group("42") + assert await aggregator.get_group("42") == [] + + async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator): + for _ in range(2): + assert await aggregator.acquire_lock("42") + assert not await aggregator.acquire_lock("42") + await aggregator.release_lock("42") + + async def test_expired_objects_removed(self): + aggregator = MemoryMediaGroupAggregator() + await aggregator.add_into_group("42", _get_message(1)) + with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1): + new_msg = _get_message(2) + await aggregator.add_into_group("24", new_msg) + assert await aggregator.get_group("42") == [] + assert await aggregator.get_group("24") == [new_msg] + + async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): + assert await aggregator.get_last_message_time("42") is None + await aggregator.add_into_group("42", _get_message(1)) + time_before_second_message = time.time() + assert await aggregator.get_last_message_time("42") <= time_before_second_message + await aggregator.add_into_group("42", _get_message(2)) + assert await aggregator.get_last_message_time("42") >= time_before_second_message