diff --git a/handler/session.go b/handler/session.go index b660d9ca..86be792c 100644 --- a/handler/session.go +++ b/handler/session.go @@ -8,7 +8,11 @@ import ( "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/store/jsondb" "github.com/ngoduykhanh/wireguard-ui/util" + "github.com/rs/xid" ) func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { @@ -43,6 +47,86 @@ func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc { } } +// SSOauth uses external authentication (usually by reverseproxy) in the form of HTTP header REMOTE_USER +func SSOauth(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if !util.RemoteUser { + return next(c) + } + if !isValidSession(c) { + remoteUser := c.Request().Header.Get("REMOTE_USER") + if remoteUser == "" { + // TODO: Better error handling + log.Infof("No REMOTE_USER in reqest. Bailing out.") + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + log.Debugf("No valid session for REMOTE_USER: %s", remoteUser) + + db := c.Get("db").(*jsondb.JsonDB) + dbuser, err := db.GetUserByName(remoteUser) + if err != nil { + log.Infof("User %s not in database, creating user", remoteUser) + newUser := model.User{ + Username: remoteUser, + Admin: false, + } + err = db.SaveUser(newUser) + if err != nil { + // TODO: Better error handling + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + // Update dbuser from database + dbuser, err = db.GetUserByName(remoteUser) + if err != nil { + // TODO: Better error handling + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + + } else { + log.Debugf("Got user from db: %s", dbuser.Username) + } + + // Set session for REMOTE_USER + ageMax := 0 + + cookiePath := util.GetCookiePath() + + sess, _ := session.Get("session", c) + sess.Options = &sessions.Options{ + Path: cookiePath, + MaxAge: ageMax, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + + // set session_token + tokenUID := xid.New().String() + now := time.Now().UTC().Unix() + sess.Values["username"] = dbuser.Username + sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser) + sess.Values["admin"] = dbuser.Admin + sess.Values["session_token"] = tokenUID + sess.Values["max_age"] = ageMax + sess.Values["created_at"] = now + sess.Values["updated_at"] = now + sess.Save(c.Request(), c.Response()) + + // set session_token in cookie + cookie := new(http.Cookie) + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.Value = tokenUID + cookie.MaxAge = ageMax + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) + + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath) + } + return next(c) + } +} + func isValidSession(c echo.Context) bool { if util.DisableLogin { return true diff --git a/main.go b/main.go index 1125746f..b3aee92a 100644 --- a/main.go +++ b/main.go @@ -33,6 +33,7 @@ var ( buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05")) // configuration variables flagDisableLogin = false + flagRemoteUser = false flagBindAddress = "0.0.0.0:5000" flagSmtpHostname = "127.0.0.1" flagSmtpPort = 25 @@ -77,6 +78,7 @@ var embeddedAssets embed.FS func init() { // command-line flags and env variables flag.BoolVar(&flagDisableLogin, "disable-login", util.LookupEnvOrBool("DISABLE_LOGIN", flagDisableLogin), "Disable authentication on the app. This is potentially dangerous.") + flag.BoolVar(&flagRemoteUser, "remote_user", util.LookupEnvOrBool("REMOTE_USER", flagRemoteUser), "Use HTTP header REMOTE_USER for auth. Commonly used with SSO and a proxy funcion.") flag.StringVar(&flagBindAddress, "bind-address", util.LookupEnvOrString("BIND_ADDRESS", flagBindAddress), "Address:Port to which the app will be bound.") flag.StringVar(&flagSmtpHostname, "smtp-hostname", util.LookupEnvOrString("SMTP_HOSTNAME", flagSmtpHostname), "SMTP Hostname") flag.IntVar(&flagSmtpPort, "smtp-port", util.LookupEnvOrInt("SMTP_PORT", flagSmtpPort), "SMTP Port") @@ -126,6 +128,7 @@ func init() { // update runtime config util.DisableLogin = flagDisableLogin + util.RemoteUser = flagRemoteUser util.BindAddress = flagBindAddress util.SmtpHostname = flagSmtpHostname util.SmtpPort = flagSmtpPort @@ -161,6 +164,7 @@ func init() { fmt.Println("Build Time\t:", buildTime) fmt.Println("Git Repo\t:", "https://github.com/ngoduykhanh/wireguard-ui") fmt.Println("Authentication\t:", !util.DisableLogin) + fmt.Println("Remote_user\t:", util.RemoteUser) fmt.Println("Bind address\t:", util.BindAddress) //fmt.Println("Sendgrid key\t:", util.SendgridApiKey) fmt.Println("Email from\t:", util.EmailFrom) @@ -206,9 +210,9 @@ func main() { } // register routes - app := router.New(tmplDir, extraData, util.SessionSecret) + app := router.New(tmplDir, extraData, util.SessionSecret, db) - app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession) + app.GET(util.BasePath, handler.WireGuardClients(db), handler.SSOauth, handler.ValidSession, handler.RefreshSession) // Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to // mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on diff --git a/router/router.go b/router/router.go index 59d352eb..8e97c796 100644 --- a/router/router.go +++ b/router/router.go @@ -13,6 +13,7 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/labstack/gommon/log" + "github.com/ngoduykhanh/wireguard-ui/store/jsondb" "github.com/ngoduykhanh/wireguard-ui/util" ) @@ -48,7 +49,7 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c } // New function -func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo { +func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte, db *jsondb.JsonDB) *echo.Echo { e := echo.New() cookiePath := util.GetCookiePath() @@ -60,6 +61,14 @@ func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo e.Use(session.Middleware(cookieStore)) + // Add db to context so middlewares can use it. + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("db", db) + return next(c) + } + }) + // read html template file to string tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html") if err != nil { diff --git a/util/config.go b/util/config.go index 4af6bd2b..e05b81c3 100644 --- a/util/config.go +++ b/util/config.go @@ -10,6 +10,7 @@ import ( // Runtime config var ( DisableLogin bool + RemoteUser bool BindAddress string SmtpHostname string SmtpPort int