mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Small optimizations
This commit is contained in:
parent
278697297e
commit
bcdf8ea9da
3 changed files with 80 additions and 67 deletions
|
|
@ -1,22 +1,17 @@
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import replace
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from aiogram import loggers
|
from aiogram import loggers
|
||||||
from aiogram.fsm.context import FSMContext
|
from aiogram.fsm.context import FSMContext
|
||||||
|
from aiogram.fsm.storage.memory import MemoryStorageRecord
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StateContainer:
|
|
||||||
state: Optional[str]
|
|
||||||
data: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryManager:
|
class HistoryManager:
|
||||||
def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10):
|
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
|
||||||
self._size = size
|
self._size = size
|
||||||
self._state = context
|
self._state = state
|
||||||
self._history_state = FSMContext(
|
self._history_state = FSMContext(
|
||||||
storage=context.storage, key=replace(context.key, destiny=destiny)
|
storage=state.storage, key=replace(state.key, destiny=destiny)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
||||||
|
|
@ -25,10 +20,14 @@ class HistoryManager:
|
||||||
history.append({"state": state, "data": data})
|
history.append({"state": state, "data": data})
|
||||||
if len(history) > self._size:
|
if len(history) > self._size:
|
||||||
history = history[-self._size :]
|
history = history[-self._size :]
|
||||||
loggers.scene.debug("Push state=%s data=%s", state, data)
|
loggers.scene.debug("Push state=%s data=%s to history", state, data)
|
||||||
await self._history_state.update_data(history=history)
|
|
||||||
|
|
||||||
async def pop(self) -> Optional[StateContainer]:
|
if not history:
|
||||||
|
await self._history_state.set_data({})
|
||||||
|
else:
|
||||||
|
await self._history_state.update_data(history=history)
|
||||||
|
|
||||||
|
async def pop(self) -> Optional[MemoryStorageRecord]:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
if not history:
|
if not history:
|
||||||
|
|
@ -36,41 +35,48 @@ class HistoryManager:
|
||||||
record = history.pop()
|
record = history.pop()
|
||||||
state = record["state"]
|
state = record["state"]
|
||||||
data = record["data"]
|
data = record["data"]
|
||||||
await self._history_state.update_data(history=history)
|
if not history:
|
||||||
loggers.scene.debug("Pop state=%s data=%s", state, data)
|
await self._history_state.set_data({})
|
||||||
return StateContainer(state=state, data=data)
|
else:
|
||||||
|
await self._history_state.update_data(history=history)
|
||||||
|
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
|
||||||
|
return MemoryStorageRecord(state=state, data=data)
|
||||||
|
|
||||||
async def get(self) -> Optional[StateContainer]:
|
async def get(self) -> Optional[MemoryStorageRecord]:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
if not history:
|
if not history:
|
||||||
return None
|
return None
|
||||||
return StateContainer(**history[-1])
|
return MemoryStorageRecord(**history[-1])
|
||||||
|
|
||||||
async def all(self) -> List[StateContainer]:
|
async def all(self) -> List[MemoryStorageRecord]:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
return [StateContainer(**item) for item in history]
|
return [MemoryStorageRecord(**item) for item in history]
|
||||||
|
|
||||||
async def clear(self) -> None:
|
async def clear(self) -> None:
|
||||||
loggers.scene.debug("Clear history")
|
loggers.scene.debug("Clear history")
|
||||||
await self._history_state.clear()
|
await self._history_state.set_data({})
|
||||||
|
|
||||||
async def snapshot(self) -> None:
|
async def snapshot(self) -> None:
|
||||||
state = await self._state.get_state()
|
state = await self._state.get_state()
|
||||||
data = await self._state.get_data()
|
data = await self._state.get_data()
|
||||||
await self.push(state, data)
|
await self.push(state, data)
|
||||||
|
|
||||||
|
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
||||||
|
await self._state.set_state(state)
|
||||||
|
await self._state.set_data(data)
|
||||||
|
|
||||||
async def rollback(self) -> Optional[str]:
|
async def rollback(self) -> Optional[str]:
|
||||||
state_container = await self.pop()
|
previous_state = await self.pop()
|
||||||
if not state_container:
|
if not previous_state:
|
||||||
|
await self._set_state(None, {})
|
||||||
return None
|
return None
|
||||||
|
|
||||||
loggers.scene.debug(
|
loggers.scene.debug(
|
||||||
"Rollback to state=%s data=%s",
|
"Rollback to state=%s data=%s",
|
||||||
state_container.state,
|
previous_state.state,
|
||||||
state_container.data,
|
previous_state.data,
|
||||||
)
|
)
|
||||||
await self._state.set_state(state_container.state)
|
await self._set_state(previous_state.state, previous_state.data)
|
||||||
await self._state.set_data(state_container.data)
|
return previous_state.state
|
||||||
return state_container.state
|
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,11 @@ import inspect
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import (
|
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from aiogram import Router, loggers
|
from aiogram import Dispatcher, Router, loggers
|
||||||
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
||||||
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
||||||
from aiogram.filters import StateFilter
|
from aiogram.filters import StateFilter
|
||||||
|
|
@ -325,24 +314,23 @@ class SceneWizard:
|
||||||
await self.state.set_state(self.scene_config.state)
|
await self.state.set_state(self.scene_config.state)
|
||||||
await self._on_action(SceneAction.enter, **kwargs)
|
await self._on_action(SceneAction.enter, **kwargs)
|
||||||
|
|
||||||
async def leave(self, **kwargs: Any) -> None:
|
async def leave(self, _with_history: bool = True, **kwargs: Any) -> None:
|
||||||
loggers.scene.debug("Leaving scene %r", self.scene_config.state)
|
loggers.scene.debug("Leaving scene %r", self.scene_config.state)
|
||||||
await self.manager.history.snapshot()
|
if _with_history:
|
||||||
|
await self.manager.history.snapshot()
|
||||||
await self._on_action(SceneAction.leave, **kwargs)
|
await self._on_action(SceneAction.leave, **kwargs)
|
||||||
|
|
||||||
async def exit(self, **kwargs: Any) -> None:
|
async def exit(self, **kwargs: Any) -> None:
|
||||||
loggers.scene.debug("Exiting scene %r", self.scene_config.state)
|
loggers.scene.debug("Exiting scene %r", self.scene_config.state)
|
||||||
await self.state.set_state(None)
|
|
||||||
await self.manager.history.clear()
|
await self.manager.history.clear()
|
||||||
await self._on_action(SceneAction.exit, **kwargs)
|
await self._on_action(SceneAction.exit, **kwargs)
|
||||||
await self.manager.enter(None, _check_active=False, _with_history=False, **kwargs)
|
await self.manager.enter(None, _check_active=False, **kwargs)
|
||||||
|
|
||||||
async def back(self, **kwargs: Any) -> None:
|
async def back(self, **kwargs: Any) -> None:
|
||||||
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
|
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
|
||||||
await self.leave()
|
await self.leave(_with_history=False, **kwargs)
|
||||||
await self.manager.history.rollback()
|
|
||||||
new_scene = await self.manager.history.rollback()
|
new_scene = await self.manager.history.rollback()
|
||||||
await self.manager.enter(new_scene, _check_active=False, _with_history=False, **kwargs)
|
await self.manager.enter(new_scene, _check_active=False, **kwargs)
|
||||||
|
|
||||||
async def replay(self, event: TelegramObject) -> None:
|
async def replay(self, event: TelegramObject) -> None:
|
||||||
await self._on_action(SceneAction.enter, event=event)
|
await self._on_action(SceneAction.enter, event=event)
|
||||||
|
|
@ -430,16 +418,20 @@ class ScenesManager:
|
||||||
self,
|
self,
|
||||||
scene_type: Optional[Union[Type[Scene], str]],
|
scene_type: Optional[Union[Type[Scene], str]],
|
||||||
_check_active: bool = True,
|
_check_active: bool = True,
|
||||||
_with_history: bool = True,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
scene = await self._get_scene(scene_type)
|
scene = await self._get_scene(scene_type)
|
||||||
|
|
||||||
if _check_active:
|
if _check_active:
|
||||||
active_scene = await self._get_active_scene()
|
active_scene = await self._get_active_scene()
|
||||||
if active_scene is not None:
|
if active_scene is not None:
|
||||||
await active_scene.wizard.exit(**kwargs)
|
await active_scene.wizard.exit(**kwargs)
|
||||||
|
|
||||||
await scene.wizard.enter(_with_history=_with_history, **kwargs)
|
if not scene:
|
||||||
|
loggers.scene.debug("Reset state")
|
||||||
|
await self.state.set_state(None)
|
||||||
|
else:
|
||||||
|
await scene.wizard.enter(**kwargs)
|
||||||
|
|
||||||
async def close(self, **kwargs: Any) -> None:
|
async def close(self, **kwargs: Any) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -454,13 +446,38 @@ class ScenesManager:
|
||||||
class SceneRegistry:
|
class SceneRegistry:
|
||||||
def __init__(self, router: Router) -> None:
|
def __init__(self, router: Router) -> None:
|
||||||
self.router = router
|
self.router = router
|
||||||
|
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
||||||
|
|
||||||
|
self._setup_middleware(router)
|
||||||
|
|
||||||
|
def _setup_middleware(self, router: Router) -> None:
|
||||||
|
if isinstance(router, Dispatcher):
|
||||||
|
# Small optimization for Dispatcher
|
||||||
|
# - we don't need to set up middleware for all observers
|
||||||
|
router.update.outer_middleware(self._update_middleware)
|
||||||
|
return
|
||||||
|
|
||||||
for observer in router.observers.values():
|
for observer in router.observers.values():
|
||||||
if observer.event_name in {"update", "error"}:
|
if observer.event_name in {"update", "error"}:
|
||||||
continue
|
continue
|
||||||
observer.outer_middleware(self._middleware)
|
observer.outer_middleware(self._middleware)
|
||||||
|
|
||||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
async def _update_middleware(
|
||||||
|
self,
|
||||||
|
handler: NextMiddlewareType[TelegramObject],
|
||||||
|
event: TelegramObject,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
assert isinstance(event, Update), "Event must be an Update instance"
|
||||||
|
|
||||||
|
data["scenes"] = ScenesManager(
|
||||||
|
registry=self,
|
||||||
|
update_type=event.event_type,
|
||||||
|
event=event.event,
|
||||||
|
state=data["state"],
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
return await handler(event, data)
|
||||||
|
|
||||||
async def _middleware(
|
async def _middleware(
|
||||||
self,
|
self,
|
||||||
|
|
@ -497,12 +514,10 @@ class SceneRegistry:
|
||||||
raise TypeError("Scene must be a subclass of Scene or a string")
|
raise TypeError("Scene must be a subclass of Scene or a string")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self._scenes[scene]
|
return self._scenes[scene]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValueError(f"Scene {scene!r} is not registered")
|
raise ValueError(f"Scene {scene!r} is not registered")
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class After:
|
class After:
|
||||||
|
|
|
||||||
|
|
@ -37,11 +37,7 @@ class CancellableScene(Scene):
|
||||||
|
|
||||||
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
|
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
|
||||||
async def handle_back(self, message: Message):
|
async def handle_back(self, message: Message):
|
||||||
await message.answer("Back.", reply_markup=ReplyKeyboardRemove())
|
await message.answer("Back.")
|
||||||
|
|
||||||
@on.message.exit()
|
|
||||||
async def on_exit(self, message: Message):
|
|
||||||
await self.wizard.clear_data()
|
|
||||||
|
|
||||||
|
|
||||||
class LanguageScene(CancellableScene, state="language"):
|
class LanguageScene(CancellableScene, state="language"):
|
||||||
|
|
@ -99,10 +95,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
|
||||||
|
|
||||||
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
|
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
|
||||||
async def process_like_write_bots(self, message: Message):
|
async def process_like_write_bots(self, message: Message):
|
||||||
await message.reply(
|
await message.reply("Cool! I'm too!")
|
||||||
"Cool! I'm too!",
|
|
||||||
reply_markup=ReplyKeyboardRemove(),
|
|
||||||
)
|
|
||||||
|
|
||||||
@on.message(F.text.casefold() == "no", after=After.exit())
|
@on.message(F.text.casefold() == "no", after=After.exit())
|
||||||
async def process_dont_like_write_bots(self, message: Message):
|
async def process_dont_like_write_bots(self, message: Message):
|
||||||
|
|
@ -135,10 +128,9 @@ class NameScene(CancellableScene, state="name"):
|
||||||
|
|
||||||
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
|
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
|
||||||
async def on_leave(self, message: Message):
|
async def on_leave(self, message: Message):
|
||||||
await message.answer(
|
data: FSMData = await self.wizard.get_data()
|
||||||
f"Nice to meet you, {html.quote(message.text)}!",
|
name = data.get("name", "Anonymous")
|
||||||
reply_markup=ReplyKeyboardRemove(),
|
await message.answer(f"Nice to meet you, {html.quote(name)}!")
|
||||||
)
|
|
||||||
|
|
||||||
@on.message(after=After.goto(LikeBotsScene))
|
@on.message(after=After.goto(LikeBotsScene))
|
||||||
async def input_name(self, message: Message):
|
async def input_name(self, message: Message):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue