From 45b3ac6edec67418941462ef29df9bfcf3d06e89 Mon Sep 17 00:00:00 2001 From: rafalp Date: Wed, 9 Oct 2024 21:50:07 +0200 Subject: [PATCH] Add input type --- example/mutations/__init__.py | 8 ++++++-- example/mutations/calc.py | 32 ++++++++++++++++++++++++++++++++ tests/test_mutation.py | 12 ++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 example/mutations/calc.py diff --git a/example/mutations/__init__.py b/example/mutations/__init__.py index ab61dac..aead3b4 100644 --- a/example/mutations/__init__.py +++ b/example/mutations/__init__.py @@ -1,9 +1,13 @@ from typing import Any -from . import compare_roles -from . import dates_delta +from . import ( + calc, + compare_roles, + dates_delta, +) mutations: list[Any] = [ + calc.Mutation, compare_roles.Mutation, dates_delta.Mutation, ] diff --git a/example/mutations/calc.py b/example/mutations/calc.py new file mode 100644 index 0000000..d8519f4 --- /dev/null +++ b/example/mutations/calc.py @@ -0,0 +1,32 @@ +from enum import StrEnum + +from ariadne_graphql_modules import GraphQLInput, GraphQLObject +from graphql import GraphQLResolveInfo + +from ..scalars.date import DateScalar + + +class CalcOperation(StrEnum): + ADD = "add" + SUB = "sub" + MUL = "MUL" + + +class CalcInput(GraphQLInput): + a: int + b: int + op: CalcOperation + + +class Mutation(GraphQLObject): + @GraphQLObject.field(name="calc") + @staticmethod + def resolve_calc(obj, info: GraphQLResolveInfo, *, input: CalcInput) -> int: + if input.op == CalcOperation.ADD: + return input.a + input.b + if input.op == CalcOperation.SUB: + return input.a - input.b + if input.op == CalcOperation.MUL: + return input.a * input.b + + return 0 diff --git a/tests/test_mutation.py b/tests/test_mutation.py index b410a22..32e2b8d 100644 --- a/tests/test_mutation.py +++ b/tests/test_mutation.py @@ -1,6 +1,18 @@ import pytest +@pytest.mark.asyncio +async def test_query_calc_mutation(exec_query): + result = await exec_query( + """ + mutation { + calc(input: {a: 4, b: 3, op: MUL}) + } + """ + ) + assert result.data == {"calc": 12} + + @pytest.mark.asyncio async def test_query_compare_roles_mutation(exec_query): result = await exec_query(