diff --git a/ariadne_graphql_proxy/query_filter.py b/ariadne_graphql_proxy/query_filter.py index 3751cac..2a6d3d9 100644 --- a/ariadne_graphql_proxy/query_filter.py +++ b/ariadne_graphql_proxy/query_filter.py @@ -7,7 +7,9 @@ FragmentSpreadNode, GraphQLSchema, InlineFragmentNode, + ListValueNode, NameNode, + ObjectValueNode, OperationDefinitionNode, SelectionNode, SelectionSetNode, @@ -147,18 +149,33 @@ def filter_operation_node( ), ) + def extract_variables( + self, + value: VariableNode | ListValueNode | ObjectValueNode, + context: QueryFilterContext, + ): + if isinstance(value, VariableNode): + context.variables.add(value.name.value) + elif isinstance(value, ObjectValueNode): + for field in value.fields: + self.extract_variables(field.value, context) # type: ignore + elif isinstance(value, ListValueNode): + for item in value.values: + self.extract_variables(item, context) # type: ignore + + def update_context_variables( + self, field_node: FieldNode, context: QueryFilterContext + ): + for argument in field_node.arguments: + self.extract_variables(argument.value, context) # type: ignore + def filter_field_node( self, field_node: FieldNode, schema_obj: str, context: QueryFilterContext, ) -> Optional[FieldNode]: - context.variables.update( - argument.value.name.value - for argument in field_node.arguments - if isinstance(argument.value, VariableNode) - ) - + self.update_context_variables(field_node, context) if not field_node.selection_set: return field_node diff --git a/tests/conftest.py b/tests/conftest.py index 9f251f3..d39a621 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,3 +322,52 @@ def search_root_value(): @pytest.fixture def gql(): return lambda x: x + + +@pytest.fixture +def car_schema(): + return make_executable_schema( + """ + type Query { + carsByIds(ids: [ID!]!): [Car!]! + carsByCriteria(input: SearchInput!): [Car!]! + } + + type Car { + id: ID! + make: String! + model: String! + year: Int! + } + + input SearchInput { + search: SearchCriteria + } + + input SearchCriteria { + make: String + model: String + year: Int + } + """ + ) + + +@pytest.fixture +def car_schema_json(car_schema): + schema_data = graphql_sync(car_schema, get_introspection_query()).data + return {"data": schema_data} + + +@pytest.fixture +def car_root_value(): + return { + "carsByIds": [ + {"id": "car1", "make": "Toyota", "model": "Corolla", "year": 2020}, + {"id": "car2", "make": "Honda", "model": "Civic", "year": 2019}, + ], + "carsByCriteria": [ + {"id": "car3", "make": "Ford", "model": "Mustang", "year": 2018}, + {"id": "car4", "make": "Chevrolet", "model": "Camaro", "year": 2017}, + ], + } diff --git a/tests/test_proxy_schema.py b/tests/test_proxy_schema.py index 068b282..7a46cc1 100644 --- a/tests/test_proxy_schema.py +++ b/tests/test_proxy_schema.py @@ -743,6 +743,118 @@ async def test_proxy_schema_splits_variables_between_schemas( } +@pytest.mark.asyncio +async def test_proxy_schema_handles_object_variables_correctly( + httpx_mock, + car_schema_json, + car_root_value, +): + httpx_mock.add_response( + url="http://graphql.example.com/cars/", json=car_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/cars/", + json={"data": car_root_value["carsByCriteria"]}, + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/cars/") + proxy_schema.get_final_schema() + + await proxy_schema.root_resolver( + {}, + "CarsByCriteriaQuery", + {"criteria": {"make": "Toyota", "model": "Corolla", "year": 2020}}, + parse( + """ + query CarsByCriteriaQuery($criteria: SearchCriteria!) { + carsByCriteria(input: { criteria: $criteria }) { + id + make + model + year + } + } + """ + ), + ) + + cars_request = httpx_mock.get_requests(url="http://graphql.example.com/cars/")[-1] + + assert json.loads(cars_request.content) == { + "operationName": "CarsByCriteriaQuery", + "variables": {"criteria": {"make": "Toyota", "model": "Corolla", "year": 2020}}, + "query": dedent( + """ + query CarsByCriteriaQuery($criteria: SearchCriteria!) { + carsByCriteria(input: {criteria: $criteria}) { + id + make + model + year + } + } + """ + ).strip(), + } + + +@pytest.mark.asyncio +async def test_proxy_schema_handles_list_variables_correctly( + httpx_mock, + car_schema_json, + car_root_value, +): + httpx_mock.add_response( + url="http://graphql.example.com/cars/", json=car_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/cars/", + json={"data": car_root_value["carsByIds"]}, + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/cars/") + proxy_schema.get_final_schema() + + await proxy_schema.root_resolver( + {}, + "CarsQuery", + {"id": "car2"}, + parse( + """ + query CarsQuery($id: ID!) { + carsByIds(ids: [$id]) { + id + make + model + year + } + } + """ + ), + ) + + cars_request = httpx_mock.get_requests(url="http://graphql.example.com/cars/")[-1] + + assert json.loads(cars_request.content) == { + "operationName": "CarsQuery", + "variables": {"id": "car2"}, + "query": dedent( + """ + query CarsQuery($id: ID!) { + carsByIds(ids: [$id]) { + id + make + model + year + } + } + """ + ).strip(), + } + + @pytest.mark.asyncio async def test_proxy_schema_splits_variables_from_fragments_between_schemas( httpx_mock,