mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Dev 3.x flat package (#961)
* Move packages * Added changelog * Update examples/echo_bot.py Co-authored-by: Oleg A. <t0rr@mail.ru> * Rename `handler` -> `handlers` * Update __init__.py Co-authored-by: Oleg A. <t0rr@mail.ru>
This commit is contained in:
parent
5e7932ca20
commit
4315ecf1a2
111 changed files with 376 additions and 390 deletions
0
aiogram/fsm/__init__.py
Normal file
0
aiogram/fsm/__init__.py
Normal file
34
aiogram/fsm/context.py
Normal file
34
aiogram/fsm/context.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||
|
||||
|
||||
class FSMContext:
|
||||
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
|
||||
self.bot = bot
|
||||
self.storage = storage
|
||||
self.key = key
|
||||
|
||||
async def set_state(self, state: StateType = None) -> None:
|
||||
await self.storage.set_state(bot=self.bot, key=self.key, state=state)
|
||||
|
||||
async def get_state(self) -> Optional[str]:
|
||||
return await self.storage.get_state(bot=self.bot, key=self.key)
|
||||
|
||||
async def set_data(self, data: Dict[str, Any]) -> None:
|
||||
await self.storage.set_data(bot=self.bot, key=self.key, data=data)
|
||||
|
||||
async def get_data(self) -> Dict[str, Any]:
|
||||
return await self.storage.get_data(bot=self.bot, key=self.key)
|
||||
|
||||
async def update_data(
|
||||
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
if data:
|
||||
kwargs.update(data)
|
||||
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
|
||||
|
||||
async def clear(self) -> None:
|
||||
await self.set_state(state=None)
|
||||
await self.set_data({})
|
||||
86
aiogram/fsm/middleware.py
Normal file
86
aiogram/fsm/middleware.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.fsm.context import FSMContext
|
||||
from aiogram.fsm.storage.base import DEFAULT_DESTINY, BaseEventIsolation, BaseStorage, StorageKey
|
||||
from aiogram.fsm.strategy import FSMStrategy, apply_strategy
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class FSMContextMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
storage: BaseStorage,
|
||||
events_isolation: BaseEventIsolation,
|
||||
strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
||||
) -> None:
|
||||
self.storage = storage
|
||||
self.strategy = strategy
|
||||
self.events_isolation = events_isolation
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
||||
event: TelegramObject,
|
||||
data: Dict[str, Any],
|
||||
) -> Any:
|
||||
bot: Bot = cast(Bot, data["bot"])
|
||||
context = self.resolve_event_context(bot, data)
|
||||
data["fsm_storage"] = self.storage
|
||||
if context:
|
||||
data.update({"state": context, "raw_state": await context.get_state()})
|
||||
async with self.events_isolation.lock(bot=bot, key=context.key):
|
||||
return await handler(event, data)
|
||||
return await handler(event, data)
|
||||
|
||||
def resolve_event_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
data: Dict[str, Any],
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> Optional[FSMContext]:
|
||||
user = data.get("event_from_user")
|
||||
chat = data.get("event_chat")
|
||||
chat_id = chat.id if chat else None
|
||||
user_id = user.id if user else None
|
||||
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
|
||||
def resolve_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: Optional[int],
|
||||
user_id: Optional[int],
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> Optional[FSMContext]:
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
|
||||
if chat_id is not None and user_id is not None:
|
||||
chat_id, user_id = apply_strategy(
|
||||
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
||||
)
|
||||
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
return None
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> FSMContext:
|
||||
return FSMContext(
|
||||
bot=bot,
|
||||
storage=self.storage,
|
||||
key=StorageKey(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
bot_id=bot.id,
|
||||
destiny=destiny,
|
||||
),
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.storage.close()
|
||||
await self.events_isolation.close()
|
||||
150
aiogram/fsm/state.py
Normal file
150
aiogram/fsm/state.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import inspect
|
||||
from typing import Any, Iterator, Optional, Tuple, Type, no_type_check
|
||||
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class State:
|
||||
"""
|
||||
State object
|
||||
"""
|
||||
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None) -> None:
|
||||
self._state = state
|
||||
self._group_name = group_name
|
||||
self._group: Optional[Type[StatesGroup]] = None
|
||||
|
||||
@property
|
||||
def group(self) -> "Type[StatesGroup]":
|
||||
if not self._group:
|
||||
raise RuntimeError("This state is not in any group.")
|
||||
return self._group
|
||||
|
||||
@property
|
||||
def state(self) -> Optional[str]:
|
||||
if self._state is None or self._state == "*":
|
||||
return self._state
|
||||
|
||||
if self._group_name is None and self._group:
|
||||
group = self._group.__full_group_name__
|
||||
elif self._group_name:
|
||||
group = self._group_name
|
||||
else:
|
||||
group = "@"
|
||||
|
||||
return f"{group}:{self._state}"
|
||||
|
||||
def set_parent(self, group: "Type[StatesGroup]") -> None:
|
||||
if not issubclass(group, StatesGroup):
|
||||
raise ValueError("Group must be subclass of StatesGroup")
|
||||
self._group = group
|
||||
|
||||
def __set_name__(self, owner: "Type[StatesGroup]", name: str) -> None:
|
||||
if self._state is None:
|
||||
self._state = name
|
||||
self.set_parent(owner)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<State '{self.state or ''}'>"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
||||
if self.state == "*":
|
||||
return True
|
||||
return raw_state == self.state
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self.state == other.state
|
||||
if isinstance(other, str):
|
||||
return self.state == other
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.state)
|
||||
|
||||
|
||||
class StatesGroupMeta(type):
|
||||
__parent__: "Optional[Type[StatesGroup]]"
|
||||
__childs__: "Tuple[Type[StatesGroup], ...]"
|
||||
__states__: Tuple[State, ...]
|
||||
__state_names__: Tuple[str, ...]
|
||||
|
||||
@no_type_check
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
|
||||
|
||||
states = []
|
||||
childs = []
|
||||
|
||||
for name, arg in namespace.items():
|
||||
if isinstance(arg, State):
|
||||
states.append(arg)
|
||||
elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
|
||||
childs.append(arg)
|
||||
arg.__parent__ = cls
|
||||
|
||||
cls.__parent__ = None
|
||||
cls.__childs__ = tuple(childs)
|
||||
cls.__states__ = tuple(states)
|
||||
cls.__state_names__ = tuple(state.state for state in states)
|
||||
|
||||
return cls
|
||||
|
||||
@property
|
||||
def __full_group_name__(cls) -> str:
|
||||
if cls.__parent__:
|
||||
return ".".join((cls.__parent__.__full_group_name__, cls.__name__))
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def __all_childs__(cls) -> Tuple[Type["StatesGroup"], ...]:
|
||||
result = cls.__childs__
|
||||
for child in cls.__childs__:
|
||||
result += child.__childs__
|
||||
return result
|
||||
|
||||
@property
|
||||
def __all_states__(cls) -> Tuple[State, ...]:
|
||||
result = cls.__states__
|
||||
for group in cls.__childs__:
|
||||
result += group.__all_states__
|
||||
return result
|
||||
|
||||
@property
|
||||
def __all_states_names__(cls) -> Tuple[str, ...]:
|
||||
return tuple(state.state for state in cls.__all_states__ if state.state)
|
||||
|
||||
def __contains__(cls, item: Any) -> bool:
|
||||
if isinstance(item, str):
|
||||
return item in cls.__all_states_names__
|
||||
if isinstance(item, State):
|
||||
return item in cls.__all_states__
|
||||
if isinstance(item, StatesGroupMeta):
|
||||
return item in cls.__all_childs__
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<StatesGroup '{self.__full_group_name__}'>"
|
||||
|
||||
def __iter__(self) -> Iterator[State]:
|
||||
return iter(self.__all_states__)
|
||||
|
||||
|
||||
class StatesGroup(metaclass=StatesGroupMeta):
|
||||
@classmethod
|
||||
def get_root(cls) -> Type["StatesGroup"]:
|
||||
if cls.__parent__ is None:
|
||||
return cls
|
||||
return cls.__parent__.get_root()
|
||||
|
||||
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
||||
return raw_state in type(self).__all_states_names__
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StatesGroup {type(self).__full_group_name__}"
|
||||
|
||||
|
||||
default_state = State()
|
||||
any_state = State(state="*")
|
||||
0
aiogram/fsm/storage/__init__.py
Normal file
0
aiogram/fsm/storage/__init__.py
Normal file
109
aiogram/fsm/storage/base.py
Normal file
109
aiogram/fsm/storage/base.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, Union
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.fsm.state import State
|
||||
|
||||
StateType = Optional[Union[str, State]]
|
||||
|
||||
DEFAULT_DESTINY = "default"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageKey:
|
||||
bot_id: int
|
||||
chat_id: int
|
||||
user_id: int
|
||||
destiny: str = DEFAULT_DESTINY
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
"""
|
||||
Base class for all FSM storages
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
"""
|
||||
Set state for specified key
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:param state: new state
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||
"""
|
||||
Get key state
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:return: current state
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Write data (replace)
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:param data: new data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current data for key
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:return: current data
|
||||
"""
|
||||
pass
|
||||
|
||||
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Update date in the storage for key (like dict.update)
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:param data: partial data
|
||||
:return: new data
|
||||
"""
|
||||
current_data = await self.get_data(bot=bot, key=key)
|
||||
current_data.update(data)
|
||||
await self.set_data(bot=bot, key=key, data=current_data)
|
||||
return current_data.copy()
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None: # pragma: no cover
|
||||
"""
|
||||
Close storage (database connection, file or etc.)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseEventIsolation(ABC):
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
Isolate events with lock.
|
||||
Will be used as context manager
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:return: An async generator
|
||||
"""
|
||||
yield None
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
70
aiogram/fsm/storage/memory.py
Normal file
70
aiogram/fsm/storage/memory.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.fsm.state import State
|
||||
from aiogram.fsm.storage.base import BaseEventIsolation, BaseStorage, StateType, StorageKey
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryStorageRecord:
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
state: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryStorage(BaseStorage):
|
||||
"""
|
||||
Default FSM storage, stores all data in :class:`dict` and loss everything on shutdown
|
||||
|
||||
.. warning::
|
||||
|
||||
Is not recommended using in production in due to you will lose all data
|
||||
when your bot restarts
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
|
||||
MemoryStorageRecord
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
self.storage[key].state = state.state if isinstance(state, State) else state
|
||||
|
||||
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||
return self.storage[key].state
|
||||
|
||||
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||
self.storage[key].data = data.copy()
|
||||
|
||||
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||
return self.storage[key].data.copy()
|
||||
|
||||
|
||||
class DisabledEventIsolation(BaseEventIsolation):
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SimpleEventIsolation(BaseEventIsolation):
|
||||
def __init__(self) -> None:
|
||||
# TODO: Unused locks cleaner is needed
|
||||
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
lock = self._locks[key]
|
||||
async with lock:
|
||||
yield
|
||||
|
||||
async def close(self) -> None:
|
||||
self._locks.clear()
|
||||
231
aiogram/fsm/storage/redis.py
Normal file
231
aiogram/fsm/storage/redis.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
|
||||
|
||||
from redis.asyncio.client import Redis
|
||||
from redis.asyncio.connection import ConnectionPool
|
||||
from redis.asyncio.lock import Lock
|
||||
from redis.typing import ExpiryT
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.fsm.state import State
|
||||
from aiogram.fsm.storage.base import (
|
||||
DEFAULT_DESTINY,
|
||||
BaseEventIsolation,
|
||||
BaseStorage,
|
||||
StateType,
|
||||
StorageKey,
|
||||
)
|
||||
|
||||
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
|
||||
|
||||
|
||||
class KeyBuilder(ABC):
|
||||
"""
|
||||
Base class for Redis key builder
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||
"""
|
||||
This method should be implemented in subclasses
|
||||
|
||||
:param key: contextual key
|
||||
:param part: part of the record
|
||||
:return: key to be used in Redis queries
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultKeyBuilder(KeyBuilder):
|
||||
"""
|
||||
Simple Redis key builder with default prefix.
|
||||
|
||||
Generates a colon-joined string with prefix, chat_id, user_id,
|
||||
optional bot_id and optional destiny.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
prefix: str = "fsm",
|
||||
separator: str = ":",
|
||||
with_bot_id: bool = False,
|
||||
with_destiny: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
:param prefix: prefix for all records
|
||||
:param separator: separator
|
||||
:param with_bot_id: include Bot id in the key
|
||||
:param with_destiny: include destiny key
|
||||
"""
|
||||
self.prefix = prefix
|
||||
self.separator = separator
|
||||
self.with_bot_id = with_bot_id
|
||||
self.with_destiny = with_destiny
|
||||
|
||||
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||
parts = [self.prefix]
|
||||
if self.with_bot_id:
|
||||
parts.append(str(key.bot_id))
|
||||
parts.extend([str(key.chat_id), str(key.user_id)])
|
||||
if self.with_destiny:
|
||||
parts.append(key.destiny)
|
||||
elif key.destiny != DEFAULT_DESTINY:
|
||||
raise ValueError(
|
||||
"Redis key builder is not configured to use key destiny other the default.\n"
|
||||
"\n"
|
||||
"Probably, you should set `with_destiny=True` in for DefaultKeyBuilder.\n"
|
||||
"E.g: `RedisStorage(redis, key_builder=DefaultKeyBuilder(with_destiny=True))`"
|
||||
)
|
||||
parts.append(part)
|
||||
return self.separator.join(parts)
|
||||
|
||||
|
||||
class RedisStorage(BaseStorage):
|
||||
"""
|
||||
Redis storage required :code:`aioredis` package installed (:code:`pip install aioredis`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: Redis,
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
state_ttl: Optional[ExpiryT] = None,
|
||||
data_ttl: Optional[ExpiryT] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param redis: Instance of Redis connection
|
||||
:param key_builder: builder that helps to convert contextual key to string
|
||||
:param state_ttl: TTL for state records
|
||||
:param data_ttl: TTL for data records
|
||||
:param lock_kwargs: Custom arguments for Redis lock
|
||||
"""
|
||||
if key_builder is None:
|
||||
key_builder = DefaultKeyBuilder()
|
||||
self.redis = redis
|
||||
self.key_builder = key_builder
|
||||
self.state_ttl = state_ttl
|
||||
self.data_ttl = data_ttl
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> "RedisStorage":
|
||||
"""
|
||||
Create an instance of :class:`RedisStorage` with specifying the connection string
|
||||
|
||||
:param url: for example :code:`redis://user:password@host:port/db`
|
||||
:param connection_kwargs: see :code:`aioredis` docs
|
||||
:param kwargs: arguments to be passed to :class:`RedisStorage`
|
||||
:return: an instance of :class:`RedisStorage`
|
||||
"""
|
||||
if connection_kwargs is None:
|
||||
connection_kwargs = {}
|
||||
pool = ConnectionPool.from_url(url, **connection_kwargs)
|
||||
redis = Redis(connection_pool=pool)
|
||||
return cls(redis=redis, **kwargs)
|
||||
|
||||
def create_isolation(self, **kwargs: Any) -> "RedisEventIsolation":
|
||||
return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.redis.close()
|
||||
|
||||
async def set_state(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
state: StateType = None,
|
||||
) -> None:
|
||||
redis_key = self.key_builder.build(key, "state")
|
||||
if state is None:
|
||||
await self.redis.delete(redis_key)
|
||||
else:
|
||||
await self.redis.set(
|
||||
redis_key,
|
||||
cast(str, state.state if isinstance(state, State) else state),
|
||||
ex=self.state_ttl,
|
||||
)
|
||||
|
||||
async def get_state(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> Optional[str]:
|
||||
redis_key = self.key_builder.build(key, "state")
|
||||
value = await self.redis.get(redis_key)
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return cast(Optional[str], value)
|
||||
|
||||
async def set_data(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
redis_key = self.key_builder.build(key, "data")
|
||||
if not data:
|
||||
await self.redis.delete(redis_key)
|
||||
return
|
||||
await self.redis.set(
|
||||
redis_key,
|
||||
bot.session.json_dumps(data),
|
||||
ex=self.data_ttl,
|
||||
)
|
||||
|
||||
async def get_data(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> Dict[str, Any]:
|
||||
redis_key = self.key_builder.build(key, "data")
|
||||
value = await self.redis.get(redis_key)
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
return cast(Dict[str, Any], bot.session.json_loads(value))
|
||||
|
||||
|
||||
class RedisEventIsolation(BaseEventIsolation):
|
||||
def __init__(
|
||||
self,
|
||||
redis: Redis,
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
lock_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if key_builder is None:
|
||||
key_builder = DefaultKeyBuilder()
|
||||
if lock_kwargs is None:
|
||||
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
|
||||
self.redis = redis
|
||||
self.key_builder = key_builder
|
||||
self.lock_kwargs = lock_kwargs
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls,
|
||||
url: str,
|
||||
connection_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "RedisEventIsolation":
|
||||
if connection_kwargs is None:
|
||||
connection_kwargs = {}
|
||||
pool = ConnectionPool.from_url(url, **connection_kwargs)
|
||||
redis = Redis(connection_pool=pool)
|
||||
return cls(redis=redis, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
redis_key = self.key_builder.build(key, "lock")
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
|
||||
yield None
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
16
aiogram/fsm/strategy.py
Normal file
16
aiogram/fsm/strategy.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from enum import Enum, auto
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class FSMStrategy(Enum):
|
||||
USER_IN_CHAT = auto()
|
||||
CHAT = auto()
|
||||
GLOBAL_USER = auto()
|
||||
|
||||
|
||||
def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]:
|
||||
if strategy == FSMStrategy.CHAT:
|
||||
return chat_id, chat_id
|
||||
if strategy == FSMStrategy.GLOBAL_USER:
|
||||
return user_id, user_id
|
||||
return chat_id, user_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue