Skip to content

Commit

Permalink
Proxy errors and extensions in ProxySchema
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Mar 1, 2024
1 parent ff8698a commit 67ea420
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 10 deletions.
30 changes: 30 additions & 0 deletions ariadne_graphql_proxy/proxy_root_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List, Optional, Tuple

from ariadne.types import BaseProxyRootValue, GraphQLResult


class ProxyRootValue(BaseProxyRootValue):
__slots__ = ("root_value", "errors", "extensions")

def __init__(
self,
root_value: Optional[dict] = None,
errors: Optional[List[dict]] = None,
extensions: Optional[dict] = None,
):
super().__init__(root_value)
self.errors = errors
self.extensions = extensions

def update_result(self, result: Tuple[bool, dict]) -> Tuple[bool, dict]:
success, data = super().update_result(result)

if self.errors:
data.setdefault("errors", [])
data["errors"] += self.errors

if self.extensions:
data.setdefault("extensions", {})
data["extensions"].update(self.extensions)

return success, data
73 changes: 64 additions & 9 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from asyncio import gather
from functools import reduce
from inspect import isawaitable
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

from ariadne.types import RootValue
from ariadne.types import BaseProxyRootValue, RootValue
from graphql import (
DocumentNode,
GraphQLInterfaceType,
Expand All @@ -17,6 +17,7 @@

from .copy import copy_schema
from .merge import merge_schemas
from .proxy_root_value import ProxyRootValue
from .query_filter import QueryFilter
from .remote_schema import get_remote_schema
from .standard_types import STANDARD_TYPES, add_missing_scalar_types
Expand All @@ -30,15 +31,24 @@


class ProxySchema:
def __init__(self, root_value: Optional[RootValue] = None):
def __init__(
self,
root_value: Optional[RootValue] = None,
proxy_root_value: Type[ProxyRootValue] = ProxyRootValue,
):
self.schemas: List[GraphQLSchema] = []
self.urls: List[Optional[str]] = []
self.headers: List[Optional[ProxyHeaders]] = []
self.proxy_errors: List[bool] = []
self.proxy_extensions: List[bool] = []
self.labels: List[str] = []
self.fields_map: Dict[str, Dict[str, Set[int]]] = {}
self.fields_types: Dict[str, Dict[str, str]] = {}
self.unions: Dict[str, List[str]] = {}
self.foreign_keys: Dict[str, Dict[str, List[str]]] = {}

self.proxy_root_value = proxy_root_value

self.schema: Optional[GraphQLSchema] = None
self.query_filter: Optional[QueryFilter] = None
self.root_value: Optional[RootValue] = root_value
Expand All @@ -54,12 +64,17 @@ def add_remote_schema(
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
extra_fields: Optional[Dict[str, List[str]]] = None,
label: Optional[str] = None,
proxy_errors: bool = True,
proxy_extensions: bool = True,
) -> int:
if callable(headers):
remote_schema = get_remote_schema(url, headers(None))
else:
remote_schema = get_remote_schema(url, headers)

schema_id = len(self.schemas)

return self.add_schema(
remote_schema,
url,
Expand All @@ -70,6 +85,9 @@ def add_remote_schema(
exclude_directives=exclude_directives,
exclude_directives_args=exclude_directives_args,
extra_fields=extra_fields,
label=label or f"remote_{schema_id}",
proxy_errors=proxy_errors,
proxy_extensions=proxy_extensions,
)

def add_schema(
Expand All @@ -84,6 +102,9 @@ def add_schema(
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
extra_fields: Optional[Dict[str, List[str]]] = None,
label: Optional[str] = None,
proxy_errors: bool = True,
proxy_extensions: bool = True,
) -> int:
if (
exclude_types
Expand All @@ -103,11 +124,15 @@ def add_schema(

schema.type_map = add_missing_scalar_types(schema.type_map)

schema_id = len(self.schemas)

self.schemas.append(schema)
self.urls.append(url)
self.headers.append(headers)
self.labels.append(label or f"schema_{schema_id}")
self.proxy_errors.append(proxy_errors)
self.proxy_extensions.append(proxy_extensions)

schema_id = len(self.schemas) - 1
for type_name, type_def in schema.type_map.items():
if type_name in STANDARD_TYPES:
continue
Expand Down Expand Up @@ -241,9 +266,13 @@ async def root_resolver(
if not queries:
return root_value

root_errors: List[dict] = []
root_extensions: dict = {}

subqueries_data = await gather(
*[
self.fetch_data(
schema_id,
context_value,
self.urls[schema_id],
self.headers[schema_id],
Expand All @@ -266,13 +295,31 @@ async def root_resolver(
]
)

for subquery_data in subqueries_data:
if subquery_data:
root_value.update(subquery_data)
for schema_id, subquery_data in subqueries_data:
label = self.labels[schema_id]
if isinstance(subquery_data.get("data"), dict):
root_value.update(subquery_data["data"])
if (
isinstance(subquery_data.get("errors"), list)
and self.proxy_errors[schema_id]
):
root_errors += self.clean_errors(label, subquery_data["errors"])
if (
isinstance(subquery_data.get("extensions"), dict)
and self.proxy_extensions[schema_id]
):
root_extensions[label] = subquery_data["extensions"]

if root_errors or root_extensions:
return self.proxy_root_value(
root_value,
root_errors or None,
root_extensions or None,
)

return root_value or None

async def fetch_data(self, context, url, headers, json):
async def fetch_data(self, schema_id, context, url, headers, json):
async with AsyncClient() as client:
if callable(headers):
headers = headers(context)
Expand All @@ -284,4 +331,12 @@ async def fetch_data(self, context, url, headers, json):
)

query_data = r.json()
return query_data.get("data")
return (schema_id, query_data)

def clean_errors(self, label: str, errors: List[dict]) -> List[dict]:
clean_errors: List[dict] = []
for error in errors:
if isinstance(error, dict) and isinstance(error.get("path"), list):
error["path"].insert(0, label)
clean_errors.append(error)
return clean_errors
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
version = "0.2.0"
dependencies = ["graphql-core>=3.2.0,<3.3", "httpx", "ariadne"]
dependencies = [
"graphql-core>=3.2.0,<3.3",
"httpx",
"ariadne>=0.23",
]

[project.optional-dependencies]
test = [
Expand Down

0 comments on commit 67ea420

Please sign in to comment.