diff --git a/litestar/_kwargs/dependencies.py b/litestar/_kwargs/dependencies.py index 88ffb07b1e..bd3eb1b33c 100644 --- a/litestar/_kwargs/dependencies.py +++ b/litestar/_kwargs/dependencies.py @@ -58,7 +58,7 @@ async def resolve_dependency( """ signature_model = dependency.provide.signature_model dependency_kwargs = ( - signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs) + signature_model.parse_values_from_connection_kwargs(connection=connection, kwargs=kwargs) if signature_model._fields else {} ) diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py index 33653ed548..3faec5de59 100644 --- a/litestar/_signature/model.py +++ b/litestar/_signature/model.py @@ -168,7 +168,9 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG return message @classmethod - def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) -> list[tuple[str, Exception]]: + def _collect_errors( + cls, deserializer: Callable[[Any, Any], Any], kwargs: dict[str, Any] + ) -> list[tuple[str, Exception]]: exceptions: list[tuple[str, Exception]] = [] for field_name in cls._fields: try: @@ -181,12 +183,12 @@ def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) return exceptions @classmethod - def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: + def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, kwargs: dict[str, Any]) -> dict[str, Any]: """Extract values from the connection instance and return a dict of parsed values. Args: connection: The ASGI connection instance. - **kwargs: A dictionary of kwargs. + kwargs: A dictionary of kwargs. Raises: ValidationException: If validation failed. @@ -206,7 +208,7 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg messages.append(message) raise cls._create_exception(messages=messages, connection=connection) from e except ValidationError as e: - for field_name, exc in cls._collect_errors(deserializer=deserializer, **kwargs): # type: ignore[assignment] + for field_name, exc in cls._collect_errors(deserializer=deserializer, kwargs=kwargs): # type: ignore[assignment] match = ERR_RE.search(str(exc)) keys = [field_name, str(match.group(1))] if match else [field_name] message = cls._build_error_message(keys=keys, exc_msg=str(exc), connection=connection) diff --git a/litestar/routes/http.py b/litestar/routes/http.py index 95595edc1a..99ef4afe78 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -189,7 +189,7 @@ async def _get_response_data( cleanup_group = await parameter_model.resolve_dependencies(request, kwargs) parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs( - connection=request, **kwargs + connection=request, kwargs=kwargs ) if cleanup_group: diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index ebf4959d46..3248e2a83c 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -75,7 +75,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs) parsed_kwargs = self.route_handler.signature_model.parse_values_from_connection_kwargs( - connection=websocket, **parsed_kwargs + connection=websocket, kwargs=parsed_kwargs ) if cleanup_group: diff --git a/pyproject.toml b/pyproject.toml index b595342374..e116e9f087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -310,6 +310,11 @@ exclude-classes = """ """ [tool.ruff] +include = [ + "{litestar,tests,docs,test_apps,tools}/**/*.{py,pyi}", + "pyproject.toml" +] + lint.select = [ "A", # flake8-builtins "B", # flake8-bugbear diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index bc39616723..acc8f97cae 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -28,7 +28,7 @@ def fn(a: int) -> None: type_decoders=[], ) with pytest.raises(ValidationException): - model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a="not an int") + model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": "not an int"}) def test_create_signature_validation() -> None: @@ -128,7 +128,7 @@ def handler(data: Parent) -> None: with pytest.raises(ValidationException) as exc_info: model.parse_values_from_connection_kwargs( - connection=RequestFactory().get(route_handler=handler), data={"child": {}, "other_child": {}} + connection=RequestFactory().get(route_handler=handler), kwargs={"data": {"child": {}, "other_child": {}}} ) assert isinstance(exc_info.value.extra, list) @@ -283,7 +283,7 @@ def fn(a: Annotated[int, Parameter(gt=5)], b: Annotated[int, Parameter(lt=5)]) - type_decoders=[], ) with pytest.raises(ValidationException) as exc: - model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a=0, b=9) + model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": 0, "b": 9}) assert exc.value.extra == [ {"message": "Expected `int` >= 6", "key": "a", "source": ParamType.QUERY}, @@ -303,3 +303,27 @@ async def something(foo: Foo[str] = Foo()) -> None: with create_test_client([something]) as client: assert client.get("/").status_code == 200 + + +def test_separate_model_namespace() -> None: + # https://github.com/litestar-org/litestar/issues/3593 + + async def provide_connection() -> str: + return "connection" + + @get("/connection", dependencies={"connection": provide_connection}) + async def get_connection(connection: str) -> str: + return connection + + async def provide_deserializer() -> str: + return "deserializer" + + @get("/deserializer", dependencies={"deserializer": provide_deserializer}) + async def get_deserializer(deserializer: int) -> str: + return deserializer # type: ignore[return-value] + + with create_test_client([get_connection, get_deserializer], raise_server_exceptions=True, debug=True) as client: + assert client.get("/connection").text == "connection" + res = client.get("/deserializer") + assert res.status_code == 500 + assert "Expected `int`, got `str` - at `$.deserializer`" in res.text