Add support for nullable fields in callback data

This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable.
This commit is contained in:
JRoot Junior 2023-11-20 22:17:48 +02:00
parent 7c295f6b3d
commit 42599fa82a
No known key found for this signature in database
GPG key ID: 738964250D5FF6E2
2 changed files with 31 additions and 5 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
@ -18,6 +19,7 @@ from uuid import UUID
from magic_filter import MagicFilter from magic_filter import MagicFilter
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery from aiogram.types import CallbackQuery
@ -120,8 +122,9 @@ class CallbackData(BaseModel):
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})") raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
payload = {} payload = {}
for k, v in zip(names, parts): # type: str, Optional[str] for k, v in zip(names, parts): # type: str, Optional[str]
if v == "": if field := cls.model_fields.get(k):
v = None if v == "" and _check_field_is_nullable(field):
v = None
payload[k] = v payload[k] = v
return cls(**payload) return cls(**payload)
@ -179,3 +182,19 @@ class CallbackQueryFilter(Filter):
if self.rule is None or self.rule.resolve(callback_data): if self.rule is None or self.rule.resolve(callback_data):
return {"callback_data": callback_data} return {"callback_data": callback_data}
return False return False
def _check_field_is_nullable(field: FieldInfo) -> bool:
"""
Check if the given field is nullable.
:param field: The FieldInfo object representing the field to check.
:return: True if the field is nullable, False otherwise.
"""
if not field.is_required():
return True
return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args(
field.annotation
)

View file

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum, auto from enum import Enum, auto
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional, Union
from uuid import UUID from uuid import UUID
import pytest import pytest
@ -147,12 +147,19 @@ class TestCallbackData:
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
def test_unpack_optional_wo_default(self): @pytest.mark.parametrize(
"hint",
[
Union[int, None],
Optional[int],
],
)
def test_unpack_optional_wo_default(self, hint):
"""Test CallbackData without default optional.""" """Test CallbackData without default optional."""
class TgData(CallbackData, prefix="tg"): class TgData(CallbackData, prefix="tg"):
chat_id: int chat_id: int
thread_id: Optional[int] thread_id: hint
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)