diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt index 239311bb..e91e3e2c 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt @@ -149,10 +149,11 @@ internal class SchemaClassScanner( ?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName") when (typeDefinition) { is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition) - is InputObjectTypeDefinition -> { - for (input in typeDefinition.inputValueDefinitions) { - handleDirectiveInput(input.type) - } + is EnumTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) { + "Enum type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary." + } + is InputObjectTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) { + "Input object type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary." } } } @@ -209,9 +210,9 @@ internal class SchemaClassScanner( log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}") } - val fieldResolvers = fieldResolversByType.flatMap { it.value.map { it.value } } - val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance() - val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance().flatMap { it.resolverInfoList } + val fieldResolvers = fieldResolversByType.flatMap { entry -> entry.value.map { it.value } } + val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance().toSet() + val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance().flatMap { it.resolverInfoList }.toSet() (resolverInfos - observedNormalResolverInfos - observedMultiResolverInfos).forEach { resolverInfo -> log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}") @@ -255,7 +256,7 @@ internal class SchemaClassScanner( }.flatten().distinct() } - private fun handleDictionaryTypes(types: List, failureMessage: (ObjectTypeDefinition) -> String) { + private fun handleDictionaryTypes(types: List>, failureMessage: (TypeDefinition<*>) -> String) { types.forEach { type -> val dictionaryContainsType = dictionary.filter { it.key.name == type.name }.isNotEmpty() if (!unvalidatedTypes.contains(type) && !dictionaryContainsType) { diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index f48ced77..c26efc81 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -1,6 +1,5 @@ package graphql.kickstart.tools -import graphql.Scalars import graphql.introspection.Introspection import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION import graphql.kickstart.tools.directive.DirectiveWiringHelper @@ -9,6 +8,7 @@ import graphql.kickstart.tools.util.getExtendedFieldDefinitions import graphql.kickstart.tools.util.unwrap import graphql.language.* import graphql.schema.* +import graphql.schema.idl.DirectiveInfo import graphql.schema.idl.RuntimeWiring import graphql.schema.idl.ScalarInfo import graphql.schema.visibility.NoIntrospectionGraphqlFieldVisibility @@ -60,6 +60,8 @@ class SchemaParser internal constructor( private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry() private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions) + private lateinit var schemaDirectives : Set + /** * Parses the given schema with respect to the given dictionary and returns GraphQL objects. */ @@ -72,6 +74,7 @@ class SchemaParser internal constructor( // Create GraphQL objects val inputObjects: MutableList = mutableListOf() + createDirectives(inputObjects) inputObjectDefinitions.forEach { if (inputObjects.none { io -> io.name == it.name }) { inputObjects.add(createInputObject(it, inputObjects, mutableSetOf())) @@ -82,8 +85,6 @@ class SchemaParser internal constructor( val unions = unionDefinitions.map { createUnionObject(it, objects) } val enums = enumDefinitions.map { createEnumObject(it) } - val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet() - // Assign type resolver to interfaces now that we know all of the object types interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) } unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) } @@ -103,7 +104,7 @@ class SchemaParser internal constructor( val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation } val types = (additionalObjects.toSet() as Set) + inputObjects + enums + interfaces + unions - return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription()) + return SchemaObjects(query, mutation, subscription, types, schemaDirectives, codeRegistryBuilder, rootInfo.getDescription()) } /** @@ -300,7 +301,7 @@ class SchemaParser internal constructor( .name(definition.name) .definition(definition) .description(getDocumentation(definition, options)) - .type(determineInputType(definition.type, inputObjects, setOf())) + .type(determineInputType(definition.type, inputObjects, mutableSetOf())) .apply { getDeprecated(definition.directives)?.let { deprecate(it) } } .apply { definition.defaultValue?.let { defaultValueLiteral(it) } } .withAppliedDirectives(*buildAppliedDirectives(definition.directives)) @@ -308,36 +309,67 @@ class SchemaParser internal constructor( .build() } - private fun createDirective(definition: DirectiveDefinition, inputObjects: List): GraphQLDirective { - val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray() + private fun createDirectives(inputObjects: MutableList) { + schemaDirectives = directiveDefinitions.map { definition -> + val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray() + + GraphQLDirective.newDirective() + .name(definition.name) + .description(getDocumentation(definition, options)) + .definition(definition) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .validLocations(*locations) + .repeatable(definition.isRepeatable) + .apply { + definition.inputValueDefinitions.forEach { argumentDefinition -> + argument(createDirectiveArgument(argumentDefinition, inputObjects)) + } + } + .build() + }.toSet() + // because the arguments can have directives too, we attach them only after the directives themselves are created + schemaDirectives = schemaDirectives.map { d -> + val arguments = d.arguments.map { a -> a.transform { + it.withAppliedDirectives(*buildAppliedDirectives(a.definition!!.directives)) + .withDirectives(*buildDirectives(a.definition!!.directives, Introspection.DirectiveLocation.OBJECT)) + } } + d.transform { it.replaceArguments(arguments) } + }.toSet() + } - return GraphQLDirective.newDirective() + private fun createDirectiveArgument(definition: InputValueDefinition, inputObjects: List): GraphQLArgument { + return GraphQLArgument.newArgument() .name(definition.name) - .description(getDocumentation(definition, options)) .definition(definition) - .comparatorRegistry(runtimeWiring.comparatorRegistry) - .validLocations(*locations) - .repeatable(definition.isRepeatable) - .apply { - definition.inputValueDefinitions.forEach { argumentDefinition -> - argument(createArgument(argumentDefinition, inputObjects)) - } - } + .description(getDocumentation(definition, options)) + .type(determineInputType(definition.type, inputObjects, mutableSetOf())) + .apply { getDeprecated(definition.directives)?.let { deprecate(it) } } + .apply { definition.defaultValue?.let { defaultValueLiteral(it) } } .build() } private fun buildAppliedDirectives(directives: List): Array { - return directives.map { + return directives.map { directive -> + val graphQLDirective = schemaDirectives.find { d -> d.name == directive.name } + ?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name] + ?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.") + val graphQLArguments = graphQLDirective.arguments.associateBy { it.name } + GraphQLAppliedDirective.newDirective() - .name(it.name) - .description(getDocumentation(it, options)) + .name(directive.name) + .description(getDocumentation(directive, options)) + .definition(directive) .comparatorRegistry(runtimeWiring.comparatorRegistry) .apply { - it.arguments.forEach { arg -> + directive.arguments.forEach { arg -> + val graphQLArgument = graphQLArguments[arg.name] + ?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name} .") argument(GraphQLAppliedDirectiveArgument.newArgument() .name(arg.name) - .type(buildDirectiveInputType(arg.value)) + // TODO instead of guessing the type from its value, lookup the directive definition + .type(graphQLArgument.type) .valueLiteral(arg.value) + .description(graphQLArgument.description) .build() ) } @@ -358,6 +390,10 @@ class SchemaParser internal constructor( val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false if (repeatable || !names.contains(directive.name)) { names.add(directive.name) + val graphQLDirective = this.schemaDirectives.find { d -> d.name == directive.name } + ?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name] + ?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.") + val graphQLArguments = graphQLDirective.arguments.associateBy { it.name } output.add( GraphQLDirective.newDirective() .name(directive.name) @@ -367,9 +403,11 @@ class SchemaParser internal constructor( .repeatable(repeatable) .apply { directive.arguments.forEach { arg -> + val graphQLArgument = graphQLArguments[arg.name] + ?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name}.") argument(GraphQLArgument.newArgument() .name(arg.name) - .type(buildDirectiveInputType(arg.value)) + .type(graphQLArgument.type) // TODO remove this once directives are fully replaced with applied directives .valueLiteral(arg.value) .build()) @@ -383,46 +421,6 @@ class SchemaParser internal constructor( return output.toTypedArray() } - private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { - return when (value) { - is NullValue -> Scalars.GraphQLString - is FloatValue -> Scalars.GraphQLFloat - is StringValue -> Scalars.GraphQLString - is IntValue -> Scalars.GraphQLInt - is BooleanValue -> Scalars.GraphQLBoolean - is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) - // TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?) - else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") - } - } - - private fun getArrayValueWrappedType(value: ArrayValue): Value<*> { - // empty array [] is equivalent to [null] - if (value.values.isEmpty()) { - return NullValue.newNullValue().build() - } - - // get rid of null values - val nonNullValueList = value.values.filter { v -> v !is NullValue } - - // [null, null, ...] unwrapped is null - if (nonNullValueList.isEmpty()) { - return NullValue.newNullValue().build() - } - - // make sure the array isn't polymorphic - val distinctTypes = nonNullValueList - .map { it::class.java } - .distinct() - - if (distinctTypes.size > 1) { - throw SchemaError("Arrays containing multiple types of values are not supported yet.") - } - - // peek at first value, value exists and is assured to be non-null - return nonNullValueList[0] - } - private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List) = determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType @@ -455,13 +453,15 @@ class SchemaParser internal constructor( else -> throw SchemaError("Unknown type: $typeDefinition") } - private fun determineInputType(typeDefinition: Type<*>, inputObjects: List, referencingInputObjects: Set) = + private fun determineInputType(typeDefinition: Type<*>, inputObjects: List, referencingInputObjects: MutableSet) = determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) - private fun determineInputType(expectedType: KClass, - typeDefinition: Type<*>, allowedTypeReferences: Set, - inputObjects: List, - referencingInputObjects: Set): GraphQLInputType = + private fun determineInputType( + expectedType: KClass, + typeDefinition: Type<*>, + allowedTypeReferences: Set, + inputObjects: List, + referencingInputObjects: MutableSet): GraphQLInputType = when (typeDefinition) { is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) @@ -489,7 +489,7 @@ class SchemaParser internal constructor( if (referencingInputObject != null) { GraphQLTypeReference(referencingInputObject) } else { - val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet) + val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects) (inputObjects as MutableList).add(inputObject) inputObject } diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerDirectiveTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerDirectiveTest.kt new file mode 100644 index 00000000..10573c72 --- /dev/null +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerDirectiveTest.kt @@ -0,0 +1,136 @@ +package graphql.kickstart.tools + +import graphql.GraphQLContext +import graphql.execution.CoercedVariables +import graphql.kickstart.tools.SchemaClassScannerDirectiveTest.CustomEnum.ONE +import graphql.language.StringValue +import graphql.language.Value +import graphql.schema.Coercing +import graphql.schema.GraphQLScalarType +import org.junit.Test +import java.util.* + +class SchemaClassScannerDirectiveTest { + + @Test + fun `scanner should handle directives with scalar input value`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + scalar CustomValue + directive @doSomething(value: CustomValue) on FIELD_DEFINITION + + type Query { + string: String @doSomething(value: "some thing") + } + """) + .resolvers(object : GraphQLQueryResolver { fun string(): String = "hello" }) + .scalars(customValueScalar) + .build() + .makeExecutableSchema() + + val value = schema.queryType.getFieldDefinition("string") + .getAppliedDirective("doSomething") + .getArgument("value") + .getValue() + + assertEquals(value.value, "some thing") + } + + data class CustomValue(val value: String) + private val customValueScalar: GraphQLScalarType = GraphQLScalarType.newScalar() + .name("CustomValue") + .coercing(object : Coercing { + override fun serialize(input: Any, context: GraphQLContext, locale: Locale) = input.toString() + override fun parseValue(input: Any, context: GraphQLContext, locale: Locale) = + CustomValue(input.toString()) + override fun parseLiteral(input: Value<*>, variables: CoercedVariables, context: GraphQLContext, locale: Locale) = + CustomValue((input as StringValue).value) + }) + .build() + + @Test + fun `scanner should handle directives with enum input value`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + enum CustomEnum { ONE TWO THREE } + directive @doSomething(value: CustomEnum) on FIELD_DEFINITION + + type Query { + string: String @doSomething(value: ONE) + another: CustomEnum + } + """) + .resolvers(object : GraphQLQueryResolver { + fun string(): String = "hello" + fun another(): CustomEnum = ONE + }) + .scalars(customValueScalar) + .build() + .makeExecutableSchema() + + val value = schema.queryType.getFieldDefinition("string") + .getAppliedDirective("doSomething") + .getArgument("value") + .getValue() + + assertEquals(value, ONE) + } + + enum class CustomEnum { ONE, TWO, THREE} + + @Test + fun `scanner should handle directives with input object input value`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + input CustomInput { value: String } + directive @doSomething(input: CustomInput) on FIELD_DEFINITION + + type Query { + string: String @doSomething(input: { value: "some value" }) + another(input: CustomInput): String + } + """) + .resolvers(object : GraphQLQueryResolver { + fun string(): String = "hello" + fun another(input: CustomInput): String = input.value + }) + .scalars(customValueScalar) + .build() + .makeExecutableSchema() + + val value = schema.queryType.getFieldDefinition("string") + .getAppliedDirective("doSomething") + .getArgument("input") + .getValue>()["value"] + + assertEquals(value, "some value") + } + + data class CustomInput(val value: String) + + @Test + fun `scanner should handle directives with arguments with directives`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + directive @doSomething(one: String @somethingElse) on FIELD_DEFINITION | ARGUMENT_DEFINITION + directive @somethingElse(two: String @doSomething) on FIELD_DEFINITION | ARGUMENT_DEFINITION + + type Query { + string: String @doSomething(one: "sss") + } + """) + .resolvers(object : GraphQLQueryResolver { + fun string(): String = "hello" + }) + .scalars(customValueScalar) + .build() + .makeExecutableSchema() + + assertNotNull(schema.directivesByName["doSomething"]?.getArgument("one")?.directivesByName?.get("somethingElse")) + assertNotNull(schema.directivesByName["somethingElse"]?.getArgument("two")?.directivesByName?.get("doSomething")) + } +} diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt index 24ec19b2..9dff7011 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt @@ -431,9 +431,14 @@ class SchemaClassScannerTest { # these directives are defined in the Apollo Federation Specification: # https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ scalar FieldSet + scalar link__Import + enum link__Purpose { SECURITY EXECUTION } directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE directive @extends on OBJECT | INTERFACE directive @external on FIELD_DEFINITION | OBJECT + directive @link(url: String!, as: String, for: link__Purpose) repeatable on SCHEMA + + extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: ["@key", "@shareable"]) # Let's say this is the Products service from Apollo Federation Introduction type Query { @@ -459,6 +464,7 @@ class SchemaClassScannerTest { }) .options(SchemaParserOptions.newOptions().includeUnusedTypes(true).build()) .dictionary(User::class) + .dictionary("link__Purpose", LinkPurpose::class) .scalars(fieldSetScalar) .build() .makeExecutableSchema() @@ -469,6 +475,7 @@ class SchemaClassScannerTest { } data class FieldSet(val value: String) + enum class LinkPurpose { SECURITY, EXECUTION } private val fieldSetScalar: GraphQLScalarType = GraphQLScalarType.newScalar() .name("FieldSet")