diff --git a/internal/models/model.go b/internal/models/model.go index 558fe228e..8ae44bc9e 100644 --- a/internal/models/model.go +++ b/internal/models/model.go @@ -5,9 +5,13 @@ import ( "strings" "time" + "crypto/tls" + "crypto/x509" + "io/ioutil" + macaron "gopkg.in/macaron.v1" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" "github.com/go-xorm/core" "github.com/go-xorm/xorm" _ "github.com/lib/pq" @@ -72,7 +76,10 @@ func (model *BaseModel) pageLimitOffset() int { // 创建Db func CreateDb() *xorm.Engine { - dsn := getDbEngineDSN(app.Setting) + dsn, err := getDbEngineDSN(app.Setting) + if err != nil { + logger.Fatal("创建xorm引擎失败", err) + } engine, err := xorm.NewEngine(app.Setting.Db.Engine, dsn) if err != nil { logger.Fatal("创建xorm引擎失败", err) @@ -100,13 +107,15 @@ func CreateDb() *xorm.Engine { // 创建临时数据库连接 func CreateTmpDb(setting *setting.Setting) (*xorm.Engine, error) { - dsn := getDbEngineDSN(setting) - + dsn, err := getDbEngineDSN(setting) + if err != nil { + return nil, err + } return xorm.NewEngine(setting.Db.Engine, dsn) } // 获取数据库引擎DSN mysql,sqlite,postgres -func getDbEngineDSN(setting *setting.Setting) string { +func getDbEngineDSN(setting *setting.Setting) (string, error) { engine := strings.ToLower(setting.Db.Engine) dsn := "" switch engine { @@ -118,6 +127,14 @@ func getDbEngineDSN(setting *setting.Setting) string { setting.Db.Port, setting.Db.Database, setting.Db.Charset) + if setting.Db.Sslmode == "true" || setting.Db.Sslmode == "skip-verify" { + tlsConfig, err := getTlsConfig(setting) + if err != nil { + return dsn, err + } + mysql.RegisterTLSConfig("custom", tlsConfig) + dsn += "&tls=custom" + } case "postgres": dsn = fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=disable", setting.Db.User, @@ -127,7 +144,59 @@ func getDbEngineDSN(setting *setting.Setting) string { setting.Db.Database) } - return dsn + return dsn, nil +} + +func getTlsConfig(setting *setting.Setting) (*tls.Config, error) { + // https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig + rootCertPool := x509.NewCertPool() + pem, err := ioutil.ReadFile(setting.Db.SslCaFile) + if err != nil { + return nil, err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return nil, fmt.Errorf("Failed to append PEM.") + } + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair(setting.Db.SslCertFile, setting.Db.SslKeyFile) + if err != nil { + return nil, err + } + clientCert = append(clientCert, certs) + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: rootCertPool, + Certificates: clientCert, + } + if setting.Db.Sslmode == "skip-verify" { + cfg.InsecureSkipVerify = true + } + sn := setting.Db.SslServerName + if sn != "" { + cfg.ServerName = sn + // Solve gcp invalid hostname in CN: https://github.com/golang/go/issues/40748#issuecomment-673599371 + if strings.Contains(sn, ":") { + if cfg.InsecureSkipVerify != true { + cfg.InsecureSkipVerify = true + cfg.VerifyConnection = func(cs tls.ConnectionState) error { + commonName := cs.PeerCertificates[0].Subject.CommonName + if commonName != cs.ServerName { + return fmt.Errorf("invalid certificate name %q, expected %q", commonName, cs.ServerName) + } + opts := x509.VerifyOptions{ + Roots: rootCertPool, + Intermediates: x509.NewCertPool(), + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + return err + } + } + } + } + return cfg, nil } func keepDbAlived(engine *xorm.Engine) { diff --git a/internal/modules/setting/setting.go b/internal/modules/setting/setting.go index 46d53d843..8d2aa8a18 100644 --- a/internal/modules/setting/setting.go +++ b/internal/modules/setting/setting.go @@ -22,6 +22,12 @@ type Setting struct { Charset string MaxIdleConns int MaxOpenConns int + + Sslmode string + SslCaFile string + SslCertFile string + SslKeyFile string + SslServerName string } AllowIps string AppName string @@ -59,6 +65,12 @@ func Read(filename string) (*Setting, error) { s.Db.MaxIdleConns = section.Key("db.max.idle.conns").MustInt(30) s.Db.MaxOpenConns = section.Key("db.max.open.conns").MustInt(100) + s.Db.Sslmode = section.Key("db.sslmode").MustString("") + s.Db.SslCaFile = section.Key("db.ssl_ca_file").MustString("") + s.Db.SslCertFile = section.Key("db.ssl_cert_file").MustString("") + s.Db.SslKeyFile = section.Key("db.ssl_key_file").MustString("") + s.Db.SslServerName = section.Key("db.ssl_server_name").MustString("") + s.AllowIps = section.Key("allow_ips").MustString("") s.AppName = section.Key("app.name").MustString("定时任务管理系统") s.ApiKey = section.Key("api.key").MustString("") diff --git a/internal/routers/install/install.go b/internal/routers/install/install.go index 22e0f8cb8..a0a37879a 100644 --- a/internal/routers/install/install.go +++ b/internal/routers/install/install.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "strings" macaron "gopkg.in/macaron.v1" @@ -27,6 +28,11 @@ type InstallForm struct { DbPassword string `binding:"Required;MaxSize(30)"` DbName string `binding:"Required;MaxSize(50)"` DbTablePrefix string `binding:"MaxSize(20)"` + DbSslmode string `binding:"In(,false,true,skip-verify)"` + DbSslCaFile string `binding:"MaxSize(255)"` + DbSslCertFile string `binding:"MaxSize(255)"` + DbSslKeyFile string `binding:"MaxSize(255)"` + DbSslServerName string `binding:"MaxSize(255)"` AdminUsername string `binding:"Required;MinSize(3)"` AdminPassword string `binding:"Required;MinSize(6)"` ConfirmAdminPassword string `binding:"Required;MinSize(6)"` @@ -38,7 +44,11 @@ func (f InstallForm) Error(ctx *macaron.Context, errs binding.Errors) { return } json := utils.JsonResponse{} - content := json.CommonFailure("表单验证失败, 请检测输入") + newErrs := make([]error, len(errs)) + for i, e := range errs { + newErrs[i] = fmt.Errorf("表单验证失败-Fields: %s, Kind: %s, Error: %s", strings.Join(e.Fields(), ", "), e.Kind(), e.Error()) + } + content := json.CommonFailure("表单验证失败, 请检测输入", newErrs...) ctx.Write([]byte(content)) } @@ -108,6 +118,11 @@ func writeConfig(form InstallForm) error { "db.database", form.DbName, "db.prefix", form.DbTablePrefix, "db.charset", "utf8", + "db.sslmode", form.DbSslmode, + "db.ssl_ca_file", form.DbSslCaFile, + "db.ssl_cert_file", form.DbSslCertFile, + "db.ssl_key_file", form.DbSslKeyFile, + "db.ssl_server_name", form.DbSslServerName, "db.max.idle.conns", "5", "db.max.open.conns", "100", "allow_ips", "", @@ -146,6 +161,11 @@ func testDbConnection(form InstallForm) error { s.Db.User = form.DbUsername s.Db.Password = form.DbPassword s.Db.Database = form.DbName + s.Db.Sslmode = form.DbSslmode + s.Db.SslCaFile = form.DbSslCaFile + s.Db.SslCertFile = form.DbSslCertFile + s.Db.SslKeyFile = form.DbSslKeyFile + s.Db.SslServerName = form.DbSslServerName s.Db.Charset = "utf8" db, err := models.CreateTmpDb(&s) if err != nil { diff --git a/internal/routers/routers.go b/internal/routers/routers.go index c2693feb6..fc01e4e35 100644 --- a/internal/routers/routers.go +++ b/internal/routers/routers.go @@ -50,7 +50,7 @@ func Register(m *macaron.Macaron) { m.Get("/", func(ctx *macaron.Context) { file, err := statikFS.Open("/index.html") if err != nil { - logger.Error("读取首页文件失败: %s", err) + logger.Errorf("读取首页文件失败: %s", err) ctx.WriteHeader(http.StatusInternalServerError) return } diff --git a/web/vue/src/pages/install/index.vue b/web/vue/src/pages/install/index.vue index 663f8621b..468df22f2 100644 --- a/web/vue/src/pages/install/index.vue +++ b/web/vue/src/pages/install/index.vue @@ -1,7 +1,13 @@