diff --git a/cmd/server/main.go b/cmd/server/main.go index a0341eb..07cc9f9 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -48,9 +48,6 @@ func main() { // }, Scopes: []string{oidc.ScopeOpenID}, } - oidcConfig := oidc.Config{ - ClientID: requireStringFromEnv("DELBM_OIDC_CLIENT_ID"), - } // Get and init config config := domain.Configuration{ MaxContentDownloadAttempts: getIntFromEnv("DELBM_MAX_CONTENT_DOWNLOAD_ATTEMPTS", 3), @@ -65,8 +62,6 @@ func main() { ServerWriteTimeoutSeconds: getIntFromEnv("DELBM_SERVER_WRITE_TIMEOUT_SECONDS", 10), SessionCookieSecretKey: getStringFromEnv("DELBM_SESSION_COOKIE_SECRET_KEY", uuid.New().String()), ServerPort: getIntFromEnv("DELBM_SERVER_PORT", 1323), - Oauth2Config: oauth2Config, - OidcConfig: oidcConfig, } // Start the bookMarkCrawler quitChannel := make(chan struct{}) @@ -77,10 +72,9 @@ func main() { bookMarkCrawler.Run(quitChannel) // Start the server server.RunServer(server.Controller{ - Store: &store, - Config: config, - OidcProvider: oidcProvider, - }) + Store: &store, + Config: config, + }, oauth2Config, oidcProvider) } func requireStringFromEnv(s string) string { diff --git a/internal/domain/domain.go b/internal/domain/domain.go index aace177..034826f 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -2,9 +2,6 @@ package domain import ( "time" - - "github.com/coreos/go-oidc/v3/oidc" - "golang.org/x/oauth2" ) type BookmarkSlice struct { @@ -52,8 +49,6 @@ type Configuration struct { ServerWriteTimeoutSeconds int SessionCookieSecretKey string ServerPort int - Oauth2Config oauth2.Config - OidcConfig oidc.Config } const ( diff --git a/internal/middleware/oidc.go b/internal/middleware/oidc.go new file mode 100644 index 0000000..48ae8f9 --- /dev/null +++ b/internal/middleware/oidc.go @@ -0,0 +1,68 @@ +package middleware + +import ( + "encoding/base64" + "github.com/aggregat4/go-baselib/crypto" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/labstack/echo/v4" + "golang.org/x/oauth2" + "log" + "net/http" + "time" +) + +func CreateOidcMiddleware(isAuthenticated func(c echo.Context) bool, oidcConfig oauth2.Config) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if !isAuthenticated(c) { + state, err := crypto.RandString(16) + if err != nil { + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + // encode the original request URL into the state so we can redirect back to it after a successful login + // TODO: think about whether storing the original URL like this is generic or should be some sort of custom config + state = state + "|" + base64.StdEncoding.EncodeToString([]byte(c.Request().URL.String())) + c.SetCookie(&http.Cookie{ + Name: "oidc-callback-state-cookie", + Value: state, + Path: "/", // TODO: this path is not context path safe + Expires: time.Now().Add(time.Minute * 5), + HttpOnly: true, + }) + return c.Redirect(http.StatusFound, oidcConfig.AuthCodeURL(state)) + } else { + return next(c) + } + } + } +} + +func CreateOidcCallbackEndpoint(oidcConfig oauth2.Config, oidcProvider *oidc.Provider, delegate func(c echo.Context, idToken *oidc.IDToken, state string) error) echo.HandlerFunc { + verifier := oidcProvider.Verifier(&oidc.Config{ClientID: oidcConfig.ClientID}) + return func(c echo.Context) error { + // check state vs cookie + state, err := c.Cookie("oidc-callback-state-cookie") + if err != nil { + log.Println(err) + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + if c.QueryParam("state") != state.Value { + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + oauth2Token, err := oidcConfig.Exchange(c.Request().Context(), c.QueryParam("code")) + if err != nil { + log.Println(err) + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + idToken, err := verifier.Verify(c.Request().Context(), rawIDToken) + if err != nil { + log.Println(err) + return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) + } + return delegate(c, idToken, state.Value) + } +} diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 437caa7..e0c7849 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -199,16 +199,11 @@ func (store *Store) FindExistingBookmarkId(url string, userid int) (int64, error } func (store *Store) DeleteBookmark(url string, userid int) error { - id, err := store.FindExistingBookmarkId(url, userid) - if err != nil { - return err - } - _, err = store.db.Exec("DELETE FROM bookmarks WHERE id = ?", id) + _, err := store.db.Exec("DELETE FROM bookmarks WHERE user_id = ? AND url = ?", userid, url) return err } func (store *Store) AddOrUpdateBookmark(bookmark domain.Bookmark, userid int) error { - // we perform an upsert because the URL may already be stored and we just want to update the other fields _, err := store.db.Exec(` INSERT INTO bookmarks (user_id, url, title, description, tags, private, readlater, created, updated) diff --git a/internal/server/server.go b/internal/server/server.go index 86c235a..71d9ec8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "errors" "github.com/labstack/echo/v4" + "golang.org/x/oauth2" "html/template" "io" "log" @@ -15,8 +16,8 @@ import ( "time" "aggregat4/gobookmarks/internal/domain" + internalmiddleware "aggregat4/gobookmarks/internal/middleware" "aggregat4/gobookmarks/internal/repository" - "github.com/aggregat4/go-baselib/crypto" "github.com/aggregat4/go-baselib/lang" "github.com/coreos/go-oidc/v3/oidc" @@ -33,12 +34,11 @@ var viewTemplates embed.FS var images embed.FS type Controller struct { - Store *repository.Store - Config domain.Configuration - OidcProvider *oidc.Provider + Store *repository.Store + Config domain.Configuration } -func RunServer(controller Controller) { +func RunServer(controller Controller, oidcConfig oauth2.Config, oidcProvider *oidc.Provider) { e := echo.New() // Set server timeouts based on advice from https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#1687428081 e.Server.ReadTimeout = time.Duration(controller.Config.ServerReadTimeoutSeconds) * time.Second @@ -53,7 +53,16 @@ func RunServer(controller Controller) { templates: template.Must(template.New("").Funcs(funcMap).ParseFS(viewTemplates, "public/views/*.html")), } e.Renderer = t - + // Configure all the middleware + e.Use(internalmiddleware.CreateOidcMiddleware(func(c echo.Context) bool { + userId, err := getUserIdFromSession(c) + if err != nil && userId != 0 { + return true + } else { + clearSessionCookie(c) + return false + } + }, oidcConfig)) e.Use(middleware.Logger()) e.Use(middleware.Recover()) sessionCookieSecretKey := controller.Config.SessionCookieSecretKey @@ -61,27 +70,43 @@ func RunServer(controller Controller) { e.Use(middleware.GzipWithConfig(middleware.GzipConfig{ Level: 5, })) + // TODO: replace form based CSRF with origin check e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ TokenLookup: "form:csrf_token", })) - + // Endpoints // MustSubFS basically strips the prefix from the path that is automatically added by Go's embedFS imageFS := echo.MustSubFS(images, "public/images") e.StaticFS("/images", imageFS) - - e.GET("/oidccallback", controller.oidcCallback) - + e.GET("/oidccallback", internalmiddleware.CreateOidcCallbackEndpoint(oidcConfig, oidcProvider, controller.oidcDelegate)) e.GET("/bookmarks", controller.showBookmarks) e.POST("/bookmarks", controller.addBookmark) e.GET("/addbookmark", controller.showAddBookmark) e.POST("/deletebookmark", controller.deleteBookmark) e.GET("/feeds/:id", controller.showFeed) - + // Start the server port := controller.Config.ServerPort e.Logger.Fatal(e.Start(":" + strconv.Itoa(port))) // NO MORE CODE HERE, IT WILL NOT BE EXECUTED } +func handleInternalServerError(c echo.Context, err error) error { + log.Println(err) + return c.Render(http.StatusInternalServerError, "error-internalserver", nil) +} + +func getUserIdFromSession(c echo.Context) (int, error) { + sess, err := session.Get("delicious-bookmarks-session", c) + if err != nil { + return 0, err + } + if sess.Values["userid"] != nil { + return sess.Values["userid"].(int), nil + } else { + return 0, errors.New("no userid in session") + } +} + func highlight(text string) string { return strings.ReplaceAll(strings.ReplaceAll(text, "{{mark}}", ""), "{{endmark}}", "") } @@ -96,71 +121,7 @@ func clearSessionCookie(c echo.Context) { }) } -func setOidcCallbackCookie(c echo.Context, state string) { - c.SetCookie(&http.Cookie{ - Name: "delicious-bookmarks-oidc-callback", - Value: state, - Path: "/", // TODO: this path is not context path safe - Expires: time.Now().Add(time.Minute * 5), - HttpOnly: true, - }) -} - -func (controller *Controller) withValidSession(c echo.Context, delegate func(userid int) error) error { - sess, err := session.Get("delicious-bookmarks-session", c) - originalRequestUrlBase64 := base64.StdEncoding.EncodeToString([]byte(c.Request().URL.String())) - handleUnauthenticated := func() error { - clearSessionCookie(c) - state, err := crypto.RandString(16) - if err != nil { - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - state = state + "|" + originalRequestUrlBase64 - setOidcCallbackCookie(c, state) - return c.Redirect(http.StatusFound, controller.Config.Oauth2Config.AuthCodeURL(state)) - } - if err != nil { - return handleUnauthenticated() - } else { - useridraw := sess.Values["userid"] - if useridraw == nil { - return handleUnauthenticated() - } - sessionUserid := useridraw.(int) - if sessionUserid == 0 { - return handleUnauthenticated() - } else { - return delegate(sessionUserid) - } - } -} - -func (controller *Controller) oidcCallback(c echo.Context) error { - // check state vs cookie - state, err := c.Cookie("delicious-bookmarks-oidc-callback") - if err != nil { - log.Println(err) - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - if c.QueryParam("state") != state.Value { - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - oauth2Token, err := controller.Config.Oauth2Config.Exchange(c.Request().Context(), c.QueryParam("code")) - if err != nil { - log.Println(err) - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - // TODO: maybe initialize this verifier beforehand and reuse it here - verifier := controller.OidcProvider.Verifier(&controller.Config.OidcConfig) - idToken, err := verifier.Verify(c.Request().Context(), rawIDToken) - if err != nil { - log.Println(err) - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } +func (controller *Controller) oidcDelegate(c echo.Context, idToken *oidc.IDToken, state string) error { // we now have a valid ID token, to progress in the application we need to map this // to an existing user or create a new one on demand username := idToken.Subject @@ -177,7 +138,7 @@ func (controller *Controller) oidcCallback(c echo.Context) error { log.Println(err) return c.Render(http.StatusInternalServerError, "error-internal", nil) } - stateParts := strings.Split(state.Value, "|") + stateParts := strings.Split(state, "|") if len(stateParts) > 1 { originalRequestUrlBase64 := stateParts[1] decodedOriginalRequestUrl, err := base64.StdEncoding.DecodeString(originalRequestUrlBase64) @@ -200,151 +161,146 @@ type AddBookmarkPage struct { // TODO: continue here refactoring controller methods into this struct and moving db operations to the repository func (controller *Controller) showBookmarks(c echo.Context) error { - return controller.withValidSession(c, func(userid int) error { - handleError := func(err error) error { - log.Println(err) - return c.Render(http.StatusUnauthorized, "error-unauthorized", nil) - } - currentLastModifiedDateTime, err := controller.Store.GetLastModifiedDate(userid) - if err != nil { - return handleError(err) - } - if c.Request().Header.Get("If-Modified-Since") == currentLastModifiedDateTime.Format(http.TimeFormat) { - return c.NoContent(http.StatusNotModified) - } - var direction = domain.DirectionRight - if c.QueryParam("direction") != "" { - direction, err = strconv.Atoi(c.QueryParam("direction")) - if err != nil { - direction = domain.DirectionRight - } - if direction != 0 && direction != 1 { - direction = domain.DirectionRight - } - } - var offset int64 - if direction == domain.DirectionLeft { - offset = 0 - } else { - offset = math.MaxInt64 - } - if c.QueryParam("offset") != "" { - offset, _ = strconv.ParseInt(c.QueryParam("offset"), 10, 64) - // ignore error here, we'll just use the default value - } - var searchQuery = c.QueryParam("q") - bookmarks, err := controller.Store.GetBookmarks(searchQuery, direction, userid, offset, controller.Config.BookmarksPageSize) + userid, err := getUserIdFromSession(c) + if err != nil { + return handleInternalServerError(c, err) + } + currentLastModifiedDateTime, err := controller.Store.GetLastModifiedDate(userid) + if err != nil { + return handleInternalServerError(c, err) + } + if c.Request().Header.Get("If-Modified-Since") == currentLastModifiedDateTime.Format(http.TimeFormat) { + return c.NoContent(http.StatusNotModified) + } + var direction = domain.DirectionRight + if c.QueryParam("direction") != "" { + direction, err = strconv.Atoi(c.QueryParam("direction")) if err != nil { - return handleError(err) - } - moreResultsLeft := len(bookmarks) == (controller.Config.BookmarksPageSize + 1) - if moreResultsLeft { - bookmarks = bookmarks[:len(bookmarks)-1] - } - if direction == domain.DirectionLeft { - // if we are moving back in the list of bookmarks the query has given us an ascending list of them - // we need to reverse them to satisfy the invariant of having a descending list of bookmarks - for i, j := 0, len(bookmarks)-1; i < j; i, j = i+1, j-1 { - bookmarks[i], bookmarks[j] = bookmarks[j], bookmarks[i] - } - } - var HasLeft = true - if /*!(direction == right && offset != 0 && len(bookmarks) == config.BookmarksPageSize) && */ offset == math.MaxInt64 || (direction == domain.DirectionLeft && !moreResultsLeft) { - HasLeft = false - } - var LeftOffset int64 = 0 - if len(bookmarks) > 0 { - LeftOffset = bookmarks[0].Created.Unix() + direction = domain.DirectionRight } - var HasRight = true - if /* !(direction == left && offset != 0 && len(bookmarks) == config.BookmarksPageSize) && */ offset == 0 || (direction == domain.DirectionRight && !moreResultsLeft) { - HasRight = false + if direction != 0 && direction != 1 { + direction = domain.DirectionRight } - var RightOffset int64 = math.MaxInt64 - if len(bookmarks) >= controller.Config.BookmarksPageSize { - RightOffset = bookmarks[controller.Config.BookmarksPageSize-1].Created.Unix() + } + var offset int64 + if direction == domain.DirectionLeft { + offset = 0 + } else { + offset = math.MaxInt64 + } + if c.QueryParam("offset") != "" { + offset, _ = strconv.ParseInt(c.QueryParam("offset"), 10, 64) + // ignore error here, we'll just use the default value + } + var searchQuery = c.QueryParam("q") + bookmarks, err := controller.Store.GetBookmarks(searchQuery, direction, userid, offset, controller.Config.BookmarksPageSize) + if err != nil { + return handleInternalServerError(c, err) + } + moreResultsLeft := len(bookmarks) == (controller.Config.BookmarksPageSize + 1) + if moreResultsLeft { + bookmarks = bookmarks[:len(bookmarks)-1] + } + if direction == domain.DirectionLeft { + // if we are moving back in the list of bookmarks the query has given us an ascending list of them + // we need to reverse them to satisfy the invariant of having a descending list of bookmarks + for i, j := 0, len(bookmarks)-1; i < j; i, j = i+1, j-1 { + bookmarks[i], bookmarks[j] = bookmarks[j], bookmarks[i] } + } + var HasLeft = true + if /*!(direction == right && offset != 0 && len(bookmarks) == config.BookmarksPageSize) && */ offset == math.MaxInt64 || (direction == domain.DirectionLeft && !moreResultsLeft) { + HasLeft = false + } + var LeftOffset int64 = 0 + if len(bookmarks) > 0 { + LeftOffset = bookmarks[0].Created.Unix() + } + var HasRight = true + if /* !(direction == left && offset != 0 && len(bookmarks) == config.BookmarksPageSize) && */ offset == 0 || (direction == domain.DirectionRight && !moreResultsLeft) { + HasRight = false + } + var RightOffset int64 = math.MaxInt64 + if len(bookmarks) >= controller.Config.BookmarksPageSize { + RightOffset = bookmarks[controller.Config.BookmarksPageSize-1].Created.Unix() + } - feedId, err := controller.Store.GetOrCreateFeedIdForUser(userid) - if err != nil { - return handleError(err) - } + feedId, err := controller.Store.GetOrCreateFeedIdForUser(userid) + if err != nil { + return handleInternalServerError(c, err) + } - c.Response().Header().Set("Cache-Control", "no-cache") - c.Response().Header().Set("Last-Modified", currentLastModifiedDateTime.Format(http.TimeFormat)) - return c.Render(http.StatusOK, "bookmarks", domain.BookmarkSlice{ - Bookmarks: bookmarks, - HasLeft: HasLeft, - LeftOffset: LeftOffset, - HasRight: HasRight, - RightOffset: RightOffset, - SearchQuery: searchQuery, - CsrfToken: c.Get("csrf").(string), - RssFeedUrl: controller.Config.DeliciousBookmarksBaseUrl + "/feeds/" + feedId}) - }) + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Last-Modified", currentLastModifiedDateTime.Format(http.TimeFormat)) + return c.Render(http.StatusOK, "bookmarks", domain.BookmarkSlice{ + Bookmarks: bookmarks, + HasLeft: HasLeft, + LeftOffset: LeftOffset, + HasRight: HasRight, + RightOffset: RightOffset, + SearchQuery: searchQuery, + CsrfToken: c.Get("csrf").(string), + RssFeedUrl: controller.Config.DeliciousBookmarksBaseUrl + "/feeds/" + feedId}) } func (controller *Controller) showAddBookmark(c echo.Context) error { - return controller.withValidSession(c, func(userid int) error { - handleError := func(err error) error { - log.Println(err) - return c.Render(http.StatusInternalServerError, "addbookmark", nil) + handleError := func(err error) error { + log.Println(err) + return c.Render(http.StatusInternalServerError, "error-internalserver", nil) + } + userid, err := getUserIdFromSession(c) + if err != nil { + return handleError(err) + } + url := c.QueryParam("url") + title := c.QueryParam("title") + description := c.QueryParam("description") + if url != "" { + existingBookmark, err := controller.Store.FindExistingBookmark(url, userid) + if err != nil { + return handleError(err) } - url := c.QueryParam("url") - title := c.QueryParam("title") - description := c.QueryParam("description") - if url != "" { - existingBookmark, err := controller.Store.FindExistingBookmark(url, userid) - if err != nil { - return handleError(err) - } - if existingBookmark != (domain.Bookmark{}) { - return c.Render(http.StatusOK, "addbookmark", AddBookmarkPage{Bookmark: existingBookmark, CsrfToken: c.Get("csrf").(string)}) - } + if existingBookmark != (domain.Bookmark{}) { + return c.Render(http.StatusOK, "addbookmark", AddBookmarkPage{Bookmark: existingBookmark, CsrfToken: c.Get("csrf").(string)}) } - return c.Render(http.StatusOK, "addbookmark", AddBookmarkPage{Bookmark: domain.Bookmark{URL: url, Title: title, Description: description}, CsrfToken: c.Get("csrf").(string)}) - }) + } + return c.Render(http.StatusOK, "addbookmark", AddBookmarkPage{Bookmark: domain.Bookmark{URL: url, Title: title, Description: description}, CsrfToken: c.Get("csrf").(string)}) } func (controller *Controller) deleteBookmark(c echo.Context) error { - return controller.withValidSession(c, func(userid int) error { - handleError := func(err error) error { - log.Println(err) - // TODO: add error toast or something based on URL parameter in redirect - return c.Redirect(http.StatusFound, "/bookmarks") - } - url := c.FormValue("url") - if url != "" { - err := controller.Store.DeleteBookmark(url, userid) - if err != nil { - return handleError(err) - } + userid, err := getUserIdFromSession(c) + if err != nil { + return handleInternalServerError(c, err) + } + url := c.FormValue("url") + if url != "" { + err := controller.Store.DeleteBookmark(url, userid) + if err != nil { + return handleInternalServerError(c, err) } - return c.Redirect(http.StatusFound, "/bookmarks") - }) + } + return c.Redirect(http.StatusFound, "/bookmarks") } func (controller *Controller) addBookmark(c echo.Context) error { - return controller.withValidSession(c, func(userid int) error { - handleError := func(err error) error { - log.Println("addBookmark error: ", err) - return c.Redirect(http.StatusFound, "/bookmarks") - } - url := c.FormValue("url") - if url == "" { - return handleError(errors.New("URL is required")) - } - title := c.FormValue("title") - description := c.FormValue("description") - tags := c.FormValue("tags") - private := c.FormValue("private") == "on" - readlater := c.FormValue("readlater") == "on" - err := controller.Store.AddOrUpdateBookmark(domain.Bookmark{URL: url, Title: title, Description: description, Tags: tags, Private: private, Readlater: readlater}, userid) - if err != nil { - return handleError(err) - } - return c.Redirect(http.StatusFound, "/bookmarks") - }) + userid, err := getUserIdFromSession(c) + if err != nil { + return handleInternalServerError(c, err) + } + url := c.FormValue("url") + if url == "" { + return c.Render(http.StatusBadRequest, "error-badrequest", "url parameter is required") + } + title := c.FormValue("title") + description := c.FormValue("description") + tags := c.FormValue("tags") + private := c.FormValue("private") == "on" + readlater := c.FormValue("readlater") == "on" + err = controller.Store.AddOrUpdateBookmark(domain.Bookmark{URL: url, Title: title, Description: description, Tags: tags, Private: private, Readlater: readlater}, userid) + if err != nil { + return handleInternalServerError(c, err) + } + return c.Redirect(http.StatusFound, "/bookmarks") } func (controller *Controller) showFeed(c echo.Context) error { @@ -362,8 +318,7 @@ func (controller *Controller) showFeed(c echo.Context) error { } readLaterBookmarks, err := controller.Store.FindReadLaterBookmarksWithContent(userId, controller.Config.MaxContentDownloadAttempts) if err != nil { - log.Println(err) - return c.String(http.StatusInternalServerError, "error retrieving read later bookmarks") + return handleInternalServerError(c, err) } feed := &feeds.Feed{ Title: "Delicious Read Later Bookmarks", @@ -392,7 +347,7 @@ func (controller *Controller) showFeed(c echo.Context) error { rss, err := feed.ToRss() if err != nil { - log.Fatal(err) + return handleInternalServerError(c, err) } c.Response().Header().Set("Content-Type", "application/rss+xml") return c.String(http.StatusOK, rss) @@ -402,6 +357,6 @@ type Template struct { templates *template.Template } -func (t *Template) Render(w io.Writer, name string, data interface{}, c echo.Context) error { +func (t *Template) Render(w io.Writer, name string, data interface{}, _ echo.Context) error { return t.templates.ExecuteTemplate(w, name, data) }