Skip to content

Commit

Permalink
Merge pull request #87 from labthings/pydantic-2-10
Browse files Browse the repository at this point in the history
Work with pydantic 2.10
  • Loading branch information
rwb27 authored Nov 28, 2024
2 parents 7ac72fa + 307845b commit b446f8d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "labthings-fastapi"
version = "0.0.6"
version = "0.0.7"
authors = [
{ name="Richard Bowman", email="[email protected]" },
]
Expand Down
13 changes: 9 additions & 4 deletions src/labthings_fastapi/thing_description/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
import json

from pydantic import TypeAdapter, ValidationError, BaseModel
from pydantic import TypeAdapter, ValidationError
from .model import DataSchema


Expand Down Expand Up @@ -192,7 +192,7 @@ def jsonschema_to_dataschema(
return output


def type_to_dataschema(t: Union[type, BaseModel], **kwargs) -> DataSchema:
def type_to_dataschema(t: type, **kwargs) -> DataSchema:
"""Convert a Python type to a Thing Description DataSchema
This makes use of pydantic's `schema_of` function to create a
Expand All @@ -205,9 +205,14 @@ def type_to_dataschema(t: Union[type, BaseModel], **kwargs) -> DataSchema:
is passed in. Typically you'll want to use this for the
`title` field.
"""
if isinstance(t, BaseModel):
if hasattr(t, "model_json_schema"):
# The input should be a `BaseModel` subclass, in which case this works:
json_schema = t.model_json_schema()
else:
# In principle, the below should work for any type, though some
# deferred annotations can go wrong.
# Some attempt at looking up the environment of functions might help
# here.
json_schema = TypeAdapter(t).json_schema()
schema_dict = jsonschema_to_dataschema(json_schema)
# Definitions of referenced ($ref) schemas are put in a
Expand Down
11 changes: 7 additions & 4 deletions src/labthings_fastapi/types/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ def double(arr: NDArray) -> NDArray:
WrapSerializer,
)
from typing import Annotated, Any, List, Union
from typing_extensions import TypeAlias
from collections.abc import Mapping, Sequence


# Define a nested list of floats with 0-6 dimensions
# This would be most elegantly defined as a recursive type
# but the below gets the job done for now.
Number = Union[int, float]
NestedListOfNumbers = Union[
Number: TypeAlias = Union[int, float]
NestedListOfNumbers: TypeAlias = Union[
Number,
List[Number],
List[List[Number]],
Expand Down Expand Up @@ -68,10 +69,12 @@ def listoflists_to_np(lol: Union[NestedListOfNumbers, np.ndarray]) -> np.ndarray


# Define an annotated type so Pydantic can cope with numpy
NDArray = Annotated[
NDArray: TypeAlias = Annotated[
np.ndarray,
PlainValidator(listoflists_to_np),
PlainSerializer(np_to_listoflists, when_used="json-unless-none"),
PlainSerializer(
np_to_listoflists, when_used="json-unless-none", return_type=NestedListOfNumbers
),
WithJsonSchema(NestedListOfNumbersModel.model_json_schema(), mode="validation"),
]

Expand Down
13 changes: 12 additions & 1 deletion tests/test_numpy_type.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

from pydantic import BaseModel
from pydantic import BaseModel, RootModel
import numpy as np

from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict
from labthings_fastapi.thing import Thing
from labthings_fastapi.decorators import thing_action


class ArrayModel(RootModel):
root: NDArray


def check_field_works_with_list(data):
class Model(BaseModel):
a: NDArray
Expand Down Expand Up @@ -86,3 +90,10 @@ def test_denumpifying_dict():
assert dump["e"] is None
assert dump["f"] == 1
d.model_dump_json()


def test_rootmodel():
for input in [[0, 1, 2], np.arange(3)]:
m = ArrayModel(root=input)
assert isinstance(m.root, np.ndarray)
assert (m.model_dump() == [0, 1, 2]).all()

0 comments on commit b446f8d

Please sign in to comment.