mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat(reqs): determine function types
This commit is contained in:
parent
ca44f9c01a
commit
2d4419ecdc
1 changed files with 18 additions and 5 deletions
|
|
@ -1,5 +1,8 @@
|
||||||
import enum
|
import enum
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import contextvars
|
||||||
|
from functools import partial
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
|
@ -46,7 +49,7 @@ async def move_to_async_gen(context_manager: Any) -> Any:
|
||||||
|
|
||||||
|
|
||||||
class Requirement(Generic[T]):
|
class Requirement(Generic[T]):
|
||||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type", "is_async"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -60,10 +63,14 @@ class Requirement(Generic[T]):
|
||||||
|
|
||||||
self.generator_type = GeneratorKind.not_a_gen
|
self.generator_type = GeneratorKind.not_a_gen
|
||||||
|
|
||||||
|
self.is_async: Optional[bool] = None # unset value
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(callable_):
|
if inspect.isasyncgenfunction(callable_):
|
||||||
self.generator_type = GeneratorKind.async_gen
|
self.generator_type = GeneratorKind.async_gen
|
||||||
elif inspect.isgeneratorfunction(callable_):
|
elif inspect.isgeneratorfunction(callable_):
|
||||||
self.generator_type = GeneratorKind.plain_gen
|
self.generator_type = GeneratorKind.plain_gen
|
||||||
|
else:
|
||||||
|
self.is_async = inspect.iscoroutinefunction(callable_)
|
||||||
|
|
||||||
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
||||||
self.children = get_reqs_from_callable(callable_)
|
self.children = get_reqs_from_callable(callable_)
|
||||||
|
|
@ -119,11 +126,17 @@ async def initialize_callable_requirement(
|
||||||
|
|
||||||
if async_cm is not None:
|
if async_cm is not None:
|
||||||
return await stack.enter_async_context(async_cm)
|
return await stack.enter_async_context(async_cm)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
result = required.callable(**actual_data)
|
if not required.is_async:
|
||||||
if isinstance(result, Awaitable):
|
context = contextvars.copy_context()
|
||||||
return cast(T, await result)
|
wrapped = partial(context.run, partial(required.callable, **actual_data))
|
||||||
return cast(T, result)
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return cast(T, await loop.run_in_executor(None, wrapped))
|
||||||
|
|
||||||
|
else:
|
||||||
|
return cast(T, await required.callable(**actual_data))
|
||||||
|
|
||||||
|
|
||||||
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]:
|
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue