mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
update tests for media group aggregator
This commit is contained in:
parent
16b718ca07
commit
b7f2da391b
3 changed files with 144 additions and 27 deletions
|
|
@ -5,6 +5,7 @@ from collections import defaultdict
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from aiogram.types import Message, TelegramObject
|
from aiogram.types import Message, TelegramObject
|
||||||
|
|
||||||
|
|
@ -56,8 +57,9 @@ class BaseMediaGroupAggregator(ABC):
|
||||||
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
redis: "Redis"
|
redis: "Redis"
|
||||||
|
|
||||||
def __init__(self, redis: "Redis") -> None:
|
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
|
self.ttl_sec = ttl_sec
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_group_key(media_group_id: str) -> str:
|
def get_group_key(media_group_id: str) -> str:
|
||||||
|
|
@ -73,9 +75,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
|
|
||||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||||
async with self.redis.pipeline(transaction=True) as pipe:
|
async with self.redis.pipeline(transaction=True) as pipe:
|
||||||
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=TTL_SEC)
|
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
|
||||||
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
||||||
pipe.expire(self.get_group_key(media_group_id), TTL_SEC)
|
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
||||||
res = await pipe.execute()
|
res = await pipe.execute()
|
||||||
return cast(int, res[1])
|
return cast(int, res[1])
|
||||||
|
|
||||||
|
|
@ -110,16 +112,17 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
|
|
||||||
|
|
||||||
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
def __init__(self) -> None:
|
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||||
self.groups: dict[str, list[Message]] = defaultdict(list)
|
self.groups: dict[str, list[Message]] = defaultdict(list)
|
||||||
self.last_message_timers: dict[str, float] = {}
|
self.last_message_timers: dict[str, float] = {}
|
||||||
self.locks: dict[str, bool] = {}
|
self.locks: dict[str, bool] = {}
|
||||||
|
self.ttl_sec = ttl_sec
|
||||||
|
|
||||||
def remove_expired_objects(self) -> None:
|
def remove_expired_objects(self) -> None:
|
||||||
expired_group_ids = []
|
expired_group_ids = []
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
for group_id, last_message_time in self.last_message_timers.items():
|
for group_id, last_message_time in self.last_message_timers.items():
|
||||||
if current_time - last_message_time > TTL_SEC:
|
if current_time - last_message_time > self.ttl_sec:
|
||||||
expired_group_ids.append(group_id)
|
expired_group_ids.append(group_id)
|
||||||
else:
|
else:
|
||||||
break # the list is sorted in ascending order
|
break # the list is sorted in ascending order
|
||||||
|
|
@ -188,7 +191,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
||||||
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
||||||
if not album:
|
if not album:
|
||||||
return None
|
return None
|
||||||
album.sort(key=lambda msg: msg.message_id)
|
album = sorted(
|
||||||
|
(msg.as_(cast(Bot, data.get("bot"))) for msg in album),
|
||||||
|
key=lambda msg: msg.message_id,
|
||||||
|
)
|
||||||
data.update(album=album)
|
data.update(album=album)
|
||||||
result = await handler(album[0], data)
|
result = await handler(album[0], data)
|
||||||
await self.media_group_aggregator.delete_group(event.media_group_id)
|
await self.media_group_aggregator.delete_group(event.media_group_id)
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ class MediaGroupFilter(Filter):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self, message: Message, album: list[Message] = None
|
self, message: Message, album: list[Message] | None = None
|
||||||
) -> Literal[False] | dict[str, Any]:
|
) -> Literal[False] | dict[str, Any]:
|
||||||
media_count = len(album or [])
|
media_count = len(album or [])
|
||||||
if not (self.min_media_count <= media_count <= self.max_media_count):
|
if not (self.min_media_count <= media_count <= self.max_media_count):
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,41 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
|
|
||||||
from aiogram.types import Message, Chat
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, Awaitable, Callable
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from redis.asyncio.client import Redis
|
||||||
|
|
||||||
|
from aiogram.dispatcher.middlewares.media_group import (
|
||||||
|
BaseMediaGroupAggregator,
|
||||||
|
MediaGroupAggregatorMiddleware,
|
||||||
|
MemoryMediaGroupAggregator,
|
||||||
|
RedisMediaGroupAggregator,
|
||||||
|
)
|
||||||
|
from aiogram.types import Chat, Message
|
||||||
|
|
||||||
|
|
||||||
|
def _get_message(message_id: int, **kwargs):
|
||||||
|
chat = Chat(id=1, type="private", title="Test")
|
||||||
|
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
|
||||||
|
start_sleep = asyncio.Event()
|
||||||
|
real_sleep = asyncio.sleep
|
||||||
|
|
||||||
|
async def mock_sleep(*args, **kwargs):
|
||||||
|
start_sleep.set()
|
||||||
|
await real_sleep(0)
|
||||||
|
|
||||||
|
with mock.patch("asyncio.sleep", mock_sleep):
|
||||||
|
task1 = func(*args, **kwargs)
|
||||||
|
await start_sleep.wait()
|
||||||
|
return task1
|
||||||
|
|
||||||
|
|
||||||
class TestMediaGroupAggregatorMiddleware:
|
class TestMediaGroupAggregatorMiddleware:
|
||||||
def _get_message(self, message_id: int, **kwargs):
|
|
||||||
chat = Chat(id=1, type="private", title="Test")
|
|
||||||
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
|
||||||
|
|
||||||
def get_middleware(self):
|
def get_middleware(self):
|
||||||
return MediaGroupAggregatorMiddleware(delay=0.1)
|
return MediaGroupAggregatorMiddleware(delay=0.1)
|
||||||
|
|
||||||
|
|
@ -22,7 +46,7 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
nonlocal is_called
|
nonlocal is_called
|
||||||
is_called = True
|
is_called = True
|
||||||
|
|
||||||
await self.get_middleware()(next_handler, self._get_message(1), {})
|
await self.get_middleware()(next_handler, _get_message(1), {})
|
||||||
assert is_called
|
assert is_called
|
||||||
|
|
||||||
async def test_called_once_for_album(self):
|
async def test_called_once_for_album(self):
|
||||||
|
|
@ -36,13 +60,26 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
album = data.get("album")
|
album = data.get("album")
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
middleware(next_handler, _get_message(1, media_group_id="42"), {}),
|
||||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
middleware(next_handler, _get_message(2, media_group_id="42"), {}),
|
||||||
)
|
)
|
||||||
assert album is not None
|
assert album is not None
|
||||||
assert len(album) == 2
|
assert len(album) == 2
|
||||||
assert counter == 1
|
assert counter == 1
|
||||||
|
|
||||||
|
async def test_bot_object_saved(self, bot):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
event = album = None
|
||||||
|
|
||||||
|
async def next_handler(message: Message, data: dict[str, Any]):
|
||||||
|
nonlocal event, album
|
||||||
|
event = message
|
||||||
|
album = data.get("album")
|
||||||
|
|
||||||
|
await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
|
||||||
|
assert event.bot is bot
|
||||||
|
assert all(msg.bot is bot for msg in album)
|
||||||
|
|
||||||
async def test_propagate_first_media_in_album(self):
|
async def test_propagate_first_media_in_album(self):
|
||||||
middleware = self.get_middleware()
|
middleware = self.get_middleware()
|
||||||
first_message = None
|
first_message = None
|
||||||
|
|
@ -51,13 +88,29 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
nonlocal first_message
|
nonlocal first_message
|
||||||
first_message = message
|
first_message = message
|
||||||
|
|
||||||
await asyncio.gather(
|
task1 = await wait_until_func_call_sleep(
|
||||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
|
||||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
|
||||||
)
|
)
|
||||||
|
await middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||||
|
await task1
|
||||||
assert isinstance(first_message, Message)
|
assert isinstance(first_message, Message)
|
||||||
assert first_message.message_id == 1
|
assert first_message.message_id == 1
|
||||||
|
|
||||||
|
async def test_skip_propagating_if_data_deleted(self):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
async def next_handler(*args, **kwargs):
|
||||||
|
nonlocal counter
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
task1 = await wait_until_func_call_sleep(
|
||||||
|
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||||
|
)
|
||||||
|
await middleware.media_group_aggregator.delete_group("42")
|
||||||
|
await task1
|
||||||
|
assert counter == 0
|
||||||
|
|
||||||
async def test_different_albums_non_interfere(self):
|
async def test_different_albums_non_interfere(self):
|
||||||
middleware = self.get_middleware()
|
middleware = self.get_middleware()
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
@ -69,8 +122,8 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
albums.append(data.get("album"))
|
albums.append(data.get("album"))
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
middleware(next_handler, self._get_message(1, media_group_id="1"), {}),
|
middleware(next_handler, _get_message(1, media_group_id="1"), {}),
|
||||||
middleware(next_handler, self._get_message(2, media_group_id="2"), {}),
|
middleware(next_handler, _get_message(2, media_group_id="2"), {}),
|
||||||
)
|
)
|
||||||
assert counter == 2
|
assert counter == 2
|
||||||
assert len(albums) == 2
|
assert len(albums) == 2
|
||||||
|
|
@ -80,18 +133,76 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
album = None
|
album = None
|
||||||
|
|
||||||
async def failed_handler(*args, **kwargs):
|
async def failed_handler(*args, **kwargs):
|
||||||
raise Exception("Failed")
|
raise RuntimeError("Failed")
|
||||||
|
|
||||||
async def working_handler(_, data: dict[str, Any]):
|
async def working_handler(_, data: dict[str, Any]):
|
||||||
nonlocal album
|
nonlocal album
|
||||||
album = data.get("album")
|
album = data.get("album")
|
||||||
|
|
||||||
first_message = self._get_message(1, media_group_id="42")
|
first_message = _get_message(1, media_group_id="42")
|
||||||
second_message = self._get_message(2, media_group_id="42")
|
second_message = _get_message(2, media_group_id="42")
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(RuntimeError):
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
middleware(failed_handler, first_message, {}),
|
middleware(failed_handler, first_message, {}),
|
||||||
middleware(failed_handler, second_message, {}),
|
middleware(failed_handler, second_message, {}),
|
||||||
)
|
)
|
||||||
await middleware(working_handler, first_message, {})
|
await middleware(working_handler, first_message, {})
|
||||||
assert len(album) == 2
|
assert len(album) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_deduplication():
|
||||||
|
message_1, message_2 = _get_message(1), _get_message(2)
|
||||||
|
res = [message_1, message_2]
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["memory", "redis"], scope="function")
|
||||||
|
async def aggregator(request):
|
||||||
|
if request.param == "memory":
|
||||||
|
yield MemoryMediaGroupAggregator()
|
||||||
|
else:
|
||||||
|
redis = Redis.from_url(request.getfixturevalue("redis_server"))
|
||||||
|
yield RedisMediaGroupAggregator(redis)
|
||||||
|
keys = await redis.keys("media_group:*")
|
||||||
|
if keys:
|
||||||
|
await redis.delete(*keys)
|
||||||
|
await redis.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMediaGroupAggregator:
|
||||||
|
async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
msg1 = _get_message(1)
|
||||||
|
msg2 = _get_message(2)
|
||||||
|
assert await aggregator.add_into_group("42", msg1) == 1
|
||||||
|
assert await aggregator.add_into_group("42", msg2) == 2
|
||||||
|
assert {msg.message_id for msg in await aggregator.get_group("42")} == {
|
||||||
|
msg1.message_id,
|
||||||
|
msg2.message_id,
|
||||||
|
}
|
||||||
|
await aggregator.delete_group("42")
|
||||||
|
assert await aggregator.get_group("42") == []
|
||||||
|
|
||||||
|
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
for _ in range(2):
|
||||||
|
assert await aggregator.acquire_lock("42")
|
||||||
|
assert not await aggregator.acquire_lock("42")
|
||||||
|
await aggregator.release_lock("42")
|
||||||
|
|
||||||
|
async def test_expired_objects_removed(self):
|
||||||
|
aggregator = MemoryMediaGroupAggregator()
|
||||||
|
await aggregator.add_into_group("42", _get_message(1))
|
||||||
|
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1):
|
||||||
|
new_msg = _get_message(2)
|
||||||
|
await aggregator.add_into_group("24", new_msg)
|
||||||
|
assert await aggregator.get_group("42") == []
|
||||||
|
assert await aggregator.get_group("24") == [new_msg]
|
||||||
|
|
||||||
|
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
assert await aggregator.get_last_message_time("42") is None
|
||||||
|
await aggregator.add_into_group("42", _get_message(1))
|
||||||
|
time_before_second_message = time.time()
|
||||||
|
assert await aggregator.get_last_message_time("42") <= time_before_second_message
|
||||||
|
await aggregator.add_into_group("42", _get_message(2))
|
||||||
|
assert await aggregator.get_last_message_time("42") >= time_before_second_message
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue