Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 527 - Add to StackReference, methods converting to typed Outputs #528

Merged
merged 6 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion besom-json/src/main/scala/besom/json/JsonFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object JsonReader {
def read(json: JsValue) = f(json)
}

inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonFormatN[T]
inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonReaderN[T]
}

/** Provides the JSON serialization for type T.
Expand Down
89 changes: 72 additions & 17 deletions besom-json/src/main/scala/besom/json/ProductFormats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait ProductFormats:
def requireNullsForOptions: Boolean = false

inline def jsonFormatN[T <: Product]: RootJsonFormat[T] = ${ ProductFormatsMacro.jsonFormatImpl[T]('self) }
inline def jsonReaderN[T <: Product]: RootJsonReader[T] = ${ ProductFormatsMacro.jsonReaderImpl[T]('self) }

object ProductFormatsMacro:
import scala.deriving.*
Expand Down Expand Up @@ -57,29 +58,45 @@ object ProductFormatsMacro:
'{ $namesExpr.zip($identsExpr).toMap }
catch case cce: ClassCastException => '{ Map.empty[String, Any] } // TODO drop after https://github.com/lampepfl/dotty/issues/19732

private def prepareFormatInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareFormatInstances(Type.of[labelsTail], Type.of[tpesTail])

private def prepareReaderInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonReader[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonReader[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareReaderInstances(Type.of[labelsTail], Type.of[tpesTail])

def jsonFormatImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonFormat[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
def prepareInstances(elemLabels: Type[?], elemTypes: Type[?]): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail])

// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val allInstancesExpr = Expr.ofList(prepareFormatInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
Expand Down Expand Up @@ -121,6 +138,44 @@ object ProductFormatsMacro:

JsObject(fields.toMap)
}

def jsonReaderImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonReader[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareReaderInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
new RootJsonReader[T]:
private val allInstances = ${ allInstancesExpr }
private val fmts = ${ prodFormats }
private val defaultArgs = ${ defaultArguments }

def read(json: JsValue): T = json match
case JsObject(fields) =>
val values = allInstances.map { case (fieldName, fieldFormat, isOption) =>
try fieldFormat.read(fields(fieldName))
catch
case e: NoSuchElementException =>
// if field has a default value, use it, we didn't find anything in the JSON
if defaultArgs.contains(fieldName) then defaultArgs(fieldName)
// if field is optional and requireNullsForOptions is disabled, return None
// otherwise we require an explicit null value
else if isOption && !fmts.requireNullsForOptions then None
// it's missing so we throw an exception
else throw DeserializationException("Object is missing required member '" ++ fieldName ++ "'", null, fieldName :: Nil)
case DeserializationException(msg, cause, fieldNames) =>
throw DeserializationException(msg, cause, fieldName :: fieldNames)
}
$m.fromProduct(Tuple.fromArray(values.toArray))

case _ => throw DeserializationException("Object expected", null, allInstances.map(_._1))

}

end ProductFormatsMacro

/** This trait supplies an alternative rendering mode for optional case class members. Normally optional members that are undefined (`None`)
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/besom/aliases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object aliases:
type ComponentResource = besom.internal.ComponentResource
type RegistersOutputs[A <: ComponentResource & Product] = besom.internal.RegistersOutputs[A]
type StackReference = besom.internal.StackReference
type TypedStackReference[A] = besom.internal.TypedStackReference[A]
object StackReference extends besom.internal.StackReferenceFactory
type StackReferenceArgs = besom.internal.StackReferenceArgs
object StackReferenceArgs extends besom.internal.StackReferenceArgsFactory
Expand All @@ -52,4 +53,5 @@ object aliases:
object CustomTimeouts extends besom.internal.CustomTimeoutsFactory

export besom.internal.InvokeOptions
export besom.util.JsonReaderInstances.*
end aliases
51 changes: 49 additions & 2 deletions core/src/main/scala/besom/internal/StackReference.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,44 @@ case class StackReference(
end StackReference

trait StackReferenceFactory:
def apply(using Context)(
sealed trait StackReferenceType[T]:
type Out[T]
def transform(stackReference: StackReference)(using Context): Output[Out[T]]

object StackReferenceType:
given untyped: UntypedStackReferenceType = UntypedStackReferenceType()

given typed[T: JsonReader]: TypedStackReferenceType[T] = TypedStackReferenceType[T]

class TypedStackReferenceType[T](using JsonReader[T]) extends StackReferenceType[T]:
type Out[T] = TypedStackReference[T]
def transform(stackReference: StackReference)(using Context): Output[Out[T]] =
val objectOutput: Output[T] =
requireObject(stackReference.outputs, stackReference.secretOutputNames)

objectOutput.map(t =>
TypedStackReference(
urn = stackReference.urn,
id = stackReference.id,
name = stackReference.name,
outputs = t,
secretOutputNames = stackReference.secretOutputNames
)
)

class UntypedStackReferenceType extends StackReferenceType[Any]:
type Out[T] = StackReference
def transform(stackReference: StackReference)(using Context): Output[StackReference] = Output(stackReference)

def untypedStackReference(using Context): StackReferenceType[Any] = UntypedStackReferenceType()

def typedStackReference[T: JsonReader]: TypedStackReferenceType[T] = TypedStackReferenceType()

def apply[T](using stackRefType: StackReferenceType[T], ctx: Context)(
name: NonEmptyString,
args: Input.Optional[StackReferenceArgs] = None,
opts: StackReferenceResourceOptions = StackReferenceResourceOptions()
): Output[StackReference] =
): Output[stackRefType.Out[T]] =
args
.asOptionOutput(false)
.flatMap {
Expand All @@ -76,3 +109,17 @@ trait StackReferenceFactory:

Context().readOrRegisterResource[StackReference, StackReferenceArgs]("pulumi:pulumi:StackReference", name, stackRefArgs, mergedOpts)
}
.flatMap(stackRefType.transform)

private[internal] def requireObject[T: JsonReader](
outputs: Output[Map[String, JsValue]],
secretOutputNames: Output[Set[String]]
): Output[T] =
outputs
.map(JsObject(_).convertTo[T])
.withIsSecret(
secretOutputNames
.map(_.nonEmpty)
.getValueOrElse(false)
)
end StackReferenceFactory
11 changes: 11 additions & 0 deletions core/src/main/scala/besom/internal/TypedStackReference.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package besom.internal

import besom.types.*

case class TypedStackReference[T](
urn: Output[URN],
id: Output[ResourceId],
name: Output[String],
outputs: T,
secretOutputNames: Output[Set[String]]
) extends CustomResource
2 changes: 1 addition & 1 deletion core/src/main/scala/besom/internal/codecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ object Decoder extends DecoderInstancesLowPrio1:
.foldLeft[ValidatedResult[DecodingError, Vector[OutputData[A]]]](ValidatedResult.valid(Vector.empty))(
accumulatedOutputDataOrErrors(_, _, "iterable", label)
)
.map(_.toIterable)
.map(_.toVector)
.map(OutputData.sequence)
end if
}
Expand Down
33 changes: 33 additions & 0 deletions core/src/main/scala/besom/util/JsonReaderInstances.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package besom.util

import besom.json.*
import besom.internal.{Output, Context}
import besom.internal.Constants, Constants.SpecialSig

object JsonReaderInstances:
implicit def outputJsonReader[A](using jsonReader: JsonReader[A], ctx: Context): JsonReader[Output[A]] =
new JsonReader[Output[A]]:
def read(json: JsValue): Output[A] = json match
case JsObject(fields) =>
fields.get(SpecialSig.Key) match
case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.OutputSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.SecretSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output.secret(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case _ => Output.fail(Exception("Invalid JSON"))

case _ => Output.fail(Exception("Invalid JSON"))
61 changes: 61 additions & 0 deletions core/src/test/scala/besom/internal/StackReferenceTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package besom.internal

import besom.*
import besom.json.*
import RunResult.{*, given}

class StackReferenceTest extends munit.FunSuite:

test("convert stack reference to case class") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val expected = Test("value1", 2)
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(Set.empty))
assertEquals(requireObject.getData.unsafeRunSync(), OutputData(expected))
}

test("fail when convert stack reference to case class with missing data") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val outputs = Map("s" -> JsString("value1"))

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(Set.empty))
intercept[besom.json.DeserializationException](requireObject.getData.unsafeRunSync())
}

test("convert stack reference to case class with secret field") {
given Context = DummyContext().unsafeRunSync()
case class Test(s: String, i: Int) derives JsonReader
val expected = Test("value1", 2)
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))
val secretOutputNames = Set("i")

val requireObject = StackReference.requireObject[Test](Output(outputs), Output(secretOutputNames))
assertEquals(requireObject.getData.unsafeRunSync(), OutputData(expected).withIsSecret(true))
}

test("propagate secret field to whole typed stack reference") {
given Context = DummyContext().unsafeRunSync()

case class Test(s: String, i: Int) derives JsonReader
val outputs = Map("s" -> JsString("value1"), "i" -> JsNumber(2))
val secretOutputNames = Set("i")

val typedStackReference =
StackReference
.requireObject[Test](Output(outputs), Output(secretOutputNames))
.map(test =>
TypedStackReference(
urn = Output(URN.empty),
id = Output(ResourceId.empty),
name = Output(""),
outputs = test,
secretOutputNames = Output(secretOutputNames)
)
)

assertEquals(typedStackReference.getData.unsafeRunSync().secret, true)
}
end StackReferenceTest
Binary file removed cs
Binary file not shown.
12 changes: 6 additions & 6 deletions integration-tests/CoreTests.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down Expand Up @@ -182,19 +182,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down
Loading
Loading