Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce DB#InsertObtainID() function #64

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions database/contracts.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package database

import (
"context"
"github.com/jmoiron/sqlx"
)

// Entity is implemented by each type that works with the database package.
type Entity interface {
Fingerprinter
Expand Down Expand Up @@ -54,3 +59,10 @@ type PgsqlOnConflictConstrainter interface {
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
PgsqlOnConflictConstraint() string
}

// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance.
type TxOrDB interface {
sqlx.ExtContext

PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}
6 changes: 6 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,9 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio
db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
}))
}

var (
// Assert TxOrDB interface compliance of the DB and sqlx.Tx types.
_ TxOrDB = (*DB)(nil)
_ TxOrDB = (*sqlx.Tx)(nil)
)
37 changes: 37 additions & 0 deletions database/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/strcase"
"github.com/icinga/icinga-go-library/types"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -44,6 +45,42 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
}
}

// InsertObtainID executes the given query and fetches the last inserted ID.
//
// Using this method for database tables that don't define an auto-incrementing ID, or none at all,
// will not work. The only supported column that can be retrieved with this method is id.
//
// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance.
//
// Returns the retrieved ID on success and error on any database inserting/retrieving failure.
func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int64, error) {
var resultID int64
switch conn.DriverName() {
case PostgreSQL:
stmt = stmt + " RETURNING id"
query, args, err := conn.BindNamed(stmt, arg)
if err != nil {
return 0, errors.Wrapf(err, "can't bind named query %q", stmt)
}

if err := sqlx.GetContext(ctx, conn, &resultID, query, args...); err != nil {
return 0, CantPerformQuery(err, query)
}
default:
oxzi marked this conversation as resolved.
Show resolved Hide resolved
result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg)
if err != nil {
return 0, CantPerformQuery(err, stmt)
}

resultID, err = result.LastInsertId()
if err != nil {
return 0, errors.Wrap(err, "can't retrieve last inserted ID")
}
}

return resultID, nil
}

// unsafeSetSessionVariableIfExists sets the given MySQL/MariaDB system variable for the specified database session.
//
// NOTE: It is unsafe to use this function with untrusted/user supplied inputs and poses an SQL injection,
Expand Down
51 changes: 44 additions & 7 deletions database/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,48 @@ import (
"time"
)

func TestSetMysqlSessionVars(t *testing.T) {
func TestDatabaseUtils(t *testing.T) {
t.Parallel()

ctx := context.Background()
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")

t.Run("SetMySQLSessionVars", func(t *testing.T) {
t.Parallel()
if db.DriverName() != MySQL {
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
}

setMysqlSessionVars(ctx, db, t)
})

t.Run("InsertObtainID", func(t *testing.T) {
t.Parallel()

defer func() {
_, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS igl_test_insert_obtain")
assert.NoError(t, err, "dropping test database table should not fail")
}()

var err error
if db.DriverName() == PostgreSQL {
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id SERIAL PRIMARY KEY, name VARCHAR(255))")
} else {
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))")
}
require.NoError(t, err, "creating test database table should not fail")

id, err := InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test1"})
require.NoError(t, err, "inserting new row into test database table should not fail")
assert.Equal(t, id, int64(1))

id, err = InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test2"})
require.NoError(t, err, "inserting new row into test database table should not fail")
assert.Equal(t, id, int64(2))
})
}

func setMysqlSessionVars(ctx context.Context, db *DB, t *testing.T) {
vars := map[string][]struct {
name string
value string
Expand Down Expand Up @@ -45,14 +86,10 @@ func TestSetMysqlSessionVars(t *testing.T) {
},
}

ctx := context.Background()
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")
if db.DriverName() != MySQL {
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
}

for name, vs := range vars {
t.Run(name, func(t *testing.T) {
t.Parallel()

for _, v := range vs {
conn, err := db.DB.Conn(ctx)
require.NoError(t, err, "connecting to MySQL/MariaDB database should not fail")
Expand Down
Loading