refactor(handler): remove require func

refactor requirements feature
This commit is contained in:
mpa 2020-05-31 09:09:56 +04:00
parent de962d043d
commit e060678b40
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8
4 changed files with 69 additions and 82 deletions

View file

@ -35,7 +35,7 @@ class CallableMixin:
awaitable: bool = field(init=False) awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False) spec: inspect.FullArgSpec = field(init=False)
__reqs__: Dict[str, Requirement[Any]] = field(init=False) requirements: Dict[str, Requirement[Any]] = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
callback = inspect.unwrap(self.callback) callback = inspect.unwrap(self.callback)
@ -47,10 +47,9 @@ class CallableMixin:
if _is_class_handler(callback): if _is_class_handler(callback):
self.awaitable = True self.awaitable = True
self.__reqs__ = get_reqs_from_class(callback) self.requirements = get_reqs_from_class(callback)
else: else:
self.__reqs__ = get_reqs_from_callable(callable_=callback) self.requirements = get_reqs_from_callable(callable_=callback)
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
if self.spec.varkw: if self.spec.varkw:
@ -58,33 +57,33 @@ class CallableMixin:
return {k: v for k, v in kwargs.items() if k in self.spec.args} return {k: v for k, v in kwargs.items() if k in self.spec.args}
async def call(self, *args: Any, **kwargs: Any) -> Any: async def call(self, *args: Any, **data: Any) -> Any:
# we don't requirements_data and kwargs keys to intersect # we don't requirements_data and kwargs keys to intersect
requirements_data: Dict[str, Any] = {} requirements_data: Dict[str, Any] = {}
if self.__reqs__: if self.requirements:
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY)) stack = cast(AsyncExitStack, data.get(ASYNC_STACK_KEY))
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {}) cache_dict: Dict[CacheKeyType, Any] = data.get(REQUIREMENT_CACHE_KEY, {})
requirements_data = kwargs.copy() requirements_data = data.copy()
for req_id, req in self.__reqs__.items(): for req_id, req in self.requirements.items():
requirements_data[req_id] = await req( requirements_data[req_id] = await req(
cache_dict=cache_dict, stack=stack, data=requirements_data cache_dict=cache_dict, stack=stack, data=requirements_data,
) )
for to_pop in kwargs: for to_pop in data:
requirements_data.pop(to_pop, None) requirements_data.pop(to_pop, None)
kwargs.pop(ASYNC_STACK_KEY, None) data.pop(ASYNC_STACK_KEY, None)
kwargs.pop(REQUIREMENT_CACHE_KEY, None) data.pop(REQUIREMENT_CACHE_KEY, None)
if _is_class_handler(self.callback): if _is_class_handler(self.callback):
wrapped = partial(self.callback, *args, requirements_data, kwargs) wrapped = partial(self.callback, *args, requirements_data, data)
else: else:
wrapped = partial( wrapped = partial(
self.callback, self.callback,
*args, *args,
**self._prepare_kwargs(kwargs), **self._prepare_kwargs(data),
**self._prepare_kwargs(requirements_data), **self._prepare_kwargs(requirements_data),
) )

View file

