mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
fix(bot,dispatcher): do not use _MainThread event loop on ::Bot, ::Dispatcher initializations
This commit is contained in:
parent
d1452b1620
commit
54da4d2f2b
3 changed files with 70 additions and 20 deletions
|
|
@ -56,6 +56,8 @@ class BaseBot:
|
||||||
:type timeout: :obj:`typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]`
|
:type timeout: :obj:`typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]`
|
||||||
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
||||||
"""
|
"""
|
||||||
|
self._main_loop = loop
|
||||||
|
|
||||||
# Authentication
|
# Authentication
|
||||||
if validate_token:
|
if validate_token:
|
||||||
api.check_token(token)
|
api.check_token(token)
|
||||||
|
|
@ -66,19 +68,12 @@ class BaseBot:
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.proxy_auth = proxy_auth
|
self.proxy_auth = proxy_auth
|
||||||
|
|
||||||
# Asyncio loop instance
|
|
||||||
if loop is None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
self.loop = loop
|
|
||||||
|
|
||||||
# aiohttp main session
|
# aiohttp main session
|
||||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||||
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = None
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
self._connector_class: Type[aiohttp.TCPConnector] = aiohttp.TCPConnector
|
self._connector_class: Type[aiohttp.TCPConnector] = aiohttp.TCPConnector
|
||||||
self._connector_init = dict(
|
self._connector_init = dict(limit=connections_limit, ssl=ssl_context)
|
||||||
limit=connections_limit, ssl=ssl_context, loop=self.loop
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
|
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
|
||||||
from aiohttp_socks import SocksConnector
|
from aiohttp_socks import SocksConnector
|
||||||
|
|
@ -106,11 +101,15 @@ class BaseBot:
|
||||||
|
|
||||||
def get_new_session(self) -> aiohttp.ClientSession:
|
def get_new_session(self) -> aiohttp.ClientSession:
|
||||||
return aiohttp.ClientSession(
|
return aiohttp.ClientSession(
|
||||||
connector=self._connector_class(**self._connector_init),
|
connector=self._connector_class(**self._connector_init, loop=self._main_loop),
|
||||||
loop=self.loop,
|
loop=self._main_loop,
|
||||||
json_serialize=json.dumps
|
json_serialize=json.dumps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self) -> Optional[asyncio.AbstractEventLoop]:
|
||||||
|
return self._main_loop
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self) -> Optional[aiohttp.ClientSession]:
|
def session(self) -> Optional[aiohttp.ClientSession]:
|
||||||
if self._session is None or self._session.closed:
|
if self._session is None or self._session.closed:
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,32 @@ log = logging.getLogger(__name__)
|
||||||
DEFAULT_RATE_LIMIT = .1
|
DEFAULT_RATE_LIMIT = .1
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_loop(x):
|
||||||
|
assert isinstance(
|
||||||
|
x, asyncio.AbstractEventLoop
|
||||||
|
), f"Loop must the implementation of {asyncio.AbstractEventLoop!r}, " \
|
||||||
|
f"not {type(x)!r}"
|
||||||
|
|
||||||
|
|
||||||
|
if callable(getattr(asyncio, "create_task")):
|
||||||
|
_asyncio_create_task = asyncio.create_task
|
||||||
|
else:
|
||||||
|
from asyncio import events as _asyncio_events
|
||||||
|
|
||||||
|
def _asyncio_create_task(coro, *, name=None):
|
||||||
|
# ported and modified from asyncio-py38
|
||||||
|
loop = _asyncio_events.get_running_loop()
|
||||||
|
task = loop.create_task(coro)
|
||||||
|
if name is not None:
|
||||||
|
try:
|
||||||
|
set_name = task.set_name
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
set_name(name)
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher(DataMixin, ContextInstanceMixin):
|
class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
"""
|
"""
|
||||||
Simple Updates dispatcher
|
Simple Updates dispatcher
|
||||||
|
|
@ -43,15 +69,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
if not isinstance(bot, Bot):
|
if not isinstance(bot, Bot):
|
||||||
raise TypeError(f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'")
|
raise TypeError(f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'")
|
||||||
|
|
||||||
if loop is None:
|
|
||||||
loop = bot.loop
|
|
||||||
if storage is None:
|
if storage is None:
|
||||||
storage = DisabledStorage()
|
storage = DisabledStorage()
|
||||||
if filters_factory is None:
|
if filters_factory is None:
|
||||||
filters_factory = FiltersFactory(self)
|
filters_factory = FiltersFactory(self)
|
||||||
|
|
||||||
self.bot: Bot = bot
|
self.bot: Bot = bot
|
||||||
self.loop = loop
|
if loop is not None:
|
||||||
|
_ensure_loop(loop)
|
||||||
|
self._main_loop = loop
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.run_tasks_by_default = run_tasks_by_default
|
self.run_tasks_by_default = run_tasks_by_default
|
||||||
|
|
||||||
|
|
@ -79,10 +105,25 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
|
|
||||||
self._polling = False
|
self._polling = False
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._close_waiter = loop.create_future()
|
self._dispatcher_close_waiter = None
|
||||||
|
|
||||||
self._setup_filters()
|
self._setup_filters()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self) -> asyncio.AbstractEventLoop:
|
||||||
|
# for the sake of backward compatibility
|
||||||
|
# lib internally must delegate tasks with respect to _main_loop attribute
|
||||||
|
return self._main_loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _close_waiter(self) -> "asyncio.Future":
|
||||||
|
if self._dispatcher_close_waiter is None:
|
||||||
|
if self._main_loop is not None:
|
||||||
|
self._dispatcher_close_waiter = self._main_loop.create_future()
|
||||||
|
else:
|
||||||
|
self._dispatcher_close_waiter = asyncio.get_running_loop().create_future()
|
||||||
|
return self._dispatcher_close_waiter
|
||||||
|
|
||||||
def _setup_filters(self):
|
def _setup_filters(self):
|
||||||
filters_factory = self.filters_factory
|
filters_factory = self.filters_factory
|
||||||
|
|
||||||
|
|
@ -282,6 +323,13 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
|
|
||||||
return await self.bot.delete_webhook()
|
return await self.bot.delete_webhook()
|
||||||
|
|
||||||
|
def _loop_create_task(self, coro):
|
||||||
|
if self._main_loop is None:
|
||||||
|
return _asyncio_create_task(coro=coro)
|
||||||
|
else:
|
||||||
|
_ensure_loop(self._main_loop)
|
||||||
|
return self._main_loop.create_task(coro)
|
||||||
|
|
||||||
async def start_polling(self,
|
async def start_polling(self,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
relax=0.1,
|
relax=0.1,
|
||||||
|
|
@ -337,7 +385,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
log.debug(f"Received {len(updates)} updates.")
|
log.debug(f"Received {len(updates)} updates.")
|
||||||
offset = updates[-1].update_id + 1
|
offset = updates[-1].update_id + 1
|
||||||
|
|
||||||
self.loop.create_task(self._process_polling_updates(updates, fast))
|
self._loop_create_task(self._process_polling_updates(updates, fast))
|
||||||
|
|
||||||
if relax:
|
if relax:
|
||||||
await asyncio.sleep(relax)
|
await asyncio.sleep(relax)
|
||||||
|
|
@ -381,7 +429,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
await asyncio.shield(self._close_waiter, loop=self.loop)
|
await asyncio.shield(self._close_waiter)
|
||||||
|
|
||||||
def is_polling(self):
|
def is_polling(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1158,15 +1206,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
try:
|
try:
|
||||||
response = task.result()
|
response = task.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.loop.create_task(
|
self._loop_create_task(
|
||||||
self.errors_handlers.notify(types.Update.get_current(), e))
|
self.errors_handlers.notify(types.Update.get_current(), e))
|
||||||
else:
|
else:
|
||||||
if isinstance(response, BaseResponse):
|
if isinstance(response, BaseResponse):
|
||||||
self.loop.create_task(response.execute_response(self.bot))
|
self._loop_create_task(response.execute_response(self.bot))
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
task = self.loop.create_task(func(*args, **kwargs))
|
task = self._loop_create_task(func(*args, **kwargs))
|
||||||
task.add_done_callback(process_response)
|
task.add_done_callback(process_response)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,14 @@ class MiddlewareManager:
|
||||||
:param dispatcher: instance of Dispatcher
|
:param dispatcher: instance of Dispatcher
|
||||||
"""
|
"""
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
self.loop = dispatcher.loop
|
|
||||||
self.bot = dispatcher.bot
|
self.bot = dispatcher.bot
|
||||||
self.storage = dispatcher.storage
|
self.storage = dispatcher.storage
|
||||||
self.applications = []
|
self.applications = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
return self.dispatcher.loop
|
||||||
|
|
||||||
def setup(self, middleware):
|
def setup(self, middleware):
|
||||||
"""
|
"""
|
||||||
Setup middleware
|
Setup middleware
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue