Small refactoring

This commit is contained in:
Alex Root Junior 2023-10-13 00:00:36 +03:00
parent bda7fcd13b
commit c999964f11
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
3 changed files with 67 additions and 42 deletions

View file

@ -1,11 +1,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Final, Generator, List, Optional, Set from typing import Any, Dict, Final, Generator, List, Optional, Set, Type, TYPE_CHECKING
from ..types import TelegramObject
from .event.bases import REJECTED, UNHANDLED from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver from .event.event import EventObserver
from .event.telegram import TelegramEventObserver from .event.telegram import TelegramEventObserver
from ..types import TelegramObject
if TYPE_CHECKING:
from ..fsm.scene import Scene
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"}) INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
@ -218,6 +221,29 @@ class Router:
router.parent_router = self router.parent_router = self
return router return router
def include_scenes(self, *scenes: Type[Scene], name: Optional[str] = None) -> None:
"""
Include multiple scenes to this router as sub-routers
:param scenes: scene instances
:param name: optional name for router
:return:
"""
if not scenes:
raise ValueError("At least one scene must be provided")
for scene in scenes:
self.include_scene(scene, name=name)
def include_scene(self, scene: Type[Scene], name: Optional[str] = None) -> None:
"""
Include a scene to this router as sub-router
:param scene: scene instance
:param name: optional name for router
:return:
"""
self.include_router(scene.as_router(name=name))
async def emit_startup(self, *args: Any, **kwargs: Any) -> None: async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
""" """
Recursively call startup callbacks Recursively call startup callbacks

View file

