Skip to content

Commit

Permalink
fix: apply api middleware correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
davemooreuws committed Apr 3, 2024
1 parent 46842c3 commit 73afdba
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
4 changes: 2 additions & 2 deletions nitric/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,9 @@ async def chained_middleware(ctx: C, nxt: Optional[Middleware[C]] = None) -> C:

return chained_middleware

middleware_chain = functools.reduce(reduce_chain, reversed(middlewares)) # type: ignore
middleware_chain = functools.reduce(reduce_chain, reversed(middlewares), last_middleware) # type: ignore
# type ignored because mypy appears to misidentify the correct return type
return await middleware_chain(ctx, last_middleware) # type: ignore
return await middleware_chain(ctx) # type: ignore

return composed

Expand Down
10 changes: 10 additions & 0 deletions nitric/resources/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def _route(self, match: str, opts: Optional[RouteOptions] = None) -> Route:
if opts is None:
opts = RouteOptions()

if self.middleware is not None:
opts.middleware = self.middleware + opts.middleware

r = Route(self, match, opts)
self.routes.append(r)
return r
Expand Down Expand Up @@ -339,6 +342,13 @@ def method(
self, methods: List[HttpMethod], *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None
) -> None:
"""Register middleware for multiple HTTP Methods."""

# ensure route/api middlewares are added
middleware = (
*self.middleware,
*middleware
)

Method(self, methods, *middleware, opts=opts if opts else MethodOptions())

def get(self, *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None) -> None:
Expand Down
19 changes: 18 additions & 1 deletion tests/resources/test_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# from nitric.faas import HttpMethod, MethodOptions, ApiWorkerOptions
from nitric.resources import api, ApiOptions, JwtSecurityDefinition
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule, HttpMiddleware
from nitric.proto.resources.v1 import (
ApiOpenIdConnectionDefinition,
ApiSecurityDefinitionResource,
Expand All @@ -40,6 +40,7 @@
from nitric.proto.apis.v1 import ApiDetailsResponse, ApiDetailsRequest, ApiWorkerScopes

from nitric.context import (
HttpContext,
HttpMethod,
)

Expand Down Expand Up @@ -221,6 +222,22 @@ def test_api_route(self):
assert test_route.middleware == []
assert test_route.api.name == test_api.name

def test_api_route_middleware(self):
mock_declare = AsyncMock()
mock_response = Object()
mock_declare.return_value = mock_response

async def middleware_test(ctx: HttpContext, nxt: HttpMiddleware):
return nxt(ctx)

with patch("nitric.proto.resources.v1.ResourcesStub.declare", mock_declare):
test_api = api("test-api-route", ApiOptions(path="/api/v2/", middleware=[middleware_test]))

test_route = test_api._route("/hello")

assert len(test_route.middleware) == 1
assert len(test_api.middleware) == 1

def test_define_route(self):
mock_declare = AsyncMock()
mock_response = Object()
Expand Down

0 comments on commit 73afdba

Please sign in to comment.