Skip to content

Commit

Permalink
Merge pull request #991 from typelevel/bugfix/990
Browse files Browse the repository at this point in the history
Fix for #990 - nested fragment interpolation
  • Loading branch information
mpilquist authored Oct 10, 2023
2 parents 40c6ede + 4193787 commit 372f697
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
39 changes: 17 additions & 22 deletions modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ object StringContextOps {

def yell(s: String) = println(s"${Console.RED}$s${Console.RESET}")

def sqlImpl(sc: Expr[StringContext], argsExpr: Expr[Seq[Any]])(using qc:Quotes): Expr[Any] = {
def sqlImpl(sc: Expr[StringContext], argsExpr: Expr[Seq[Any]])(using qc: Quotes): Expr[Any] = {
import qc.reflect.report

// Ok we want to construct an Origin here
Expand Down Expand Up @@ -78,23 +78,17 @@ object StringContextOps {
val partsEncoders: Either[Expr[Any], (List[Expr[Part]], List[Expr[Any]])] = strings.flatMap { strings =>
val lastPart: Expr[Part] = '{Str(${Expr(strings.last)})}
(strings zip args).reverse.foldLeftM((List[Expr[Part]](lastPart), List.empty[Expr[Any]])) {

case ((parts, es), (str, arg)) =>

if (str.endsWith("#")) {

if (str.endsWith("#")) then {
// Interpolations like "...#$foo ..." require `foo` to be a String.
arg match {
case '{ $s: String } => Right(('{Str(${Expr(str.dropRight(1))})} :: '{Str($s)} :: parts, es))
case '{ $a: t } =>
report.error(s"Found ${Type.show[t]}, expected String.}", a)
Left('{ compiletime.error("Expected String") }) ///
}

report.error(s"Found ${Type.show[t]}, expected String.}", a)
Left('{ compiletime.error("Expected String") }) ///
}
} else {

arg match {

// The interpolated thing is an Encoder.
case '{ $e: Encoder[t] } =>
val newParts = '{Str(${Expr(str)})} :: '{Par($e.sql)} :: parts
Expand All @@ -113,32 +107,33 @@ object StringContextOps {
Right((newParts, newEncoders))

case '{ $a: t } =>
report.error(s"Found ${Type.show[t]}, expected String, Encoder, or Fragment.", a)
Left('{compiletime.error("Expected String, Encoder, or Fragment.")})

report.error(s"Found ${Type.show[t]}, expected String, Encoder, or Fragment.", a)
Left('{compiletime.error("Expected String, Encoder, or Fragment.")})
}

}

}
}

val legacyCommandSyntax = Expr.summon[skunk.featureFlags.legacyCommandSyntax].isDefined
partsEncoders.map { (parts, encoders) =>
val finalEnc: Expr[Any] =
if encoders.isEmpty then '{ Void.codec }
else if legacyCommandSyntax then encoders.reduceLeft {
case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a ~ $b}
} else encoders.reduceRight {
case ('{$a : Encoder[a]}, '{ $b : Encoder[bh *: bt] }) => '{$a *: $b}
case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a *: $b}
else if legacyCommandSyntax then
encoders.reduceLeft {
case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a ~ $b}
}
else if encoders.size == 1 then encoders.head
else {
val last: Expr[Any] = encoders.last match {
case '{$a: Encoder[a]} => '{$a.imap(_ *: EmptyTuple)(_.head)}
}
encoders.init.foldRight(last) { case ('{$a: Encoder[a]}, '{$acc: Encoder[t & Tuple]}) => '{$a *: $acc} }
}

finalEnc match {
case '{ $e : Encoder[t] } => '{ fragmentFromParts[t](${Expr.ofList(parts)}, $e, $origin) }
}
}.merge

}

def idImpl(sc: Expr[StringContext])(using qc: Quotes): Expr[Identifier] =
Expand Down
15 changes: 15 additions & 0 deletions modules/tests/shared/src/test/scala/issue/990.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) 2018-2021 by Rob Norris
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package tests.issue

import skunk._
import skunk.codec.all._
import skunk.syntax.all._


object Issue990 {
def updateBy[A](where: Fragment[A]): Command[String *: A *: EmptyTuple] =
sql"""UPDATE foo SET bar = $text WHERE $where;""".command
}

0 comments on commit 372f697

Please sign in to comment.