Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add field dependencies #54

Merged
merged 6 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ jobs:
pytest
- name: Linters
run: |
ruff ariadne_graphql_proxy tests
ruff check ariadne_graphql_proxy tests
mypy ariadne_graphql_proxy --ignore-missing-imports --check-untyped-defs
black --check ariadne_graphql_proxy tests
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# CHANGELOG

## UNRELEASED
## 0.3.0 (UNRELEASED)

- Added `CacheSerializer`, `NoopCacheSerializer` and `JSONCacheSerializer`. Changed `CacheBackend`, `InMemoryCache`, `CloudflareCacheBackend` and `DynamoDBCacheBackend` to accept `serializer` initialization option.
- Fixed schema proxy returning an error when variable defined in an operation is missing from its variables.
- Improved custom headers handling in `ProxyResolver` and `ProxySchema`.
- Added fields dependencies configuration option to `ProxySchema`.


## 0.2.0 (2023-09-25)
Expand Down
93 changes: 93 additions & 0 deletions GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,71 @@ If `proxy_headers` is a callable, it will be called with single argument (`conte
If `proxy_headers` is `None` or `False`, no headers are proxied to the other service.


## Fields dependencies

In situations where field depends on data from sibling fields in order to be resolved, `ProxySchema` can be configured to include those additional fields in root value query sent to remote schema.

Below example pulls a remote schema that defines `Product` type, extends this type with `image: String` field, and then uses `ProxySchema.add_field_dependencies` to configure `{ metadata { thumb} }` as additional fields to retrieve when `image` field is queried. It also includes custom resolver for `image` field that uses this additional data:


```python
from ariadne.asgi import GraphQL
from ariadne_graphql_proxy import (
ProxySchema,
get_context_value,
set_resolver,
)
from graphql import build_ast_schema, parse


proxy_schema = ProxySchema()

# Store schema ID for remote schema
remote_schema_id = proxy_schema.add_remote_schema(
"https://example.com/graphql/",
)

# Extend Product type with additional image field
proxy_schema.add_schema(
build_ast_schema(
parse(
"""
type Product {
image: String
}
"""
)
)
)

# Configure proxy schema to retrieve thumb from metadata
# from remote schema when image is queried
proxy_schema.add_field_dependencies(
remote_schema_id, "Product", "image", "{ metadata { thumb } }"
)

# Create schema instance
final_schema = proxy_schema.get_final_schema()


# Add product image resolver
def resolve_product_image(obj, info):
return obj["metadata"]["thumb"]


set_resolver(final_schema, "Product", "image", resolve_product_image)


# Setup Ariadne ASGI GraphQL application
app = GraphQL(
final_schema,
context_value=get_context_value,
root_value=proxy_schema.root_resolver,
debug=True,
)
```


## Cache framework

Ariadne GraphQL Proxy implements basic cache framework that enables of caching parts of GraphQL queries.
Expand Down Expand Up @@ -855,6 +920,34 @@ def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
Sets specific fields in schema as delayed. Delayed fields are excluded from queries ran by `root_resolver` against the remote GraphQL APIs.


#### `delayed_fields`

This is a dict of type name and fields names lists:

```python
{"Type": ["field", "otherField"], "OtherType": ["field"]}
```


### `add_field_dependencies`

```python
def add_field_dependencies(
self, schema_id: int, type_name: str, field_name: str, query: str
):
```

Adds fields specified in `query` as dependencies for `field_name` of `type_name` that should be retrieved from schema with `schema_id`.


#### Required arguments

- `schema_id`: an `int` with ID of schema returned by `add_remote_schema` or `add_schema`.
- `type_name`: a `str` with name of type for which dependencies will be set.
- `field_name`: a `str` with name of field which dependencies will be set.
- `query`: a `str` with additional fields to fetch when `field_name` is included, eg. `{ metadata { key value} }`.


### `add_foreign_key`

```python
Expand Down
3 changes: 3 additions & 0 deletions ariadne_graphql_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .query_filter import QueryFilter, QueryFilterContext
from .remote_schema import get_remote_schema
from .resolvers import set_resolver, unset_resolver
from .selections import merge_selection_sets, merge_selections

__all__ = [
"ForeignKeyResolver",
Expand Down Expand Up @@ -84,6 +85,8 @@
"merge_objects",
"merge_scalars",
"merge_schemas",
"merge_selection_sets",
"merge_selections",
"merge_type_maps",
"merge_types",
"merge_unions",
Expand Down
104 changes: 104 additions & 0 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
GraphQLSchema,
GraphQLUnionType,
GraphQLWrappingType,
OperationDefinitionNode,
OperationType,
SelectionSetNode,
parse,
print_ast,
)
from httpx import AsyncClient
Expand All @@ -20,6 +24,7 @@
from .proxy_root_value import ProxyRootValue
from .query_filter import QueryFilter
from .remote_schema import get_remote_schema
from .selections import merge_selection_sets
from .standard_types import STANDARD_TYPES, add_missing_scalar_types
from .str_to_field import (
get_field_definition_from_str,
Expand All @@ -46,6 +51,7 @@ def __init__(
self.fields_types: Dict[str, Dict[str, str]] = {}
self.unions: Dict[str, List[str]] = {}
self.foreign_keys: Dict[str, Dict[str, List[str]]] = {}
self.dependencies: Dict[int, Dict[str, Dict[str, SelectionSetNode]]] = {}

self.proxy_root_value = proxy_root_value

Expand Down Expand Up @@ -176,8 +182,105 @@ def add_foreign_key(
if field_name in self.foreign_keys[type_name]:
raise ValueError(f"Foreign key already exists on {type_name}.{field_name}")

for schema_dependencies in self.dependencies.values():
if (
type_name in schema_dependencies
and field_name in schema_dependencies[type_name]
):
raise ValueError(
f"Foreign key can't be created for {type_name}.{field_name} because "
"field dependencies were previously defined for it."
)

self.foreign_keys[type_name][field_name] = [on] if isinstance(on, str) else on

def add_field_dependencies(
self, schema_id: int, type_name: str, field_name: str, query: str
):
if type_name in ("Query", "Mutation", "Subscription"):
raise ValueError(
f"Defining field dependencies for {type_name} fields is not allowed."
)

if (
type_name in self.foreign_keys
and field_name in self.foreign_keys[type_name]
):
raise ValueError(
f"Dependencies can't be created for {type_name}.{field_name} because "
"foreign key was previously defined for it."
)

if schema_id < 0 or schema_id + 1 > len(self.urls):
raise ValueError(f"Schema with ID '{schema_id}' doesn't exist.")
if not self.urls[schema_id]:
raise ValueError(f"Schema with ID '{schema_id}' is not a remote schema.")

schema = self.schemas[schema_id]
if type_name not in schema.type_map:
raise ValueError(
f"Type '{type_name}' doesn't exist in schema with ID '{schema_id}'."
)

schema_type = schema.type_map[type_name]
if not isinstance(schema_type, GraphQLObjectType):
raise ValueError(
f"Type '{type_name}' in schema with ID '{schema_id}' is not "
"an object type."
)

self.validate_field_with_dependencies(type_name, field_name)

if schema_id not in self.dependencies:
self.dependencies[schema_id] = {}
if type_name not in self.dependencies[schema_id]:
self.dependencies[schema_id][type_name] = {}

selection_set = self.parse_field_dependencies(field_name, query)

type_dependencies = self.dependencies[schema_id][type_name]
if not type_dependencies.get(field_name):
type_dependencies[field_name] = selection_set
else:
type_dependencies[field_name] = merge_selection_sets(
type_dependencies[field_name], selection_set
)

def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetNode:
clean_query = query.strip()
if not clean_query.startswith("{") or not clean_query.endswith("}"):
raise ValueError(
f"'{field_name}' field dependencies should be defined as a single "
"GraphQL operation, e.g.: '{ field other { subfield } }'."
)

ast = parse(clean_query)

if (
not len(ast.definitions) == 1
or not isinstance(ast.definitions[0], OperationDefinitionNode)
or ast.definitions[0].operation != OperationType.QUERY
):
raise ValueError(
f"'{field_name}' field dependencies should be defined as a single "
"GraphQL operation, e.g.: '{ field other { subfield } }'."
)

return ast.definitions[0].selection_set

def validate_field_with_dependencies(self, type_name: str, field_name: str) -> None:
for schema in self.schemas:
if (
type_name in schema.type_map
and isinstance(schema.type_map[type_name], GraphQLObjectType)
and field_name in schema.type_map[type_name].fields # type: ignore
):
return

raise ValueError(
f"Type '{type_name}' doesn't define the '{field_name}' field in any of schemas."
)

def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
for type_name, type_fields in delayed_fields.items():
if type_name not in self.fields_map:
Expand Down Expand Up @@ -227,6 +330,7 @@ def get_final_schema(self) -> GraphQLSchema:
self.fields_types,
self.unions,
self.foreign_keys,
self.dependencies,
)

return self.schema
Expand Down
56 changes: 50 additions & 6 deletions ariadne_graphql_proxy/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
VariableNode,
)

from .selections import merge_selections


class QueryFilterContext:
schema_id: int
Expand All @@ -35,13 +37,15 @@ def __init__(
fields_types: Dict[str, Dict[str, str]],
unions: Dict[str, List[str]],
foreign_keys: Dict[str, Dict[str, List[str]]],
dependencies: Dict[int, Dict[str, Dict[str, SelectionSetNode]]],
):
self.schema = schema
self.schemas = schemas
self.fields_map = fields_map
self.fields_types = fields_types
self.unions = unions
self.foreign_keys = foreign_keys
self.dependencies = dependencies

def split_query(
self, document: DocumentNode
Expand Down Expand Up @@ -189,12 +193,22 @@ def filter_field_node(
else:
type_fields = self.fields_map[type_name]

fields_dependencies = self.get_type_fields_dependencies(
context.schema_id, type_name
)

new_selections: List[SelectionNode] = []
for selection in field_node.selection_set.selections:
if isinstance(selection, FieldNode):
field_name = selection.name.value
if fields_dependencies and field_name in fields_dependencies:
new_selections = merge_selections(
new_selections, fields_dependencies[field_name].selections
)

if (
selection.name.value not in type_fields
or context.schema_id not in type_fields[selection.name.value]
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
):
continue

Expand Down Expand Up @@ -244,12 +258,22 @@ def filter_inline_fragment_node(
type_name = fragment_node.type_condition.name.value
type_fields = self.fields_map[type_name]

fields_dependencies = self.get_type_fields_dependencies(
context.schema_id, type_name
)

new_selections: List[SelectionNode] = []
for selection in fragment_node.selection_set.selections:
if isinstance(selection, FieldNode):
field_name = selection.name.value
if fields_dependencies and field_name in fields_dependencies:
new_selections = merge_selections(
new_selections, fields_dependencies[field_name].selections
)

if (
selection.name.value not in type_fields
or context.schema_id not in type_fields[selection.name.value]
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
):
continue

Expand Down Expand Up @@ -294,12 +318,22 @@ def filter_fragment_spread_node(
type_name = fragment.type_condition.name.value
type_fields = self.fields_map[type_name]

fields_dependencies = self.get_type_fields_dependencies(
context.schema_id, type_name
)

new_selections: List[SelectionNode] = []
for selection in fragment.selection_set.selections:
if isinstance(selection, FieldNode):
field_name = selection.name.value
if fields_dependencies and field_name in fields_dependencies:
new_selections = merge_selections(
new_selections, fields_dependencies[field_name].selections
)

if (
selection.name.value not in type_fields
or context.schema_id not in type_fields[selection.name.value]
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
):
continue

Expand Down Expand Up @@ -347,3 +381,13 @@ def inline_fragment_spread_node(
selections=tuple(selections),
),
)

def get_type_fields_dependencies(
self,
schema_id: int,
type_name: str,
) -> Optional[Dict[str, SelectionSetNode]]:
if schema_id in self.dependencies and type_name in self.dependencies[schema_id]:
return self.dependencies[schema_id][type_name]

return None
Loading
Loading