From 9fe676f2a9a3af5288864658a7ae0a7e50aeb1c3 Mon Sep 17 00:00:00 2001 From: slyces Date: Fri, 15 Nov 2024 17:36:28 +0100 Subject: [PATCH 1/3] feat: add a factory supporting pydantic dataclasses --- docs/usage/library_factories/index.rst | 4 ++ polyfactory/factories/pydantic_factory.py | 48 +++++++++++++++++++--- tests/test_pydantic_factory.py | 50 ++++++++++++++++++++++- 3 files changed, 95 insertions(+), 7 deletions(-) diff --git a/docs/usage/library_factories/index.rst b/docs/usage/library_factories/index.rst index 20911acf..8ccfebb4 100644 --- a/docs/usage/library_factories/index.rst +++ b/docs/usage/library_factories/index.rst @@ -11,9 +11,13 @@ These include: :class:`TypedDictFactory ` a base factory for typed-dicts + :class:`ModelFactory ` a base factory for `pydantic `_ models +:class:`PydanticDataclassFactory ` + a base factory for `pydantic `_ dataclasses + :class:`BeanieDocumentFactory ` a base factory for `beanie `_ documents diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 0d53f341..e29348f6 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -55,6 +55,7 @@ ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue] Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue] ) + from pydantic.dataclasses import is_pydantic_dataclass # Keep this import last to prevent warnings from pydantic if pydantic v2 # is installed. @@ -68,6 +69,7 @@ # v2 specific imports from pydantic import BaseModel as BaseModelV2 + from pydantic.dataclasses import is_pydantic_dataclass from pydantic_core import PydanticUndefined as UndefinedV2 from pydantic_core import to_json @@ -99,7 +101,8 @@ from typing_extensions import NotRequired, TypeGuard -T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] +ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] +T = TypeVar("T") _IS_PYDANTIC_V1 = VERSION.startswith("1") @@ -370,7 +373,7 @@ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: return metadata -class ModelFactory(Generic[T], BaseFactory[T]): +class ModelFactory(Generic[ModelT], BaseFactory[ModelT]): """Base factory for pydantic models""" __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} @@ -388,7 +391,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + def is_supported_type(cls, value: Any) -> TypeGuard[type[ModelT]]: """Determine whether the given value is supported by the factory. :param value: An arbitrary value. @@ -454,7 +457,7 @@ def build( cls, factory_use_construct: bool = False, **kwargs: Any, - ) -> T: + ) -> ModelT: """Build an instance of the factory's __model__ :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the @@ -492,7 +495,7 @@ def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildConte } @classmethod - def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T: + def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> ModelT: """Create an instance of the factory's __model__ :param _build_context: BuildContext instance. @@ -508,7 +511,7 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T return cls.__model__(**kwargs) # type: ignore[return-value] @classmethod - def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]: + def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[ModelT]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -629,3 +632,36 @@ def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ignore[reportInvalidTypeForm] return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) + + +class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] + """Base factory for pydantic dataclasses""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + return is_pydantic_dataclass(value) + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + if not is_pydantic_dataclass(cls.__model__): + # This should be unreachable + return [] + + pydantic_fields = cls.__model__.__pydantic_fields__ + pydantic_config = cls.__model__.__pydantic_config__ + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not pydantic_config.get( + "populate_by_name", + False, + ), + ) + for field_name, field_info in pydantic_fields.items() + ] + + return cls._fields_metadata diff --git a/tests/test_pydantic_factory.py b/tests/test_pydantic_factory.py index 2c24db78..1f859566 100644 --- a/tests/test_pydantic_factory.py +++ b/tests/test_pydantic_factory.py @@ -63,9 +63,11 @@ constr, validator, ) +from pydantic.dataclasses import dataclass as pydantic_dataclass +from polyfactory.exceptions import ConfigurationException from polyfactory.factories import DataclassFactory -from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory +from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory, PydanticDataclassFactory from tests.models import Person, PetFactory IS_PYDANTIC_V1 = _IS_PYDANTIC_V1 @@ -1038,3 +1040,49 @@ class A(BaseModel): AFactory = ModelFactory.create_factory(A) assert AFactory.build() + + +def test_simple_pydantic_dataclass() -> None: + @pydantic_dataclass + class DataclassModel: + a: int + b: Annotated[str, MinLen(1)] + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, DataclassModel) + assert isinstance(instance.a, int) + assert isinstance(instance.b, str) + assert len(instance.b) >= 1 + + +def test_nested_pydantic_dataclass() -> None: + @pydantic_dataclass + class FooDataclass: + content: int + + @pydantic_dataclass + class NestedDataclassModel: + foo: FooDataclass + + class DataclassModelFactory(PydanticDataclassFactory[NestedDataclassModel]): + __model__ = NestedDataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, NestedDataclassModel) + assert isinstance(instance.foo, FooDataclass) + assert isinstance(instance.foo.content, int) + + +def test_pydantic_dataclass_factory_raises_for_std_dataclasses() -> None: + @dataclass + class DataclassModel: + a: int + b: str + + with pytest.raises(ConfigurationException): + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel From 04112aeec89b94f4915911611d6df7a776538a36 Mon Sep 17 00:00:00 2001 From: slyces Date: Fri, 15 Nov 2024 23:05:08 +0100 Subject: [PATCH 2/3] fixup! feat: add a factory supporting pydantic dataclasses --- polyfactory/factories/pydantic_factory.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index e29348f6..ab1ebf70 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -2,6 +2,7 @@ import copy from contextlib import suppress +from dataclasses import is_dataclass from datetime import timezone from functools import partial from os.path import realpath @@ -55,7 +56,6 @@ ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue] Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue] ) - from pydantic.dataclasses import is_pydantic_dataclass # Keep this import last to prevent warnings from pydantic if pydantic v2 # is installed. @@ -69,7 +69,6 @@ # v2 specific imports from pydantic import BaseModel as BaseModelV2 - from pydantic.dataclasses import is_pydantic_dataclass from pydantic_core import PydanticUndefined as UndefinedV2 from pydantic_core import to_json @@ -101,6 +100,8 @@ from typing_extensions import NotRequired, TypeGuard + from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage] + ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] T = TypeVar("T") @@ -634,6 +635,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) +def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]: + # This method is available in the `pydantic.dataclasses` module for python >= 3.9 + return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__ + + class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] """Base factory for pydantic dataclasses""" From 58735ebfe2ed49bc8490ff72f92651bec4d0c148 Mon Sep 17 00:00:00 2001 From: slyces Date: Fri, 6 Dec 2024 12:56:38 +0100 Subject: [PATCH 3/3] fix: introduce pydantic v1/v2 code to hanble v1 dataclasses --- polyfactory/factories/pydantic_factory.py | 62 +++++++++++++++-------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index ab1ebf70..3c4b26d2 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -61,9 +61,15 @@ # is installed. from pydantic import PyObject - # prevent unbound variable warnings + # Prevent unbound variable warnings BaseModelV2 = BaseModelV1 UndefinedV2 = Undefined + + if TYPE_CHECKING: + from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage] + + # Prevent unbound variable warnings + PydanticDataclassV2 = PydanticDataclassV1 except ImportError: # pydantic v2 @@ -92,6 +98,8 @@ from pydantic.v1.color import Color # type: ignore[assignment] from pydantic.v1.fields import DeferredType, ModelField, Undefined + if TYPE_CHECKING: + from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage] if TYPE_CHECKING: from collections import abc @@ -100,7 +108,6 @@ from typing_extensions import NotRequired, TypeGuard - from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage] ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] T = TypeVar("T") @@ -635,8 +642,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) -def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]: - # This method is available in the `pydantic.dataclasses` module for python >= 3.9 +def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]: + return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__ + + +def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]: return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__ @@ -647,27 +657,37 @@ class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] @classmethod def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - return is_pydantic_dataclass(value) + return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value) @classmethod def get_model_fields(cls) -> list[FieldMeta]: - if not is_pydantic_dataclass(cls.__model__): + if _is_pydantic_v1_dataclass(cls.__model__): + pydantic_model = cls.__model__.__pydantic_model__ + cls._fields_metadata = [ + PydanticFieldMeta.from_model_field( + field, + use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined] + random=cls.__random__, + ) + for field in pydantic_model.__fields__.values() + ] + elif _is_pydantic_v2_dataclass(cls.__model__): + pydantic_fields = cls.__model__.__pydantic_fields__ + pydantic_config = cls.__model__.__pydantic_config__ + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not pydantic_config.get( + "populate_by_name", + False, + ), + ) + for field_name, field_info in pydantic_fields.items() + ] + else: # This should be unreachable return [] - pydantic_fields = cls.__model__.__pydantic_fields__ - pydantic_config = cls.__model__.__pydantic_config__ - cls._fields_metadata = [ - PydanticFieldMeta.from_field_info( - field_info=field_info, - field_name=field_name, - random=cls.__random__, - use_alias=not pydantic_config.get( - "populate_by_name", - False, - ), - ) - for field_name, field_info in pydantic_fields.items() - ] - return cls._fields_metadata