@ -1,4 +1,3 @@
import abc
import enum import enum
import inspect import inspect
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
@ -8,6 +7,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
Generator,
Generic, Generic,
Optional, Optional,
Type, Type,
@ -18,7 +18,10 @@ from typing import (
T = TypeVar("T") T = TypeVar("T")
CacheKeyType = Union[str, int] CacheKeyType = Union[str, int]
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]] CacheType = Dict[CacheKeyType, Any]
_RequiredCallback = Callable[
..., Union[T, Generator[T, None, None], AsyncGenerator[None, T], Awaitable[T]]
]
class GeneratorKind(enum.IntEnum): class GeneratorKind(enum.IntEnum):
@ -42,28 +45,19 @@ async def move_to_async_gen(context_manager: Any) -> Any:
context_manager.__exit__(None, None, None) context_manager.__exit__(None, None, None)
class Requirement(abc.ABC, Generic[T]): class Requirement(Generic[T]):
"""
Interface for all requirements
"""
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
raise NotImplementedError()
class CallableRequirement(Requirement[T]):
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type" __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
def __init__( def __init__(
self, self,
callable_: _RequiredCallback[T], callable_: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None, cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True, use_cache: bool = True,
): ):
self.callable = callable_ self.callable = callable_
self.use_cache = use_cache
self.children: Dict[str, Requirement[Any]] = {}
self.generator_type = GeneratorKind.not_a_gen self.generator_type = GeneratorKind.not_a_gen
if inspect.isasyncgenfunction(callable_): if inspect.isasyncgenfunction(callable_):
@ -72,7 +66,6 @@ class CallableRequirement(Requirement[T]):
self.generator_type = GeneratorKind.plain_gen self.generator_type = GeneratorKind.plain_gen
self.cache_key = hash(callable_) if cache_key is None else cache_key self.cache_key = hash(callable_) if cache_key is None else cache_key
self.use_cache = use_cache
self.children = get_reqs_from_callable(callable_) self.children = get_reqs_from_callable(callable_)
assert not ( assert not (
@ -83,27 +76,23 @@ class CallableRequirement(Requirement[T]):
return {key: value for key, value in data.items() if key in self.children} return {key: value for key, value in data.items() if key in self.children}
async def initialize_children( async def initialize_children(
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack self, data: Dict[str, Any], cache_dict: CacheType, stack: AsyncExitStack,
) -> None: ) -> None:
for req_id, req in self.children.items(): for req_id, req in self.children.items():
if isinstance(req, CachedRequirement):
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
continue
if isinstance(req, CallableRequirement): await req.initialize_children(data, cache_dict, stack)
await req.initialize_children(data, cache_dict, stack)
if req.use_cache and req.cache_key in cache_dict: if req.use_cache and req.cache_key in cache_dict:
data[req_id] = cache_dict[req.cache_key] data[req_id] = cache_dict[req.cache_key]
else: else:
data[req_id] = await initialize_callable_requirement(req, data, stack) data[req_id] = await initialize_callable_requirement(req, data, stack)
if req.use_cache: if req.use_cache:
cache_dict[req.cache_key] = data[req_id] cache_dict[req.cache_key] = data[req_id]
async def __call__( async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], self, *, cache_dict: CacheType, stack: AsyncExitStack, data: Dict[str, Any],
) -> T: ) -> T:
await self.initialize_children(data, cache_dict, stack) await self.initialize_children(data, cache_dict, stack)
@ -113,24 +102,12 @@ class CallableRequirement(Requirement[T]):
result = await initialize_callable_requirement(self, data, stack) result = await initialize_callable_requirement(self, data, stack)
if self.use_cache: if self.use_cache:
cache_dict[self.cache_key] = result cache_dict[self.cache_key] = result
return result return result
class CachedRequirement(Requirement[T]):
__slots__ = "cache_key", "value_on_miss"
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
self.cache_key: CacheKeyType = cache_key
self.value_on_miss = value_on_miss
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
return cache_dict.get(self.cache_key, self.value_on_miss)
async def initialize_callable_requirement( async def initialize_callable_requirement(
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack required: Requirement[T], data: Dict[str, Any], stack: AsyncExitStack
) -> T: ) -> T:
actual_data = required.filter_kwargs(data) actual_data = required.filter_kwargs(data)
async_cm: Optional[Any] = None async_cm: Optional[Any] = None
@ -162,12 +139,3 @@ def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
return { return {
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement) req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
} }
def require(
what: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True,
) -> T:
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore

View file

@ -18,7 +18,7 @@ async def callback2(foo: int, bar: int, baz: int):
return locals() return locals()
async def callback3(foo: int, **kwargs): async def callback3(foo: int, **data):
return locals() return locals()
@ -62,8 +62,8 @@ class TestCallableMixin:
def test_init_decorated(self): def test_init_decorated(self):
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **data):
return func(*args, **kwargs) return func(*args, **data)
return wrapper return wrapper
@ -85,7 +85,7 @@ class TestCallableMixin:
assert obj2.callback == callback2 assert obj2.callback == callback2
@pytest.mark.parametrize( @pytest.mark.parametrize(
"callback,kwargs,result", "callback,data,result",
[ [
pytest.param( pytest.param(
callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"} callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
@ -108,9 +108,9 @@ class TestCallableMixin:
), ),
], ],
) )
def test_prepare_kwargs(self, callback, kwargs, result): def test_prepare_data(self, callback, data, result):
obj = CallableMixin(callback) obj = CallableMixin(callback)
assert obj._prepare_kwargs(kwargs) == result assert obj._prepare_kwargs(data) == result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_call(self): async def test_sync_call(self):
@ -127,8 +127,8 @@ class TestCallableMixin:
assert result == {"foo": 42, "bar": "test", "baz": "fuz"} assert result == {"foo": 42, "bar": "test", "baz": "fuz"}
async def simple_handler(*args, **kwargs): async def simple_handler(*args, **data):
return args, kwargs return args, data
class TestHandlerObject: class TestHandlerObject:

View file

@ -1,17 +1,37 @@
# todo # todo
from aiogram.dispatcher.requirement import require, CallableRequirement from aiogram.dispatcher.requirement import Requirement, get_reqs_from_class, get_reqs_from_callable
tick_data = {"ticks": 0} tick_data = {"ticks": 0}
req1 = Requirement(lambda: 1)
req2 = Requirement(lambda: 1)
async def callback(
o,
x=req1,
y=req2
):
...
def test_require(): def test_require():
x = require(lambda: "str", use_cache=True, cache_key=0) x = Requirement(lambda: "str", use_cache=True, cache_key=0)
assert isinstance(x, CallableRequirement) assert isinstance(x, Requirement)
assert callable(x) & callable(x.callable) assert callable(x) & callable(x.callable)
assert x.cache_key == 0 assert x.cache_key == 0
assert x.use_cache assert x.use_cache
class TestCallableRequirementCache: class TestReqUtils:
def test_cache(self): def test_get_reqs_from_callable(self):
... assert set(get_reqs_from_callable(callback).values()) == {req1, req2}
assert set(get_reqs_from_callable(callback).keys()) == {"x", "y"}
def test_get_reqs_from_class(self):
class Class:
x = req1
y = req2
assert set(get_reqs_from_class(Class)) == {"x", "y"}