diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index 5c755af6..fd8f1e13 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import typing import weakref @@ -11,9 +12,6 @@ __all__ = ['RethinkDBStorage', 'ConnectionNotClosed'] r.set_loop_type('asyncio') -# TODO: rewrite connections pool - - class ConnectionNotClosed(Exception): """ Indicates that DB connection wasn't closed. @@ -86,6 +84,12 @@ class RethinkDBStorage(BaseStorage): self._queue.put_nowait(conn) self._outstanding_connections.remove(conn) + @contextlib.asynccontextmanager + async def connection(self): + conn = await self.get_connection() + yield conn + await self.put_connection(conn) + async def close(self): """ Close all connections. @@ -113,49 +117,44 @@ class RethinkDBStorage(BaseStorage): async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - result = await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + result = await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn) return result async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) return result async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) def has_bucket(self): return True @@ -163,31 +162,28 @@ class RethinkDBStorage(BaseStorage): async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) return result async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.get_connection() - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]: """ @@ -195,18 +191,16 @@ class RethinkDBStorage(BaseStorage): :return: list of tuples where first element is chat id and second is user id """ - conn = await self.get_connection() - result = [] + async with self.connection() as conn: + result = [] - items = (await r.table(self._table).run(conn)).items + items = (await r.table(self._table).run(conn)).items - for item in items: - chat = int(item.pop('id')) - for key in item.keys(): - user = int(key) - result.append((chat, user)) - - await self.put_connection(conn) + for item in items: + chat = int(item.pop('id')) + for key in item.keys(): + user = int(key) + result.append((chat, user)) return result @@ -214,6 +208,5 @@ class RethinkDBStorage(BaseStorage): """ Reset states in DB """ - conn = await self.get_connection() - await r.table(self._table).delete().run(conn) - await self.put_connection(conn) + async with self.connection() as conn: + await r.table(self._table).delete().run(conn) diff --git a/dev_requirements.txt b/dev_requirements.txt index ac4a62ec..72d6a989 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -5,6 +5,8 @@ python-rapidjson>=0.6.3 emoji>=0.5.0 pytest>=3.5.0 pytest-asyncio>=0.8.0 +tox>=3.0.0 +aresponses>=1.0.0 uvloop>=0.9.1 aioredis>=1.1.0 wheel>=0.31.0 @@ -13,6 +15,4 @@ sphinx>=1.7.3 sphinx-rtd-theme>=0.3.0 sphinxcontrib-programoutput>=0.11 aresponses>=1.0.0 -tox>=3.0.0 -aiosocksy>=0.1 -click>=6.7 +aiohttp-socks>=0.1.5