diff --git a/aiogram/contrib/fsm_storage/memory.py b/aiogram/contrib/fsm_storage/memory.py index e1d6bdc0..8950aa8e 100644 --- a/aiogram/contrib/fsm_storage/memory.py +++ b/aiogram/contrib/fsm_storage/memory.py @@ -66,6 +66,7 @@ class MemoryStorage(BaseStorage): data: typing.Dict = None): chat, user = self.resolve_address(chat=chat, user=user) self.data[chat][user]['data'] = copy.deepcopy(data) + self._cleanup(chat, user) async def reset_state(self, *, chat: typing.Union[str, int, None] = None, @@ -74,6 +75,7 @@ class MemoryStorage(BaseStorage): await self.set_state(chat=chat, user=user, state=None) if with_data: await self.set_data(chat=chat, user=user, data={}) + self._cleanup(chat, user) def has_bucket(self): return True @@ -91,6 +93,7 @@ class MemoryStorage(BaseStorage): bucket: typing.Dict = None): chat, user = self.resolve_address(chat=chat, user=user) self.data[chat][user]['bucket'] = copy.deepcopy(bucket) + self._cleanup(chat, user) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, @@ -100,3 +103,9 @@ class MemoryStorage(BaseStorage): bucket = {} chat, user = self.resolve_address(chat=chat, user=user) self.data[chat][user]['bucket'].update(bucket, **kwargs) + + def _cleanup(self, chat, user): + if self.data[chat][user] == {'state': None, 'data': {}, 'bucket': {}}: + del self.data[chat][user] + if not self.data[chat]: + del self.data[chat] diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 01a0fe5c..5d0b762c 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -110,10 +110,12 @@ class RedisStorage(BaseStorage): chat, user = self.check_address(chat=chat, user=user) addr = f"fsm:{chat}:{user}" - record = {'state': state, 'data': data, 'bucket': bucket} - conn = await self.redis() - await conn.execute('SET', addr, json.dumps(record)) + if state is None and data == bucket == {}: + await conn.execute('DEL', addr) + else: + record = {'state': state, 'data': data, 'bucket': bucket} + await conn.execute('SET', addr, json.dumps(record)) async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: @@ -222,7 +224,7 @@ class RedisStorage2(BaseStorage): await dp.storage.wait_closed() """ - def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, + def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, ssl=None, pool_size=10, loop=None, prefix='fsm', state_ttl: int = 0, data_ttl: int = 0, @@ -304,7 +306,10 @@ class RedisStorage2(BaseStorage): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) redis = await self.redis() - await redis.set(key, json.dumps(data), expire=self._data_ttl) + if data: + await redis.set(key, json.dumps(data), expire=self._data_ttl) + else: + await redis.delete(key) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): @@ -332,7 +337,10 @@ class RedisStorage2(BaseStorage): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) redis = await self.redis() - await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) + if bucket: + await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) + else: + await redis.delete(key) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, diff --git a/dev_requirements.txt b/dev_requirements.txt index ef5272af..26e410aa 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,3 +16,4 @@ aiohttp-socks>=0.3.4 rethinkdb>=2.4.1 coverage==4.5.3 motor>=2.2.0 +pytest-lazy-fixture==0.6.* diff --git a/tests/contrib/fsm_storage/test_redis.py b/tests/contrib/fsm_storage/test_redis.py deleted file mode 100644 index 527c905e..00000000 --- a/tests/contrib/fsm_storage/test_redis.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from aiogram.contrib.fsm_storage.redis import RedisStorage2 - - -@pytest.fixture() -async def store(redis_options): - s = RedisStorage2(**redis_options) - try: - yield s - finally: - conn = await s.redis() - await conn.flushdb() - await s.close() - await s.wait_closed() - - -@pytest.mark.redis -class TestRedisStorage2: - @pytest.mark.asyncio - async def test_set_get(self, store): - assert await store.get_data(chat='1234') == {} - await store.set_data(chat='1234', data={'foo': 'bar'}) - assert await store.get_data(chat='1234') == {'foo': 'bar'} - - @pytest.mark.asyncio - async def test_close_and_open_connection(self, store): - await store.set_data(chat='1234', data={'foo': 'bar'}) - assert await store.get_data(chat='1234') == {'foo': 'bar'} - pool_id = id(store._redis) - await store.close() - assert await store.get_data(chat='1234') == {'foo': 'bar'} # new pool was opened at this point - assert id(store._redis) != pool_id diff --git a/tests/contrib/fsm_storage/test_storage.py b/tests/contrib/fsm_storage/test_storage.py new file mode 100644 index 00000000..0cde2de2 --- /dev/null +++ b/tests/contrib/fsm_storage/test_storage.py @@ -0,0 +1,79 @@ +import pytest + +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.contrib.fsm_storage.redis import RedisStorage2, RedisStorage + + +@pytest.fixture() +@pytest.mark.redis +async def redis_store(redis_options): + s = RedisStorage(**redis_options) + try: + yield s + finally: + conn = await s.redis() + await conn.execute('FLUSHDB') + await s.close() + await s.wait_closed() + + +@pytest.fixture() +@pytest.mark.redis +async def redis_store2(redis_options): + s = RedisStorage2(**redis_options) + try: + yield s + finally: + conn = await s.redis() + await conn.flushdb() + await s.close() + await s.wait_closed() + + +@pytest.fixture() +async def memory_store(): + yield MemoryStorage() + + +@pytest.mark.parametrize( + "store", [ + pytest.lazy_fixture('redis_store'), + pytest.lazy_fixture('redis_store2'), + pytest.lazy_fixture('memory_store'), + ] +) +class TestStorage: + @pytest.mark.asyncio + async def test_set_get(self, store): + assert await store.get_data(chat='1234') == {} + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + + @pytest.mark.asyncio + async def test_reset(self, store): + await store.set_data(chat='1234', data={'foo': 'bar'}) + await store.reset_data(chat='1234') + assert await store.get_data(chat='1234') == {} + + @pytest.mark.asyncio + async def test_reset_empty(self, store): + await store.reset_data(chat='1234') + assert await store.get_data(chat='1234') == {} + + +@pytest.mark.parametrize( + "store", [ + pytest.lazy_fixture('redis_store'), + pytest.lazy_fixture('redis_store2'), + ] +) +class TestRedisStorage2: + @pytest.mark.asyncio + async def test_close_and_open_connection(self, store): + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + pool_id = id(store._redis) + await store.close() + assert await store.get_data(chat='1234') == { + 'foo': 'bar'} # new pool was opened at this point + assert id(store._redis) != pool_id