Small optimizations

This commit is contained in:
Alex Root Junior 2023-08-26 17:32:36 +03:00
parent 278697297e
commit bcdf8ea9da
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
3 changed files with 80 additions and 67 deletions

View file

@ -1,22 +1,17 @@
from dataclasses import dataclass, replace from dataclasses import replace
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from aiogram import loggers from aiogram import loggers
from aiogram.fsm.context import FSMContext from aiogram.fsm.context import FSMContext
from aiogram.fsm.storage.memory import MemoryStorageRecord
@dataclass
class StateContainer:
state: Optional[str]
data: Dict[str, Any]
class HistoryManager: class HistoryManager:
def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10): def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
self._size = size self._size = size
self._state = context self._state = state
self._history_state = FSMContext( self._history_state = FSMContext(
storage=context.storage, key=replace(context.key, destiny=destiny) storage=state.storage, key=replace(state.key, destiny=destiny)
) )
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None: async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
@ -25,10 +20,14 @@ class HistoryManager:
history.append({"state": state, "data": data}) history.append({"state": state, "data": data})
if len(history) > self._size: if len(history) > self._size:
history = history[-self._size :] history = history[-self._size :]
loggers.scene.debug("Push state=%s data=%s", state, data) loggers.scene.debug("Push state=%s data=%s to history", state, data)
await self._history_state.update_data(history=history)
async def pop(self) -> Optional[StateContainer]: if not history:
await self._history_state.set_data({})
else:
await self._history_state.update_data(history=history)
async def pop(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
if not history: if not history:
@ -36,41 +35,48 @@ class HistoryManager:
record = history.pop() record = history.pop()
state = record["state"] state = record["state"]
data = record["data"] data = record["data"]
await self._history_state.update_data(history=history) if not history:
loggers.scene.debug("Pop state=%s data=%s", state, data) await self._history_state.set_data({})
return StateContainer(state=state, data=data) else:
await self._history_state.update_data(history=history)
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
return MemoryStorageRecord(state=state, data=data)
async def get(self) -> Optional[StateContainer]: async def get(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
if not history: if not history:
return None return None
return StateContainer(**history[-1]) return MemoryStorageRecord(**history[-1])
async def all(self) -> List[StateContainer]: async def all(self) -> List[MemoryStorageRecord]:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
return [StateContainer(**item) for item in history] return [MemoryStorageRecord(**item) for item in history]
async def clear(self) -> None: async def clear(self) -> None:
loggers.scene.debug("Clear history") loggers.scene.debug("Clear history")
await self._history_state.clear() await self._history_state.set_data({})
async def snapshot(self) -> None: async def snapshot(self) -> None:
state = await self._state.get_state() state = await self._state.get_state()
data = await self._state.get_data() data = await self._state.get_data()
await self.push(state, data) await self.push(state, data)
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
await self._state.set_state(state)
await self._state.set_data(data)
async def rollback(self) -> Optional[str]: async def rollback(self) -> Optional[str]:
state_container = await self.pop() previous_state = await self.pop()
if not state_container: if not previous_state:
await self._set_state(None, {})
return None return None
loggers.scene.debug( loggers.scene.debug(
"Rollback to state=%s data=%s", "Rollback to state=%s data=%s",
state_container.state, previous_state.state,
state_container.data, previous_state.data,
) )
await self._state.set_state(state_container.state) await self._set_state(previous_state.state, previous_state.data)
await self._state.set_data(state_container.data) return previous_state.state
return state_container.state

View file

@ -4,22 +4,11 @@ import inspect
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import ( from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing_extensions import Self from typing_extensions import Self
from aiogram import Router, loggers from aiogram import Dispatcher, Router, loggers
from aiogram.dispatcher.event.bases import NextMiddlewareType from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.filters import StateFilter from aiogram.filters import StateFilter
@ -325,24 +314,23 @@ class SceneWizard:
await self.state.set_state(self.scene_config.state) await self.state.set_state(self.scene_config.state)
await self._on_action(SceneAction.enter, **kwargs) await self._on_action(SceneAction.enter, **kwargs)
async def leave(self, **kwargs: Any) -> None: async def leave(self, _with_history: bool = True, **kwargs: Any) -> None:
loggers.scene.debug("Leaving scene %r", self.scene_config.state) loggers.scene.debug("Leaving scene %r", self.scene_config.state)
await self.manager.history.snapshot() if _with_history:
await self.manager.history.snapshot()
await self._on_action(SceneAction.leave, **kwargs) await self._on_action(SceneAction.leave, **kwargs)
async def exit(self, **kwargs: Any) -> None: async def exit(self, **kwargs: Any) -> None:
loggers.scene.debug("Exiting scene %r", self.scene_config.state) loggers.scene.debug("Exiting scene %r", self.scene_config.state)
await self.state.set_state(None)
await self.manager.history.clear() await self.manager.history.clear()
await self._on_action(SceneAction.exit, **kwargs) await self._on_action(SceneAction.exit, **kwargs)
await self.manager.enter(None, _check_active=False, _with_history=False, **kwargs) await self.manager.enter(None, _check_active=False, **kwargs)
async def back(self, **kwargs: Any) -> None: async def back(self, **kwargs: Any) -> None:
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state) loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
await self.leave() await self.leave(_with_history=False, **kwargs)
await self.manager.history.rollback()
new_scene = await self.manager.history.rollback() new_scene = await self.manager.history.rollback()
await self.manager.enter(new_scene, _check_active=False, _with_history=False, **kwargs) await self.manager.enter(new_scene, _check_active=False, **kwargs)
async def replay(self, event: TelegramObject) -> None: async def replay(self, event: TelegramObject) -> None:
await self._on_action(SceneAction.enter, event=event) await self._on_action(SceneAction.enter, event=event)
@ -430,16 +418,20 @@ class ScenesManager:
self, self,
scene_type: Optional[Union[Type[Scene], str]], scene_type: Optional[Union[Type[Scene], str]],
_check_active: bool = True, _check_active: bool = True,
_with_history: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
scene = await self._get_scene(scene_type) scene = await self._get_scene(scene_type)
if _check_active: if _check_active:
active_scene = await self._get_active_scene() active_scene = await self._get_active_scene()
if active_scene is not None: if active_scene is not None:
await active_scene.wizard.exit(**kwargs) await active_scene.wizard.exit(**kwargs)
await scene.wizard.enter(_with_history=_with_history, **kwargs) if not scene:
loggers.scene.debug("Reset state")
await self.state.set_state(None)
else:
await scene.wizard.enter(**kwargs)
async def close(self, **kwargs: Any) -> None: async def close(self, **kwargs: Any) -> None:
try: try:
@ -454,13 +446,38 @@ class ScenesManager:
class SceneRegistry: class SceneRegistry:
def __init__(self, router: Router) -> None: def __init__(self, router: Router) -> None:
self.router = router self.router = router
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None:
if isinstance(router, Dispatcher):
# Small optimization for Dispatcher
# - we don't need to set up middleware for all observers
router.update.outer_middleware(self._update_middleware)
return
for observer in router.observers.values(): for observer in router.observers.values():
if observer.event_name in {"update", "error"}: if observer.event_name in {"update", "error"}:
continue continue
observer.outer_middleware(self._middleware) observer.outer_middleware(self._middleware)
self._scenes: Dict[Optional[str], Type[Scene]] = {} async def _update_middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
assert isinstance(event, Update), "Event must be an Update instance"
data["scenes"] = ScenesManager(
registry=self,
update_type=event.event_type,
event=event.event,
state=data["state"],
data=data,
)
return await handler(event, data)
async def _middleware( async def _middleware(
self, self,
@ -497,12 +514,10 @@ class SceneRegistry:
raise TypeError("Scene must be a subclass of Scene or a string") raise TypeError("Scene must be a subclass of Scene or a string")
try: try:
result = self._scenes[scene] return self._scenes[scene]
except KeyError: except KeyError:
raise ValueError(f"Scene {scene!r} is not registered") raise ValueError(f"Scene {scene!r} is not registered")
return result
@dataclass @dataclass
class After: class After:

View file

@ -37,11 +37,7 @@ class CancellableScene(Scene):
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back()) @on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
async def handle_back(self, message: Message): async def handle_back(self, message: Message):
await message.answer("Back.", reply_markup=ReplyKeyboardRemove()) await message.answer("Back.")
@on.message.exit()
async def on_exit(self, message: Message):
await self.wizard.clear_data()
class LanguageScene(CancellableScene, state="language"): class LanguageScene(CancellableScene, state="language"):
@ -99,10 +95,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene)) @on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
async def process_like_write_bots(self, message: Message): async def process_like_write_bots(self, message: Message):
await message.reply( await message.reply("Cool! I'm too!")
"Cool! I'm too!",
reply_markup=ReplyKeyboardRemove(),
)
@on.message(F.text.casefold() == "no", after=After.exit()) @on.message(F.text.casefold() == "no", after=After.exit())
async def process_dont_like_write_bots(self, message: Message): async def process_dont_like_write_bots(self, message: Message):
@ -135,10 +128,9 @@ class NameScene(CancellableScene, state="name"):
@on.message.leave() # Marker for handler that should be called when a user leaves the scene. @on.message.leave() # Marker for handler that should be called when a user leaves the scene.
async def on_leave(self, message: Message): async def on_leave(self, message: Message):
await message.answer( data: FSMData = await self.wizard.get_data()
f"Nice to meet you, {html.quote(message.text)}!", name = data.get("name", "Anonymous")
reply_markup=ReplyKeyboardRemove(), await message.answer(f"Nice to meet you, {html.quote(name)}!")
)
@on.message(after=After.goto(LikeBotsScene)) @on.message(after=After.goto(LikeBotsScene))
async def input_name(self, message: Message): async def input_name(self, message: Message):