This commit is contained in:
Виталий 2026-04-05 18:52:48 +03:00 committed by GitHub
commit 17539c777a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 533 additions and 4 deletions

View file

@ -0,0 +1,227 @@
import asyncio
import time
from datetime import datetime
from typing import Any, Awaitable, Callable
from unittest import mock
import pytest
from redis.asyncio.client import Redis
from aiogram.dispatcher.middlewares.media_group import (
BaseMediaGroupAggregator,
MediaGroupAggregatorMiddleware,
MemoryMediaGroupAggregator,
RedisMediaGroupAggregator,
)
from aiogram.types import Chat, Message
def _get_message(message_id: int, **kwargs):
chat = Chat(id=1, type="private", title="Test")
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
start_sleep = asyncio.Event()
real_sleep = asyncio.sleep
async def mock_sleep(*args, **kwargs):
start_sleep.set()
await real_sleep(0)
with mock.patch("asyncio.sleep", mock_sleep):
task1 = func(*args, **kwargs)
await start_sleep.wait()
return task1
class TestMediaGroupAggregatorMiddleware:
def get_middleware(self):
return MediaGroupAggregatorMiddleware(delay=0.1)
async def test_skip_non_media_group(self):
is_called = False
async def next_handler(*args, **kwargs):
nonlocal is_called
is_called = True
await self.get_middleware()(next_handler, _get_message(1), {})
assert is_called
async def test_called_once_for_album(self):
middleware = self.get_middleware()
counter = 0
album = None
async def next_handler(_, data: dict[str, Any]):
nonlocal counter, album
counter += 1
album = data.get("album")
await asyncio.gather(
middleware(next_handler, _get_message(1, media_group_id="42"), {}),
middleware(next_handler, _get_message(2, media_group_id="42"), {}),
)
assert album is not None
assert len(album) == 2
assert counter == 1
async def test_bot_object_saved(self, bot):
middleware = self.get_middleware()
event = album = None
async def next_handler(message: Message, data: dict[str, Any]):
nonlocal event, album
event = message
album = data.get("album")
await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
assert event.bot is bot
assert all(msg.bot is bot for msg in album)
async def test_propagate_first_media_in_album(self):
middleware = self.get_middleware()
first_message = None
async def next_handler(message: Message, _):
nonlocal first_message
first_message = message
task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
)
await middleware(next_handler, _get_message(1, media_group_id="42"), {})
await task1
assert isinstance(first_message, Message)
assert first_message.message_id == 1
@pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
async def test_skip_propagating_if_data_deleted(self, deleted_object):
middleware = self.get_middleware()
counter = 0
async def next_handler(*args, **kwargs):
nonlocal counter
counter += 1
task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
)
if deleted_object == "album":
middleware.media_group_aggregator.groups.pop("42")
else:
middleware.media_group_aggregator.last_message_timers.pop("42")
await task1
assert counter == 0
async def test_different_albums_non_interfere(self):
middleware = self.get_middleware()
counter = 0
albums = []
async def next_handler(_, data: dict[str, Any]):
nonlocal counter, albums
counter += 1
albums.append(data.get("album"))
await asyncio.gather(
middleware(next_handler, _get_message(1, media_group_id="1"), {}),
middleware(next_handler, _get_message(2, media_group_id="2"), {}),
)
assert counter == 2
assert len(albums) == 2
async def test_retry_handling(self):
middleware = self.get_middleware()
album = None
async def failed_handler(*args, **kwargs):
raise RuntimeError("Failed")
async def working_handler(_, data: dict[str, Any]):
nonlocal album
album = data.get("album")
first_message = _get_message(1, media_group_id="42")
second_message = _get_message(2, media_group_id="42")
with pytest.raises(RuntimeError):
await asyncio.gather(
middleware(failed_handler, first_message, {}),
middleware(failed_handler, second_message, {}),
)
await middleware(working_handler, first_message, {})
assert len(album) == 2
def test_message_deduplication():
message_1, message_2 = _get_message(1), _get_message(2)
res = [message_1, message_2]
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res
@pytest.fixture(params=["memory", "redis"], scope="function")
async def aggregator(request):
if request.param == "memory":
yield MemoryMediaGroupAggregator()
else:
redis = Redis.from_url(request.getfixturevalue("redis_server"))
yield RedisMediaGroupAggregator(redis)
keys = await redis.keys("media_group:*")
if keys:
await redis.delete(*keys)
await redis.aclose()
class TestMediaGroupAggregator:
async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
msg1 = _get_message(1)
msg2 = _get_message(2)
assert await aggregator.add_into_group("42", msg1) == 1
assert await aggregator.add_into_group("42", msg2) == 2
assert {msg.message_id for msg in await aggregator.get_group("42")} == {
msg1.message_id,
msg2.message_id,
}
await aggregator.delete_group("42")
assert await aggregator.get_group("42") == []
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
for i in ("key1", "key2"):
assert await aggregator.acquire_lock("42", i)
assert not await aggregator.acquire_lock("42", i)
await aggregator.release_lock("42", i)
async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator):
await aggregator.acquire_lock("42", "key1")
await aggregator.release_lock("42", "key2")
assert not await aggregator.acquire_lock("42", "key1")
async def test_expired_objects_removed(self):
aggregator = MemoryMediaGroupAggregator()
await aggregator.add_into_group("42", _get_message(1))
with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
new_msg = _get_message(2)
await aggregator.add_into_group("24", new_msg)
assert await aggregator.get_group("42") == []
assert await aggregator.get_group("24") == [new_msg]
async def test_get_current_time_memory_aggregator(self):
aggregator = MemoryMediaGroupAggregator()
with mock.patch("time.monotonic", return_value=1.1):
assert await aggregator.get_current_time() == 1.1
async def test_get_current_time_redis_aggregator(self):
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
assert await aggregator.get_current_time() == 1.123456
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1))
time_before_second_message = await aggregator.get_current_time()
assert await aggregator.get_last_message_time("42") <= time_before_second_message
await aggregator.add_into_group("42", _get_message(2))
assert await aggregator.get_last_message_time("42") >= time_before_second_message