Skip to content

Commit

Permalink
fix issue with discriminated transform
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjkl committed Jul 30, 2024
1 parent 7745181 commit 1513fa8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,56 +54,61 @@ class DiscriminatedUnionMemberComponents() extends OpenApiMapper {
.map(s => ShapeId.from(s.getValue) -> schema)
}
}
unions.asScala.foreach { union =>
val unionMixinName = union.getId().getName() + "Mixin"
val unionMixinId =
ShapeId.fromParts(union.getId().getNamespace(), unionMixinName)
val discriminatorField =
union.expectTrait(classOf[DiscriminatedUnionTrait]).getValue()
unions.asScala
.filter(u => componentSchemas.contains(u.toShapeId))
.foreach { union =>
val unionMixinName = union.getId().getName() + "Mixin"
val unionMixinId =
ShapeId.fromParts(union.getId().getNamespace(), unionMixinName)
val discriminatorField =
union.expectTrait(classOf[DiscriminatedUnionTrait]).getValue()

val unionMixinSchema = Schema
.builder()
.`type`("object")
.properties(
Map(
discriminatorField -> Schema
.builder()
.`type`("string")
.build()
).asJava
)
.required(List(discriminatorField).asJava)
.build()
val unionMixinSchema = Schema
.builder()
.`type`("object")
.properties(
Map(
discriminatorField -> Schema
.builder()
.`type`("string")
.build()
).asJava
)
.required(List(discriminatorField).asJava)
.build()

val unionMixinRef = context.createRef(unionMixinId)
val unionMixinRef = context.createRef(unionMixinId)

componentBuilder.putSchema(unionMixinName, unionMixinSchema)
componentBuilder.putSchema(unionMixinName, unionMixinSchema)

union.members().asScala.foreach { memberShape =>
val syntheticMemberName =
union.getId().getName() + memberShape.getMemberName.capitalize
context.getPointer(union).split('/').last + memberShape
.getMemberName()
.capitalize
val targetRef = context.createRef(memberShape.getTarget())
val syntheticUnionMember =
Schema.builder().allOf(List(targetRef, unionMixinRef).asJava).build()
componentBuilder.putSchema(syntheticMemberName, syntheticUnionMember)
}
union.members().asScala.foreach { memberShape =>
val syntheticMemberName =
union.getId().getName() + memberShape.getMemberName.capitalize
context.getPointer(union).split('/').last + memberShape
.getMemberName()
.capitalize
val targetRef = context.createRef(memberShape.getTarget())
val syntheticUnionMember =
Schema
.builder()
.allOf(List(targetRef, unionMixinRef).asJava)
.build()
componentBuilder.putSchema(syntheticMemberName, syntheticUnionMember)
}

componentSchemas.get(union.toShapeId).foreach { sch =>
componentBuilder.putSchema(
union.toShapeId.getName,
updateDiscriminatedUnion(
union,
sch.toBuilder(),
discriminatorField
componentSchemas.get(union.toShapeId).foreach { sch =>
componentBuilder.putSchema(
union.toShapeId.getName,
updateDiscriminatedUnion(
union,
sch.toBuilder(),
discriminatorField
)
.build()
)
.build()
)
}
}

}
}
openapi.toBuilder.components(componentBuilder.build()).build()
}

Expand Down
7 changes: 7 additions & 0 deletions modules/openapi/test/resources/bar.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@ union CatOrDog {
one: String
two: Integer
}

structure ProblemSomething {}

@alloy#discriminated("type")
union Problem {
something: ProblemSomething
}
24 changes: 24 additions & 0 deletions modules/openapi/test/src/alloy/openapi/OpenApiConversionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,30 @@ final class OpenApiConversionSpec extends munit.FunSuite {
assertEquals(result, expected)
}

test(
"OpenAPI conversion with one namespace excluded and one included"
) {
val model = Model
.assembler()
.addImport(getClass().getClassLoader().getResource("baz.smithy"))
.addImport(getClass().getClassLoader().getResource("bar.smithy"))
.discoverModels()
.assemble()
.unwrap()

val result = convert(model, Some(Set("baz")))
.map(_.contents)
.mkString
.filterNot(_.isWhitespace)

val expected = Using
.resource(Source.fromResource("baz.json"))(
_.getLines().mkString.filterNot(_.isWhitespace)
)

assertEquals(result, expected)
}

test("OpenAPI conversion from testJson protocol") {
val model = Model
.assembler()
Expand Down

0 comments on commit 1513fa8

Please sign in to comment.