Skip to content

Commit

Permalink
Merge pull request #50 from mirumee/fix-43-add-headers-for-remote-schema
Browse files Browse the repository at this point in the history
Improve headers configuration
  • Loading branch information
rafalp authored Feb 14, 2024
2 parents d9084ff + 85fedf2 commit ff8698a
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Added `CacheSerializer`, `NoopCacheSerializer` and `JSONCacheSerializer`. Changed `CacheBackend`, `InMemoryCache`, `CloudflareCacheBackend` and `DynamoDBCacheBackend` to accept `serializer` initialization option.
- Fixed schema proxy returning an error when variable defined in an operation is missing from its variables.
- Improved custom headers handling in `ProxyResolver` and `ProxySchema`.


## 0.2.0 (2023-09-25)
Expand Down
40 changes: 27 additions & 13 deletions GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ It takes following optional arguments:
- `cache_key`: `Union[str, Callable[[GraphQLResolveInfo], str]]`
- `cache_ttl`: `int`

`proxy_headers` option is documented in "Proxying headers" section of this guide.
`proxy_headers` option is documented in "Configuring headers" section of this guide.

`cache`, `cache_key` and `cache_ttl` arguments are documented in cache section of this guide.

Expand Down Expand Up @@ -525,7 +525,7 @@ query getProduct {
```


## Proxying headers
## Configuring headers

Ariadne GraphQL Proxy requires that `GraphQLResolveInfo.context` attribute is a dictionary containing `headers` key, which in itself is a `Dict[str, str]` dictionary.

Expand All @@ -534,26 +534,40 @@ Ariadne GraphQL Proxy requires that `GraphQLResolveInfo.context` attribute is a

### `ProxySchema`

Proxy schema doesn't include any custom headers in requests used to introspect remote schemas.
It is possible to configure headers per schema with second positional argument of the `add_remote_schema` method:

`root_resolver` includes `authorization` in header proxied requests if it's present in context's `headers` dictionary.
```python
schema.add_remote_schema("https://example.com/graphql", {"Authorization": "Bearer T0K3N"})
```

Configured headers will be included in all HTTP requests to `https://example.com/graphql` made by the `ProxySchema`. This excludes requests made by `ForeignKeyResolver` and `ProxyResolver` which require headers to be configured on them separately.

### `ForeignKeyResolver` and `ProxyResolver`
If you need to create headers from `context` (eg. to proxy authorization header), you can use a function instead of a `dict`:

```python
def get_proxy_schema_headers(context):
if not context:
# Context is not available when `ProxySchema` retrieves remote schema for the first time
return {"Authorization": "Bearer T0K3N"}

return context.get("headers")

Both foreign key and proxy resolvers constructors take `proxy_headers` as second option. This option controls which headers from `context["headers"]` are proxied to services and which aren't.

If this option is not set, only `authorization` header is proxied, if it was sent to the proxy.
schema.add_remote_schema("https://example.com/graphql", get_proxy_schema_headers)
```


### `ForeignKeyResolver` and `ProxyResolver`

Both foreign key and proxy resolvers constructors take `proxy_headers` as second option. This option controls headers proxying:

If `proxy_headers` is a `List[str]`, its assumed to be a list of names of headers that should be proxied if sent by client.
If this option is not set, no headers are set on proxied queries.

If `proxy_headers` is a callable, it will be called with three arguments:
If `proxy_headers` is `True` and `context["headers"]` dictionary exists, its `authorization` value will be proxied.

- `obj`: `Any` value that was passed to resolved field's first argument.
- `info`: a `GraphQLResolveInfo` object for field with proxy or foreign key resolver.
- `payload`: a `dict` with GraphQL JSON payload that will be sent to a proxy server (`operationName`, `query`, `variables`).
If `proxy_headers` is a `List[str]`, its assumed to be a list of names of headers that should be proxied from `context["headers"]`.

Callable should return `None` or `Dict[str, str]` with headers to send to other server.
If `proxy_headers` is a callable, it will be called with single argument (`context`) and should return either `None` or `Dict[str, str]` with headers to send to the other server.

If `proxy_headers` is `None` or `False`, no headers are proxied to the other service.

Expand Down
16 changes: 7 additions & 9 deletions ariadne_graphql_proxy/proxy_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProxyResolver:
def __init__(
self,
url: str,
proxy_headers: Union[bool, Callable, List[str]] = True,
proxy_headers: Union[bool, Callable, List[str]] = False,
cache: Optional[CacheBackend] = None,
cache_key: Optional[Union[str, Callable[[GraphQLResolveInfo], str]]] = None,
cache_ttl: Optional[int] = None,
Expand Down Expand Up @@ -97,22 +97,20 @@ async def proxy_query_with_cache(
async def proxy_query(
self, obj: Any, info: GraphQLResolveInfo, payload: dict
) -> Any:
proxy_headers = None
if self._proxy_headers is True:
authorization = info.context["headers"].get("authorization")
if authorization:
proxy_headers = {"authorization": authorization}
else:
proxy_headers = None
if "headers" in info.context:
authorization = info.context["headers"].get("authorization")
if authorization:
proxy_headers = {"authorization": authorization}
elif callable(self._proxy_headers):
proxy_headers = self._proxy_headers(obj, info, payload)
proxy_headers = self._proxy_headers(info.context)
elif self._proxy_headers:
proxy_headers = {
header: value
for header, value in info.context["headers"].items()
if header in self._proxy_headers
}
else:
proxy_headers = None

async with AsyncClient() as client:
r = await client.post(
Expand Down
31 changes: 21 additions & 10 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from asyncio import gather
from functools import reduce
from inspect import isawaitable
from typing import Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

from ariadne.types import RootValue
from graphql import (
Expand All @@ -26,10 +26,14 @@
)


ProxyHeaders = Union[dict, Callable[[Any], dict]]


class ProxySchema:
def __init__(self, root_value: Optional[RootValue] = None):
self.schemas: List[GraphQLSchema] = []
self.urls: List[Optional[str]] = []
self.headers: List[Optional[ProxyHeaders]] = []
self.fields_map: Dict[str, Dict[str, Set[int]]] = {}
self.fields_types: Dict[str, Dict[str, str]] = {}
self.unions: Dict[str, List[str]] = {}
Expand All @@ -42,6 +46,7 @@ def __init__(self, root_value: Optional[RootValue] = None):
def add_remote_schema(
self,
url: str,
headers: Optional[ProxyHeaders] = None,
*,
exclude_types: Optional[List[str]] = None,
exclude_args: Optional[Dict[str, Dict[str, List[str]]]] = None,
Expand All @@ -50,9 +55,15 @@ def add_remote_schema(
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
extra_fields: Optional[Dict[str, List[str]]] = None,
) -> int:
if callable(headers):
remote_schema = get_remote_schema(url, headers(None))
else:
remote_schema = get_remote_schema(url, headers)

return self.add_schema(
get_remote_schema(url),
remote_schema,
url,
headers,
exclude_types=exclude_types,
exclude_args=exclude_args,
exclude_fields=exclude_fields,
Expand All @@ -65,6 +76,7 @@ def add_schema(
self,
schema: GraphQLSchema,
url: Optional[str] = None,
headers: Optional[ProxyHeaders] = None,
*,
exclude_types: Optional[List[str]] = None,
exclude_args: Optional[Dict[str, Dict[str, List[str]]]] = None,
Expand Down Expand Up @@ -93,6 +105,7 @@ def add_schema(

self.schemas.append(schema)
self.urls.append(url)
self.headers.append(headers)

schema_id = len(self.schemas) - 1
for type_name, type_def in schema.type_map.items():
Expand Down Expand Up @@ -228,17 +241,12 @@ async def root_resolver(
if not queries:
return root_value

headers = {}
if context_value.get("request"):
authorization = context_value["request"].headers.get("authorization")
if authorization:
headers["Authorization"] = authorization

subqueries_data = await gather(
*[
self.fetch_data(
context_value,
self.urls[schema_id],
headers,
self.headers[schema_id],
{
"operationName": operation_name,
"query": print_ast(query_document),
Expand All @@ -264,8 +272,11 @@ async def root_resolver(

return root_value or None

async def fetch_data(self, url, headers, json):
async def fetch_data(self, context, url, headers, json):
async with AsyncClient() as client:
if callable(headers):
headers = headers(context)

r = await client.post(
url,
headers=headers,
Expand Down
7 changes: 6 additions & 1 deletion ariadne_graphql_proxy/remote_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import httpx
from graphql import (
GraphQLSchema,
Expand All @@ -6,9 +8,12 @@
)


def get_remote_schema(graphql_url: str) -> GraphQLSchema:
def get_remote_schema(
graphql_url: str, headers: Optional[dict] = None
) -> GraphQLSchema:
response = httpx.post(
graphql_url,
headers=headers,
json={
"operationName": "IntrospectionQuery",
"query": get_introspection_query(),
Expand Down
8 changes: 4 additions & 4 deletions tests/cache/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ def test_json_serialize_calls_json_dumps_with_decode_if_orjson_is_available(
):
JSONCacheSerializer().serialize("test value")

assert mocked_orjson.dumps.called_with("test value")
mocked_orjson.dumps.assert_called_with("test value")
assert mocked_orjson.dumps().decode.called


def test_json_serialize_calls_json_dumps_if_orjson_is_not_available(mocked_json):
JSONCacheSerializer().serialize("test value")

assert mocked_json.dumps.called_with("test value")
mocked_json.dumps.assert_called_with("test value")


def test_json_deserialize_calls_orjson_loads_if_orjson_is_available(mocked_orjson):
JSONCacheSerializer().deserialize("test value")

assert mocked_orjson.loads.called_with("test value")
mocked_orjson.loads.assert_called_with("test value")


def test_json_deserialize_calls_json_loads_if_orjson_is_not_available(mocked_json):
JSONCacheSerializer().deserialize("test value")

assert mocked_json.loads.called_with("test value")
mocked_json.loads.assert_called_with("test value")
10 changes: 8 additions & 2 deletions tests/test_copy_directives.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import call

from graphql import (
DirectiveLocation,
GraphQLArgument,
Expand All @@ -20,8 +22,12 @@ def test_copy_directives_calls_copy_directive_for_each_object(mocker):
copy_directives({}, (directive1, directive2))

assert mocked_copy_directive.call_count == 2
assert mocked_copy_directive.called_with({}, directive1)
assert mocked_copy_directive.called_with({}, directive2)
mocked_copy_directive.assert_has_calls(
[
call({}, directive1, directive_exclude_args=None),
call({}, directive2, directive_exclude_args=None),
]
)


def test_copy_directives_returns_tuple_without_excluded_directive():
Expand Down
12 changes: 12 additions & 0 deletions tests/test_get_remote_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ def test_remote_schema_document_is_returned(httpx_mock, schema, schema_json):
assert print_schema(remote_schema) == print_schema(schema)


def test_remote_schema_document_is_retrieved_with_headers_dict(
httpx_mock, schema, schema_json
):
httpx_mock.add_response(json=schema_json)
remote_schema = get_remote_schema("http://graphql.example.com/", {"auth": "ok"})
assert remote_schema
assert print_schema(remote_schema) == print_schema(schema)

request = httpx_mock.get_requests(url="http://graphql.example.com/")[0]
assert request.headers["auth"] == "ok"


def test_remote_schema_fetch_raises_http_response_error(httpx_mock):
httpx_mock.add_response(status_code=404, text="Not found")

Expand Down
13 changes: 9 additions & 4 deletions tests/test_merge_types_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ def test_merge_type_maps_calls_copy_schema_type_if_object_is_not_present_in_one_
merge_type_maps(type_map1={"TypeName": type_}, type_map2={})

assert mocked_copy_schema_type.called
assert mocked_copy_schema_type.called_with(new_types={}, graphql_type=type_)
assert mocked_copy_schema_type(new_types={}, graphql_type=type_)


def test_merge_type_maps_calls_merge_types_if_object_is_present_in_both_maps(
mocker,
):
mocked_merge_types = mocker.patch("ariadne_graphql_proxy.merge.merge_types")
mocked_merge_types = mocker.patch(
"ariadne_graphql_proxy.merge.merge_types", return_value=True
)
type1 = GraphQLObjectType(
name="TypeName",
fields={"fieldA": GraphQLField(type_=GraphQLString)},
Expand All @@ -34,5 +36,8 @@ def test_merge_type_maps_calls_merge_types_if_object_is_present_in_both_maps(
)
merge_type_maps(type_map1={"TypeName": type1}, type_map2={"TypeName": type2})

assert mocked_merge_types.called
assert mocked_merge_types.called_with(merge_types={}, type1=type1, type2=type2)
mocked_merge_types.assert_called_with(
merged_types={"TypeName": True},
type1=type1,
type2=type2,
)
4 changes: 2 additions & 2 deletions tests/test_proxy_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ async def test_proxy_resolver_proxies_headers_via_callable(
schema,
root_value,
):
def proxy_headers(_, info, payload):
def proxy_headers(context):
return {
"x-auth": info.context["headers"].get("authorization"),
"x-auth": context["headers"].get("authorization"),
}

resolver = ProxyResolver(
Expand Down
Loading

0 comments on commit ff8698a

Please sign in to comment.