diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 93fd29a6..70481595 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -137,18 +137,19 @@ class Router: return result for router in self.sub_routers: - kwargs.update(event_router=router) async for result in router.update_handler.trigger(update, **kwargs): return result raise SkipHandler - def emit_startup(self, *args, **kwargs): - self.startup.trigger(*args, **kwargs) + async def emit_startup(self, *args, **kwargs): + async for _ in self.startup.trigger(*args, **kwargs): # pragma: no cover + pass for router in self.sub_routers: - router.emit_startup(*args, **kwargs) + await router.emit_startup(*args, **kwargs) - def emit_shutdown(self, *args, **kwargs): - self.startup.trigger(*args, **kwargs) + async def emit_shutdown(self, *args, **kwargs): + async for _ in self.shutdown.trigger(*args, **kwargs): # pragma: no cover + pass for router in self.sub_routers: - router.emit_startup(*args, **kwargs) + await router.emit_shutdown(*args, **kwargs) diff --git a/tests/test_dispatcher/test_router.py b/tests/test_dispatcher/test_router.py index 1c06f4df..b61fc28e 100644 --- a/tests/test_dispatcher/test_router.py +++ b/tests/test_dispatcher/test_router.py @@ -9,11 +9,15 @@ from aiogram.api.types import ( ChosenInlineResult, InlineQuery, Message, + Poll, + PollOption, + PreCheckoutQuery, ShippingAddress, ShippingQuery, Update, User, ) +from aiogram.dispatcher.event.observer import SkipHandler from aiogram.dispatcher.router import Router @@ -206,8 +210,38 @@ class TestRouter: False, True, ), - # pytest.param("pre_checkout_query", Update(update_id=42, pre_checkout_query=...), False, False), - # pytest.param("poll", Update(update_id=42, poll=...), False, False), + pytest.param( + "pre_checkout_query", + Update( + update_id=42, + pre_checkout_query=PreCheckoutQuery( + id="query id", + from_user=User(id=42, is_bot=False, first_name="Test"), + currency="BTC", + total_amount=1, + invoice_payload="payload", + ), + ), + False, + True, + ), + pytest.param( + "poll", + Update( + update_id=42, + poll=Poll( + id="poll id", + question="Q?", + options=[ + PollOption(text="A1", voter_count=2), + PollOption(text="A2", voter_count=3), + ], + is_closed=False, + ), + ), + False, + False, + ), ], ) async def test_listen_update( @@ -230,3 +264,108 @@ class TestRouter: assert result["event_update"] == update assert result["event_router"] == router assert result["test"] == "PASS" + + @pytest.mark.asyncio + async def test_listen_unknown_update(self): + router = Router() + + with pytest.raises(SkipHandler): + await router._listen_update(Update(update_id=42)) + + @pytest.mark.asyncio + async def test_listen_unhandled_update(self): + router = Router() + observer = router.observers["message"] + + @observer(lambda event: False) + async def handler(event: Any): + pass + + with pytest.raises(SkipHandler): + await router._listen_update( + Update( + update_id=42, + poll=Poll( + id="poll id", + question="Q?", + options=[ + PollOption(text="A1", voter_count=2), + PollOption(text="A2", voter_count=3), + ], + is_closed=False, + ), + ) + ) + + @pytest.mark.asyncio + async def test_nested_router_listen_update(self): + router1 = Router() + router2 = Router() + router1.include_router(router2) + observer = router2.message_handler + + @observer() + async def my_handler(event: Message, **kwargs: Any): + assert Chat.get_current(False) + assert User.get_current(False) + return kwargs + + update = Update( + update_id=42, + message=Message( + message_id=42, + date=datetime.datetime.now(), + text="test", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + ) + result = await router1._listen_update(update, test="PASS") + assert isinstance(result, dict) + assert result["event_update"] == update + assert result["event_router"] == router2 + assert result["test"] == "PASS" + + @pytest.mark.asyncio + async def test_emit_startup(self): + router1 = Router() + router2 = Router() + router1.include_router(router2) + + results = [] + + @router1.startup() + async def startup1(): + results.append(1) + + @router2.startup() + async def startup2(): + results.append(2) + + await router2.emit_startup() + assert results == [2] + + await router1.emit_startup() + assert results == [2, 1, 2] + + @pytest.mark.asyncio + async def test_emit_shutdown(self): + router1 = Router() + router2 = Router() + router1.include_router(router2) + + results = [] + + @router1.shutdown() + async def shutdown1(): + results.append(1) + + @router2.shutdown() + async def shutdown2(): + results.append(2) + + await router2.emit_shutdown() + assert results == [2] + + await router1.emit_shutdown() + assert results == [2, 1, 2]