Skip to content

Commit

Permalink
Merge pull request #49 from mirumee/fix-46-union-query-fails
Browse files Browse the repository at this point in the history
Change fragment spreads for union field selection to inline fragments
  • Loading branch information
rafalp authored Feb 9, 2024
2 parents ba59564 + b9b6ac7 commit d9084ff
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 7 deletions.
8 changes: 8 additions & 0 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLSchema,
GraphQLUnionType,
GraphQLWrappingType,
print_ast,
)
Expand All @@ -31,6 +32,7 @@ def __init__(self, root_value: Optional[RootValue] = None):
self.urls: List[Optional[str]] = []
self.fields_map: Dict[str, Dict[str, Set[int]]] = {}
self.fields_types: Dict[str, Dict[str, str]] = {}
self.unions: Dict[str, List[str]] = {}
self.foreign_keys: Dict[str, Dict[str, List[str]]] = {}

self.schema: Optional[GraphQLSchema] = None
Expand Down Expand Up @@ -97,6 +99,11 @@ def add_schema(
if type_name in STANDARD_TYPES:
continue

if isinstance(type_def, GraphQLUnionType):
self.unions[type_name] = [
object_type.name for object_type in type_def.types
]

if not isinstance(type_def, (GraphQLInterfaceType, GraphQLObjectType)):
continue

Expand Down Expand Up @@ -180,6 +187,7 @@ def get_final_schema(self) -> GraphQLSchema:
self.schemas,
self.fields_map,
self.fields_types,
self.unions,
self.foreign_keys,
)

Expand Down
47 changes: 43 additions & 4 deletions ariadne_graphql_proxy/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def __init__(
schemas: List[GraphQLSchema],
fields_map: Dict[str, Dict[str, Set[int]]],
fields_types: Dict[str, Dict[str, str]],
unions: Dict[str, List[str]],
foreign_keys: Dict[str, Dict[str, List[str]]],
):
self.schema = schema
self.schemas = schemas
self.fields_map = fields_map
self.fields_types = fields_types
self.unions = unions
self.foreign_keys = foreign_keys

