Skip to content

Commit

Permalink
Spec Refactor:
Browse files Browse the repository at this point in the history
* Spec.build method is removed. Specs are now fully declarative and the DbType is responsible for implementation.
* Add Spec.prefix method
* Change Spec.seek nullOrder, Spec.orderBy direction and nullOrder to be database Default
* Tests are improved and now shared between database specs insteads of duplicated. The increased coverage exposed some tangential issues which are also resolved in this MR (see below). Note: switched from munit assertEquals -> assert until fix for scalameta/munit#855 (comment) is released.

Additional changes:

* No longer need to handle null when using DbCodec.biMap, or implementing DbCodec.readSingle
* When implementing DbCodec, new method readSingleOption must be defined
* Support optional products in outer-join queries (see test OptionalProductTests)
* Switched to latest scalafmt version to prevent OOM in OracleTests
* Make Frag, Query, Update, Returning into regular classes
* Frag.returningKeys method which uses ResultSet.getGeneratedKeys
* Make MySq insertReturning throw, since we shouldn't define 2-query repository methods.
  • Loading branch information
AugustNagro committed Nov 29, 2024
1 parent 7555a22 commit 77b0121
Show file tree
Hide file tree
Showing 73 changed files with 2,091 additions and 3,107 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = 3.8.3
version = 3.8.4-RC3
runner.dialect = scala3
rewrite.scala3.insertEndMarkerMinLines = 20
rewrite.scala3.removeEndMarkerMaxLines = 19
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ ThisBuild / publishTo := {
}
ThisBuild / publish / skip := true

Global / onChangedBuildSource := ReloadOnSourceChanges
addCommandAlias("fmt", "scalafmtAll")

val testcontainersVersion = "0.41.4"
val circeVersion = "0.14.10"
Expand Down
95 changes: 94 additions & 1 deletion magnum-pg/src/main/scala/com/augustnagro/magnum/pg/PgCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import org.postgresql.util.PGInterval
import java.sql
import java.sql.{JDBCType, PreparedStatement, ResultSet, Types}
import scala.reflect.ClassTag
import scala.collection.mutable as m
import scala.collection.{mutable as m}
import scala.compiletime.*