@ -8,9 +8,12 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing_extensions import Self from typing_extensions import Self
from aiogram import Dispatcher, Router, loggers from aiogram import loggers
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NextMiddlewareType from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.dispatcher.router import Router
from aiogram.exceptions import SceneException from aiogram.exceptions import SceneException
from aiogram.filters import StateFilter from aiogram.filters import StateFilter
from aiogram.fsm.context import FSMContext from aiogram.fsm.context import FSMContext
@ -107,18 +110,6 @@ class ObserverDecorator:
self.action = action self.action = action
self.after = after self.after = after
def _wrap_class(self, target: Type[Scene]) -> None:
if not issubclass(target, Scene):
raise TypeError("Only subclass of Scene is allowed")
if self.action is not None:
raise TypeError("This action is not allowed for class")
filters = getattr(target, "__aiogram_filters__", None)
if filters is None:
filters = defaultdict(list)
setattr(target, "__aiogram_filters__", filters)
filters[self.name].extend(self.filters)
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None: def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
handlers = getattr(target, "__aiogram_handler__", None) handlers = getattr(target, "__aiogram_handler__", None)
if not handlers: if not handlers:
@ -134,21 +125,21 @@ class ObserverDecorator:
) )
) )
def _wrap_action(self, target: Type[Scene] | CallbackType) -> None: def _wrap_action(self, target: CallbackType) -> None:
action = getattr(target, "__aiogram_action__", None) action = getattr(target, "__aiogram_action__", None)
if action is None: if action is None:
action = defaultdict(dict) action = defaultdict(dict)
setattr(target, "__aiogram_action__", action) setattr(target, "__aiogram_action__", action)
action[self.action][self.name] = CallableObject(target) action[self.action][self.name] = CallableObject(target)
def __call__(self, target: Type[Scene] | CallbackType) -> Type[Scene] | CallbackType: def __call__(self, target: CallbackType) -> CallbackType:
if inspect.isclass(target): if inspect.isfunction(target):
self._wrap_class(target)
elif inspect.isfunction(target):
if self.action is None: if self.action is None:
self._wrap_filter(target) self._wrap_filter(target)
else: else:
self._wrap_action(target) self._wrap_action(target)
else:
raise TypeError("Only function or method is allowed")
return target return target
def leave(self) -> ActionContainer: def leave(self) -> ActionContainer:
@ -213,8 +204,6 @@ class HandlerContainer:
class SceneConfig: class SceneConfig:
state: Optional[str] state: Optional[str]
"""Scene state""" """Scene state"""
filters: Dict[str, List[CallbackType]]
"""Global scene filters"""
handlers: List[HandlerContainer] handlers: List[HandlerContainer]
"""Scene handlers""" """Scene handlers"""
actions: Dict[SceneAction, Dict[str, CallableObject]] actions: Dict[SceneAction, Dict[str, CallableObject]]
@ -316,7 +305,6 @@ class Scene:
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
filters: defaultdict[str, List[CallbackType]] = defaultdict(list)
handlers: list[HandlerContainer] = [] handlers: list[HandlerContainer] = []
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict) actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
@ -328,7 +316,6 @@ class Scene:
if not parent_scene_config: if not parent_scene_config:
continue continue
filters.update(parent_scene_config.filters)
handlers.extend(parent_scene_config.handlers) handlers.extend(parent_scene_config.handlers)
for action, action_handlers in parent_scene_config.actions.items(): for action, action_handlers in parent_scene_config.actions.items():
actions[action].update(action_handlers) actions[action].update(action_handlers)
@ -360,7 +347,6 @@ class Scene:
cls.__scene_config__ = SceneConfig( cls.__scene_config__ = SceneConfig(
state=state_name, state=state_name,
filters=dict(filters),
handlers=handlers, handlers=handlers,
actions=dict(actions), actions=dict(actions),
reset_data_on_enter=reset_data_on_enter, reset_data_on_enter=reset_data_on_enter,
@ -379,10 +365,6 @@ class Scene:
scene_config = cls.__scene_config__ scene_config = cls.__scene_config__
used_observers = set() used_observers = set()
for observer_name, filters in scene_config.filters.items():
router.observers[observer_name].filter(*filters)
used_observers.add(observer_name)
for handler in scene_config.handlers: for handler in scene_config.handlers:
router.observers[handler.name].register( router.observers[handler.name].register(
SceneHandlerWrapper( SceneHandlerWrapper(
@ -391,6 +373,7 @@ class Scene:
after=handler.after, after=handler.after,
), ),
*handler.filters, *handler.filters,
flags=extract_flags_from_object(handler.handler),
) )
used_observers.add(handler.name) used_observers.add(handler.name)
@ -706,15 +689,17 @@ class SceneRegistry:
A class that represents a registry for scenes in a Telegram bot. A class that represents a registry for scenes in a Telegram bot.
""" """
def __init__(self, router: Router) -> None: def __init__(self, router: Router, register_on_add: bool = False) -> None:
""" """
Initialize a new instance of the SceneRegistry class. Initialize a new instance of the SceneRegistry class.
:param router: The router instance used for scene registration. :param router: The router instance used for scene registration.
:param register_on_add: Whether to register the scenes to the router when they are added.
""" """
self.router = router self.router = router
self._scenes: Dict[Optional[str], Type[Scene]] = {} self.register_on_add = register_on_add
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._setup_middleware(router) self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None: def _setup_middleware(self, router: Router) -> None:
@ -764,23 +749,24 @@ class SceneRegistry:
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None: def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
""" """
This method adds the specified scenes to the router. This method adds the specified scenes to the registry
If a router is not provided, it uses the default router stored and optionally registers it to the router.
in the SceneRegistry instance.
The scenes are included in the router by calling the `as_router()`
method on each scene and passing the router as a parameter to this method.
If a scene with the same state already exists in the registry, a SceneException is raised. If a scene with the same state already exists in the registry, a SceneException is raised.
.. warning::
If the router is not specified, the scenes will not be registered to the router.
You will need to include the scenes manually to the router or use the register method.
:param scenes: A variable length parameter that accepts one or more types of scenes. :param scenes: A variable length parameter that accepts one or more types of scenes.
These scenes are instances of the Scene class. These scenes are instances of the Scene class.
:param router: An optional parameter that specifies the router :param router: An optional parameter that specifies the router
to which the scenes should be added. If not provided, the scenes will be to which the scenes should be added.
added to the default router stored in the SceneRegistry instance.
:return: None :return: None
""" """
if router is None: if not scenes:
router = self.router raise ValueError("At least one scene must be specified")
for scene in scenes: for scene in scenes:
if scene.__scene_config__.state in self._scenes: if scene.__scene_config__.state in self._scenes:
@ -790,7 +776,19 @@ class SceneRegistry:
self._scenes[scene.__scene_config__.state] = scene self._scenes[scene.__scene_config__.state] = scene
router.include_router(scene.as_router()) if router:
router.include_router(scene.as_router())
elif self.register_on_add:
self.router.include_router(scene.as_router())
def register(self, *scenes: Type[Scene]) -> None:
"""
Registers one or more scenes to the SceneRegistry.
:param scenes: One or more scene classes to register.
:return: None
"""
self.add(*scenes, router=self.router)
def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]: def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]:
""" """

View file

@ -252,6 +252,7 @@ class QuizScene(Scene, state="quiz"):
quiz_router = Router(name=__name__) quiz_router = Router(name=__name__)
# Add handler that initializes the scene # Add handler that initializes the scene
quiz_router.message.register(QuizScene.as_handler(), Command("quiz")) quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
quiz_router.include_scene(QuizScene)
@quiz_router.message(Command("start")) @quiz_router.message(Command("start"))
@ -275,7 +276,7 @@ def create_dispatcher():
# ... and then register a scene in the registry # ... and then register a scene in the registry
# by default, Scene will be mounted to the router that passed to the SceneRegistry, # by default, Scene will be mounted to the router that passed to the SceneRegistry,
# but you can specify the router explicitly using the `router` argument # but you can specify the router explicitly using the `router` argument
scene_registry.add(QuizScene, router=quiz_router) scene_registry.add(QuizScene)
return dispatcher return dispatcher