From f9eeffd4889c8813133bb558c26d1add5ddcc5c2 Mon Sep 17 00:00:00 2001 From: Rodrigo Bernardi Date: Mon, 12 Aug 2019 10:16:48 -0300 Subject: [PATCH] Separated commands --- dialect.go | 75 +++++++++++++++++++++++++++++++++++++++------------- migrate.go | 12 +++++++++ migration.go | 2 +- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/dialect.go b/dialect.go index fc383a323..29561993b 100644 --- a/dialect.go +++ b/dialect.go @@ -11,6 +11,7 @@ type SQLDialect interface { createVersionTableSQL() string // sql string to create the db version table insertVersionSQL() string // sql string to insert the initial version table row dbVersionQuery(db *sql.DB) (*sql.Rows, error) + dbRunAux(db *sql.Tx) error } var dialect SQLDialect = &PostgresDialect{} @@ -49,6 +50,10 @@ func SetDialect(d string) error { // PostgresDialect struct. type PostgresDialect struct{} +func (pg PostgresDialect) dbRunAux(db *sql.Tx) error { + return nil +} + func (pg PostgresDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, @@ -79,6 +84,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // MySQLDialect struct. type MySQLDialect struct{} +func (m MySQLDialect) dbRunAux(db *sql.Tx) error { + return nil +} + func (m MySQLDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, @@ -109,6 +118,10 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // Sqlite3Dialect struct. type Sqlite3Dialect struct{} +func (m Sqlite3Dialect) dbRunAux(db *sql.Tx) error { + return nil +} + func (m Sqlite3Dialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -138,6 +151,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // RedshiftDialect struct. type RedshiftDialect struct{} +func (rs RedshiftDialect) dbRunAux(db *sql.Tx) error { + return nil +} + func (rs RedshiftDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id integer NOT NULL identity(1, 1), @@ -168,6 +185,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // TiDBDialect struct. type TiDBDialect struct{} +func (m TiDBDialect) dbRunAux(db *sql.Tx) error { + return nil +} + func (m TiDBDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, @@ -198,34 +219,52 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // OracleDialect struct. type OracleDialect struct{} +func (OracleDialect) dbRunAux(db *sql.Tx) error { + _, err := db.Exec(fmt.Sprintf(`ALTER TABLE "%s" ADD PRIMARY KEY ("ID")`, TableName())) + if err != nil { + println("error on create PK: %s", err.Error()) + } + _, err = db.Exec(fmt.Sprintf(`CREATE SEQUENCE %s_id_seq`, TableName())) + if err != nil { + println("error on create SEQ: %s", err.Error()) + } + + var trigger = fmt.Sprintf(` + create or replace trigger %[1]s_BI + before insert on "%[1]s" + for each row + begin + if inserting then + if :NEW."ID" is null then + select %[1]s_id_seq.nextval into :NEW."ID" from dual; + end if; + end if; + end`, TableName()) + + _, err = db.Exec(trigger) + if err != nil { + println("error on create Trigger: %s", err.Error()) + } + return nil +} + func (OracleDialect) createVersionTableSQL() string { - return fmt.Sprintf(` - CREATE TABLE "%[1]s" ( - id NUMBER, + var command = fmt.Sprintf(` + CREATE TABLE %[1]s ( + id NUMBER(19), version_id NUMBER(19) NOT NULL, is_applied char(1) NOT NULL, tstamp TIMESTAMP(6) default SYS_EXTRACT_UTC(SYSTIMESTAMP) - ); - ALTER TABLE "%[1]s" ADD PRIMARY KEY ("ID"); - CREATE SEQUENCE %[1]s_id_seq; - create or replace trigger %[1]s_BI - before insert on "%[1]s" - for each row - begin - if inserting then - if :NEW."ID" is null then - select %[1]s_id_seq.nextval into :NEW."ID" from dual; - end if; - end if; - end;`, TableName()) + )`, TableName()) + return command } func (OracleDialect) insertVersionSQL() string { - return fmt.Sprintf(`INSERT INTO "%s" (version_id, is_applied) VALUES (?, ?);`, TableName()) + return fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`, TableName()) } func (OracleDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf(`SELECT version_id, is_applied from "%s" ORDER BY id DESC`, TableName())) + rows, err := db.Query(fmt.Sprintf(`SELECT version_id, is_applied from %s ORDER BY id DESC`, TableName())) if err != nil { return nil, err } diff --git a/migrate.go b/migrate.go index 4774af022..de01d2a8e 100644 --- a/migrate.go +++ b/migrate.go @@ -255,9 +255,21 @@ func createVersionTable(db *sql.DB) error { txn.Rollback() return err } + return txn.Commit() + + txn, err = db.Begin() + if err != nil { + return err + } version := 0 applied := true + + if err := d.dbRunAux(txn); err != nil { + txn.Rollback() + return err + + } if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil { txn.Rollback() return err diff --git a/migration.go b/migration.go index 595b54a27..283e54f30 100644 --- a/migration.go +++ b/migration.go @@ -54,7 +54,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { switch filepath.Ext(m.Source) { case ".sql": if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil { - return fmt.Errorf("FAIL %v, quitting migration", err) + return fmt.Errorf("FAIL %v, quitting migration %s", err, m.Source) } case ".go":