Skip to content

Commit

Permalink
fix(signature): Fix #3593: Ensure signature model internal function s…
Browse files Browse the repository at this point in the history
…ignatures don't clash with model signature (#3605)

* Ensure signature model internal function signatures don't clash with model signature
  • Loading branch information
provinzkraut authored Jun 30, 2024
1 parent de8f4a7 commit 3a6a293
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion litestar/_kwargs/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def resolve_dependency(
"""
signature_model = dependency.provide.signature_model
dependency_kwargs = (
signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs)
signature_model.parse_values_from_connection_kwargs(connection=connection, kwargs=kwargs)
if signature_model._fields
else {}
)
Expand Down
10 changes: 6 additions & 4 deletions litestar/_signature/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG
return message

@classmethod
def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) -> list[tuple[str, Exception]]:
def _collect_errors(
cls, deserializer: Callable[[Any, Any], Any], kwargs: dict[str, Any]
) -> list[tuple[str, Exception]]:
exceptions: list[tuple[str, Exception]] = []
for field_name in cls._fields:
try:
Expand All @@ -181,12 +183,12 @@ def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any)
return exceptions

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.
Args:
connection: The ASGI connection instance.
**kwargs: A dictionary of kwargs.
kwargs: A dictionary of kwargs.
Raises:
ValidationException: If validation failed.
Expand All @@ -206,7 +208,7 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
for field_name, exc in cls._collect_errors(deserializer=deserializer, **kwargs): # type: ignore[assignment]
for field_name, exc in cls._collect_errors(deserializer=deserializer, kwargs=kwargs): # type: ignore[assignment]
match = ERR_RE.search(str(exc))
keys = [field_name, str(match.group(1))] if match else [field_name]
message = cls._build_error_message(keys=keys, exc_msg=str(exc), connection=connection)
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def _get_response_data(
cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)

parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
connection=request, **kwargs
connection=request, kwargs=kwargs
)

if cleanup_group:
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N
cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs)

parsed_kwargs = self.route_handler.signature_model.parse_values_from_connection_kwargs(
connection=websocket, **parsed_kwargs
connection=websocket, kwargs=parsed_kwargs
)

if cleanup_group:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ exclude-classes = """
"""

[tool.ruff]
include = [
"{litestar,tests,docs,test_apps,tools}/**/*.{py,pyi}",
"pyproject.toml"
]

lint.select = [
"A", # flake8-builtins
"B", # flake8-bugbear
Expand Down
30 changes: 27 additions & 3 deletions tests/unit/test_signature/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def fn(a: int) -> None:
type_decoders=[],
)
with pytest.raises(ValidationException):
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a="not an int")
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": "not an int"})


def test_create_signature_validation() -> None:
Expand Down Expand Up @@ -128,7 +128,7 @@ def handler(data: Parent) -> None:

with pytest.raises(ValidationException) as exc_info:
model.parse_values_from_connection_kwargs(
connection=RequestFactory().get(route_handler=handler), data={"child": {}, "other_child": {}}
connection=RequestFactory().get(route_handler=handler), kwargs={"data": {"child": {}, "other_child": {}}}
)

assert isinstance(exc_info.value.extra, list)
Expand Down Expand Up @@ -283,7 +283,7 @@ def fn(a: Annotated[int, Parameter(gt=5)], b: Annotated[int, Parameter(lt=5)]) -
type_decoders=[],
)
with pytest.raises(ValidationException) as exc:
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a=0, b=9)
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": 0, "b": 9})

assert exc.value.extra == [
{"message": "Expected `int` >= 6", "key": "a", "source": ParamType.QUERY},
Expand All @@ -303,3 +303,27 @@ async def something(foo: Foo[str] = Foo()) -> None:

with create_test_client([something]) as client:
assert client.get("/").status_code == 200


def test_separate_model_namespace() -> None:
# https://github.com/litestar-org/litestar/issues/3593

async def provide_connection() -> str:
return "connection"

@get("/connection", dependencies={"connection": provide_connection})
async def get_connection(connection: str) -> str:
return connection

async def provide_deserializer() -> str:
return "deserializer"

@get("/deserializer", dependencies={"deserializer": provide_deserializer})
async def get_deserializer(deserializer: int) -> str:
return deserializer # type: ignore[return-value]

with create_test_client([get_connection, get_deserializer], raise_server_exceptions=True, debug=True) as client:
assert client.get("/connection").text == "connection"
res = client.get("/deserializer")
assert res.status_code == 500
assert "Expected `int`, got `str` - at `$.deserializer`" in res.text

0 comments on commit 3a6a293

Please sign in to comment.