Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add a factory supporting pydantic dataclasses #605

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/usage/library_factories/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ These include:

:class:`TypedDictFactory <polyfactory.factories.typed_dict_factory.TypedDictFactory>`
a base factory for typed-dicts

:class:`ModelFactory <polyfactory.factories.pydantic_factory.ModelFactory>`
a base factory for `pydantic <https://docs.pydantic.dev/>`_ models

:class:`PydanticDataclassFactory <polyfactory.factories.pydantic_factory.PydanticDataclassFactory>`
a base factory for `pydantic <https://docs.pydantic.dev/latest/concepts/dataclasses/>`_ dataclasses

:class:`BeanieDocumentFactory <polyfactory.factories.beanie_odm_factory.BeanieDocumentFactory>`
a base factory for `beanie <https://beanie-odm.dev/>`_ documents

Expand Down
76 changes: 69 additions & 7 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
from contextlib import suppress
from dataclasses import is_dataclass
from datetime import timezone
from functools import partial
from os.path import realpath
Expand Down Expand Up @@ -60,9 +61,15 @@
# is installed.
from pydantic import PyObject

# prevent unbound variable warnings
# Prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined

if TYPE_CHECKING:
from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage]

# Prevent unbound variable warnings
PydanticDataclassV2 = PydanticDataclassV1
except ImportError:
# pydantic v2

Expand Down Expand Up @@ -91,6 +98,8 @@
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined

if TYPE_CHECKING:
from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage]

if TYPE_CHECKING:
from collections import abc
Expand All @@ -99,7 +108,9 @@

from typing_extensions import NotRequired, TypeGuard

T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]

ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
T = TypeVar("T")

_IS_PYDANTIC_V1 = VERSION.startswith("1")

Expand Down Expand Up @@ -370,7 +381,7 @@ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
return metadata


class ModelFactory(Generic[T], BaseFactory[T]):
class ModelFactory(Generic[ModelT], BaseFactory[ModelT]):
"""Base factory for pydantic models"""

__forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {}
Expand All @@ -388,7 +399,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined]

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
def is_supported_type(cls, value: Any) -> TypeGuard[type[ModelT]]:
"""Determine whether the given value is supported by the factory.

:param value: An arbitrary value.
Expand Down Expand Up @@ -454,7 +465,7 @@ def build(
cls,
factory_use_construct: bool = False,
**kwargs: Any,
) -> T:
) -> ModelT:
"""Build an instance of the factory's __model__

:param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
Expand Down Expand Up @@ -492,7 +503,7 @@ def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildConte
}

@classmethod
def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T:
def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> ModelT:
"""Create an instance of the factory's __model__

:param _build_context: BuildContext instance.
Expand All @@ -508,7 +519,7 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T
return cls.__model__(**kwargs) # type: ignore[return-value]

@classmethod
def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]:
def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[ModelT]:
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.

:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
Expand Down Expand Up @@ -629,3 +640,54 @@ def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:

def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ignore[reportInvalidTypeForm]
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)


def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]:
return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__


def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]:
return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__


class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var]
"""Base factory for pydantic dataclasses"""

__is_base_factory__ = True

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value)

@classmethod
def get_model_fields(cls) -> list[FieldMeta]:
if _is_pydantic_v1_dataclass(cls.__model__):
pydantic_model = cls.__model__.__pydantic_model__
cls._fields_metadata = [
PydanticFieldMeta.from_model_field(
field,
use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
random=cls.__random__,
)
for field in pydantic_model.__fields__.values()
]
elif _is_pydantic_v2_dataclass(cls.__model__):
pydantic_fields = cls.__model__.__pydantic_fields__
pydantic_config = cls.__model__.__pydantic_config__
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not pydantic_config.get(
"populate_by_name",
False,
),
)
for field_name, field_info in pydantic_fields.items()
]
else:
# This should be unreachable
return []

return cls._fields_metadata
50 changes: 49 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@
constr,
validator,
)
from pydantic.dataclasses import dataclass as pydantic_dataclass

from polyfactory.exceptions import ConfigurationException
from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory, PydanticDataclassFactory
from tests.models import Person, PetFactory

IS_PYDANTIC_V1 = _IS_PYDANTIC_V1
Expand Down Expand Up @@ -1038,3 +1040,49 @@ class A(BaseModel):
AFactory = ModelFactory.create_factory(A)

assert AFactory.build()


def test_simple_pydantic_dataclass() -> None:
@pydantic_dataclass
class DataclassModel:
a: int
b: Annotated[str, MinLen(1)]

class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]):
__model__ = DataclassModel

instance = DataclassModelFactory.build()
assert isinstance(instance, DataclassModel)
assert isinstance(instance.a, int)
assert isinstance(instance.b, str)
assert len(instance.b) >= 1


def test_nested_pydantic_dataclass() -> None:
@pydantic_dataclass
class FooDataclass:
content: int

@pydantic_dataclass
class NestedDataclassModel:
foo: FooDataclass

class DataclassModelFactory(PydanticDataclassFactory[NestedDataclassModel]):
__model__ = NestedDataclassModel

instance = DataclassModelFactory.build()
assert isinstance(instance, NestedDataclassModel)
assert isinstance(instance.foo, FooDataclass)
assert isinstance(instance.foo.content, int)


def test_pydantic_dataclass_factory_raises_for_std_dataclasses() -> None:
@dataclass
class DataclassModel:
a: int
b: str

with pytest.raises(ConfigurationException):

class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]):
__model__ = DataclassModel
Loading