diff --git a/README.md b/README.md index c61a0cd..531c37a 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,6 @@ Supported types: - `Mapping` (with typed keys and values), `Set`, `Sequence` Example: -class Annotated: -pass ```python import dataclasses @@ -77,3 +75,38 @@ loaded = mr.load(CompanyUpdateData, {"annual_turnover": None}) assert loaded.name is mr.MISSING assert loaded.annual_turnover is None ``` + +Generics are also supported. Everything works automatically except in one case: dumping a generic dataclass declared with `frozen=True` or `slots=True` requires the subscripted generic type to be passed explicitly as the `cls` argument of `dump` and `dump_many`. + +```python +import dataclasses +from typing import Generic, TypeVar +import marshmallow_recipe as mr + +T = TypeVar("T") + +@dataclasses.dataclass() +class Regular(Generic[T]): + value: T + +mr.dump(Regular[int](value=123)) # it works without explicit cls arg + +@dataclasses.dataclass(frozen=True) +class Frozen(Generic[T]): + value: T + +mr.dump(Frozen[int](value=123), cls=Frozen[int]) # cls required for frozen generic + +@dataclasses.dataclass(slots=True) +class Slots(Generic[T]): + value: T + +mr.dump(Slots[int](value=123), cls=Slots[int]) # cls required for generic with slots + +@dataclasses.dataclass(slots=True) +class SlotsNonGeneric(Slots[int]): + pass + +mr.dump(SlotsNonGeneric(value=123)) # cls not required + +``` diff --git a/marshmallow_recipe/bake.py b/marshmallow_recipe/bake.py index 4984c4e..5af9bc0 100644 --- a/marshmallow_recipe/bake.py +++ b/marshmallow_recipe/bake.py @@ -6,7 +6,7 @@ import inspect import types import uuid -from typing import Annotated, Any, Protocol, TypeVar, Union, get_args, get_origin +from typing import Annotated, Any, NamedTuple, Protocol, TypeVar, Union, cast, get_args, get_origin import marshmallow as m @@ -29,6 +29,7 @@ tuple_field, uuid_field, ) +from .generics import TypeLike, get_fields_type_map from .hooks import get_pre_loads from .metadata import EMPTY_METADATA, 
Metadata, is_metadata from .naming_case import NamingCase @@ -48,16 +49,23 @@ class _SchemaTypeKey: _schema_types: dict[_SchemaTypeKey, type[m.Schema]] = {} +class _FieldDescription(NamedTuple): + field: dataclasses.Field + value_type: TypeLike + metadata: Metadata + + def bake_schema( cls: type, *, naming_case: NamingCase | None = None, none_value_handling: NoneValueHandling | None = None, ) -> type[m.Schema]: - if not dataclasses.is_dataclass(cls): - raise ValueError(f"{cls} is not a dataclass") + origin: type = get_origin(cls) or cls + if not dataclasses.is_dataclass(origin): + raise ValueError(f"{origin} is not a dataclass") - if options := try_get_options_for(cls): + if options := try_get_options_for(origin): cls_none_value_handling = none_value_handling or options.none_value_handling cls_naming_case = naming_case or options.naming_case else: @@ -72,40 +80,42 @@ def bake_schema( if result := _schema_types.get(key): return result - fields_with_metadata = [ - ( + fields_type_map = get_fields_type_map(cls) + + fields = [ + _FieldDescription( field, + fields_type_map[field.name], _get_metadata( name=field.name if cls_naming_case is None else cls_naming_case(field.name), default=_get_field_default(field), metadata=field.metadata, ), ) - for field in dataclasses.fields(cls) + for field in dataclasses.fields(origin) if field.init ] - for field, _ in fields_with_metadata: - for other_field, metadata in fields_with_metadata: - if field is other_field: + for first in fields: + for second in fields: + if first is second: continue + second_name = second.metadata["name"] + if first.field.name == second_name: + raise ValueError(f"Invalid name={second_name} in metadata for field={second.field.name}") - other_field_name = metadata["name"] - if field.name == other_field_name: - raise ValueError(f"Invalid name={other_field_name} in metadata for field={other_field.name}") - - schema_type: type[m.Schema] = type( + schema_type = type( cls.__name__, (_get_base_schema(cls, 
cls_none_value_handling or NoneValueHandling.IGNORE),), {"__module__": f"{__package__}.auto_generated"} | { field.name: get_field_for( - field.type, # type: ignore + value_type, metadata, naming_case=naming_case, none_value_handling=none_value_handling, ) - for field, metadata in fields_with_metadata + for field, value_type, metadata in fields }, ) _schema_types[key] = schema_type @@ -113,20 +123,18 @@ def bake_schema( def get_field_for( - type: type, + t: TypeLike, metadata: Metadata, naming_case: NamingCase | None, none_value_handling: NoneValueHandling | None, ) -> m.fields.Field: - if type is Any: + if t is Any: return raw_field(**metadata) - type = _substitute_any_to_open_generic(type) - - if underlying_type_from_optional := _try_get_underlying_type_from_optional(type): + if underlying_type_from_optional := _try_get_underlying_type_from_optional(t): required = False allow_none = True - type = underlying_type_from_optional + t = underlying_type_from_optional elif metadata.get("default", dataclasses.MISSING) is not dataclasses.MISSING: required = False allow_none = False @@ -134,19 +142,19 @@ def get_field_for( required = True allow_none = False - if inspect.isclass(type) and issubclass(type, enum.Enum): - return enum_field(enum_type=type, required=required, allow_none=allow_none, **metadata) + if inspect.isclass(t) and issubclass(t, enum.Enum): + return enum_field(enum_type=t, required=required, allow_none=allow_none, **metadata) - if dataclasses.is_dataclass(type): + if dataclasses.is_dataclass(get_origin(t) or t): return nested_field( - bake_schema(type, naming_case=naming_case, none_value_handling=none_value_handling), + bake_schema(cast(type, t), naming_case=naming_case, none_value_handling=none_value_handling), required=required, allow_none=allow_none, **metadata, ) - if (origin := get_origin(type)) is not None: - arguments = get_args(type) + if (origin := get_origin(t)) is not None: + arguments = get_args(t) if origin is list or origin is 
collections.abc.Sequence: collection_field_metadata = dict(metadata) @@ -268,11 +276,11 @@ def get_field_for( none_value_handling=none_value_handling, ) - field_factory = _SIMPLE_TYPE_FIELD_FACTORIES.get(type) - if field_factory: + if t in _SIMPLE_TYPE_FIELD_FACTORIES: + field_factory = _SIMPLE_TYPE_FIELD_FACTORIES[t] return field_factory(required=required, allow_none=allow_none, **metadata) - raise ValueError(f"Unsupported {type=}") + raise ValueError(f"Unsupported {t=}") if _MARSHMALLOW_VERSION_MAJOR >= 3: @@ -373,26 +381,12 @@ def _get_metadata(*, name: str, default: Any, metadata: collections.abc.Mapping[ return Metadata(values) -def _substitute_any_to_open_generic(type: type) -> type: - if type is list: - return list[Any] - if type is set: - return set[Any] - if type is frozenset: - return frozenset[Any] - if type is dict: - return dict[Any, Any] - if type is tuple: - return tuple[Any, ...] - return type - - -def _try_get_underlying_type_from_optional(type: type) -> type | None: +def _try_get_underlying_type_from_optional(t: TypeLike) -> TypeLike | None: # to support Union[int, None] and int | None - if get_origin(type) is Union or isinstance(type, types.UnionType): # type: ignore - type_args = get_args(type) + if get_origin(t) is Union or isinstance(t, types.UnionType): # type: ignore + type_args = get_args(t) if types.NoneType not in type_args or len(type_args) != 2: - raise ValueError(f"Unsupported {type=}") + raise ValueError(f"Unsupported {t=}") return next(type_arg for type_arg in type_args if type_arg is not types.NoneType) # noqa return None diff --git a/marshmallow_recipe/generics.py b/marshmallow_recipe/generics.py new file mode 100644 index 0000000..8d4ae06 --- /dev/null +++ b/marshmallow_recipe/generics.py @@ -0,0 +1,150 @@ +import dataclasses +import types +import typing +from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeAlias, TypeVar, Union, get_args, get_origin + +_GenericAlias: TypeAlias = typing._GenericAlias # type: ignore + +if 
TYPE_CHECKING: + from _typeshed import DataclassInstance +else: + DataclassInstance: TypeAlias = type + +TypeLike: TypeAlias = type | TypeVar | types.UnionType | types.GenericAlias | _GenericAlias +FieldsTypeMap: TypeAlias = dict[str, TypeLike] +TypeVarMap: TypeAlias = dict[TypeVar, TypeLike] +FieldsClassMap: TypeAlias = dict[str, TypeLike] +ClassTypeVarMap: TypeAlias = dict[TypeLike, TypeVarMap] +FieldsTypeVarMap: TypeAlias = dict[str, TypeVarMap] + + +def extract_type(data: Any, cls: type | None) -> type: + data_type = _get_orig_class(data) or type(data) + + if not _is_unsubscripted_type(data_type): + if cls and data_type != cls: + raise ValueError(f"{cls=} is invalid but can be removed, actual type is {data_type}") + return data_type + + if not cls: + raise ValueError(f"Explicit cls required for unsubscripted type {data_type}") + + if _is_unsubscripted_type(cls) or get_origin(cls) != data_type: + raise ValueError(f"{cls=} is not subscripted version of {data_type}") + + return cls + + +def get_fields_type_map(cls: type) -> FieldsTypeMap: + origin: type = get_origin(cls) or cls + if not dataclasses.is_dataclass(origin): + raise ValueError(f"{origin} is not a dataclass") + + class_type_var_map = get_class_type_var_map(cls) + fields_class_map = get_fields_class_map(origin) + return { + f.name: build_subscripted_type(f.type, class_type_var_map.get(fields_class_map[f.name], {})) + for f in dataclasses.fields(origin) + } + + +def get_fields_class_map(cls: type[DataclassInstance]) -> FieldsClassMap: + names: dict[str, dataclasses.Field] = {} + result: FieldsClassMap = {} + + mro = cls.__mro__ + for cls in (*mro[-1:0:-1], cls): + if not dataclasses.is_dataclass(cls): + continue + for field in dataclasses.fields(cls): + if names.get(field.name) != field: + names[field.name] = field + result[field.name] = cls + + return result + + +def build_subscripted_type(t: TypeLike, type_var_map: TypeVarMap) -> TypeLike: + if isinstance(t, TypeVar): + return 
build_subscripted_type(type_var_map[t], type_var_map) + + origin = get_origin(t) + if origin is Union or origin is types.UnionType: + return Union[*(build_subscripted_type(x, type_var_map) for x in get_args(t))] + + if origin is Annotated: + t, *annotations = get_args(t) + return Annotated[build_subscripted_type(t, type_var_map), *annotations] + + if origin and isinstance(t, types.GenericAlias): + return types.GenericAlias(origin, tuple(build_subscripted_type(x, type_var_map) for x in get_args(t))) + + if origin and isinstance(t, _GenericAlias): + return _GenericAlias(origin, tuple(build_subscripted_type(x, type_var_map) for x in get_args(t))) + + return _subscript_with_any(t) + + +def get_class_type_var_map(t: TypeLike) -> ClassTypeVarMap: + class_type_var_map: ClassTypeVarMap = {} + _build_class_type_var_map(t, class_type_var_map) + return class_type_var_map + + +def _build_class_type_var_map(t: TypeLike, class_type_var_map: ClassTypeVarMap) -> None: + if _get_params(t): + raise ValueError(f"Expected subscripted generic, but got unsubscripted {t}") + + type_var_map: TypeVarMap = {} + origin = get_origin(t) or t + params = _get_params(origin) + args = get_args(t) + if params or args: + if not params or not args or len(params) != len(args): + raise ValueError(f"Unexpected generic {t}") + for i, parameter in enumerate(params): + assert isinstance(parameter, TypeVar) + type_var_map[parameter] = args[i] + if origin not in class_type_var_map: + class_type_var_map[origin] = type_var_map + elif class_type_var_map[origin] != type_var_map: + raise ValueError( + f"Incompatible Base class {origin} with generic args {class_type_var_map[origin]} and {type_var_map}" + ) + + if orig_bases := _get_orig_bases(origin): + for orig_base in orig_bases: + if get_origin(orig_base) is Generic: + continue + subscripted_base = build_subscripted_type(orig_base, type_var_map) + _build_class_type_var_map(subscripted_base, class_type_var_map) + + +def _is_unsubscripted_type(t: TypeLike) -> 
bool: + return bool(_get_params(t)) or any(_is_unsubscripted_type(arg) for arg in get_args(t) or []) + + +def _get_orig_class(t: Any) -> type | None: + return getattr(t, "__orig_class__", None) + + +def _get_params(t: Any) -> tuple[TypeLike, ...] | None: + return getattr(t, "__parameters__", None) + + +def _get_orig_bases(t: Any) -> tuple[TypeLike, ...] | None: + return getattr(t, "__orig_bases__", None) + + +def _subscript_with_any(t: TypeLike) -> TypeLike: + if t is list: + return list[Any] + if t is set: + return set[Any] + if t is frozenset: + return frozenset[Any] + if t is dict: + return dict[Any, Any] + if t is tuple: + return tuple[Any, ...] + return t diff --git a/marshmallow_recipe/serialization.py b/marshmallow_recipe/serialization.py index 46003f7..8d971ca 100644 --- a/marshmallow_recipe/serialization.py +++ b/marshmallow_recipe/serialization.py @@ -5,6 +5,7 @@ import marshmallow as m from .bake import bake_schema +from .generics import extract_type from .naming_case import NamingCase _T = TypeVar("_T") @@ -36,11 +37,15 @@ def __call__( class DumpFunction(Protocol): - def __call__(self, data: Any, *, naming_case: NamingCase | None = None) -> dict[str, Any]: ... + def __call__( + self, data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> dict[str, Any]: ... class DumpManyFunction(Protocol): - def __call__(self, data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: ... + def __call__( + self, data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: ... 
schema: SchemaFunction @@ -73,12 +78,8 @@ def load_many_v3(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami load_many = load_many_v3 - def dump_v3( - data: Any, - *, - naming_case: NamingCase | None = None, - ) -> dict[str, Any]: - data_schema = schema_v3(type(data), naming_case=naming_case) + def dump_v3(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]: + data_schema = schema_v3(extract_type(data, cls), naming_case=naming_case) dumped: dict[str, Any] = data_schema.dump(data) # type: ignore if errors := data_schema.validate(dumped): raise m.ValidationError(errors) @@ -86,10 +87,12 @@ def dump_v3( dump = dump_v3 - def dump_many_v3(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: + def dump_many_v3( + data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: if not data: return [] - data_schema = schema_v3(type(data[0]), many=True, naming_case=naming_case) + data_schema = schema_v3(extract_type(data[0], cls), many=True, naming_case=naming_case) dumped: list[dict[str, Any]] = data_schema.dump(data) # type: ignore if errors := data_schema.validate(dumped): raise m.ValidationError(errors) @@ -122,12 +125,8 @@ def load_many_v2(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami load_many = load_many_v2 - def dump_v2( - data: Any, - *, - naming_case: NamingCase | None = None, - ) -> dict[str, Any]: - data_schema = schema_v2(type(data), naming_case=naming_case) + def dump_v2(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]: + data_schema = schema_v2(extract_type(data, cls), naming_case=naming_case) dumped, errors = data_schema.dump(data) if errors: raise m.ValidationError(errors) @@ -137,10 +136,12 @@ def dump_v2( dump = dump_v2 - def dump_many_v2(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: + def dump_many_v2( + data: 
list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: if not data: return [] - data_schema = schema_v2(type(data[0]), many=True, naming_case=naming_case) + data_schema = schema_v2(extract_type(data[0], cls), many=True, naming_case=naming_case) dumped, errors = data_schema.dump(data) if errors: raise m.ValidationError(errors) diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..56af7ef --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,395 @@ +import dataclasses +import types +from contextlib import nullcontext as does_not_raise +from typing import Annotated, Any, ContextManager, Generic, Iterable, List, TypeVar, Union +from unittest.mock import ANY + +import pytest + +from marshmallow_recipe.generics import ( + build_subscripted_type, + extract_type, + get_class_type_var_map, + get_fields_class_map, + get_fields_type_map, +) + +T = TypeVar("T") + + +@dataclasses.dataclass() +class OtherType: + pass + + +@dataclasses.dataclass() +class NonGeneric: + pass + + +@dataclasses.dataclass() +class RegularGeneric(Generic[T]): + pass + + +@dataclasses.dataclass(frozen=True) +class FrozenGeneric(Generic[T]): + pass + + +def e(match: str) -> ContextManager: + return pytest.raises(ValueError, match=match) + + +@pytest.mark.parametrize( + "data, cls, expected, context", + [ + (1, None, int, does_not_raise()), + (1, int, int, does_not_raise()), + (1, OtherType, ANY, e("OtherType'> is invalid but can be removed, actual type is ")), + (NonGeneric(), None, NonGeneric, does_not_raise()), + (NonGeneric(), NonGeneric, NonGeneric, does_not_raise()), + (NonGeneric(), OtherType, ANY, e("OtherType'> is invalid but can be removed, actual type is is not subscripted version of is invalid but can be removed, actual type")), + (RegularGeneric[RegularGeneric[int]](), RegularGeneric[RegularGeneric[str]], ANY, e("str]] is invalid but")), + (RegularGeneric[int](), OtherType, ANY, e("OtherType'> is 
invalid but can be removed, actual type is tests")), + (FrozenGeneric[int](), None, ANY, e("Explicit cls required for unsubscripted type is not subscripted version of is not subscripted version of None: + with context: + actual = extract_type(data, cls) + assert actual == expected + + +def test_get_fields_type_map_with_field_override() -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value1: + v1: str + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value2(Value1): + v2: str + + _TValue = TypeVar("_TValue", bound=Value1) + _TItem = TypeVar("_TItem") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T1(Generic[_TItem]): + value: Value1 + iterable: Iterable[_TItem] + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T2(Generic[_TValue, _TItem], T1[_TItem]): + value: _TValue + iterable: set[_TItem] + + actual = get_fields_type_map(T2[Value2, int]) + assert actual == { + "value": Value2, + "iterable": set[int], + } + + +def test_get_fields_type_map_non_generic() -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value1: + v1: int + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value2(Value1): + v2: bool + + actual = get_fields_type_map(Value2) + assert actual == { + "v1": int, + "v2": bool, + } + + +def test_get_fields_type_map_generic_inheritance() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass() + class NonGeneric: + v: bool + + @dataclasses.dataclass() + class Value1(Generic[_T]): + v1: _T + + @dataclasses.dataclass() + class Value2(Value1[int], NonGeneric): + v2: float + + actual = get_fields_type_map(Value2) + assert actual == { + "v": bool, + "v1": int, + "v2": float, + } + + +def test_get_fields_type_map_non_dataclass() -> None: + with pytest.raises(ValueError) as e: + get_fields_type_map(list[int]) + assert e.value.args[0] == " is not a dataclass" + + +def 
test_get_fields_type_map_not_subscripted() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Xxx(Generic[_T]): + xxx: _T + + with pytest.raises(Exception) as e: + get_fields_type_map(Xxx) + + assert e.value.args[0] == ( + "Expected subscripted generic, but got unsubscripted " + ".Xxx'>" + ) + + +def test_get_fields_type_map_for_subscripted() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Xxx(Generic[_T]): + xxx: _T + + actual = get_fields_type_map(Xxx[str]) + assert actual == {"xxx": str} + + +def test_get_fields_class_map() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass() + class Base1(Generic[_T]): + a: str + b: str + c: str + + @dataclasses.dataclass() + class Base2(Base1[int]): + a: str + b: str + d: str + e: str + + @dataclasses.dataclass() + class BaseG: + f: str + g: str + + @dataclasses.dataclass() + class Base3(Base2, BaseG): + a: str + d: str + f: str + h: str + + actual = get_fields_class_map(Base3) + assert actual == { + "a": Base3, + "b": Base2, + "c": Base1, + "d": Base3, + "e": Base2, + "f": Base3, + "g": BaseG, + "h": Base3, + } + + +def test_get_class_type_var_map_with_inheritance() -> None: + _T1 = TypeVar("_T1") + _T2 = TypeVar("_T2") + _T3 = TypeVar("_T3") + + @dataclasses.dataclass() + class NonGeneric: + pass + + @dataclasses.dataclass() + class Aaa(Generic[_T1, _T2]): + pass + + @dataclasses.dataclass() + class Bbb(Generic[_T1], Aaa[int, _T1]): + pass + + @dataclasses.dataclass() + class Ccc(Generic[_T3]): + pass + + @dataclasses.dataclass() + class Ddd(Generic[_T1, _T2, _T3], Bbb[_T2], Ccc[_T1], NonGeneric): + pass + + actual = get_class_type_var_map(Ddd[bool, str, float]) + assert actual == { + Aaa: { + _T1: int, + _T2: str, + }, + Bbb: { + _T1: str, + }, + Ccc: { + _T3: bool, + }, + Ddd: { + _T1: bool, + _T2: str, + _T3: float, + }, + } + + +def test_get_class_type_var_map_with_incompatible_inheritance() -> None: + _T = 
TypeVar("_T") + + @dataclasses.dataclass() + class Aaa(Generic[_T]): + pass + + @dataclasses.dataclass() + class Bbb(Aaa[int]): + pass + + @dataclasses.dataclass() + class Ccc(Bbb, Aaa[str]): # type: ignore + pass + + with pytest.raises(ValueError, match="Incompatible Base class") as e: + get_class_type_var_map(Ccc) + assert ".Aaa'> with generic args {~_T: } and {~_T: }" in e.value.args[0] + + +def test_get_class_type_var_map_with_duplicated_generic_inheritance() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass() + class NonGeneric: + pass + + @dataclasses.dataclass() + class Aaa(Generic[_T]): + pass + + @dataclasses.dataclass() + class Bbb(Aaa[int], NonGeneric): + pass + + @dataclasses.dataclass() + class Ccc(Bbb, Aaa[int], NonGeneric): + pass + + actual = get_class_type_var_map(Ccc) + assert actual == { + Aaa: { + _T: int, + }, + } + + +def test_get_class_type_var_map_with_nesting() -> None: + _T1 = TypeVar("_T1") + _T2 = TypeVar("_T2") + _T3 = TypeVar("_T3") + + @dataclasses.dataclass() + class Aaa(Generic[_T1, _T2]): + pass + + @dataclasses.dataclass() + class Bbb(Generic[_T1]): + pass + + @dataclasses.dataclass() + class Ccc(Generic[_T1, _T2, _T3], Aaa[Bbb[list[Annotated[Bbb[_T2], "xxx"]]], _T1 | None]): + pass + + actual = get_class_type_var_map(Ccc[bool, str, float]) + assert actual == { + Aaa: { + _T1: Bbb[list[Annotated[Bbb[str], "xxx"]]], + _T2: bool | None, + }, + Ccc: { + _T1: bool, + _T2: str, + _T3: float, + }, + } + + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +class Xxx(Generic[_T1, _T2]): + pass + + +class Zzz(Generic[_T1]): + pass + + +_TInt = TypeVar("_TInt") +_TIntNone = TypeVar("_TIntNone") +_TStr = TypeVar("_TStr") +_TList = TypeVar("_TList") +_TDictIntTStr = TypeVar("_TDictIntTStr") + +GENERIC_MAP: dict[TypeVar, type[Any] | types.UnionType] = { + _TInt: int, + _TIntNone: int | None, + _TStr: str, + _TList: list, + _TDictIntTStr: dict[int, _TStr], # type: ignore +} + + +@pytest.mark.parametrize( + "t, expected", + [ + (_TIntNone, 
int | None), + (list[_TInt], list[int]), # type: ignore + (list[_TIntNone], list[int | None]), # type: ignore + (List[_TStr], List[str]), # type: ignore + (dict[_TStr, _TInt], dict[str, int]), # type: ignore + (dict[_TStr, list[_TInt]], dict[str, list[int]]), # type: ignore + (_TInt | None, int | None), + (bool | None, bool | None), + (Union[_TInt, float, bool, _TStr], int | float | bool | str), # type: ignore + (_TInt | float | bool | _TStr, int | float | bool | str), + (list[_TInt | None], list[int | None]), # type: ignore + (list[_TList], list[list[Any]]), # type: ignore + (dict[_TInt, _TDictIntTStr], dict[int, dict[int, str]]), # type: ignore + (Annotated[_TStr, "qwe", 123, None], Annotated[str, "qwe", 123, None]), # type: ignore + (Annotated[list[_TStr], "qwe", 123, None], Annotated[list[str], "qwe", 123, None]), # type: ignore + (list[Annotated[list[_TInt], "asd", "zxc"]], list[Annotated[list[int], "asd", "zxc"]]), # type: ignore + (Xxx[list[_TInt], Zzz[_TStr]], Xxx[list[int], Zzz[str]]), # type: ignore + ], +) +def test_build_subscripted_type(t: type, expected: type) -> None: + actual = build_subscripted_type(t, GENERIC_MAP) + assert actual == expected diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 3029cdf..b262249 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,7 +3,22 @@ import decimal import enum import uuid -from typing import Annotated, Any, Dict, FrozenSet, List, Set, Tuple +from contextlib import nullcontext as does_not_raise +from typing import ( + Annotated, + Any, + Callable, + ContextManager, + Dict, + FrozenSet, + Generic, + Iterable, + List, + Set, + Tuple, + TypeVar, + get_origin, +) import pytest @@ -627,3 +642,159 @@ class RootContainer: int_container: IntContainer = dataclasses.field(default_factory=IntContainer) assert mr.load(RootContainer, {}) == RootContainer() + + +@pytest.mark.parametrize( + "frozen, slots, get_type, context", + [ + (False, False, lambda x: None, 
does_not_raise()), + (False, False, lambda x: int, pytest.raises(ValueError, match=" is invalid but can be removed")), + (True, False, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")), + (False, True, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")), + (True, True, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")), + (True, True, lambda x: get_origin(x), pytest.raises(ValueError, match=".Data'> is not subscripted version of")), + (True, True, lambda x: list[int], pytest.raises(ValueError, match="int] is not subscripted version of")), + (True, True, lambda x: int, pytest.raises(ValueError, match=" is not subscripted version of")), + (True, True, lambda x: x, does_not_raise()), + ], +) +def test_generic_extract_type_on_dump( + frozen: bool, slots: bool, get_type: Callable[[type], type | None], context: ContextManager +) -> None: + _TValue = TypeVar("_TValue") + + @dataclasses.dataclass(frozen=frozen, slots=slots) + class Data(Generic[_TValue]): + value: _TValue + + instance = Data[int](value=123) + with context: + dumped = mr.dump(instance, cls=get_type(Data[int])) + assert dumped == {"value": 123} + + instance_many = [Data[int](value=123), Data[int](value=456)] + with context: + dumped = mr.dump_many(instance_many, cls=get_type(Data[int])) + assert dumped == [{"value": 123}, {"value": 456}] + + +@pytest.mark.parametrize( + "frozen, slots, get_type, context", + [ + (False, False, lambda x: None, does_not_raise()), + (False, True, lambda x: x, does_not_raise()), + (True, False, lambda x: x, does_not_raise()), + (True, True, lambda x: x, does_not_raise()), + (False, False, lambda x: int, pytest.raises(ValueError, match=" is invalid but can be removed")), + (True, True, lambda x: int, pytest.raises(ValueError, match=" is invalid but can be removed")), + ], +) +def test_non_generic_extract_type_on_dump( + frozen: bool, slots: 
bool, get_type: Callable[[type], type | None], context: ContextManager +) -> None: + @dataclasses.dataclass(frozen=frozen, slots=slots) + class Data: + value: int + + instance = Data(value=123) + with context: + dumped = mr.dump(instance, cls=get_type(Data)) + assert dumped == {"value": 123} + + instance_many = [Data(value=123), Data(value=456)] + with context: + dumped = mr.dump_many(instance_many, cls=get_type(Data)) + assert dumped == [{"value": 123}, {"value": 456}] + + +def test_generic_in_parents() -> None: + _TXxx = TypeVar("_TXxx") + _TData = TypeVar("_TData") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Data(Generic[_TXxx]): + xxx: _TXxx + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class ParentClass(Generic[_TData]): + value: str + data: _TData + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class ChildClass(ParentClass[Data[int]]): + pass + + instance = ChildClass(value="vvv", data=Data(xxx=111)) + dumped = mr.dump(instance) + + assert dumped == {"value": "vvv", "data": {"xxx": 111}} + assert mr.load(ChildClass, dumped) == instance + + +def test_generic_type_var_with_reuse() -> None: + _T = TypeVar("_T") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T1(Generic[_T]): + t1: _T + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T2(Generic[_T], T1[int]): + t2: _T + + instance = T2[str](t1=1, t2="2") + + dumped = mr.dump(instance, cls=T2[str]) + + assert dumped == {"t1": 1, "t2": "2"} + assert mr.load(T2[str], dumped) == instance + + +def test_generic_with_field_override() -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value1: + v1: str + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class Value2(Value1): + v2: str + + _TValue = TypeVar("_TValue", bound=Value1) + _TItem = TypeVar("_TItem") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T1(Generic[_TItem]): + 
value: Value1 + iterable: Iterable[_TItem] + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class T2(Generic[_TValue, _TItem], T1[_TItem]): + value: _TValue + iterable: set[_TItem] + + instance = T2[Value2, int](value=Value2(v1="aaa", v2="bbb"), iterable=set([3, 4, 5])) + + dumped = mr.dump(instance, cls=T2[Value2, int]) + + assert dumped == {"value": {"v1": "aaa", "v2": "bbb"}, "iterable": [3, 4, 5]} + assert mr.load(T2[Value2, int], dumped) == instance + + +def test_generic_reuse_with_different_args() -> None: + _TItem = TypeVar("_TItem") + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class GenericContainer(Generic[_TItem]): + items: list[_TItem] + + container_int = GenericContainer[int](items=[1, 2, 3]) + dumped = mr.dump(container_int, cls=GenericContainer[int]) + + assert dumped == {"items": [1, 2, 3]} + assert mr.load(GenericContainer[int], dumped) == container_int + + container_str = GenericContainer[str](items=["q", "w", "e"]) + dumped = mr.dump(container_str, cls=GenericContainer[str]) + + assert dumped == {"items": ["q", "w", "e"]} + assert mr.load(GenericContainer[str], dumped) == container_str