def split_query(
Expand Down Expand Up @@ -180,7 +182,12 @@ def filter_field_node(
)

type_name = self.fields_types[schema_obj][field_name]
type_fields = self.fields_map[type_name]
type_is_union = type_name in self.unions

if type_is_union:
type_fields = {}
else:
type_fields = self.fields_map[type_name]

new_selections: List[SelectionNode] = []
for selection in field_node.selection_set.selections:
Expand All @@ -205,9 +212,16 @@ def filter_field_node(
new_selections.append(inline_fragment_selection)

if isinstance(selection, FragmentSpreadNode):
new_selections += self.filter_fragment_spread_node(
selection, schema_obj, context
)
if type_is_union:
inline_fragment = self.inline_fragment_spread_node(
selection, schema_obj, context
)
if inline_fragment:
new_selections.append(inline_fragment)
else:
new_selections += self.filter_fragment_spread_node(
selection, schema_obj, context
)

if not new_selections:
return None
Expand Down Expand Up @@ -308,3 +322,28 @@ def filter_fragment_spread_node(
)

return new_selections

def inline_fragment_spread_node(
self,
fragment_node: FragmentSpreadNode,
schema_obj: str,
context: QueryFilterContext,
) -> Optional[InlineFragmentNode]:
fragment_name = fragment_node.name.value
fragment = context.fragments.get(fragment_name)
if not fragment:
return None

selections = self.filter_fragment_spread_node(
fragment_node, schema_obj, context
)

if not selections:
return None

return InlineFragmentNode(
type_condition=fragment.type_condition,
selection_set=SelectionSetNode(
selections=tuple(selections),
),
)
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def schema():
type Query {
basic(arg: Generic, other: Generic): String
complex(arg: Generic, other: Generic): Complex
unionField: [DeliveryMethod!]!
}
type Complex {
Expand All @@ -34,6 +35,18 @@ def schema():
arg4: ID!
arg5: Int!
}
union DeliveryMethod = Shipping | Warehouse
type Shipping {
id: ID!
name: String!
}
type Warehouse {
id: ID!
address: String!
}
"""
)

Expand Down Expand Up @@ -114,6 +127,18 @@ def root_value():
"rank": 9001,
},
},
"deliveryMethod": [
{
"__typename": "Shipping",
"id": "SHIP:1",
"name": "Test Shipping",
},
{
"__typename": "Warehouse",
"id": "WAREHOUSE:13",
"address": "Warehouse #13",
},
],
"other": "Dolor Met",
"otherComplex": {
"id": 123,
Expand Down
141 changes: 138 additions & 3 deletions tests/test_proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def test_local_and_proxy_schemas_are_added_to_proxy(

final_schema = proxy_schema.get_final_schema()
assert print_schema(final_schema)
assert len(final_schema.type_map["Query"].fields) == 4
assert len(final_schema.type_map["Query"].fields) == 5
assert "basic" in final_schema.type_map["Query"].fields
assert "complex" in final_schema.type_map["Query"].fields
assert "unionField" in final_schema.type_map["Query"].fields
assert "other" in final_schema.type_map["Query"].fields
assert "otherComplex" in final_schema.type_map["Query"].fields

Expand All @@ -50,9 +51,10 @@ def test_multiple_local_schemas_are_added_to_proxy(schema, complex_schema):

final_schema = proxy_schema.get_final_schema()
assert print_schema(final_schema)
assert len(final_schema.type_map["Query"].fields) == 2
assert len(final_schema.type_map["Query"].fields) == 3
assert "basic" in final_schema.type_map["Query"].fields
assert "complex" in final_schema.type_map["Query"].fields
assert "unionField" in final_schema.type_map["Query"].fields


def test_multiple_proxy_schemas_are_added_to_proxy(
Expand All @@ -67,9 +69,10 @@ def test_multiple_proxy_schemas_are_added_to_proxy(

final_schema = proxy_schema.get_final_schema()
assert print_schema(final_schema)
assert len(final_schema.type_map["Query"].fields) == 2
assert len(final_schema.type_map["Query"].fields) == 3
assert "basic" in final_schema.type_map["Query"].fields
assert "complex" in final_schema.type_map["Query"].fields
assert "unionField" in final_schema.type_map["Query"].fields


def test_local_schema_can_be_retrieved_from_proxy(schema):
Expand Down Expand Up @@ -771,6 +774,138 @@ async def test_proxy_schema_splits_variables_from_fragments_between_schemas(
}


@pytest.mark.asyncio
async def test_proxy_schema_handles_union_queries(
httpx_mock,
schema_json,
root_value,
):
httpx_mock.add_response(
url="http://graphql.example.com/",
json=schema_json,
)
httpx_mock.add_response(
url="http://graphql.example.com/",
json={"data": root_value},
)

proxy_schema = ProxySchema()
proxy_schema.add_remote_schema("http://graphql.example.com/")
proxy_schema.get_final_schema()

await proxy_schema.root_resolver(
{},
"TestQuery",
None,
parse(
"""
query TestQuery {
unionField {
... on Shipping {
id
name
}
... on Warehouse {
id
address
}
}
}
"""
),
)

request = httpx_mock.get_requests(url="http://graphql.example.com/")[-1]
assert json.loads(request.content) == {
"operationName": "TestQuery",
"variables": None,
"query": dedent(
"""
query TestQuery {
unionField {
... on Shipping {
id
name
}
... on Warehouse {
id
address
}
}
}
"""
).strip(),
}


@pytest.mark.asyncio
async def test_proxy_schema_handles_union_queries_with_fragments(
httpx_mock,
schema_json,
root_value,
):
httpx_mock.add_response(
url="http://graphql.example.com/",
json=schema_json,
)
httpx_mock.add_response(
url="http://graphql.example.com/",
json={"data": root_value},
)

proxy_schema = ProxySchema()
proxy_schema.add_remote_schema("http://graphql.example.com/")
proxy_schema.get_final_schema()

await proxy_schema.root_resolver(
{},
"TestQuery",
None,
parse(
"""
fragment ShippingFields on Shipping {
id
name
}
fragment WarehouseFields on Warehouse {
id
address
}
query TestQuery {
unionField {
... ShippingFields
... WarehouseFields
}
}
"""
),
)

request = httpx_mock.get_requests(url="http://graphql.example.com/")[-1]
assert json.loads(request.content) == {
"operationName": "TestQuery",
"variables": None,
"query": dedent(
"""
query TestQuery {
unionField {
... on Shipping {
id
name
}
... on Warehouse {
id
address
}
}
}
"""
).strip(),
}


@pytest.mark.asyncio
async def test_proxy_schema_handles_omitted_optional_variables(
httpx_mock,
Expand Down

0 comments on commit d9084ff

Please sign in to comment.