diff --git a/database/contracts.go b/database/contracts.go index bf55d320..6da723a8 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -1,5 +1,10 @@ package database +import ( + "context" + "github.com/jmoiron/sqlx" +) + // Entity is implemented by each type that works with the database package. type Entity interface { Fingerprinter @@ -54,3 +59,10 @@ type PgsqlOnConflictConstrainter interface { // PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table. PgsqlOnConflictConstraint() string } + +// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance. +type TxOrDB interface { + sqlx.ExtContext + + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) +} diff --git a/database/db.go b/database/db.go index 26e50e1e..ce33ac36 100644 --- a/database/db.go +++ b/database/db.go @@ -836,3 +836,9 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed) })) } + +var ( + // Assert TxOrDB interface compliance of the DB and sqlx.Tx types. + _ TxOrDB = (*DB)(nil) + _ TxOrDB = (*sqlx.Tx)(nil) +) diff --git a/database/utils.go b/database/utils.go index 067b156b..43ed574d 100644 --- a/database/utils.go +++ b/database/utils.go @@ -8,6 +8,7 @@ import ( "github.com/icinga/icinga-go-library/com" "github.com/icinga/icinga-go-library/strcase" "github.com/icinga/icinga-go-library/types" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" ) @@ -44,6 +45,42 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] { } } +// InsertObtainID executes the given query and fetches the last inserted ID. +// +// Using this method for database tables that don't define an auto-incrementing ID, or none at all, +// will not work. The only supported column that can be retrieved with this method is id. +// +// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance. +// +// Returns the retrieved ID on success and error on any database inserting/retrieving failure. +func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int64, error) { + var resultID int64 + switch conn.DriverName() { + case PostgreSQL: + stmt = stmt + " RETURNING id" + query, args, err := conn.BindNamed(stmt, arg) + if err != nil { + return 0, errors.Wrapf(err, "can't bind named query %q", stmt) + } + + if err := sqlx.GetContext(ctx, conn, &resultID, query, args...); err != nil { + return 0, CantPerformQuery(err, query) + } + default: + result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg) + if err != nil { + return 0, CantPerformQuery(err, stmt) + } + + resultID, err = result.LastInsertId() + if err != nil { + return 0, errors.Wrap(err, "can't retrieve last inserted ID") + } + } + + return resultID, nil +} + // unsafeSetSessionVariableIfExists sets the given MySQL/MariaDB system variable for the specified database session. // // NOTE: It is unsafe to use this function with untrusted/user supplied inputs and poses an SQL injection, diff --git a/database/utils_test.go b/database/utils_test.go index ba746310..285f43b7 100644 --- a/database/utils_test.go +++ b/database/utils_test.go @@ -16,7 +16,48 @@ import ( "time" ) -func TestSetMysqlSessionVars(t *testing.T) { +func TestDatabaseUtils(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := GetTestDB(ctx, t, "ICINGAGOLIBRARY") + + t.Run("SetMySQLSessionVars", func(t *testing.T) { + t.Parallel() + if db.DriverName() != MySQL { + t.Skipf("skipping set session vars test for %q driver", db.DriverName()) + } + + setMysqlSessionVars(ctx, db, t) + }) + + t.Run("InsertObtainID", func(t *testing.T) { + t.Parallel() + + defer func() { + _, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS igl_test_insert_obtain") + assert.NoError(t, err, "dropping test database table should not fail") + }() + + var err error + if db.DriverName() == PostgreSQL { + _, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id SERIAL PRIMARY KEY, name VARCHAR(255))") + } else { + _, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))") + } + require.NoError(t, err, "creating test database table should not fail") + + id, err := InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test1"}) + require.NoError(t, err, "inserting new row into test database table should not fail") + assert.Equal(t, id, int64(1)) + + id, err = InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test2"}) + require.NoError(t, err, "inserting new row into test database table should not fail") + assert.Equal(t, id, int64(2)) + }) +} + +func setMysqlSessionVars(ctx context.Context, db *DB, t *testing.T) { vars := map[string][]struct { name string value string @@ -45,14 +86,10 @@ func TestSetMysqlSessionVars(t *testing.T) { }, } - ctx := context.Background() - db := GetTestDB(ctx, t, "ICINGAGOLIBRARY") - if db.DriverName() != MySQL { - t.Skipf("skipping set session vars test for %q driver", db.DriverName()) - } - for name, vs := range vars { t.Run(name, func(t *testing.T) { + t.Parallel() + for _, v := range vs { conn, err := db.DB.Conn(ctx) require.NoError(t, err, "connecting to MySQL/MariaDB database should not fail")