object PgCodec:
Expand Down Expand Up @@ -64,6 +64,15 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
IArray.unsafeFromArray(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[IArray[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(IArray.unsafeFromArray(arr))
finally jdbcArray.free()

def writeSingle(entity: IArray[A], ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -81,6 +90,14 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
IArray.unsafeFromArray(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[IArray[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(IArray.unsafeFromArray(arr))
finally jdbcArray.free()
def writeSingle(entity: IArray[A], ps: PreparedStatement, pos: Int): Unit =
val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray
val jdbcArr =
Expand All @@ -99,6 +116,12 @@ object PgCodec:
val jdbcArray = resultSet.getArray(pos)
try aArrayCodec.readArray(jdbcArray.getArray)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[Array[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try Some(aArrayCodec.readArray(jdbcArray.getArray))
finally jdbcArray.free()
def writeSingle(entity: Array[A], ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -114,6 +137,12 @@ object PgCodec:
val jdbcArray = resultSet.getArray(pos)
try aArrayCodec.readArray(jdbcArray.getArray)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[Array[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try Some(aArrayCodec.readArray(jdbcArray.getArray))
finally jdbcArray.free()
def writeSingle(entity: Array[A], ps: PreparedStatement, pos: Int): Unit =
val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray
val jdbcArr =
Expand All @@ -133,6 +162,14 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
List.from(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[Seq[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(List.from(arr))
finally jdbcArray.free()
def writeSingle(entity: Seq[A], ps: PreparedStatement, pos: Int): Unit =
val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray
val jdbcArr =
Expand All @@ -152,6 +189,14 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
List.from(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[List[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(List.from(arr))
finally jdbcArray.free()
def writeSingle(entity: List[A], ps: PreparedStatement, pos: Int): Unit =
val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray
val jdbcArr =
Expand All @@ -171,6 +216,14 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Vector.from(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[Vector[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(Vector.from(arr))
finally jdbcArray.free()
def writeSingle(entity: Vector[A], ps: PreparedStatement, pos: Int): Unit =
val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray
val jdbcArr =
Expand All @@ -191,6 +244,14 @@ object PgCodec:
val arr = aArrayCodec.readArray(jdbcArray.getArray)
m.Buffer.from(arr)
finally jdbcArray.free()
def readSingleOption(resultSet: ResultSet, pos: Int): Option[m.Buffer[A]] =
val jdbcArray = resultSet.getArray(pos)
if resultSet.wasNull then None
else
try
val arr = aArrayCodec.readArray(jdbcArray.getArray)
Some(m.Buffer.from(arr))
finally jdbcArray.free()
def writeSingle(
entity: m.Buffer[A],
ps: PreparedStatement,
Expand All @@ -205,6 +266,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGbox =
resultSet.getObject(pos, classOf[PGbox])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGbox] =
val res = resultSet.getObject(pos, classOf[PGbox])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGbox, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -213,6 +278,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGcircle =
resultSet.getObject(pos, classOf[PGcircle])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGcircle] =
val res = resultSet.getObject(pos, classOf[PGcircle])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGcircle, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -221,6 +290,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGInterval =
resultSet.getObject(pos, classOf[PGInterval])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGInterval] =
val res = resultSet.getObject(pos, classOf[PGInterval])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGInterval, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -229,6 +302,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGline =
resultSet.getObject(pos, classOf[PGline])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGline] =
val res = resultSet.getObject(pos, classOf[PGline])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGline, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -237,6 +314,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGlseg =
resultSet.getObject(pos, classOf[PGlseg])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGlseg] =
val res = resultSet.getObject(pos, classOf[PGlseg])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGlseg, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -245,6 +326,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGpath =
resultSet.getObject(pos, classOf[PGpath])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpath] =
val res = resultSet.getObject(pos, classOf[PGpath])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGpath, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -253,6 +338,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGpoint =
resultSet.getObject(pos, classOf[PGpoint])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpoint] =
val res = resultSet.getObject(pos, classOf[PGpoint])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGpoint, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)

Expand All @@ -261,6 +350,10 @@ object PgCodec:
val cols: IArray[Int] = IArray(Types.JAVA_OBJECT)
def readSingle(resultSet: ResultSet, pos: Int): PGpolygon =
resultSet.getObject(pos, classOf[PGpolygon])
def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpolygon] =
val res = resultSet.getObject(pos, classOf[PGpolygon])
if resultSet.wasNull then None
else Some(res)
def writeSingle(entity: PGpolygon, ps: PreparedStatement, pos: Int): Unit =
ps.setObject(pos, entity)
end PgCodec
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ trait JsonBDbCodec[A] extends DbCodec[A]:
override val cols: IArray[Int] = IArray(Types.OTHER)

override def readSingle(resultSet: ResultSet, pos: Int): A =
decode(resultSet.getString(pos))

override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] =
val rawJson = resultSet.getString(pos)
if rawJson eq null then null.asInstanceOf[A]
else decode(rawJson)
if rawJson == null then None
else Some(decode(rawJson))

override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit =
val jsonObject = PGobject()
jsonObject.setType("jsonb")
val encoded = if entity == null then null else encode(entity)
jsonObject.setValue(encoded)
jsonObject.setValue(encode(entity))
ps.setObject(pos, jsonObject)

end JsonBDbCodec
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ trait JsonDbCodec[A] extends DbCodec[A]:
override val cols: IArray[Int] = IArray(Types.OTHER)

override def readSingle(resultSet: ResultSet, pos: Int): A =
decode(resultSet.getString(pos))

override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] =
val rawJson = resultSet.getString(pos)
if rawJson eq null then null.asInstanceOf[A]
else decode(rawJson)
if rawJson == null then None
else Some(decode(rawJson))

override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit =
val jsonObject = PGobject()
jsonObject.setType("json")
val encoded = if entity == null then null else encode(entity)
jsonObject.setValue(encoded)
jsonObject.setValue(encode(entity))
ps.setObject(pos, jsonObject)

end JsonDbCodec
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ trait XmlDbCodec[A] extends DbCodec[A]:
override val cols: IArray[Int] = IArray(Types.SQLXML)

override def readSingle(resultSet: ResultSet, pos: Int): A =
decode(resultSet.getString(pos))

override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] =
val xmlString = resultSet.getString(pos)
if xmlString == null then null.asInstanceOf[A]
else decode(xmlString)
if xmlString == null then None
else Some(decode(xmlString))

override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit =
val xmlObject = PGobject()
xmlObject.setType("xml")
val encoded = if entity == null then null else encode(entity)
xmlObject.setValue(encoded)
xmlObject.setValue(encode(entity))
ps.setObject(pos, xmlObject)

end XmlDbCodec
19 changes: 13 additions & 6 deletions magnum-pg/src/test/scala/PgCodecTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ class PgCodecTests extends FunSuite, TestContainersFixtures:

test("select all MagUser"):
connect(ds()):
assertEquals(userRepo.findAll, allUsers)
assert(userRepo.findAll == allUsers)

test("select all MagCar"):
connect(ds()):
assertEquals(carRepo.findAll, allCars)
assert(carRepo.findAll == allCars)

test("insert MagUser"):
connect(ds()):
Expand All @@ -124,7 +124,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures:
)
userRepo.insert(u)
val dbU = userRepo.findById(3L).get
assertEquals(dbU, u)
assert(dbU == u)

test("insert MagCar"):
connect(ds()):
Expand All @@ -141,7 +141,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures:
)
carRepo.insert(c)
val dbC = carRepo.findById(3L).get
assertEquals(dbC, c)
assert(dbC == c)

test("update MagUser arrays"):
connect(ds()):
Expand All @@ -158,7 +158,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures:
sql"UPDATE mag_car SET text_color_map = $newTextColorMap WHERE id = 2".update
.run()
val newCar = carRepo.findById(2L).get
assertEquals(newCar.textColorMap, newTextColorMap)
assert(newCar.textColorMap == newTextColorMap)

test("MagCar xml string values"):
connect(ds()):
Expand All @@ -170,7 +170,14 @@ class PgCodecTests extends FunSuite, TestContainersFixtures:
.map(_.elem.toString)
val expected = allCars.flatMap(_.myXml).map(_.elem.toString)
println(found)
assertEquals(found, expected)
assert(found == expected)

test("where = ANY()"):
connect(ds()):
val ids = Vector(1L, 2L)
val cars =
sql"SELECT * FROM mag_car WHERE id = ANY($ids)".query[MagCar].run()
assert(cars == allCars)

val pgContainer = ForAllContainerFixture(
PostgreSQLContainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.augustnagro.magnum

import java.sql.{Connection, PreparedStatement, ResultSet, Statement}
import java.time.OffsetDateTime
import java.util.StringJoiner
import scala.collection.View
import scala.deriving.Mirror
import scala.reflect.ClassTag
Expand Down Expand Up @@ -33,17 +34,18 @@ object ClickhouseDbType extends DbType:
val ecInsertKeys = ecElemNamesSql.mkString("(", ", ", ")")

val countSql = s"SELECT count(*) FROM $tableNameSql"
val countQuery = Frag(countSql).query[Long]
val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long]
val existsByIdSql =
s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}"
val findAllSql = s"SELECT $selectKeys FROM $tableNameSql"
val findAllQuery = Frag(findAllSql).query[E]
val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E]
val findByIdSql =
s"SELECT $selectKeys FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}"
val deleteByIdSql =
s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}"
val truncateSql = s"TRUNCATE TABLE $tableNameSql"
val truncateUpdate = Frag(truncateSql).update
val truncateUpdate =
Frag(truncateSql, Vector.empty, FragWriter.empty).update
val insertSql =
s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})"

Expand All @@ -63,10 +65,7 @@ object ClickhouseDbType extends DbType:
def findAll(using DbCon): Vector[E] = findAllQuery.run()

def findAll(spec: Spec[E])(using DbCon): Vector[E] =
val f = spec.build
Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer)
.query[E]
.run()
SpecImpl.Default.findAll(spec, tableNameSql)

def findById(id: ID)(using DbCon): Option[E] =
Frag(findByIdSql, IArray(id), idWriter(id))
Expand Down
Loading

0 comments on commit 77b0121

Please sign in to comment.