mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
refactor(handler): remove require func
refactor requirements feature
This commit is contained in:
parent
de962d043d
commit
e060678b40
4 changed files with 69 additions and 82 deletions
|
|
@ -35,7 +35,7 @@ class CallableMixin:
|
||||||
awaitable: bool = field(init=False)
|
awaitable: bool = field(init=False)
|
||||||
spec: inspect.FullArgSpec = field(init=False)
|
spec: inspect.FullArgSpec = field(init=False)
|
||||||
|
|
||||||
__reqs__: Dict[str, Requirement[Any]] = field(init=False)
|
requirements: Dict[str, Requirement[Any]] = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
callback = inspect.unwrap(self.callback)
|
callback = inspect.unwrap(self.callback)
|
||||||
|
|
@ -47,10 +47,9 @@ class CallableMixin:
|
||||||
|
|
||||||
if _is_class_handler(callback):
|
if _is_class_handler(callback):
|
||||||
self.awaitable = True
|
self.awaitable = True
|
||||||
self.__reqs__ = get_reqs_from_class(callback)
|
self.requirements = get_reqs_from_class(callback)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.__reqs__ = get_reqs_from_callable(callable_=callback)
|
self.requirements = get_reqs_from_callable(callable_=callback)
|
||||||
|
|
||||||
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
if self.spec.varkw:
|
if self.spec.varkw:
|
||||||
|
|
@ -58,33 +57,33 @@ class CallableMixin:
|
||||||
|
|
||||||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||||
|
|
||||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
async def call(self, *args: Any, **data: Any) -> Any:
|
||||||
# we don't requirements_data and kwargs keys to intersect
|
# we don't requirements_data and kwargs keys to intersect
|
||||||
requirements_data: Dict[str, Any] = {}
|
requirements_data: Dict[str, Any] = {}
|
||||||
|
|
||||||
if self.__reqs__:
|
if self.requirements:
|
||||||
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
|
stack = cast(AsyncExitStack, data.get(ASYNC_STACK_KEY))
|
||||||
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
|
cache_dict: Dict[CacheKeyType, Any] = data.get(REQUIREMENT_CACHE_KEY, {})
|
||||||
requirements_data = kwargs.copy()
|
requirements_data = data.copy()
|
||||||
|
|
||||||
for req_id, req in self.__reqs__.items():
|
for req_id, req in self.requirements.items():
|
||||||
requirements_data[req_id] = await req(
|
requirements_data[req_id] = await req(
|
||||||
cache_dict=cache_dict, stack=stack, data=requirements_data
|
cache_dict=cache_dict, stack=stack, data=requirements_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
for to_pop in kwargs:
|
for to_pop in data:
|
||||||
requirements_data.pop(to_pop, None)
|
requirements_data.pop(to_pop, None)
|
||||||
|
|
||||||
kwargs.pop(ASYNC_STACK_KEY, None)
|
data.pop(ASYNC_STACK_KEY, None)
|
||||||
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
|
data.pop(REQUIREMENT_CACHE_KEY, None)
|
||||||
|
|
||||||
if _is_class_handler(self.callback):
|
if _is_class_handler(self.callback):
|
||||||
wrapped = partial(self.callback, *args, requirements_data, kwargs)
|
wrapped = partial(self.callback, *args, requirements_data, data)
|
||||||
else:
|
else:
|
||||||
wrapped = partial(
|
wrapped = partial(
|
||||||
self.callback,
|
self.callback,
|
||||||
*args,
|
*args,
|
||||||
**self._prepare_kwargs(kwargs),
|
**self._prepare_kwargs(data),
|
||||||
**self._prepare_kwargs(requirements_data),
|
**self._prepare_kwargs(requirements_data),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import abc
|
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
||||||
|
|
@ -8,6 +7,7 @@ from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generator,
|
||||||
Generic,
|
Generic,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
|
|
@ -18,7 +18,10 @@ from typing import (
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
CacheKeyType = Union[str, int]
|
CacheKeyType = Union[str, int]
|
||||||
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]]
|
CacheType = Dict[CacheKeyType, Any]
|
||||||
|
_RequiredCallback = Callable[
|
||||||
|
..., Union[T, Generator[T, None, None], AsyncGenerator[None, T], Awaitable[T]]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class GeneratorKind(enum.IntEnum):
|
class GeneratorKind(enum.IntEnum):
|
||||||
|
|
@ -42,28 +45,19 @@ async def move_to_async_gen(context_manager: Any) -> Any:
|
||||||
context_manager.__exit__(None, None, None)
|
context_manager.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
class Requirement(abc.ABC, Generic[T]):
|
class Requirement(Generic[T]):
|
||||||
"""
|
|
||||||
Interface for all requirements
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
|
||||||
) -> T:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class CallableRequirement(Requirement[T]):
|
|
||||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
callable_: _RequiredCallback[T],
|
callable_: _RequiredCallback[T],
|
||||||
*,
|
|
||||||
cache_key: Optional[CacheKeyType] = None,
|
cache_key: Optional[CacheKeyType] = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
):
|
):
|
||||||
self.callable = callable_
|
self.callable = callable_
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.children: Dict[str, Requirement[Any]] = {}
|
||||||
|
|
||||||
self.generator_type = GeneratorKind.not_a_gen
|
self.generator_type = GeneratorKind.not_a_gen
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(callable_):
|
if inspect.isasyncgenfunction(callable_):
|
||||||
|
|
@ -72,7 +66,6 @@ class CallableRequirement(Requirement[T]):
|
||||||
self.generator_type = GeneratorKind.plain_gen
|
self.generator_type = GeneratorKind.plain_gen
|
||||||
|
|
||||||
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
||||||
self.use_cache = use_cache
|
|
||||||
self.children = get_reqs_from_callable(callable_)
|
self.children = get_reqs_from_callable(callable_)
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
|
|
@ -83,27 +76,23 @@ class CallableRequirement(Requirement[T]):
|
||||||
return {key: value for key, value in data.items() if key in self.children}
|
return {key: value for key, value in data.items() if key in self.children}
|
||||||
|
|
||||||
async def initialize_children(
|
async def initialize_children(
|
||||||
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack
|
self, data: Dict[str, Any], cache_dict: CacheType, stack: AsyncExitStack,
|
||||||
) -> None:
|
) -> None:
|
||||||
for req_id, req in self.children.items():
|
for req_id, req in self.children.items():
|
||||||
if isinstance(req, CachedRequirement):
|
|
||||||
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(req, CallableRequirement):
|
await req.initialize_children(data, cache_dict, stack)
|
||||||
await req.initialize_children(data, cache_dict, stack)
|
|
||||||
|
|
||||||
if req.use_cache and req.cache_key in cache_dict:
|
if req.use_cache and req.cache_key in cache_dict:
|
||||||
data[req_id] = cache_dict[req.cache_key]
|
data[req_id] = cache_dict[req.cache_key]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
data[req_id] = await initialize_callable_requirement(req, data, stack)
|
data[req_id] = await initialize_callable_requirement(req, data, stack)
|
||||||
|
|
||||||
if req.use_cache:
|
if req.use_cache:
|
||||||
cache_dict[req.cache_key] = data[req_id]
|
cache_dict[req.cache_key] = data[req_id]
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
self, *, cache_dict: CacheType, stack: AsyncExitStack, data: Dict[str, Any],
|
||||||
) -> T:
|
) -> T:
|
||||||
await self.initialize_children(data, cache_dict, stack)
|
await self.initialize_children(data, cache_dict, stack)
|
||||||
|
|
||||||
|
|
@ -113,24 +102,12 @@ class CallableRequirement(Requirement[T]):
|
||||||
result = await initialize_callable_requirement(self, data, stack)
|
result = await initialize_callable_requirement(self, data, stack)
|
||||||
if self.use_cache:
|
if self.use_cache:
|
||||||
cache_dict[self.cache_key] = result
|
cache_dict[self.cache_key] = result
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class CachedRequirement(Requirement[T]):
|
|
||||||
__slots__ = "cache_key", "value_on_miss"
|
|
||||||
|
|
||||||
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
|
|
||||||
self.cache_key: CacheKeyType = cache_key
|
|
||||||
self.value_on_miss = value_on_miss
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
|
||||||
) -> T:
|
|
||||||
return cache_dict.get(self.cache_key, self.value_on_miss)
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_callable_requirement(
|
async def initialize_callable_requirement(
|
||||||
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack
|
required: Requirement[T], data: Dict[str, Any], stack: AsyncExitStack
|
||||||
) -> T:
|
) -> T:
|
||||||
actual_data = required.filter_kwargs(data)
|
actual_data = required.filter_kwargs(data)
|
||||||
async_cm: Optional[Any] = None
|
async_cm: Optional[Any] = None
|
||||||
|
|
@ -162,12 +139,3 @@ def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
|
||||||
return {
|
return {
|
||||||
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
|
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def require(
|
|
||||||
what: _RequiredCallback[T],
|
|
||||||
*,
|
|
||||||
cache_key: Optional[CacheKeyType] = None,
|
|
||||||
use_cache: bool = True,
|
|
||||||
) -> T:
|
|
||||||
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ async def callback2(foo: int, bar: int, baz: int):
|
||||||
return locals()
|
return locals()
|
||||||
|
|
||||||
|
|
||||||
async def callback3(foo: int, **kwargs):
|
async def callback3(foo: int, **data):
|
||||||
return locals()
|
return locals()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -62,8 +62,8 @@ class TestCallableMixin:
|
||||||
def test_init_decorated(self):
|
def test_init_decorated(self):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **data):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **data)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -85,7 +85,7 @@ class TestCallableMixin:
|
||||||
assert obj2.callback == callback2
|
assert obj2.callback == callback2
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"callback,kwargs,result",
|
"callback,data,result",
|
||||||
[
|
[
|
||||||
pytest.param(
|
pytest.param(
|
||||||
callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
|
callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
|
||||||
|
|
@ -108,9 +108,9 @@ class TestCallableMixin:
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_prepare_kwargs(self, callback, kwargs, result):
|
def test_prepare_data(self, callback, data, result):
|
||||||
obj = CallableMixin(callback)
|
obj = CallableMixin(callback)
|
||||||
assert obj._prepare_kwargs(kwargs) == result
|
assert obj._prepare_kwargs(data) == result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_call(self):
|
async def test_sync_call(self):
|
||||||
|
|
@ -127,8 +127,8 @@ class TestCallableMixin:
|
||||||
assert result == {"foo": 42, "bar": "test", "baz": "fuz"}
|
assert result == {"foo": 42, "bar": "test", "baz": "fuz"}
|
||||||
|
|
||||||
|
|
||||||
async def simple_handler(*args, **kwargs):
|
async def simple_handler(*args, **data):
|
||||||
return args, kwargs
|
return args, data
|
||||||
|
|
||||||
|
|
||||||
class TestHandlerObject:
|
class TestHandlerObject:
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,37 @@
|
||||||
# todo
|
# todo
|
||||||
from aiogram.dispatcher.requirement import require, CallableRequirement
|
from aiogram.dispatcher.requirement import Requirement, get_reqs_from_class, get_reqs_from_callable
|
||||||
|
|
||||||
tick_data = {"ticks": 0}
|
tick_data = {"ticks": 0}
|
||||||
|
|
||||||
|
|
||||||
|
req1 = Requirement(lambda: 1)
|
||||||
|
req2 = Requirement(lambda: 1)
|
||||||
|
|
||||||
|
|
||||||
|
async def callback(
|
||||||
|
o,
|
||||||
|
x=req1,
|
||||||
|
y=req2
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def test_require():
|
def test_require():
|
||||||
x = require(lambda: "str", use_cache=True, cache_key=0)
|
x = Requirement(lambda: "str", use_cache=True, cache_key=0)
|
||||||
assert isinstance(x, CallableRequirement)
|
assert isinstance(x, Requirement)
|
||||||
assert callable(x) & callable(x.callable)
|
assert callable(x) & callable(x.callable)
|
||||||
assert x.cache_key == 0
|
assert x.cache_key == 0
|
||||||
assert x.use_cache
|
assert x.use_cache
|
||||||
|
|
||||||
|
|
||||||
class TestCallableRequirementCache:
|
class TestReqUtils:
|
||||||
def test_cache(self):
|
def test_get_reqs_from_callable(self):
|
||||||
...
|
assert set(get_reqs_from_callable(callback).values()) == {req1, req2}
|
||||||
|
assert set(get_reqs_from_callable(callback).keys()) == {"x", "y"}
|
||||||
|
|
||||||
|
def test_get_reqs_from_class(self):
|
||||||
|
class Class:
|
||||||
|
x = req1
|
||||||
|
y = req2
|
||||||
|
|
||||||
|
assert set(get_reqs_from_class(Class)) == {"x", "y"}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue