mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Implement new middlewares
This commit is contained in:
parent
c262cc0ce6
commit
7f26ec9935
29 changed files with 532 additions and 1252 deletions
|
|
@ -1,9 +1,10 @@
|
|||
from pkg_resources import get_distribution
|
||||
|
||||
from .api import methods, types
|
||||
from .api.client import session
|
||||
from .api.client.bot import Bot
|
||||
from .dispatcher import filters, handler
|
||||
from .dispatcher.dispatcher import Dispatcher
|
||||
from .dispatcher.middlewares.base import BaseMiddleware
|
||||
from .dispatcher.router import Router
|
||||
|
||||
try:
|
||||
|
|
@ -23,10 +24,9 @@ __all__ = (
|
|||
"session",
|
||||
"Dispatcher",
|
||||
"Router",
|
||||
"BaseMiddleware",
|
||||
"filters",
|
||||
"handler",
|
||||
)
|
||||
|
||||
__version__ = "3.0.0a4"
|
||||
__version__ = get_distribution(dist=__package__).version
|
||||
__api_version__ = "4.8"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from ..api.client.bot import Bot
|
|||
from ..api.methods import TelegramMethod
|
||||
from ..api.types import Update, User
|
||||
from ..utils.exceptions import TelegramAPIError
|
||||
from .event.bases import NOT_HANDLED
|
||||
from .middlewares.update_processing_context import UserContextMiddleware
|
||||
from .router import Router
|
||||
|
||||
|
||||
|
|
@ -23,6 +25,9 @@ class Dispatcher(Router):
|
|||
super(Dispatcher, self).__init__(**kwargs)
|
||||
self._running_lock = Lock()
|
||||
|
||||
# Default middleware is needed for contextual features
|
||||
self.update.outer_middleware(UserContextMiddleware())
|
||||
|
||||
@property
|
||||
def parent_router(self) -> None:
|
||||
"""
|
||||
|
|
@ -42,9 +47,7 @@ class Dispatcher(Router):
|
|||
"""
|
||||
raise RuntimeError("Dispatcher can not be attached to another Router.")
|
||||
|
||||
async def feed_update(
|
||||
self, bot: Bot, update: Update, **kwargs: Any
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Main entry point for incoming updates
|
||||
|
||||
|
|
@ -57,9 +60,9 @@ class Dispatcher(Router):
|
|||
|
||||
Bot.set_current(bot)
|
||||
try:
|
||||
async for result in self.update.trigger(update, bot=bot, **kwargs):
|
||||
handled = True
|
||||
yield result
|
||||
response = await self.update.trigger(update, bot=bot, **kwargs)
|
||||
handled = response is not NOT_HANDLED
|
||||
return response
|
||||
finally:
|
||||
finish_time = loop.time()
|
||||
duration = (finish_time - start_time) * 1000
|
||||
|
|
@ -71,9 +74,7 @@ class Dispatcher(Router):
|
|||
bot.id,
|
||||
)
|
||||
|
||||
async def feed_raw_update(
|
||||
self, bot: Bot, update: Dict[str, Any], **kwargs: Any
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""
|
||||
Main entry point for incoming updates with automatic Dict->Update serializer
|
||||
|
||||
|
|
@ -82,8 +83,7 @@ class Dispatcher(Router):
|
|||
:param kwargs:
|
||||
"""
|
||||
parsed_update = Update(**update)
|
||||
async for result in self.feed_update(bot=bot, update=parsed_update, **kwargs):
|
||||
yield result
|
||||
return await self.feed_update(bot=bot, update=parsed_update, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def _listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]:
|
||||
|
|
@ -114,7 +114,7 @@ class Dispatcher(Router):
|
|||
# For debugging here is added logging.
|
||||
loggers.dispatcher.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
|
||||
|
||||
async def process_update(
|
||||
async def _process_update(
|
||||
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any
|
||||
) -> bool:
|
||||
"""
|
||||
|
|
@ -126,11 +126,13 @@ class Dispatcher(Router):
|
|||
:param kwargs: contextual data for middlewares, filters and handlers
|
||||
:return: status
|
||||
"""
|
||||
handled = False
|
||||
try:
|
||||
async for result in self.feed_update(bot, update, **kwargs):
|
||||
if call_answer and isinstance(result, TelegramMethod):
|
||||
await self._silent_call_request(bot=bot, result=result)
|
||||
return True
|
||||
response = await self.feed_update(bot, update, **kwargs)
|
||||
handled = handled is not NOT_HANDLED
|
||||
if call_answer and isinstance(response, TelegramMethod):
|
||||
await self._silent_call_request(bot=bot, result=response)
|
||||
return handled
|
||||
|
||||
except Exception as e:
|
||||
loggers.dispatcher.exception(
|
||||
|
|
@ -142,8 +144,6 @@ class Dispatcher(Router):
|
|||
)
|
||||
return True # because update was processed but unsuccessful
|
||||
|
||||
return False
|
||||
|
||||
async def _polling(self, bot: Bot, **kwargs: Any) -> None:
|
||||
"""
|
||||
Internal polling process
|
||||
|
|
@ -153,16 +153,14 @@ class Dispatcher(Router):
|
|||
:return:
|
||||
"""
|
||||
async for update in self._listen_updates(bot):
|
||||
await self.process_update(bot=bot, update=update, **kwargs)
|
||||
await self._process_update(bot=bot, update=update, **kwargs)
|
||||
|
||||
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
The same with `Dispatcher.process_update()` but returns real response instead of bool
|
||||
"""
|
||||
try:
|
||||
async for result in self.feed_update(bot, update, **kwargs):
|
||||
return result
|
||||
|
||||
return await self.feed_update(bot, update, **kwargs)
|
||||
except Exception as e:
|
||||
loggers.dispatcher.exception(
|
||||
"Cause exception while process update id=%d by bot id=%d\n%s: %s",
|
||||
|
|
@ -196,10 +194,10 @@ class Dispatcher(Router):
|
|||
|
||||
def process_response(task: Future[Any]) -> None:
|
||||
warnings.warn(
|
||||
f"Detected slow response into webhook.\n"
|
||||
f"Telegram is waiting for response only first 60 seconds and then re-send update.\n"
|
||||
f"For preventing this situation response into webhook returned immediately "
|
||||
f"and handler is moved to background and still processing update.",
|
||||
"Detected slow response into webhook.\n"
|
||||
"Telegram is waiting for response only first 60 seconds and then re-send update.\n"
|
||||
"For preventing this situation response into webhook returned immediately "
|
||||
"and handler is moved to background and still processing update.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
try:
|
||||
|
|
|
|||
29
aiogram/dispatcher/event/bases.py
Normal file
29
aiogram/dispatcher/event/bases.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
|
||||
from unittest.mock import sentinel
|
||||
|
||||
from ...api.types import TelegramObject
|
||||
from ..middlewares.base import BaseMiddleware
|
||||
|
||||
NextMiddlewareType = Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]]
|
||||
MiddlewareType = Union[
|
||||
BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]]
|
||||
]
|
||||
|
||||
NOT_HANDLED = sentinel.NOT_HANDLED
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CancelHandler(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def skip(message: Optional[str] = None) -> NoReturn:
|
||||
"""
|
||||
Raise an SkipHandler
|
||||
"""
|
||||
raise SkipHandler(message or "Event skipped")
|
||||
39
aiogram/dispatcher/event/event.py
Normal file
39
aiogram/dispatcher/event/event.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from .handler import CallbackType, HandlerObject, HandlerType
|
||||
|
||||
|
||||
class EventObserver:
|
||||
"""
|
||||
Simple events observer
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType) -> None:
|
||||
"""
|
||||
Register callback with filters
|
||||
"""
|
||||
self.handlers.append(HandlerObject(callback=callback))
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Propagate event to handlers.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.call(*args, **kwargs)
|
||||
|
||||
def __call__(self) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import asyncio
|
||||
import contextvars
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
|
@ -6,9 +8,9 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type,
|
|||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
|
||||
CallbackType = Callable[[Any], Awaitable[Any]]
|
||||
SyncFilter = Callable[[Any], Any]
|
||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||
CallbackType = Callable[..., Awaitable[Any]]
|
||||
SyncFilter = Callable[..., Any]
|
||||
AsyncFilter = Callable[..., Awaitable[Any]]
|
||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||
HandlerType = Union[FilterType, Type[BaseHandler]]
|
||||
|
||||
|
|
@ -40,7 +42,11 @@ class CallableMixin:
|
|||
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
|
||||
if self.awaitable:
|
||||
return await wrapped()
|
||||
return wrapped()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
context = contextvars.copy_context()
|
||||
wrapped = partial(context.run, wrapped)
|
||||
return await loop.run_in_executor(None, wrapped)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -60,11 +66,11 @@ class HandlerObject(CallableMixin):
|
|||
|
||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||
if not self.filters:
|
||||
return True, {}
|
||||
return True, kwargs
|
||||
for event_filter in self.filters:
|
||||
check = await event_filter.call(*args, **kwargs)
|
||||
if not check:
|
||||
return False, {}
|
||||
return False, kwargs
|
||||
if isinstance(check, dict):
|
||||
kwargs.update(check)
|
||||
return True, kwargs
|
||||
|
|
|
|||
|
|
@ -1,93 +1,33 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ...api.types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from ..middlewares.types import MiddlewareStep, UpdateType
|
||||
from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CancelHandler(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def skip(message: Optional[str] = None) -> NoReturn:
|
||||
"""
|
||||
Raise an SkipHandler
|
||||
"""
|
||||
raise SkipHandler(message or "Event skipped")
|
||||
|
||||
|
||||
class EventObserver:
|
||||
"""
|
||||
Base events observer
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType) -> HandlerType:
|
||||
"""
|
||||
Register callback with filters
|
||||
"""
|
||||
self.handlers.append(HandlerObject(callback=callback))
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Propagate event to handlers.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
def __call__(self) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class TelegramEventObserver(EventObserver):
|
||||
class TelegramEventObserver:
|
||||
"""
|
||||
Event observer for Telegram events
|
||||
"""
|
||||
|
||||
def __init__(self, router: Router, event_name: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.router: Router = router
|
||||
self.event_name: str = event_name
|
||||
|
||||
self.handlers: List[HandlerObject] = []
|
||||
self.filters: List[Type[BaseFilter]] = []
|
||||
self.outer_middlewares: List[MiddlewareType] = []
|
||||
self.middlewares: List[MiddlewareType] = []
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
"""
|
||||
|
|
@ -144,37 +84,6 @@ class TelegramEventObserver(EventObserver):
|
|||
|
||||
return filters
|
||||
|
||||
async def trigger_middleware(
|
||||
self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger middlewares chain
|
||||
|
||||
:param step:
|
||||
:param event:
|
||||
:param data:
|
||||
:param result:
|
||||
:return:
|
||||
"""
|
||||
reverse = step == MiddlewareStep.POST_PROCESS
|
||||
recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS
|
||||
|
||||
if self.event_name == "update":
|
||||
routers = self.router.chain
|
||||
else:
|
||||
routers = self.router.chain_head
|
||||
for router in routers:
|
||||
await router.middleware.trigger(
|
||||
step=step,
|
||||
event_name=self.event_name,
|
||||
event=event,
|
||||
data=data,
|
||||
result=result,
|
||||
reverse=reverse,
|
||||
)
|
||||
if not recursive:
|
||||
break
|
||||
|
||||
def register(
|
||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||
) -> HandlerType:
|
||||
|
|
@ -190,32 +99,39 @@ class TelegramEventObserver(EventObserver):
|
|||
)
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
@classmethod
|
||||
def _wrap_middleware(
|
||||
cls, middlewares: List[MiddlewareType], handler: HandlerType
|
||||
) -> NextMiddlewareType:
|
||||
@functools.wraps(handler)
|
||||
def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
||||
return handler(event, **kwargs)
|
||||
|
||||
middleware = mapper
|
||||
for m in reversed(middlewares):
|
||||
middleware = functools.partial(m, middleware)
|
||||
return middleware
|
||||
|
||||
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
event = args[0]
|
||||
await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs)
|
||||
wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger)
|
||||
return await wrapped_outer(event, kwargs)
|
||||
|
||||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
for handler in self.handlers:
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
result, data = await handler.check(event, **kwargs)
|
||||
if result:
|
||||
kwargs.update(data)
|
||||
await self.trigger_middleware(
|
||||
step=MiddlewareStep.PROCESS, event=event, data=kwargs
|
||||
)
|
||||
try:
|
||||
response = await handler.call(*args, **kwargs)
|
||||
await self.trigger_middleware(
|
||||
step=MiddlewareStep.POST_PROCESS,
|
||||
event=event,
|
||||
data=kwargs,
|
||||
result=response,
|
||||
)
|
||||
yield response
|
||||
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
break
|
||||
|
||||
return NOT_HANDLED
|
||||
|
||||
def __call__(
|
||||
self, *args: FilterType, **bound_filters: BaseFilter
|
||||
|
|
@ -229,3 +145,45 @@ class TelegramEventObserver(EventObserver):
|
|||
return callback
|
||||
|
||||
return wrapper
|
||||
|
||||
def middleware(
|
||||
self, middleware: Optional[MiddlewareType] = None,
|
||||
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
|
||||
"""
|
||||
Decorator for registering inner middlewares
|
||||
|
||||
Usage:
|
||||
>>> @<event>.middleware() # via decorator (variant 1)
|
||||
>>> @<event>.middleware # via decorator (variant 2)
|
||||
>>> async def my_middleware(handler, event, data): ...
|
||||
>>> <event>.middleware(middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType) -> MiddlewareType:
|
||||
self.middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
||||
def outer_middleware(
|
||||
self, middleware: Optional[MiddlewareType] = None,
|
||||
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
|
||||
"""
|
||||
Decorator for registering outer middlewares
|
||||
|
||||
Usage:
|
||||
>>> @<event>.outer_middleware() # via decorator (variant 1)
|
||||
>>> @<event>.outer_middleware # via decorator (variant 2)
|
||||
>>> async def my_middleware(handler, event, data): ...
|
||||
>>> <event>.outer_middleware(my_middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType) -> MiddlewareType:
|
||||
self.outer_middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
|
||||
|
||||
class AbstractMiddleware(ABC):
|
||||
"""
|
||||
Abstract class for middleware.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._manager: Optional[MiddlewareManager] = None
|
||||
|
||||
@property
|
||||
def manager(self) -> MiddlewareManager:
|
||||
"""
|
||||
Instance of MiddlewareManager
|
||||
"""
|
||||
if self._manager is None:
|
||||
raise RuntimeError("Middleware is not configured!")
|
||||
return self._manager
|
||||
|
||||
def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware:
|
||||
"""
|
||||
Mark middleware as configured
|
||||
|
||||
:param manager:
|
||||
:param _stack_level:
|
||||
:return:
|
||||
"""
|
||||
if self.configured:
|
||||
return manager.setup(self, _stack_level=_stack_level + 1)
|
||||
|
||||
self._manager = manager
|
||||
return self
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
"""
|
||||
Check middleware is configured
|
||||
|
||||
:return:
|
||||
"""
|
||||
return bool(self._manager)
|
||||
|
||||
@abstractmethod
|
||||
async def trigger(
|
||||
self,
|
||||
step: MiddlewareStep,
|
||||
event_name: str,
|
||||
event: UpdateType,
|
||||
data: Dict[str, Any],
|
||||
result: Any = None,
|
||||
) -> Any: # pragma: no cover
|
||||
pass
|
||||
|
|
@ -1,317 +1,15 @@
|
|||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, TypeVar
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aiogram.api.types import (
|
||||
CallbackQuery,
|
||||
ChosenInlineResult,
|
||||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PreCheckoutQuery,
|
||||
ShippingQuery,
|
||||
Update,
|
||||
)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseMiddleware(AbstractMiddleware):
|
||||
"""
|
||||
Base class for middleware.
|
||||
|
||||
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
|
||||
"""
|
||||
|
||||
async def trigger(
|
||||
class BaseMiddleware(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
async def __call__(
|
||||
self,
|
||||
step: MiddlewareStep,
|
||||
event_name: str,
|
||||
event: UpdateType,
|
||||
handler: Callable[[T, Dict[str, Any]], Awaitable[Any]],
|
||||
event: T,
|
||||
data: Dict[str, Any],
|
||||
result: Any = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Trigger action.
|
||||
|
||||
:param step:
|
||||
:param event_name:
|
||||
:param event:
|
||||
:param data:
|
||||
:param result:
|
||||
:return:
|
||||
"""
|
||||
handler_name = f"on_{step.value}_{event_name}"
|
||||
handler = getattr(self, handler_name, None)
|
||||
if not handler:
|
||||
return None
|
||||
args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data)
|
||||
return await handler(*args)
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
# =============================================================================================
|
||||
# Event that triggers before process <event>
|
||||
# =============================================================================================
|
||||
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers before process update
|
||||
"""
|
||||
|
||||
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers before process message
|
||||
"""
|
||||
|
||||
async def on_pre_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process edited_message
|
||||
"""
|
||||
|
||||
async def on_pre_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process channel_post
|
||||
"""
|
||||
|
||||
async def on_pre_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process edited_channel_post
|
||||
"""
|
||||
|
||||
async def on_pre_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process inline_query
|
||||
"""
|
||||
|
||||
async def on_pre_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process chosen_inline_result
|
||||
"""
|
||||
|
||||
async def on_pre_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process callback_query
|
||||
"""
|
||||
|
||||
async def on_pre_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process shipping_query
|
||||
"""
|
||||
|
||||
async def on_pre_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process pre_checkout_query
|
||||
"""
|
||||
|
||||
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers before process poll
|
||||
"""
|
||||
|
||||
async def on_pre_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers before process poll_answer
|
||||
"""
|
||||
|
||||
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers before process error
|
||||
"""
|
||||
|
||||
# =============================================================================================
|
||||
# Event that triggers on process <event> after filters.
|
||||
# =============================================================================================
|
||||
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers on process update
|
||||
"""
|
||||
|
||||
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers on process message
|
||||
"""
|
||||
|
||||
async def on_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process edited_message
|
||||
"""
|
||||
|
||||
async def on_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process channel_post
|
||||
"""
|
||||
|
||||
async def on_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process edited_channel_post
|
||||
"""
|
||||
|
||||
async def on_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process inline_query
|
||||
"""
|
||||
|
||||
async def on_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process chosen_inline_result
|
||||
"""
|
||||
|
||||
async def on_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process callback_query
|
||||
"""
|
||||
|
||||
async def on_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process shipping_query
|
||||
"""
|
||||
|
||||
async def on_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process pre_checkout_query
|
||||
"""
|
||||
|
||||
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers on process poll
|
||||
"""
|
||||
|
||||
async def on_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers on process poll_answer
|
||||
"""
|
||||
|
||||
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Event that triggers on process error
|
||||
"""
|
||||
|
||||
# =============================================================================================
|
||||
# Event that triggers after process <event>.
|
||||
# =============================================================================================
|
||||
async def on_post_process_update(
|
||||
self, update: Update, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing update
|
||||
"""
|
||||
|
||||
async def on_post_process_message(
|
||||
self, message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing message
|
||||
"""
|
||||
|
||||
async def on_post_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing edited_message
|
||||
"""
|
||||
|
||||
async def on_post_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing channel_post
|
||||
"""
|
||||
|
||||
async def on_post_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing edited_channel_post
|
||||
"""
|
||||
|
||||
async def on_post_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing inline_query
|
||||
"""
|
||||
|
||||
async def on_post_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing chosen_inline_result
|
||||
"""
|
||||
|
||||
async def on_post_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing callback_query
|
||||
"""
|
||||
|
||||
async def on_post_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing shipping_query
|
||||
"""
|
||||
|
||||
async def on_post_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing pre_checkout_query
|
||||
"""
|
||||
|
||||
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
|
||||
"""
|
||||
Event that triggers after processing poll
|
||||
"""
|
||||
|
||||
async def on_post_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing poll_answer
|
||||
"""
|
||||
|
||||
async def on_post_process_error(
|
||||
self, exception: Exception, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Event that triggers after processing error
|
||||
"""
|
||||
) -> Any: # pragma: no cover
|
||||
pass
|
||||
|
|
|
|||
31
aiogram/dispatcher/middlewares/error.py
Normal file
31
aiogram/dispatcher/middlewares/error.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
|
||||
|
||||
from ...api.types import Update
|
||||
from ..event.bases import NOT_HANDLED, CancelHandler, SkipHandler
|
||||
from .base import BaseMiddleware
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..router import Router
|
||||
|
||||
|
||||
class ErrorsMiddleware(BaseMiddleware[Update]):
|
||||
def __init__(self, router: Router):
|
||||
self.router = router
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]],
|
||||
event: Any,
|
||||
data: Dict[str, Any],
|
||||
) -> Any:
|
||||
try:
|
||||
return await handler(event, data)
|
||||
except (SkipHandler, CancelHandler): # pragma: no cover
|
||||
raise
|
||||
except Exception as e:
|
||||
response = await self.router.errors.trigger(event, exception=e, **data)
|
||||
if response is NOT_HANDLED:
|
||||
raise
|
||||
return response
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from warnings import warn
|
||||
|
||||
from .abstract import AbstractMiddleware
|
||||
from .types import MiddlewareStep, UpdateType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
||||
|
||||
class MiddlewareManager:
|
||||
"""
|
||||
Middleware manager.
|
||||
"""
|
||||
|
||||
def __init__(self, router: Router) -> None:
|
||||
self.router = router
|
||||
self.middlewares: List[AbstractMiddleware] = []
|
||||
|
||||
def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
|
||||
"""
|
||||
Setup middleware
|
||||
|
||||
:param middleware:
|
||||
:param _stack_level:
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(middleware, AbstractMiddleware):
|
||||
raise TypeError(
|
||||
f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}"
|
||||
)
|
||||
if middleware.configured:
|
||||
if middleware.manager is self:
|
||||
warn(
|
||||
f"Middleware {middleware} is already configured for this Router "
|
||||
"That's mean re-installing of this middleware has no effect.",
|
||||
category=RuntimeWarning,
|
||||
stacklevel=_stack_level + 1,
|
||||
)
|
||||
return middleware
|
||||
raise ValueError(
|
||||
f"Middleware is already configured for another manager {middleware.manager} "
|
||||
f"in router {middleware.manager.router}!"
|
||||
)
|
||||
|
||||
self.middlewares.append(middleware)
|
||||
middleware.setup(self)
|
||||
return middleware
|
||||
|
||||
async def trigger(
|
||||
self,
|
||||
step: MiddlewareStep,
|
||||
event_name: str,
|
||||
event: UpdateType,
|
||||
data: Dict[str, Any],
|
||||
result: Any = None,
|
||||
reverse: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Call action to middlewares with args lilt.
|
||||
"""
|
||||
middlewares = reversed(self.middlewares) if reverse else self.middlewares
|
||||
for middleware in middlewares:
|
||||
await middleware.trigger(
|
||||
step=step, event_name=event_name, event=event, data=data, result=result
|
||||
)
|
||||
|
||||
def __contains__(self, item: AbstractMiddleware) -> bool:
|
||||
return item in self.middlewares
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from aiogram.api.types import (
|
||||
CallbackQuery,
|
||||
ChosenInlineResult,
|
||||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PreCheckoutQuery,
|
||||
ShippingQuery,
|
||||
Update,
|
||||
)
|
||||
|
||||
UpdateType = Union[
|
||||
CallbackQuery,
|
||||
ChosenInlineResult,
|
||||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PreCheckoutQuery,
|
||||
ShippingQuery,
|
||||
Update,
|
||||
BaseException,
|
||||
]
|
||||
|
||||
|
||||
class MiddlewareStep(Enum):
|
||||
PRE_PROCESS = "pre_process"
|
||||
PROCESS = "process"
|
||||
POST_PROCESS = "post_process"
|
||||
62
aiogram/dispatcher/middlewares/update_processing_context.py
Normal file
62
aiogram/dispatcher/middlewares/update_processing_context.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
|
||||
|
||||
from aiogram.api.types import Chat, Update, User
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
|
||||
|
||||
class UserContextMiddleware(BaseMiddleware[Update]):
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]],
|
||||
event: Update,
|
||||
data: Dict[str, Any],
|
||||
) -> Any:
|
||||
chat, user = self.resolve_event_context(event=event)
|
||||
with self.context(chat=chat, user=user):
|
||||
return await handler(event, data)
|
||||
|
||||
@contextmanager
|
||||
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
|
||||
chat_token = None
|
||||
user_token = None
|
||||
if chat:
|
||||
chat_token = chat.set_current(chat)
|
||||
if user:
|
||||
user_token = user.set_current(user)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if chat and chat_token:
|
||||
chat.reset_current(chat_token)
|
||||
if user and user_token:
|
||||
user.reset_current(user_token)
|
||||
|
||||
@classmethod
|
||||
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
|
||||
"""
|
||||
Resolve chat and user instance from Update object
|
||||
"""
|
||||
if event.message:
|
||||
return event.message.chat, event.message.from_user
|
||||
if event.edited_message:
|
||||
return event.edited_message.chat, event.edited_message.from_user
|
||||
if event.channel_post:
|
||||
return event.channel_post.chat, None
|
||||
if event.edited_channel_post:
|
||||
return event.edited_channel_post.chat, None
|
||||
if event.inline_query:
|
||||
return None, event.inline_query.from_user
|
||||
if event.chosen_inline_result:
|
||||
return None, event.chosen_inline_result.from_user
|
||||
if event.callback_query:
|
||||
if event.callback_query.message:
|
||||
return event.callback_query.message.chat, event.callback_query.from_user
|
||||
return None, event.callback_query.from_user
|
||||
if event.shipping_query:
|
||||
return None, event.shipping_query.from_user
|
||||
if event.pre_checkout_query:
|
||||
return None, event.pre_checkout_query.from_user
|
||||
if event.poll_answer:
|
||||
return None, event.poll_answer.user
|
||||
return None, None
|
||||
|
|
@ -3,13 +3,14 @@ from __future__ import annotations
|
|||
import warnings
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from ..api.types import Chat, TelegramObject, Update, User
|
||||
from ..api.types import TelegramObject, Update
|
||||
from ..utils.imports import import_module
|
||||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
||||
from .event.bases import NOT_HANDLED, SkipHandler
|
||||
from .event.event import EventObserver
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
from .middlewares.abstract import AbstractMiddleware
|
||||
from .middlewares.manager import MiddlewareManager
|
||||
from .middlewares.error import ErrorsMiddleware
|
||||
|
||||
|
||||
class Router:
|
||||
|
|
@ -44,8 +45,6 @@ class Router:
|
|||
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
|
||||
self.errors = TelegramEventObserver(router=self, event_name="error")
|
||||
|
||||
self.middleware = MiddlewareManager(router=self)
|
||||
|
||||
self.startup = EventObserver()
|
||||
self.shutdown = EventObserver()
|
||||
|
||||
|
|
@ -68,6 +67,8 @@ class Router:
|
|||
# Root handler
|
||||
self.update.register(self._listen_update)
|
||||
|
||||
self.update.outer_middleware(ErrorsMiddleware(self))
|
||||
|
||||
# Builtin filters
|
||||
if use_builtin_filters:
|
||||
for name, observer in self.observers.items():
|
||||
|
|
@ -94,16 +95,6 @@ class Router:
|
|||
next(tail) # Skip self
|
||||
yield from tail
|
||||
|
||||
def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
|
||||
"""
|
||||
Use middleware
|
||||
|
||||
:param middleware:
|
||||
:param _stack_level:
|
||||
:return:
|
||||
"""
|
||||
return self.middleware.setup(middleware, _stack_level=_stack_level + 1)
|
||||
|
||||
@property
|
||||
def parent_router(self) -> Optional[Router]:
|
||||
return self._parent_router
|
||||
|
|
@ -176,53 +167,40 @@ class Router:
|
|||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
chat: Optional[Chat] = None
|
||||
from_user: Optional[User] = None
|
||||
|
||||
event: TelegramObject
|
||||
if update.message:
|
||||
update_type = "message"
|
||||
from_user = update.message.from_user
|
||||
chat = update.message.chat
|
||||
event = update.message
|
||||
elif update.edited_message:
|
||||
update_type = "edited_message"
|
||||
from_user = update.edited_message.from_user
|
||||
chat = update.edited_message.chat
|
||||
event = update.edited_message
|
||||
elif update.channel_post:
|
||||
update_type = "channel_post"
|
||||
chat = update.channel_post.chat
|
||||
event = update.channel_post
|
||||
elif update.edited_channel_post:
|
||||
update_type = "edited_channel_post"
|
||||
chat = update.edited_channel_post.chat
|
||||
event = update.edited_channel_post
|
||||
elif update.inline_query:
|
||||
update_type = "inline_query"
|
||||
from_user = update.inline_query.from_user
|
||||
event = update.inline_query
|
||||
elif update.chosen_inline_result:
|
||||
update_type = "chosen_inline_result"
|
||||
from_user = update.chosen_inline_result.from_user
|
||||
event = update.chosen_inline_result
|
||||
elif update.callback_query:
|
||||
update_type = "callback_query"
|
||||
if update.callback_query.message:
|
||||
chat = update.callback_query.message.chat
|
||||
from_user = update.callback_query.from_user
|
||||
event = update.callback_query
|
||||
elif update.shipping_query:
|
||||
update_type = "shipping_query"
|
||||
from_user = update.shipping_query.from_user
|
||||
event = update.shipping_query
|
||||
elif update.pre_checkout_query:
|
||||
update_type = "pre_checkout_query"
|
||||
from_user = update.pre_checkout_query.from_user
|
||||
event = update.pre_checkout_query
|
||||
elif update.poll:
|
||||
update_type = "poll"
|
||||
event = update.poll
|
||||
elif update.poll_answer:
|
||||
update_type = "poll_answer"
|
||||
event = update.poll_answer
|
||||
else:
|
||||
warnings.warn(
|
||||
"Detected unknown update type.\n"
|
||||
|
|
@ -232,76 +210,17 @@ class Router:
|
|||
)
|
||||
raise SkipHandler
|
||||
|
||||
return await self.listen_update(
|
||||
update_type=update_type,
|
||||
update=update,
|
||||
event=event,
|
||||
from_user=from_user,
|
||||
chat=chat,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def listen_update(
|
||||
self,
|
||||
update_type: str,
|
||||
update: Update,
|
||||
event: TelegramObject,
|
||||
from_user: Optional[User] = None,
|
||||
chat: Optional[Chat] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Listen update by current and child routers
|
||||
|
||||
:param update_type:
|
||||
:param update:
|
||||
:param event:
|
||||
:param from_user:
|
||||
:param chat:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
user_token = None
|
||||
if from_user:
|
||||
user_token = User.set_current(from_user)
|
||||
chat_token = None
|
||||
if chat:
|
||||
chat_token = Chat.set_current(chat)
|
||||
|
||||
kwargs.update(event_update=update, event_router=self)
|
||||
observer = self.observers[update_type]
|
||||
try:
|
||||
async for result in observer.trigger(event, update=update, **kwargs):
|
||||
return result
|
||||
response = await observer.trigger(event, update=update, **kwargs)
|
||||
|
||||
if response is NOT_HANDLED: # Resolve nested routers
|
||||
for router in self.sub_routers:
|
||||
try:
|
||||
return await router.listen_update(
|
||||
update_type=update_type,
|
||||
update=update,
|
||||
event=event,
|
||||
from_user=from_user,
|
||||
chat=chat,
|
||||
**kwargs,
|
||||
)
|
||||
except SkipHandler:
|
||||
response = await router.update.trigger(event=update, **kwargs)
|
||||
if response is NOT_HANDLED:
|
||||
continue
|
||||
|
||||
raise SkipHandler
|
||||
|
||||
except SkipHandler:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
async for result in self.errors.trigger(e, **kwargs):
|
||||
return result
|
||||
raise
|
||||
|
||||
finally:
|
||||
if user_token:
|
||||
User.reset_current(user_token)
|
||||
if chat_token:
|
||||
Chat.reset_current(chat_token)
|
||||
return response
|
||||
|
||||
async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
|
|
@ -312,8 +231,7 @@ class Router:
|
|||
:return:
|
||||
"""
|
||||
kwargs.update(router=self)
|
||||
async for _ in self.startup.trigger(*args, **kwargs): # pragma: no cover
|
||||
pass
|
||||
await self.startup.trigger(*args, **kwargs)
|
||||
for router in self.sub_routers:
|
||||
await router.emit_startup(*args, **kwargs)
|
||||
|
||||
|
|
@ -326,8 +244,7 @@ class Router:
|
|||
:return:
|
||||
"""
|
||||
kwargs.update(router=self)
|
||||
async for _ in self.shutdown.trigger(*args, **kwargs): # pragma: no cover
|
||||
pass
|
||||
await self.shutdown.trigger(*args, **kwargs)
|
||||
for router in self.sub_routers:
|
||||
await router.emit_shutdown(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
|
||||
|
||||
from typing_extensions import Literal
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from typing_extensions import Literal
|
||||
|
||||
__all__ = ("ContextInstanceMixin", "DataMixin")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue