From 3aa40224a23f4cd0d8b824c07ebefceae913be07 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Thu, 5 Aug 2021 22:34:15 +0300 Subject: [PATCH] aioredis v2 support (#649) * feat: aioredis v1-v2 adapters #648 * chore: aioredis version without importlib * chore: refactor _get_redis for adapter * chore: proxy Redis methods * chore: adapter.get_redis become public * fix: add missed redis methods * chore: separate get_adapter method * chore: remove method proxy * chore: add docstrings * chore: add redis deprecations * docs: correct redis storage version * chore: encoding one style * refactor: remove redundant import * fix: int version --- aiogram/contrib/fsm_storage/redis.py | 230 ++++++++++++++++++++++----- docs/source/dispatcher/fsm.rst | 2 +- 2 files changed, 192 insertions(+), 40 deletions(-) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 5d0b762c..c8b95517 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -5,11 +5,13 @@ This module has redis storage for finite-state machine based on `aioredis aioredis.Redis: + """Get Redis connection.""" + pass + + def close(self): + """Grace shutdown.""" + pass + + async def wait_closed(self): + """Wait for grace shutdown finishes.""" + pass + + async def set(self, name, value, ex=None, **kwargs): + """Set the value at key ``name`` to ``value``.""" + return await self._redis.set(name, value, ex=ex, **kwargs) + + async def get(self, name, **kwargs): + """Return the value at key ``name`` or None.""" + return await self._redis.get(name, **kwargs) + + async def delete(self, *names): + """Delete one or more keys specified by ``names``""" + return await self._redis.delete(*names) + + async def keys(self, pattern, **kwargs): + """Returns a list of keys matching ``pattern``.""" + return await self._redis.keys(pattern, **kwargs) + + async def flushdb(self): + """Delete all keys in the current database.""" + return await self._redis.flushdb() + + +class AioRedisAdapterV1(AioRedisAdapterBase): + """Redis adapter for aioredis v1.""" + + async def get_redis(self) -> aioredis.Redis: + """Get Redis connection.""" + async with self._connection_lock: # to prevent race + if self._redis is None or self._redis.closed: + self._redis = await aioredis.create_redis_pool( + (self._host, self._port), + db=self._db, + password=self._password, + ssl=self._ssl, + minsize=1, + maxsize=self._pool_size, + loop=self._loop, + **self._kwargs, + ) + return self._redis + + def close(self): + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() + + async def wait_closed(self): + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True + + async def get(self, name, **kwargs): + return await self._redis.get(name, encoding="utf8", **kwargs) + + async def set(self, name, value, ex=None, **kwargs): + return await self._redis.set(name, value, expire=ex, **kwargs) + + async def keys(self, pattern, **kwargs): + """Returns a list of keys matching ``pattern``.""" + return await self._redis.keys(pattern, encoding="utf8", **kwargs) + + +class AioRedisAdapterV2(AioRedisAdapterBase): + """Redis adapter for aioredis v2.""" + + async def get_redis(self) -> aioredis.Redis: + """Get Redis connection.""" + async with self._connection_lock: # to prevent race + if self._redis is None: + self._redis = aioredis.Redis( + host=self._host, + port=self._port, + db=self._db, + password=self._password, + ssl=self._ssl, + max_connections=self._pool_size, + **self._kwargs, + ) + return self._redis + + class RedisStorage2(BaseStorage): """ Busted Redis-base storage for FSM. @@ -224,12 +356,22 @@ class RedisStorage2(BaseStorage): await dp.storage.wait_closed() """ - def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, - ssl=None, pool_size=10, loop=None, prefix='fsm', - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, - **kwargs): + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: typing.Optional[int] = None, + password: typing.Optional[str] = None, + ssl: typing.Optional[bool] = None, + pool_size: int = 10, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + prefix: str = "fsm", + state_ttl: int = 0, + data_ttl: int = 0, + bucket_ttl: int = 0, + **kwargs, + ): self._host = host self._port = port self._db = db @@ -244,49 +386,59 @@ class RedisStorage2(BaseStorage): self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: typing.Optional[aioredis.RedisConnection] = None + self._redis: typing.Optional[AioRedisAdapterBase] = None self._connection_lock = asyncio.Lock(loop=self._loop) + @deprecated("This method will be removed in aiogram v3.0. " + "You should use your own instance of Redis.", stacklevel=3) async def redis(self) -> aioredis.Redis: - """ - Get Redis connection - """ - # Use thread-safe asyncio Lock because this method without that is not safe - async with self._connection_lock: - if self._redis is None or self._redis.closed: - self._redis = await aioredis.create_redis_pool((self._host, self._port), - db=self._db, password=self._password, ssl=self._ssl, - minsize=1, maxsize=self._pool_size, - loop=self._loop, **self._kwargs) + adapter = await self._get_adapter() + return await adapter.get_redis() + + async def _get_adapter(self) -> AioRedisAdapterBase: + """Get adapter based on aioredis version.""" + if self._redis is None: + redis_version = int(aioredis.__version__.split(".")[0]) + connection_data = dict( + host=self._host, + port=self._port, + db=self._db, + password=self._password, + ssl=self._ssl, + pool_size=self._pool_size, + loop=self._loop, + **self._kwargs, + ) + if redis_version == 1: + self._redis = AioRedisAdapterV1(**connection_data) + elif redis_version == 2: + self._redis = AioRedisAdapterV2(**connection_data) return self._redis def generate_key(self, *parts): return ':'.join(self._prefix + tuple(map(str, parts))) async def close(self): - async with self._connection_lock: - if self._redis and not self._redis.closed: - self._redis.close() + if self._redis: + return self._redis.close() async def wait_closed(self): - async with self._connection_lock: - if self._redis: - return await self._redis.wait_closed() - return True + if self._redis: + return await self._redis.wait_closed() async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self.redis() - return await redis.get(key, encoding='utf8') or self.resolve_state(default) + redis = await self._get_adapter() + return await redis.get(key) or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self.redis() - raw_result = await redis.get(key, encoding='utf8') + redis = await self._get_adapter() + raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -295,7 +447,7 @@ class RedisStorage2(BaseStorage): state: typing.Optional[typing.AnyStr] = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self.redis() + redis = await self._get_adapter() if state is None: await redis.delete(key) else: @@ -305,7 +457,7 @@ class RedisStorage2(BaseStorage): data: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self.redis() + redis = await self._get_adapter() if data: await redis.set(key, json.dumps(data), expire=self._data_ttl) else: @@ -326,8 +478,8 @@ class RedisStorage2(BaseStorage): default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self.redis() - raw_result = await redis.get(key, encoding='utf8') + redis = await self._get_adapter() + raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -336,7 +488,7 @@ class RedisStorage2(BaseStorage): bucket: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self.redis() + redis = await self._get_adapter() if bucket: await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) else: @@ -358,13 +510,13 @@ class RedisStorage2(BaseStorage): :param full: clean DB or clean only states :return: """ - conn = await self.redis() + redis = await self._get_adapter() if full: - await conn.flushdb() + await redis.flushdb() else: - keys = await conn.keys(self.generate_key('*')) - await conn.delete(*keys) + keys = await redis.keys(self.generate_key('*')) + await redis.delete(*keys) async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ @@ -372,10 +524,10 @@ class RedisStorage2(BaseStorage): :return: list of tuples where first element is chat id and second is user id """ - conn = await self.redis() + redis = await self._get_adapter() result = [] - keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') + keys = await redis.keys(self.generate_key('*', '*', STATE_KEY)) for item in keys: *_, chat, user, _ = item.split(':') result.append((chat, user)) diff --git a/docs/source/dispatcher/fsm.rst b/docs/source/dispatcher/fsm.rst index 1b00e81e..dc3a868e 100644 --- a/docs/source/dispatcher/fsm.rst +++ b/docs/source/dispatcher/fsm.rst @@ -19,7 +19,7 @@ Memory storage Redis storage ~~~~~~~~~~~~~ -.. autoclass:: aiogram.contrib.fsm_storage.redis.RedisStorage +.. autoclass:: aiogram.contrib.fsm_storage.redis.RedisStorage2 :show-inheritance: Mongo storage