From ce4e1a706df0e40c88381648446085b9c36f2f03 Mon Sep 17 00:00:00 2001 From: JRoot Junior Date: Mon, 20 Nov 2023 22:49:55 +0200 Subject: [PATCH] #1370 added possibility to check X | None on Python >= 3.10 --- aiogram/filters/callback_data.py | 9 ++++++++- tests/test_filters/test_callback_data.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 7c0dadf8..17ccffbf 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys +import types import typing from decimal import Decimal from enum import Enum @@ -29,6 +31,11 @@ T = TypeVar("T", bound="CallbackData") MAX_CALLBACK_LENGTH: int = 64 +_UNION_TYPES = {typing.Union} +if sys.version_info >= (3, 10): # pragma: no cover + _UNION_TYPES.add(types.UnionType) + + class CallbackDataException(Exception): pass @@ -195,6 +202,6 @@ def _check_field_is_nullable(field: FieldInfo) -> bool: if not field.is_required(): return True - return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args( + return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args( field.annotation ) diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index 4314aa34..635b8e9f 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,3 +1,4 @@ +import sys from decimal import Decimal from enum import Enum, auto from fractions import Fraction @@ -163,6 +164,16 @@ class TestCallbackData: assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + @pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType is added in Python 3.10") + def test_unpack_optional_wo_default_union_type(self): + """Test CallbackData without default optional.""" + + class TgData(CallbackData, prefix="tg"): + chat_id: int + thread_id: int | None + + assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + def test_build_filter(self): filter_object = MyCallback.filter(F.foo == "test") assert isinstance(filter_object.rule, MagicFilter)