Skip to content

Commit

Permalink
refactor: Moving the OIDC related middleware and callback endpoint in…
Browse files Browse the repository at this point in the history
…to its own package as a preparation to extract this as a library and to clean up the code
  • Loading branch information
aggregat4 committed Apr 5, 2024
1 parent f4a439f commit e09219e
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 228 deletions.
12 changes: 3 additions & 9 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ func main() {
// },
Scopes: []string{oidc.ScopeOpenID},
}
oidcConfig := oidc.Config{
ClientID: requireStringFromEnv("DELBM_OIDC_CLIENT_ID"),
}
// Get and init config
config := domain.Configuration{
MaxContentDownloadAttempts: getIntFromEnv("DELBM_MAX_CONTENT_DOWNLOAD_ATTEMPTS", 3),
Expand All @@ -65,8 +62,6 @@ func main() {
ServerWriteTimeoutSeconds: getIntFromEnv("DELBM_SERVER_WRITE_TIMEOUT_SECONDS", 10),
SessionCookieSecretKey: getStringFromEnv("DELBM_SESSION_COOKIE_SECRET_KEY", uuid.New().String()),
ServerPort: getIntFromEnv("DELBM_SERVER_PORT", 1323),
Oauth2Config: oauth2Config,
OidcConfig: oidcConfig,
}
// Start the bookMarkCrawler
quitChannel := make(chan struct{})
Expand All @@ -77,10 +72,9 @@ func main() {
bookMarkCrawler.Run(quitChannel)
// Start the server
server.RunServer(server.Controller{
Store: &store,
Config: config,
OidcProvider: oidcProvider,
})
Store: &store,
Config: config,
}, oauth2Config, oidcProvider)
}

func requireStringFromEnv(s string) string {
Expand Down
5 changes: 0 additions & 5 deletions internal/domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package domain

import (
"time"

"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)

type BookmarkSlice struct {
Expand Down Expand Up @@ -52,8 +49,6 @@ type Configuration struct {
ServerWriteTimeoutSeconds int
SessionCookieSecretKey string
ServerPort int
Oauth2Config oauth2.Config
OidcConfig oidc.Config
}

const (
Expand Down
68 changes: 68 additions & 0 deletions internal/middleware/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package middleware

import (
"encoding/base64"
"github.com/aggregat4/go-baselib/crypto"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/labstack/echo/v4"
"golang.org/x/oauth2"
"log"
"net/http"
"time"
)

func CreateOidcMiddleware(isAuthenticated func(c echo.Context) bool, oidcConfig oauth2.Config) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isAuthenticated(c) {
state, err := crypto.RandString(16)
if err != nil {
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
// encode the original request URL into the state so we can redirect back to it after a successful login
// TODO: think about whether storing the original URL like this is generic or should be some sort of custom config
state = state + "|" + base64.StdEncoding.EncodeToString([]byte(c.Request().URL.String()))
c.SetCookie(&http.Cookie{
Name: "oidc-callback-state-cookie",
Value: state,
Path: "/", // TODO: this path is not context path safe
Expires: time.Now().Add(time.Minute * 5),
HttpOnly: true,
})
return c.Redirect(http.StatusFound, oidcConfig.AuthCodeURL(state))
} else {
return next(c)
}
}
}
}

func CreateOidcCallbackEndpoint(oidcConfig oauth2.Config, oidcProvider *oidc.Provider, delegate func(c echo.Context, idToken *oidc.IDToken, state string) error) echo.HandlerFunc {
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: oidcConfig.ClientID})
return func(c echo.Context) error {
// check state vs cookie
state, err := c.Cookie("oidc-callback-state-cookie")
if err != nil {
log.Println(err)
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
if c.QueryParam("state") != state.Value {
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
oauth2Token, err := oidcConfig.Exchange(c.Request().Context(), c.QueryParam("code"))
if err != nil {
log.Println(err)
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
idToken, err := verifier.Verify(c.Request().Context(), rawIDToken)
if err != nil {
log.Println(err)
return c.Render(http.StatusUnauthorized, "error-unauthorized", nil)
}
return delegate(c, idToken, state.Value)
}
}
7 changes: 1 addition & 6 deletions internal/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,11 @@ func (store *Store) FindExistingBookmarkId(url string, userid int) (int64, error
}

func (store *Store) DeleteBookmark(url string, userid int) error {
id, err := store.FindExistingBookmarkId(url, userid)
if err != nil {
return err
}
_, err = store.db.Exec("DELETE FROM bookmarks WHERE id = ?", id)
_, err := store.db.Exec("DELETE FROM bookmarks WHERE user_id = ? AND url = ?", userid, url)
return err
}

func (store *Store) AddOrUpdateBookmark(bookmark domain.Bookmark, userid int) error {

// we perform an upsert because the URL may already be stored and we just want to update the other fields
_, err := store.db.Exec(`
INSERT INTO bookmarks (user_id, url, title, description, tags, private, readlater, created, updated)
Expand Down
Loading

0 comments on commit e09219e

Please sign in to comment.