feat(reqs): determine function types

This commit is contained in:
mpa 2020-07-21 05:43:56 +04:00
parent ca44f9c01a
commit 2d4419ecdc
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8

View file

@ -1,5 +1,8 @@
import enum import enum
import asyncio
import inspect import inspect
import contextvars
from functools import partial
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import ( from typing import (
Any, Any,
@ -46,7 +49,7 @@ async def move_to_async_gen(context_manager: Any) -> Any:
class Requirement(Generic[T]): class Requirement(Generic[T]):
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type" __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type", "is_async"
def __init__( def __init__(
self, self,
@ -60,10 +63,14 @@ class Requirement(Generic[T]):
self.generator_type = GeneratorKind.not_a_gen self.generator_type = GeneratorKind.not_a_gen
self.is_async: Optional[bool] = None # unset value
if inspect.isasyncgenfunction(callable_): if inspect.isasyncgenfunction(callable_):
self.generator_type = GeneratorKind.async_gen self.generator_type = GeneratorKind.async_gen
elif inspect.isgeneratorfunction(callable_): elif inspect.isgeneratorfunction(callable_):
self.generator_type = GeneratorKind.plain_gen self.generator_type = GeneratorKind.plain_gen
else:
self.is_async = inspect.iscoroutinefunction(callable_)
self.cache_key = hash(callable_) if cache_key is None else cache_key self.cache_key = hash(callable_) if cache_key is None else cache_key
self.children = get_reqs_from_callable(callable_) self.children = get_reqs_from_callable(callable_)
@ -119,11 +126,17 @@ async def initialize_callable_requirement(
if async_cm is not None: if async_cm is not None:
return await stack.enter_async_context(async_cm) return await stack.enter_async_context(async_cm)
else: else:
result = required.callable(**actual_data) if not required.is_async:
if isinstance(result, Awaitable): context = contextvars.copy_context()
return cast(T, await result) wrapped = partial(context.run, partial(required.callable, **actual_data))
return cast(T, result)
loop = asyncio.get_event_loop()
return cast(T, await loop.run_in_executor(None, wrapped))
else:
return cast(T, await required.callable(**actual_data))
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]: def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]: