diff --git a/ariadne_graphql_proxy/__init__.py b/ariadne_graphql_proxy/__init__.py index 9abcc68..b46767a 100644 --- a/ariadne_graphql_proxy/__init__.py +++ b/ariadne_graphql_proxy/__init__.py @@ -39,6 +39,7 @@ ) from .narrow_graphql_query import narrow_graphql_query from .proxy_resolver import ProxyResolver +from .proxy_root_value import ProxyRootValue from .proxy_schema import ProxySchema from .query_filter import QueryFilter, QueryFilterContext from .remote_schema import get_remote_schema @@ -47,6 +48,7 @@ __all__ = [ "ForeignKeyResolver", "ProxyResolver", + "ProxyRootValue", "ProxySchema", "QueryFilter", "QueryFilterContext", diff --git a/ariadne_graphql_proxy/proxy_root_value.py b/ariadne_graphql_proxy/proxy_root_value.py new file mode 100644 index 0000000..c7d14c6 --- /dev/null +++ b/ariadne_graphql_proxy/proxy_root_value.py @@ -0,0 +1,30 @@ +from typing import List, Optional + +from ariadne.types import BaseProxyRootValue, GraphQLResult + + +class ProxyRootValue(BaseProxyRootValue): + __slots__ = ("root_value", "errors", "extensions") + + def __init__( + self, + root_value: Optional[dict] = None, + errors: Optional[List[dict]] = None, + extensions: Optional[dict] = None, + ): + super().__init__(root_value) + self.errors = errors + self.extensions = extensions + + def update_result(self, result: GraphQLResult) -> GraphQLResult: + success, data = super().update_result(result) + + if self.errors: + data.setdefault("errors", []) + data["errors"] += self.errors + + if self.extensions: + data.setdefault("extensions", {}) + data["extensions"].update(self.extensions) + + return success, data diff --git a/ariadne_graphql_proxy/proxy_schema.py b/ariadne_graphql_proxy/proxy_schema.py index 313d31e..827035c 100644 --- a/ariadne_graphql_proxy/proxy_schema.py +++ b/ariadne_graphql_proxy/proxy_schema.py @@ -1,9 +1,9 @@ from asyncio import gather from functools import reduce from inspect import isawaitable -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union -from ariadne.types import RootValue +from ariadne.types import BaseProxyRootValue, RootValue from graphql import ( DocumentNode, GraphQLInterfaceType, @@ -17,6 +17,7 @@ from .copy import copy_schema from .merge import merge_schemas +from .proxy_root_value import ProxyRootValue from .query_filter import QueryFilter from .remote_schema import get_remote_schema from .standard_types import STANDARD_TYPES, add_missing_scalar_types @@ -30,15 +31,24 @@ class ProxySchema: - def __init__(self, root_value: Optional[RootValue] = None): + def __init__( + self, + root_value: Optional[RootValue] = None, + proxy_root_value: Type[ProxyRootValue] = ProxyRootValue, + ): self.schemas: List[GraphQLSchema] = [] self.urls: List[Optional[str]] = [] self.headers: List[Optional[ProxyHeaders]] = [] + self.proxy_errors: List[bool] = [] + self.proxy_extensions: List[bool] = [] + self.labels: List[str] = [] self.fields_map: Dict[str, Dict[str, Set[int]]] = {} self.fields_types: Dict[str, Dict[str, str]] = {} self.unions: Dict[str, List[str]] = {} self.foreign_keys: Dict[str, Dict[str, List[str]]] = {} + self.proxy_root_value = proxy_root_value + self.schema: Optional[GraphQLSchema] = None self.query_filter: Optional[QueryFilter] = None self.root_value: Optional[RootValue] = root_value @@ -54,12 +64,17 @@ def add_remote_schema( exclude_directives: Optional[List[str]] = None, exclude_directives_args: Optional[Dict[str, List[str]]] = None, extra_fields: Optional[Dict[str, List[str]]] = None, + label: Optional[str] = None, + proxy_errors: bool = True, + proxy_extensions: bool = True, ) -> int: if callable(headers): remote_schema = get_remote_schema(url, headers(None)) else: remote_schema = get_remote_schema(url, headers) + schema_id = len(self.schemas) + return self.add_schema( remote_schema, url, @@ -70,6 +85,9 @@ def add_remote_schema( exclude_directives=exclude_directives, exclude_directives_args=exclude_directives_args, extra_fields=extra_fields, + label=label or f"remote_{schema_id}", + proxy_errors=proxy_errors, + proxy_extensions=proxy_extensions, ) def add_schema( @@ -84,6 +102,9 @@ def add_schema( exclude_directives: Optional[List[str]] = None, exclude_directives_args: Optional[Dict[str, List[str]]] = None, extra_fields: Optional[Dict[str, List[str]]] = None, + label: Optional[str] = None, + proxy_errors: bool = True, + proxy_extensions: bool = True, ) -> int: if ( exclude_types @@ -103,11 +124,15 @@ def add_schema( schema.type_map = add_missing_scalar_types(schema.type_map) + schema_id = len(self.schemas) + self.schemas.append(schema) self.urls.append(url) self.headers.append(headers) + self.labels.append(label or f"schema_{schema_id}") + self.proxy_errors.append(proxy_errors) + self.proxy_extensions.append(proxy_extensions) - schema_id = len(self.schemas) - 1 for type_name, type_def in schema.type_map.items(): if type_name in STANDARD_TYPES: continue @@ -212,7 +237,7 @@ async def root_resolver( operation_name: Optional[str], variables: Optional[dict], document: DocumentNode, - ) -> Optional[dict]: + ) -> Optional[Union[dict, BaseProxyRootValue]]: if not self.query_filter: raise RuntimeError( "'get_final_schema' needs to be called to build final schema " @@ -241,9 +266,13 @@ async def root_resolver( if not queries: return root_value + root_errors: List[dict] = [] + root_extensions: dict = {} + subqueries_data = await gather( *[ self.fetch_data( + schema_id, context_value, self.urls[schema_id], self.headers[schema_id], @@ -266,13 +295,32 @@ async def root_resolver( ] ) - for subquery_data in subqueries_data: - if subquery_data: - root_value.update(subquery_data) + for schema_id, subquery_data in subqueries_data: + label = self.labels[schema_id] + if isinstance(subquery_data.get("data"), dict): + root_value.update(subquery_data["data"]) + if ( + isinstance(subquery_data.get("errors"), list) + and self.proxy_errors[schema_id] + ): + root_errors += self.clean_errors(label, subquery_data["errors"]) + if ( + isinstance(subquery_data.get("extensions"), dict) + and self.proxy_extensions[schema_id] + ): + print("HERE") + root_extensions[label] = subquery_data["extensions"] + + if root_errors or root_extensions: + return self.proxy_root_value( + root_value, + root_errors or None, + root_extensions or None, + ) return root_value or None - async def fetch_data(self, context, url, headers, json): + async def fetch_data(self, schema_id, context, url, headers, json): async with AsyncClient() as client: if callable(headers): headers = headers(context) @@ -284,4 +332,12 @@ async def fetch_data(self, context, url, headers, json): ) query_data = r.json() - return query_data.get("data") + return (schema_id, query_data) + + def clean_errors(self, label: str, errors: List[dict]) -> List[dict]: + clean_errors: List[dict] = [] + for error in errors: + if isinstance(error, dict) and isinstance(error.get("path"), list): + error["path"].insert(0, label) + clean_errors.append(error) + return clean_errors diff --git a/pyproject.toml b/pyproject.toml index e0b2fd6..a3c61d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,11 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] version = "0.2.0" -dependencies = ["graphql-core>=3.2.0,<3.3", "httpx", "ariadne"] +dependencies = [ + "graphql-core>=3.2.0,<3.3", + "httpx", + "ariadne==0.23.0.b1", +] [project.optional-dependencies] test = [ diff --git a/tests/test_proxy_root_value.py b/tests/test_proxy_root_value.py new file mode 100644 index 0000000..0179583 --- /dev/null +++ b/tests/test_proxy_root_value.py @@ -0,0 +1,128 @@ +from ariadne_graphql_proxy import ProxyRootValue + + +def test_proxy_root_value_without_errors_or_extensions_skips_result_update(): + result = False, {"data": "ok"} + root_value = ProxyRootValue() + assert root_value.update_result(result) == result + + +def test_proxy_root_value_with_errors_extends_result(): + result = False, {"data": "ok"} + root_value = ProxyRootValue(errors=[{"message": "Test"}]) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "errors": [ + { + "message": "Test", + }, + ], + }, + ) + + +def test_proxy_root_value_with_extensions_extends_result(): + result = False, {"data": "ok"} + root_value = ProxyRootValue(extensions={"score": "100"}) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "extensions": { + "score": "100", + }, + }, + ) + + +def test_proxy_root_value_with_errors_and_extensions_extends_result(): + result = False, {"data": "ok"} + root_value = ProxyRootValue( + errors=[{"message": "Test"}], + extensions={"score": "100"}, + ) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "errors": [ + { + "message": "Test", + }, + ], + "extensions": { + "score": "100", + }, + }, + ) + + +def test_proxy_root_value_with_errors_updates_result(): + result = False, {"data": "ok", "errors": [{"message": "Org"}]} + + root_value = ProxyRootValue(errors=[{"message": "Test"}]) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "errors": [ + { + "message": "Org", + }, + { + "message": "Test", + }, + ], + }, + ) + + +def test_proxy_root_value_with_extensions_updates_result(): + result = False, {"data": "ok", "extensions": {"core": True}} + root_value = ProxyRootValue(extensions={"score": "100"}) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "extensions": { + "core": True, + "score": "100", + }, + }, + ) + + +def test_proxy_root_value_with_errors_and_extensions_updates_result(): + result = ( + False, + { + "data": "ok", + "errors": [{"message": "Org"}], + "extensions": {"core": True}, + }, + ) + + root_value = ProxyRootValue( + errors=[{"message": "Test"}], + extensions={"score": "100"}, + ) + assert root_value.update_result(result) == ( + False, + { + "data": "ok", + "errors": [ + { + "message": "Org", + }, + { + "message": "Test", + }, + ], + "extensions": { + "core": True, + "score": "100", + }, + }, + ) diff --git a/tests/test_proxy_schema.py b/tests/test_proxy_schema.py index 6d53a07..3d24bd7 100644 --- a/tests/test_proxy_schema.py +++ b/tests/test_proxy_schema.py @@ -5,7 +5,7 @@ import pytest from graphql import parse, print_schema -from ariadne_graphql_proxy import ProxySchema +from ariadne_graphql_proxy import ProxyRootValue, ProxySchema def test_local_schema_is_added_to_proxy(schema): @@ -1113,3 +1113,153 @@ async def test_proxy_schema_includes_headers_from_callable_in_requests( call(ANY), ] ) + + +@pytest.mark.asyncio +async def test_root_value_for_remote_schema_includes_proxied_errors( + httpx_mock, schema_json +): + httpx_mock.add_response(json=schema_json) + httpx_mock.add_response( + json={ + "errors": [ + { + "message": "Something bad has happened!", + "path": ["complex", "id"], + }, + ], + } + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/") + + proxy_schema.get_final_schema() + root_value = await proxy_schema.root_resolver( + {}, + "Query", + None, + parse("query Query { complex { id name } }"), + ) + + assert isinstance(root_value, ProxyRootValue) + assert root_value.errors == [ + { + "message": "Something bad has happened!", + "path": ["remote_0", "complex", "id"], + }, + ] + + +@pytest.mark.asyncio +async def test_root_value_for_remote_schema_excludes_errors(httpx_mock, schema_json): + httpx_mock.add_response(json=schema_json) + httpx_mock.add_response( + json={ + "errors": [ + { + "message": "Something bad has happened!", + "path": ["complex", "id"], + }, + ], + } + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/", proxy_errors=False) + + proxy_schema.get_final_schema() + root_value = await proxy_schema.root_resolver( + {}, + "Query", + None, + parse("query Query { complex { id name } }"), + ) + + assert not isinstance(root_value, ProxyRootValue) + assert root_value is None + + +@pytest.mark.asyncio +async def test_root_value_for_remote_schema_includes_proxied_extensions( + httpx_mock, schema_json +): + httpx_mock.add_response(json=schema_json) + httpx_mock.add_response( + json={ + "data": { + "complex": { + "id": "123", + "name": "Test", + }, + }, + "extensions": { + "score": 100, + }, + } + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/") + + proxy_schema.get_final_schema() + root_value = await proxy_schema.root_resolver( + {}, + "Query", + None, + parse("query Query { complex { id name } }"), + ) + + assert isinstance(root_value, ProxyRootValue) + assert root_value.root_value == { + "complex": { + "id": "123", + "name": "Test", + }, + } + assert root_value.extensions == { + "remote_0": { + "score": 100, + }, + } + + +@pytest.mark.asyncio +async def test_root_value_for_remote_schema_excludes_extensions( + httpx_mock, schema_json +): + httpx_mock.add_response(json=schema_json) + httpx_mock.add_response( + json={ + "data": { + "complex": { + "id": "123", + "name": "Test", + }, + }, + "extensions": { + "score": 100, + }, + } + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema( + "http://graphql.example.com/", proxy_extensions=False + ) + + proxy_schema.get_final_schema() + root_value = await proxy_schema.root_resolver( + {}, + "Query", + None, + parse("query Query { complex { id name } }"), + ) + + assert not isinstance(root_value, ProxyRootValue) + assert root_value == { + "complex": { + "id": "123", + "name": "Test", + }, + }