Skip to content

Commit

Permalink
Separated commands
Browse files Browse the repository at this point in the history
  • Loading branch information
snakeice committed Aug 12, 2019
1 parent 89262ff commit f9eeffd
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 19 deletions.
75 changes: 57 additions & 18 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type SQLDialect interface {
createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
dbRunAux(db *sql.Tx) error
}

var dialect SQLDialect = &PostgresDialect{}
Expand Down Expand Up @@ -49,6 +50,10 @@ func SetDialect(d string) error {
// PostgresDialect struct.
type PostgresDialect struct{}

func (pg PostgresDialect) dbRunAux(db *sql.Tx) error {
return nil
}

func (pg PostgresDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
Expand Down Expand Up @@ -79,6 +84,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// MySQLDialect struct.
type MySQLDialect struct{}

func (m MySQLDialect) dbRunAux(db *sql.Tx) error {
return nil
}

func (m MySQLDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
Expand Down Expand Up @@ -109,6 +118,10 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// Sqlite3Dialect struct.
type Sqlite3Dialect struct{}

func (m Sqlite3Dialect) dbRunAux(db *sql.Tx) error {
return nil
}

func (m Sqlite3Dialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
Expand Down Expand Up @@ -138,6 +151,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// RedshiftDialect struct.
type RedshiftDialect struct{}

func (rs RedshiftDialect) dbRunAux(db *sql.Tx) error {
return nil
}

func (rs RedshiftDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
Expand Down Expand Up @@ -168,6 +185,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// TiDBDialect struct.
type TiDBDialect struct{}

func (m TiDBDialect) dbRunAux(db *sql.Tx) error {
return nil
}

func (m TiDBDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
Expand Down Expand Up @@ -198,34 +219,52 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// OracleDialect struct.
type OracleDialect struct{}

func (OracleDialect) dbRunAux(db *sql.Tx) error {
_, err := db.Exec(fmt.Sprintf(`ALTER TABLE "%s" ADD PRIMARY KEY ("ID")`, TableName()))
if err != nil {
println("error on create PK: %s", err.Error())
}
_, err = db.Exec(fmt.Sprintf(`CREATE SEQUENCE %s_id_seq`, TableName()))
if err != nil {
println("error on create SEQ: %s", err.Error())
}

var trigger = fmt.Sprintf(`
create or replace trigger %[1]s_BI
before insert on "%[1]s"
for each row
begin
if inserting then
if :NEW."ID" is null then
select %[1]s_id_seq.nextval into :NEW."ID" from dual;
end if;
end if;
end`, TableName())

_, err = db.Exec(trigger)
if err != nil {
println("error on create Trigger: %s", err.Error())
}
return nil
}

func (OracleDialect) createVersionTableSQL() string {
return fmt.Sprintf(`
CREATE TABLE "%[1]s" (
id NUMBER,
var command = fmt.Sprintf(`
CREATE TABLE %[1]s (
id NUMBER(19),
version_id NUMBER(19) NOT NULL,
is_applied char(1) NOT NULL,
tstamp TIMESTAMP(6) default SYS_EXTRACT_UTC(SYSTIMESTAMP)
);
ALTER TABLE "%[1]s" ADD PRIMARY KEY ("ID");
CREATE SEQUENCE %[1]s_id_seq;
create or replace trigger %[1]s_BI
before insert on "%[1]s"
for each row
begin
if inserting then
if :NEW."ID" is null then
select %[1]s_id_seq.nextval into :NEW."ID" from dual;
end if;
end if;
end;`, TableName())
)`, TableName())
return command
}

func (OracleDialect) insertVersionSQL() string {
return fmt.Sprintf(`INSERT INTO "%s" (version_id, is_applied) VALUES (?, ?);`, TableName())
return fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`, TableName())
}

func (OracleDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf(`SELECT version_id, is_applied from "%s" ORDER BY id DESC`, TableName()))
rows, err := db.Query(fmt.Sprintf(`SELECT version_id, is_applied from %s ORDER BY id DESC`, TableName()))
if err != nil {
return nil, err
}
Expand Down
12 changes: 12 additions & 0 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,21 @@ func createVersionTable(db *sql.DB) error {
txn.Rollback()
return err
}
return txn.Commit()

txn, err = db.Begin()
if err != nil {
return err
}

version := 0
applied := true

if err := d.dbRunAux(txn); err != nil {
txn.Rollback()
return err

}
if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil {
txn.Rollback()
return err
Expand Down
2 changes: 1 addition & 1 deletion migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return fmt.Errorf("FAIL %v, quitting migration", err)
return fmt.Errorf("FAIL %v, quitting migration %s", err, m.Source)
}

case ".go":
Expand Down

0 comments on commit f9eeffd

Please sign in to comment.