diff --git a/modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala b/modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala index 235279b43..7567d6296 100644 --- a/modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala +++ b/modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala @@ -49,7 +49,7 @@ object StringContextOps { def yell(s: String) = println(s"${Console.RED}$s${Console.RESET}") - def sqlImpl(sc: Expr[StringContext], argsExpr: Expr[Seq[Any]])(using qc:Quotes): Expr[Any] = { + def sqlImpl(sc: Expr[StringContext], argsExpr: Expr[Seq[Any]])(using qc: Quotes): Expr[Any] = { import qc.reflect.report // Ok we want to construct an Origin here @@ -78,23 +78,17 @@ object StringContextOps { val partsEncoders: Either[Expr[Any], (List[Expr[Part]], List[Expr[Any]])] = strings.flatMap { strings => val lastPart: Expr[Part] = '{Str(${Expr(strings.last)})} (strings zip args).reverse.foldLeftM((List[Expr[Part]](lastPart), List.empty[Expr[Any]])) { - case ((parts, es), (str, arg)) => - - if (str.endsWith("#")) { - + if (str.endsWith("#")) then { // Interpolations like "...#$foo ..." require `foo` to be a String. arg match { case '{ $s: String } => Right(('{Str(${Expr(str.dropRight(1))})} :: '{Str($s)} :: parts, es)) case '{ $a: t } => - report.error(s"Found ${Type.show[t]}, expected String.}", a) - Left('{ compiletime.error("Expected String") }) /// - } - + report.error(s"Found ${Type.show[t]}, expected String.}", a) + Left('{ compiletime.error("Expected String") }) /// + } } else { - arg match { - // The interpolated thing is an Encoder. case '{ $e: Encoder[t] } => val newParts = '{Str(${Expr(str)})} :: '{Par($e.sql)} :: parts @@ -113,13 +107,10 @@ object StringContextOps { Right((newParts, newEncoders)) case '{ $a: t } => - report.error(s"Found ${Type.show[t]}, expected String, Encoder, or Fragment.", a) - Left('{compiletime.error("Expected String, Encoder, or Fragment.")}) - + report.error(s"Found ${Type.show[t]}, expected String, Encoder, or Fragment.", a) + Left('{compiletime.error("Expected String, Encoder, or Fragment.")}) } - } - } } @@ -127,18 +118,22 @@ object StringContextOps { partsEncoders.map { (parts, encoders) => val finalEnc: Expr[Any] = if encoders.isEmpty then '{ Void.codec } - else if legacyCommandSyntax then encoders.reduceLeft { - case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a ~ $b} - } else encoders.reduceRight { - case ('{$a : Encoder[a]}, '{ $b : Encoder[bh *: bt] }) => '{$a *: $b} - case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a *: $b} + else if legacyCommandSyntax then + encoders.reduceLeft { + case ('{$a : Encoder[a]}, '{ $b : Encoder[b] }) => '{$a ~ $b} + } + else if encoders.size == 1 then encoders.head + else { + val last: Expr[Any] = encoders.last match { + case '{$a: Encoder[a]} => '{$a.imap(_ *: EmptyTuple)(_.head)} + } + encoders.init.foldRight(last) { case ('{$a: Encoder[a]}, '{$acc: Encoder[t & Tuple]}) => '{$a *: $acc} } } finalEnc match { case '{ $e : Encoder[t] } => '{ fragmentFromParts[t](${Expr.ofList(parts)}, $e, $origin) } } }.merge - } def idImpl(sc: Expr[StringContext])(using qc: Quotes): Expr[Identifier] = diff --git a/modules/tests/shared/src/test/scala/issue/990.scala b/modules/tests/shared/src/test/scala/issue/990.scala new file mode 100644 index 000000000..d73c9b56a --- /dev/null +++ b/modules/tests/shared/src/test/scala/issue/990.scala @@ -0,0 +1,15 @@ +// Copyright (c) 2018-2021 by Rob Norris +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package tests.issue + +import skunk._ +import skunk.codec.all._ +import skunk.syntax.all._ + + +object Issue990 { + def updateBy[A](where: Fragment[A]): Command[String *: A *: EmptyTuple] = + sql"""UPDATE foo SET bar = $text WHERE $where;""".command +}