From 1b8bcbd1d92eb9ddfd07192bbf1d8cee3e7dd0c4 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Tue, 2 Jul 2019 02:05:20 +0300 Subject: [PATCH 1/2] added mongo storage --- aiogram/contrib/fsm_storage/mongo.py | 201 +++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 aiogram/contrib/fsm_storage/mongo.py diff --git a/aiogram/contrib/fsm_storage/mongo.py b/aiogram/contrib/fsm_storage/mongo.py new file mode 100644 index 00000000..e0a6b3cc --- /dev/null +++ b/aiogram/contrib/fsm_storage/mongo.py @@ -0,0 +1,201 @@ +""" +This module has mongo storage for finite-state machine + based on `aiomongo AioMongoClient: + if isinstance(self._mongo, AioMongoClient): + return self._mongo + + uri = 'mongodb://' + + # set username + password + if self._username and self._password: + uri += f'{self._username}:{self._password}@' + + # set host and port (optional) + uri += f'{self._host}' if self._host else 'localhost' + uri += f':{self._port}' if self._port else '/' + + # define and return client + self._mongo = await aiomongo.create_client(uri) + return self._mongo + + async def get_db(self) -> Database: + """ + Get Mongo db + + This property is awaitable. + """ + if isinstance(self._db, Database): + return self._db + + mongo = await self.get_client() + self._db = mongo.get_database(self._db_name) + + if self._index: + await self.apply_index(self._db) + return self._db + + @staticmethod + async def apply_index(db): + for collection in COLLECTIONS: + await db[collection].create_index(keys=[('chat', 1), ('user', 1)], + name="chat_user_idx", unique=True, background=True) + + async def close(self): + if self._mongo: + self._mongo.close() + + async def wait_closed(self): + if self._mongo: + return await self._mongo.wait_closed() + return True + + async def set_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + state: Optional[AnyStr] = None): + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + + if state is None: + await db[STATE].delete_one(filter={'chat': chat, 'user': user}) + else: + await db[STATE].update_one(filter={'chat': chat, 'user': user}, + update={'state': state}, upsert=True) + + async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + default: Optional[str] = None) -> Optional[str]: + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + result = await db[STATE].find_one(filter={'chat': chat, 'user': user}) + + return result.get('state') if result else default + + async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + data: Dict = None): + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + + await db[DATA].update_one(filter={'chat': chat, 'user': user}, + update={'data': data}, upsert=True) + + async def get_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + default: Optional[dict] = None) -> Dict: + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + result = await db[DATA].find_one(filter={'chat': chat, 'user': user}) + + return result.get('data') if result else default or {} + + async def update_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + data: Dict = None, **kwargs): + if data is None: + data = {} + temp_data = await self.get_data(chat=chat, user=user, default={}) + temp_data.update(data, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_data) + + def has_bucket(self): + return True + + async def get_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + default: Optional[dict] = None) -> Dict: + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + result = await db[BUCKET].find_one(filter={'chat': chat, 'user': user}) + return result.get('bucket') if result else default or {} + + async def set_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, + bucket: Dict = None): + chat, user = self.check_address(chat=chat, user=user) + db = await self.get_db() + + await db[BUCKET].update_one(filter={'chat': chat, 'user': user}, + update={'bucket': bucket}, upsert=True) + + async def update_bucket(self, *, chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + bucket: Dict = None, **kwargs): + if bucket is None: + bucket = {} + temp_bucket = await self.get_bucket(chat=chat, user=user) + temp_bucket.update(bucket, **kwargs) + await self.set_bucket(chat=chat, user=user, bucket=temp_bucket) + + async def reset_all(self, full=True): + """ + Reset states in DB + + :param full: clean DB or clean only states + :return: + """ + db = await self.get_db() + + await db[STATE].drop() + + if full: + await db[DATA].drop() + await db[BUCKET].drop() + + async def get_states_list(self) -> List[Tuple[int, int]]: + """ + Get list of all stored chat's and user's + + :return: list of tuples where first element is chat id and second is user id + """ + db = await self.get_db() + result = [] + + items = await db[STATE].find().to_list() + for item in items: + result.append( + (int(item['chat']), int(item['user'])) + ) + + return result From 4a2b569e5ff0ca278e6d60755153710021b4cbca Mon Sep 17 00:00:00 2001 From: Oleg A Date: Tue, 9 Jul 2019 00:05:46 +0300 Subject: [PATCH 2/2] fixed uri; fixed $set --- aiogram/contrib/fsm_storage/mongo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/aiogram/contrib/fsm_storage/mongo.py b/aiogram/contrib/fsm_storage/mongo.py index e0a6b3cc..9ec18090 100644 --- a/aiogram/contrib/fsm_storage/mongo.py +++ b/aiogram/contrib/fsm_storage/mongo.py @@ -60,8 +60,7 @@ class MongoStorage(BaseStorage): uri += f'{self._username}:{self._password}@' # set host and port (optional) - uri += f'{self._host}' if self._host else 'localhost' - uri += f':{self._port}' if self._port else '/' + uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}' # define and return client self._mongo = await aiomongo.create_client(uri) @@ -107,7 +106,7 @@ class MongoStorage(BaseStorage): await db[STATE].delete_one(filter={'chat': chat, 'user': user}) else: await db[STATE].update_one(filter={'chat': chat, 'user': user}, - update={'state': state}, upsert=True) + update={'$set': {'state': state}}, upsert=True) async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, default: Optional[str] = None) -> Optional[str]: @@ -123,7 +122,7 @@ class MongoStorage(BaseStorage): db = await self.get_db() await db[DATA].update_one(filter={'chat': chat, 'user': user}, - update={'data': data}, upsert=True) + update={'$set': {'data': data}}, upsert=True) async def get_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, default: Optional[dict] = None) -> Dict: @@ -157,7 +156,7 @@ class MongoStorage(BaseStorage): db = await self.get_db() await db[BUCKET].update_one(filter={'chat': chat, 'user': user}, - update={'bucket': bucket}, upsert=True) + update={'$set': {'bucket': bucket}}, upsert=True) async def update_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,