Change user, chat in redis

This commit is contained in:
Anthony Byuraev 2020-07-04 18:10:47 +03:00
parent 0d274af8f1
commit c7cdfc5ab4

View file

@ -1,5 +1,6 @@
""" """
This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver This module has redis storage for finite-state machine
based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver
""" """
import asyncio import asyncio
@ -35,7 +36,8 @@ class RedisStorage(BaseStorage):
await dp.storage.wait_closed() await dp.storage.wait_closed()
""" """
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs): def __init__(self, host='localhost', port=6379, db=None,
password=None, ssl=None, loop=None, **kwargs):
self._host = host self._host = host
self._port = port self._port = port
self._db = db self._db = db
@ -72,17 +74,17 @@ class RedisStorage(BaseStorage):
return self._redis return self._redis
async def get_record(self, *, async def get_record(self, *,
chat: typing.Union[str, int, None] = None, chat_id: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> typing.Dict: user_id: typing.Union[str, int, None] = None) -> typing.Dict:
""" """
Get record from storage Get record from storage
:param chat: :param chat_id:
:param user: :param user_id:
:return: :return:
""" """
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
addr = f"fsm:{chat}:{user}" addr = f"fsm:{chat_id}:{user_id}"
conn = await self.redis() conn = await self.redis()
data = await conn.execute('GET', addr) data = await conn.execute('GET', addr)
@ -90,14 +92,16 @@ class RedisStorage(BaseStorage):
return {'state': None, 'data': {}} return {'state': None, 'data': {}}
return json.loads(data) return json.loads(data)
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_record(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
state=None, data=None, bucket=None): state=None, data=None, bucket=None):
""" """
Write record to storage Write record to storage
:param bucket: :param bucket:
:param chat: :param chat_id:
:param user: :param user_id:
:param state: :param state:
:param data: :param data:
:return: :return:
@ -107,42 +111,52 @@ class RedisStorage(BaseStorage):
if bucket is None: if bucket is None:
bucket = {} bucket = {}
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
addr = f"fsm:{chat}:{user}" addr = f"fsm:{chat_id}:{user_id}"
record = {'state': state, 'data': data, 'bucket': bucket} record = {'state': state, 'data': data, 'bucket': bucket}
conn = await self.redis() conn = await self.redis()
await conn.execute('SET', addr, json.dumps(record)) await conn.execute('SET', addr, json.dumps(record))
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_state(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]: default: typing.Optional[str] = None) -> typing.Optional[str]:
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
return record['state'] return record['state']
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict: default: typing.Optional[str] = None) -> typing.Dict:
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
return record['data'] return record['data']
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_state(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None): state: typing.Optional[typing.AnyStr] = None):
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
await self.set_record(chat=chat, user=user, state=state, data=record['data']) await self.set_record(chat_id=chat_id, user_id=user_id, state=state, data=record['data'])
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
data: typing.Dict = None): data: typing.Dict = None):
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
await self.set_record(chat=chat, user=user, state=record['state'], data=data) await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=data)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def update_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs): data: typing.Dict = None, **kwargs):
if data is None: if data is None:
data = {} data = {}
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
record_data = record.get('data', {}) record_data = record.get('data', {})
record_data.update(data, **kwargs) record_data.update(data, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data) await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record_data)
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
""" """
@ -155,8 +169,8 @@ class RedisStorage(BaseStorage):
keys = await conn.execute('KEYS', 'fsm:*') keys = await conn.execute('KEYS', 'fsm:*')
for item in keys: for item in keys:
*_, chat, user = item.decode('utf-8').split(':') *_, chat_id, user_id = item.decode('utf-8').split(':')
result.append((chat, user)) result.append((chat_id, user_id))
return result return result
@ -178,25 +192,30 @@ class RedisStorage(BaseStorage):
def has_bucket(self): def has_bucket(self):
return True return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_bucket(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict: default: typing.Optional[str] = None) -> typing.Dict:
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
return record.get('bucket', {}) return record.get('bucket', {})
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_bucket(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
bucket: typing.Dict = None): bucket: typing.Dict = None):
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket) await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record['data'], bucket=bucket)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, async def update_bucket(self, *,
user: typing.Union[str, int, None] = None, chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs): bucket: typing.Dict = None, **kwargs):
record = await self.get_record(chat=chat, user=user) record = await self.get_record(chat_id=chat_id, user_id=user_id)
record_bucket = record.get('bucket', {}) record_bucket = record.get('bucket', {})
if bucket is None: if bucket is None:
bucket = {} bucket = {}
record_bucket.update(bucket, **kwargs) record_bucket.update(bucket, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket) await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record_bucket, bucket=bucket)
class RedisStorage2(BaseStorage): class RedisStorage2(BaseStorage):
@ -269,76 +288,91 @@ class RedisStorage2(BaseStorage):
return await self._redis.wait_closed() return await self._redis.wait_closed()
return True return True
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_state(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]: default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_KEY) key = self.generate_key(chat_id, user_id, STATE_KEY)
redis = await self.redis() redis = await self.redis()
return await redis.get(key, encoding='utf8') or None return await redis.get(key, encoding='utf8') or None
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict: default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_DATA_KEY) key = self.generate_key(chat_id, user_id, STATE_DATA_KEY)
redis = await self.redis() redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8') raw_result = await redis.get(key, encoding='utf8')
if raw_result: if raw_result:
return json.loads(raw_result) return json.loads(raw_result)
return default or {} return default or {}
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_state(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None): state: typing.Optional[typing.AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_KEY) key = self.generate_key(chat_id, user_id, STATE_KEY)
redis = await self.redis() redis = await self.redis()
if state is None: if state is None:
await redis.delete(key) await redis.delete(key)
else: else:
await redis.set(key, state, expire=self._state_ttl) await redis.set(key, state, expire=self._state_ttl)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
data: typing.Dict = None): data: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_DATA_KEY) key = self.generate_key(chat_id, user_id, STATE_DATA_KEY)
redis = await self.redis() redis = await self.redis()
await redis.set(key, json.dumps(data), expire=self._data_ttl) await redis.set(key, json.dumps(data), expire=self._data_ttl)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def update_data(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs): data: typing.Dict = None, **kwargs):
if data is None: if data is None:
data = {} data = {}
temp_data = await self.get_data(chat=chat, user=user, default={}) temp_data = await self.get_data(chat_id=chat_id, user_id=user_id, default={})
temp_data.update(data, **kwargs) temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data) await self.set_data(chat_id=chat_id, user_id=user_id, data=temp_data)
def has_bucket(self): def has_bucket(self):
return True return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def get_bucket(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict: default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_BUCKET_KEY) key = self.generate_key(chat_id, user_id, STATE_BUCKET_KEY)
redis = await self.redis() redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8') raw_result = await redis.get(key, encoding='utf8')
if raw_result: if raw_result:
return json.loads(raw_result) return json.loads(raw_result)
return default or {} return default or {}
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def set_bucket(self, *,
chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
bucket: typing.Dict = None): bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user) chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
key = self.generate_key(chat, user, STATE_BUCKET_KEY) key = self.generate_key(chat_id, user_id, STATE_BUCKET_KEY)
redis = await self.redis() redis = await self.redis()
await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, async def update_bucket(self, *,
user: typing.Union[str, int, None] = None, chat_id: typing.Union[str, int, None] = None,
user_id: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs): bucket: typing.Dict = None, **kwargs):
if bucket is None: if bucket is None:
bucket = {} bucket = {}
temp_bucket = await self.get_bucket(chat=chat, user=user) temp_bucket = await self.get_bucket(chat_id=chat_id, user_id=user_id)
temp_bucket.update(bucket, **kwargs) temp_bucket.update(bucket, **kwargs)
await self.set_bucket(chat=chat, user=user, bucket=temp_bucket) await self.set_bucket(chat_id=chat_id, user_id=user_id, bucket=temp_bucket)
async def reset_all(self, full=True): async def reset_all(self, full=True):
""" """
@ -366,8 +400,8 @@ class RedisStorage2(BaseStorage):
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
for item in keys: for item in keys:
*_, chat, user, _ = item.split(':') *_, chat_id, user_id, _ = item.split(':')
result.append((chat, user)) result.append((chat_id, user_id))
return result return result
@ -390,14 +424,14 @@ async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorag
log = logging.getLogger('aiogram.RedisStorage') log = logging.getLogger('aiogram.RedisStorage')
for chat, user in await storage1.get_states_list(): for chat_id, user_id in await storage1.get_states_list():
state = await storage1.get_state(chat=chat, user=user) state = await storage1.get_state(chat_id=chat_id, user_id=user_id)
await storage2.set_state(chat=chat, user=user, state=state) await storage2.set_state(chat_id=chat_id, user_id=user_id, state=state)
data = await storage1.get_data(chat=chat, user=user) data = await storage1.get_data(chat_id=chat_id, user_id=user_id)
await storage2.set_data(chat=chat, user=user, data=data) await storage2.set_data(chat_id=chat_id, user_id=user_id, data=data)
bucket = await storage1.get_bucket(chat=chat, user=user) bucket = await storage1.get_bucket(chat_id=chat_id, user_id=user_id)
await storage2.set_bucket(chat=chat, user=user, bucket=bucket) await storage2.set_bucket(chat_id=chat_id, user_id=user_id, bucket=bucket)
log.info(f"Migrated user {user} in chat {chat}") log.info(f"Migrated user {user_id} in chat {chat_id}")