diff --git a/docs/changelog.md b/docs/changelog.md index af19479..938d9bd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,10 @@ ## 1.* +### 1.3.2 - 24-08-12 - Allow subclasses of dtypes + +(also when using objects for dtypes, subclasses of that object are allowed to validate) + ### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes Previously we would only allow dtypes if we knew for sure that there was some diff --git a/pyproject.toml b/pyproject.toml index 3f14e28..e1ffa3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.3.1" +version = "1.3.2" description = "Type and shape validation and serialization for numpy arrays in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 3030220..3dc3fdc 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -128,7 +128,13 @@ def validate_dtype(self, dtype: DtypeType) -> bool: elif self.dtype is np.str_: valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ else: - valid = dtype == self.dtype + # try to match as any subclass, if self.dtype is a class + try: + valid = issubclass(dtype, self.dtype) + except TypeError: + # expected, if dtype or self.dtype is not a class + valid = dtype == self.dtype + return valid def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 0655362..0467f25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,10 @@ class BadModel(BaseModel): x: int +class SubClass(BasicModel): + pass + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False), + ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ], ids=[ "float", @@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase: "model-model", "model-badmodel", "model-int", + "model-subclass", ], ) def dtype_cases(request) -> ValidationCase: