diff --git a/cmd/octillery/main.go b/cmd/octillery/main.go index a83c4ef..94d4af1 100644 --- a/cmd/octillery/main.go +++ b/cmd/octillery/main.go @@ -2,12 +2,10 @@ package main import ( "bufio" - "bytes" coresql "database/sql" "encoding/csv" "encoding/json" "fmt" - "io" "io/ioutil" "log" "os" @@ -17,19 +15,17 @@ import ( "strconv" "strings" "time" - "unicode" flags "github.com/jessevdk/go-flags" vtparser "github.com/knocknote/vitess-sqlparser/sqlparser" "github.com/pkg/errors" - "github.com/schemalex/schemalex" - "github.com/schemalex/schemalex/diff" "go.knocknote.io/octillery" "go.knocknote.io/octillery/algorithm" "go.knocknote.io/octillery/config" "go.knocknote.io/octillery/connection" _ "go.knocknote.io/octillery/connection/adapter/plugin" "go.knocknote.io/octillery/database/sql" + "go.knocknote.io/octillery/migrator" "go.knocknote.io/octillery/printer" "go.knocknote.io/octillery/sqlparser" "go.knocknote.io/octillery/transposer" @@ -117,98 +113,7 @@ func (cmd *TransposeCommand) Execute(args []string) error { return errors.WithStack(transposer.New().Transpose(pattern, searchPath, cmd.Ignore, transposeClosure)) } -type schemaTextSource string - -func (s schemaTextSource) WriteSchema(dst io.Writer) error { - if _, err := io.WriteString(dst, string(s)); err != nil { - return errors.Wrap(err, `failed to copy text contents to dst`) - } - return nil -} - -type serverSource struct { - conn *coresql.DB -} - -// WriteSchema get normalized schema from mysql server and write it to dst. -// This method's original source code is `schemalex/source.go` -func (s *serverSource) WriteSchema(dst io.Writer) error { - db := s.conn - tableRows, err := db.Query("SHOW TABLES") - if err != nil { - return errors.Wrap(err, `failed to execute 'SHOW TABLES'`) - } - defer tableRows.Close() - parser, err := sqlparser.New() - if err != nil { - return errors.WithStack(err) - } - var table string - var tableSchema string - var buf bytes.Buffer - for tableRows.Next() { - if err = tableRows.Scan(&table); err != nil { - return errors.Wrap(err, `failed to scan tables`) - } - - if err = db.QueryRow("SHOW CREATE TABLE `"+table+"`").Scan(&table, &tableSchema); err != nil { - return errors.Wrapf(err, `failed to execute 'SHOW CREATE TABLE "%s"'`, table) - } - if buf.Len() > 0 { - buf.WriteString("\n\n") - } - query, err := parser.Parse(tableSchema) - if err != nil { - return errors.WithStack(err) - } - // normalize DDL because schemalex cannot parse PARTITION option - normalizedSchema := vtparser.String(query.(*sqlparser.QueryBase).Stmt) - buf.WriteString(normalizedSchema) - buf.WriteByte(';') - } - - return errors.WithStack(schemalex.NewReaderSource(&buf).WriteSchema(dst)) -} - -func (cmd *MigrateCommand) compareSchema(from schemalex.SchemaSource, to schemalex.SchemaSource) (string, error) { - var buf bytes.Buffer - p := schemalex.New() - if err := diff.Sources( - &buf, - from, - to, - diff.WithTransaction(false), diff.WithParser(p), - ); err != nil { - return "", errors.WithStack(err) - } - return buf.String(), nil -} - -// CompareResult type for results of comparing schema -type CompareResult struct { - diff string - dsn string - conn *coresql.DB -} - -// CombinedQuery has all `sqlparser.Query` for a DNS -type CombinedQuery struct { - queries []sqlparser.Query - conn *coresql.DB -} - -func (c *CombinedQuery) allDDL() string { - allDDL := []string{} - for _, query := range c.queries { - // normalize DDL because schemalex cannot parse PARTITION option - normalizedDDL := vtparser.String(query.(*sqlparser.QueryBase).Stmt) - allDDL = append(allDDL, normalizedDDL) - } - return strings.Join(allDDL, ";\n") -} - // Execute executes migrate command -// nolint: gocyclo func (cmd *MigrateCommand) Execute(args []string) error { if len(args) == 0 { return errors.New("argument is required. it is path to directory includes schema file or direct path to schema file") @@ -217,151 +122,12 @@ func (cmd *MigrateCommand) Execute(args []string) error { return errors.WithStack(err) } - schamePath := args[0] - parser, err := sqlparser.New() + schemaPath := args[0] + migrator, err := migrator.NewMigrator("mysql", cmd.DryRun, cmd.Quiet) if err != nil { return errors.WithStack(err) } - tableNameToOriginalQueryMap := map[string]sqlparser.Query{} - queries := []sqlparser.Query{} - if err := filepath.Walk(schamePath, func(path string, info os.FileInfo, err error) error { - if err != nil { - return errors.WithStack(err) - } - if info.IsDir() { - return nil - } - schema, err := ioutil.ReadFile(path) - if err != nil { - return errors.WithStack(err) - } - query, err := parser.Parse(string(schema)) - if err != nil { - return errors.WithStack(err) - } - tableNameToOriginalQueryMap[query.Table()] = query - queries = append(queries, query) - return nil - }); err != nil { - return errors.WithStack(err) - } - - mgr, err := connection.NewConnectionManager() - if err != nil { - return errors.WithStack(err) - } - dsnToQueryMap := map[string]*CombinedQuery{} - for _, query := range queries { - conn, err := mgr.ConnectionByTableName(query.Table()) - if err != nil { - return errors.WithStack(err) - } - if conn.IsShard { - for _, shard := range conn.ShardConnections.AllShard() { - cfg := conn.Config.ShardConfigByName(shard.ShardName) - dsn := fmt.Sprintf("%s/%s", cfg.Masters[0], cfg.NameOrPath) - if _, exists := dsnToQueryMap[dsn]; exists { - dsnToQueryMap[dsn].queries = append(dsnToQueryMap[dsn].queries, query) - } else { - dsnToQueryMap[dsn] = &CombinedQuery{ - queries: []sqlparser.Query{query}, - conn: shard.Connection, - } - } - } - } else { - cfg := conn.Config - dsn := fmt.Sprintf("%s/%s", cfg.Masters[0], cfg.NameOrPath) - if _, exists := dsnToQueryMap[dsn]; exists { - dsnToQueryMap[dsn].queries = append(dsnToQueryMap[dsn].queries, query) - } else { - dsnToQueryMap[dsn] = &CombinedQuery{ - queries: []sqlparser.Query{query}, - conn: conn.Connection, - } - } - } - } - results := []*CompareResult{} - for dsn, combinedQuery := range dsnToQueryMap { - allDDL := combinedQuery.allDDL() - fromSource := &serverSource{ - conn: combinedQuery.conn, - } - diff, err := cmd.compareSchema(fromSource, schemaTextSource(allDDL)) - if err != nil { - return errors.WithStack(err) - } - if len(diff) == 0 { - continue - } - - replacedDDL := []string{} - splittedDDL := strings.Split(diff, ";") - for _, ddl := range splittedDDL { - if ddl == "" || ddl == "\n" { - continue - } - if !strings.HasPrefix(ddl, "CREATE TABLE") { - replacedDDL = append(replacedDDL, ddl+";") - continue - } - - // If diff is `CREATE TABLE` statement, use original DDL ( not eliminated PARTITION option ) - stmt, err := parser.Parse(ddl) - if err != nil { - return errors.WithStack(err) - } - tableName := stmt.Table() - query := tableNameToOriginalQueryMap[tableName] - replacedDDL = append(replacedDDL, query.(*sqlparser.QueryBase).Text) - } - results = append(results, &CompareResult{ - diff: strings.Join(replacedDDL, "\n"), - dsn: dsn, - conn: combinedQuery.conn, - }) - } - if cmd.DryRun { - if len(results) > 0 { - for _, result := range results { - if result.diff == "" || result.diff == "\n" { - continue - } - fmt.Printf("[ %s ]\n\n", result.dsn) - for _, diff := range strings.Split(result.diff, ";") { - trimmedDiff := strings.TrimFunc(diff, func(r rune) bool { - return unicode.IsSpace(r) - }) - if trimmedDiff == "" { - continue - } - fmt.Printf("%s\n\n", trimmedDiff) - } - } - } - } else { - for _, result := range results { - if !cmd.Quiet { - fmt.Printf("[ %s ]\n\n", result.dsn) - } - for _, diff := range strings.Split(result.diff, ";") { - trimmedDiff := strings.TrimFunc(diff, func(r rune) bool { - return unicode.IsSpace(r) - }) - if trimmedDiff == "" { - continue - } - if !cmd.Quiet { - fmt.Printf("%s\n\n", trimmedDiff) - } - if _, err := result.conn.Exec(trimmedDiff); err != nil { - return errors.WithStack(err) - } - } - } - } - return nil + return errors.WithStack(migrator.Migrate(schemaPath)) } func (cmd *ImportCommand) schemaFromTableName(tableName string) (vtparser.Statement, error) { diff --git a/migrator/migrator.go b/migrator/migrator.go new file mode 100644 index 0000000..9419bad --- /dev/null +++ b/migrator/migrator.go @@ -0,0 +1,190 @@ +package migrator + +import ( + "database/sql" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sync" + + vtparser "github.com/knocknote/vitess-sqlparser/sqlparser" + "github.com/pkg/errors" + "go.knocknote.io/octillery/connection" + "go.knocknote.io/octillery/sqlparser" +) + +// DBMigratorPlugin interface for migration +type DBMigratorPlugin interface { + Init([]sqlparser.Query) + CompareSchema(*sql.DB, []string) ([]string, error) +} + +var ( + migratorPluginsMu sync.RWMutex + migratorPlugins = make(map[string]func() DBMigratorPlugin) +) + +// Migrator migrates database schema +type Migrator struct { + DryRun bool + Quiet bool + Plugin DBMigratorPlugin +} + +type dsnWithConnection struct { + dsn string + conn *sql.DB +} + +type combinedQuery struct { + queries []sqlparser.Query + conn *sql.DB +} + +// Register register DBMigratorPlugin with adapter name +func Register(name string, pluginCreator func() DBMigratorPlugin) { + migratorPluginsMu.Lock() + defer migratorPluginsMu.Unlock() + if pluginCreator == nil { + panic("plugin creator is nil") + } + if _, dup := migratorPlugins[name]; dup { + panic("register called twice for migrator plugin " + name) + } + migratorPlugins[name] = pluginCreator +} + +// NewMigrator creates instance of Migrator +func NewMigrator(adapter string, dryRun bool, isQuiet bool) (*Migrator, error) { + plugin := migratorPlugins[adapter] + if plugin == nil { + return nil, errors.Errorf("cannot find migrator plugin for %s", adapter) + } + return &Migrator{ + DryRun: dryRun, + Quiet: !dryRun && isQuiet, + Plugin: plugin(), + }, nil +} + +// Migrate executes migrate +func (m *Migrator) Migrate(schemaPath string) error { + queries, err := m.queries(schemaPath) + if err != nil { + return errors.WithStack(err) + } + m.Plugin.Init(queries) + dsnToQueryMap := map[string]*combinedQuery{} + for _, query := range queries { + dsnConns, err := m.dsnWithConnections(query) + if err != nil { + return errors.WithStack(err) + } + for _, dsnConn := range dsnConns { + dsn := dsnConn.dsn + if _, exists := dsnToQueryMap[dsn]; exists { + dsnToQueryMap[dsn].queries = append(dsnToQueryMap[dsn].queries, query) + } else { + dsnToQueryMap[dsn] = &combinedQuery{ + queries: []sqlparser.Query{query}, + conn: dsnConn.conn, + } + } + } + } + for dsn, combinedQuery := range dsnToQueryMap { + allDDL := combinedQuery.allDDL() + diff, err := m.Plugin.CompareSchema(combinedQuery.conn, allDDL) + if err != nil { + return errors.WithStack(err) + } + if len(diff) == 0 { + continue + } + if !m.Quiet { + fmt.Printf("[ %s ]\n\n", dsn) + } + for _, diff := range diff { + if !m.Quiet { + fmt.Printf("%s\n\n", diff) + } + if m.DryRun { + continue + } + if _, err := combinedQuery.conn.Exec(diff); err != nil { + return errors.WithStack(err) + } + } + } + return nil +} + +func (m *Migrator) queries(schemaPath string) ([]sqlparser.Query, error) { + parser, err := sqlparser.New() + if err != nil { + return nil, errors.WithStack(err) + } + queries := []sqlparser.Query{} + if err := filepath.Walk(schemaPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return errors.WithStack(err) + } + if info.IsDir() { + return nil + } + schema, err := ioutil.ReadFile(path) + if err != nil { + return errors.WithStack(err) + } + query, err := parser.Parse(string(schema)) + if err != nil { + return errors.WithStack(err) + } + queries = append(queries, query) + return nil + }); err != nil { + return nil, errors.WithStack(err) + } + return queries, nil +} + +func (c *combinedQuery) allDDL() []string { + allDDL := []string{} + for _, query := range c.queries { + // normalize DDL because schemalex cannot parse PARTITION option + normalizedDDL := vtparser.String(query.(*sqlparser.QueryBase).Stmt) + allDDL = append(allDDL, normalizedDDL) + } + return allDDL +} + +func (m *Migrator) dsnWithConnections(query sqlparser.Query) ([]*dsnWithConnection, error) { + mgr, err := connection.NewConnectionManager() + if err != nil { + return nil, errors.WithStack(err) + } + conn, err := mgr.ConnectionByTableName(query.Table()) + if err != nil { + return nil, errors.WithStack(err) + } + dsnConns := []*dsnWithConnection{} + if conn.IsShard { + for _, shard := range conn.ShardConnections.AllShard() { + cfg := conn.Config.ShardConfigByName(shard.ShardName) + dsn := fmt.Sprintf("%s/%s", cfg.Masters[0], cfg.NameOrPath) + dsnConns = append(dsnConns, &dsnWithConnection{ + dsn: dsn, + conn: shard.Connection, + }) + } + } else { + cfg := conn.Config + dsn := fmt.Sprintf("%s/%s", cfg.Masters[0], cfg.NameOrPath) + dsnConns = append(dsnConns, &dsnWithConnection{ + dsn: dsn, + conn: conn.Connection, + }) + } + return dsnConns, nil +} diff --git a/migrator/mysql.go b/migrator/mysql.go new file mode 100644 index 0000000..7e9f068 --- /dev/null +++ b/migrator/mysql.go @@ -0,0 +1,139 @@ +package migrator + +import ( + "bytes" + "database/sql" + "io" + "strings" + "unicode" + + vtparser "github.com/knocknote/vitess-sqlparser/sqlparser" + "github.com/pkg/errors" + "github.com/schemalex/schemalex" + "github.com/schemalex/schemalex/diff" + "go.knocknote.io/octillery/sqlparser" +) + +type schemaTextSource string + +func (s schemaTextSource) WriteSchema(dst io.Writer) error { + if _, err := io.WriteString(dst, string(s)); err != nil { + return errors.Wrap(err, `failed to copy text contents to dst`) + } + return nil +} + +type serverSource struct { + conn *sql.DB +} + +// WriteSchema get normalized schema from mysql server and write it to dst. +// This method's original source code is `schemalex/source.go` +func (s *serverSource) WriteSchema(dst io.Writer) error { + db := s.conn + tableRows, err := db.Query("SHOW TABLES") + if err != nil { + return errors.Wrap(err, `failed to execute 'SHOW TABLES'`) + } + defer tableRows.Close() + parser, err := sqlparser.New() + if err != nil { + return errors.WithStack(err) + } + var ( + table string + tableSchema string + buf bytes.Buffer + ) + for tableRows.Next() { + if err := tableRows.Scan(&table); err != nil { + return errors.Wrap(err, `failed to scan tables`) + } + + if err := db.QueryRow("SHOW CREATE TABLE `"+table+"`").Scan(&table, &tableSchema); err != nil { + return errors.Wrapf(err, `failed to execute 'SHOW CREATE TABLE "%s"'`, table) + } + if buf.Len() > 0 { + buf.WriteString("\n\n") + } + query, err := parser.Parse(tableSchema) + if err != nil { + return errors.WithStack(err) + } + // normalize DDL because schemalex cannot parse PARTITION option + normalizedSchema := vtparser.String(query.(*sqlparser.QueryBase).Stmt) + buf.WriteString(normalizedSchema) + buf.WriteByte(';') + } + + return errors.WithStack(schemalex.NewReaderSource(&buf).WriteSchema(dst)) +} + +// MySQLMigrator implements DBMigratorPlugin +type MySQLMigrator struct { + tableNameToQueryMap map[string]sqlparser.Query +} + +// Init create mapping from table name to sqlparser.Query +func (m *MySQLMigrator) Init(queries []sqlparser.Query) { + m.tableNameToQueryMap = map[string]sqlparser.Query{} + for _, query := range queries { + m.tableNameToQueryMap[query.Table()] = query + } +} + +// CompareSchema compare schema on mysql server with local schema +func (m *MySQLMigrator) CompareSchema(conn *sql.DB, allDDL []string) ([]string, error) { + from := &serverSource{conn: conn} + to := schemaTextSource(strings.Join(allDDL, ";\n")) + var buf bytes.Buffer + p := schemalex.New() + if err := diff.Sources( + &buf, + from, + to, + diff.WithTransaction(false), diff.WithParser(p), + ); err != nil { + return nil, errors.WithStack(err) + } + schemaDiff := buf.String() + if len(schemaDiff) == 0 { + return nil, nil + } + parser, err := sqlparser.New() + if err != nil { + return nil, errors.WithStack(err) + } + replacedDDL := []string{} + splittedDDL := strings.Split(schemaDiff, ";") + for _, ddl := range splittedDDL { + trimmedDDL := strings.TrimFunc(ddl, func(r rune) bool { + return unicode.IsSpace(r) + }) + if trimmedDDL == "" { + continue + } + if !strings.HasPrefix(trimmedDDL, "CREATE TABLE") { + replacedDDL = append(replacedDDL, trimmedDDL) + continue + } + + // If diff is `CREATE TABLE` statement, use original DDL ( not eliminated PARTITION option ) + stmt, err := parser.Parse(trimmedDDL) + if err != nil { + return nil, errors.WithStack(err) + } + tableName := stmt.Table() + query := m.tableNameToQueryMap[tableName] + replacedDDL = append(replacedDDL, strings.TrimFunc(query.(*sqlparser.QueryBase).Text, func(r rune) bool { + return unicode.IsSpace(r) || string(r) == ";" + })) + } + return replacedDDL, nil +} + +func init() { + Register("mysql", func() DBMigratorPlugin { + return &MySQLMigrator{} + }) +}