From be41face0d31c8b967cee6eb96b28f74ce7bf03f Mon Sep 17 00:00:00 2001 From: Nathanael Liechti Date: Fri, 13 Jan 2023 16:49:48 +0100 Subject: [PATCH] feat(oidc): optionally query OIDC UserInfo to gather group claims Signed-off-by: Nathanael Liechti --- docs/operator-manual/user-management/index.md | 14 + server/server.go | 34 +- server/server_test.go | 125 +++++++- util/cache/inmemory.go | 4 + util/oidc/oidc.go | 183 ++++++++++- util/oidc/oidc_test.go | 294 +++++++++++++++++- util/settings/settings.go | 34 ++ util/test/testutil.go | 10 + 8 files changed, 668 insertions(+), 30 deletions(-) diff --git a/docs/operator-manual/user-management/index.md b/docs/operator-manual/user-management/index.md index 8e459202456ebf..8c10f8de8c26a4 100644 --- a/docs/operator-manual/user-management/index.md +++ b/docs/operator-manual/user-management/index.md @@ -387,6 +387,20 @@ For a simple case this can be: oidc.config: | requestedIDTokenClaims: {"groups": {"essential": true}} ``` + +### Retrieving group claims when not in the token + +Some OIDC providers don't return the group information for a user in the token, even if explicitly requested using the `requestedIDTokenClaims` setting (Okta for example). They instead provide the groups on the user info endpoint. With the following config, Argo CD queries the user info endpoint during login for groups information of a user: + +```yaml +oidc.config: | + enableUserInfoGroups: true + userInfoPath: /userinfo + userInfoCacheExpiration: "5m" +``` + +**Note: If you omit the `userInfoCacheExpiration` setting, the argocd-server will cache group information as long as the OIDC token is valid!** + ### Configuring a custom logout URL for your OIDC provider Optionally, if your OIDC provider exposes a logout API and you wish to configure a custom logout URL for the purposes of invalidating diff --git a/server/server.go b/server/server.go index e52416927143b9..ed623c3fd0a060 100644 --- a/server/server.go +++ b/server/server.go @@ -1121,7 +1121,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { // Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex) var err error mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(a.DexServerAddr, a.BaseHRef, a.DexTLSConfig)) - a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef) + a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef, cacheutil.NewRedisCache(a.RedisClient, a.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) errorsutil.CheckError(err) mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin) mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback) @@ -1315,7 +1315,37 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error if err != nil { return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - return claims, newToken, nil + + // Some SSO implementations (Okta) require a call to + // the OIDC user info path to get attributes like groups + // we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims + // otherwise this would cause a panic + var groupClaims jwt.MapClaims + if groupClaims, ok = claims.(jwt.MapClaims); !ok { + if tmpClaims, ok := claims.(*jwt.MapClaims); ok { + groupClaims = *tmpClaims + + } + } + iss := jwtutil.StringField(groupClaims, "iss") + if iss != util_session.SessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" { + sub := jwtutil.StringField(groupClaims, "sub") + userInfo, unauthorized, err := a.ssoClientApp.GetUserInfo(sub, a.settings.IssuerURL(), a.settings.UserInfoPath()) + if unauthorized { + log.Errorf("error while quering userinfo endpoint: %v", err) + return claims, "", status.Errorf(codes.Unauthenticated, "invalid session") + } + if err != nil { + log.Errorf("error fetching user info endpoint: %v", err) + return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") + } + if groupClaims["sub"] != userInfo["sub"] { + return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") + } + groupClaims["groups"] = userInfo["groups"] + } + + return groupClaims, newToken, nil } // getToken extracts the token from gRPC metadata or cookie headers diff --git a/server/server_test.go b/server/server_test.go index 303f938871f383..acfb32e57e5d4e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,8 +32,10 @@ import ( "github.com/argoproj/argo-cd/v2/server/rbacpolicy" "github.com/argoproj/argo-cd/v2/test" "github.com/argoproj/argo-cd/v2/util/assets" + "github.com/argoproj/argo-cd/v2/util/cache" cacheutil "github.com/argoproj/argo-cd/v2/util/cache" appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate" + "github.com/argoproj/argo-cd/v2/util/oidc" "github.com/argoproj/argo-cd/v2/util/rbac" settings_util "github.com/argoproj/argo-cd/v2/util/settings" testutil "github.com/argoproj/argo-cd/v2/util/test" @@ -533,7 +535,7 @@ func dexMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.Re } } -func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool) (argocd *ArgoCDServer, oidcURL string) { +func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool, additionalOIDCConfig settings_util.OIDCConfig) (argocd *ArgoCDServer, oidcURL string) { cm := test.NewFakeConfigMap() if anonymousEnabled { cm.Data["users.anonymous.enabled"] = "true" @@ -562,13 +564,12 @@ connectors: clientID: test-client clientSecret: $dex.oidc.clientSecret` } else { - oidcConfig := settings_util.OIDCConfig{ - Name: "Okta", - Issuer: oidcServer.URL, - ClientID: "argo-cd", - ClientSecret: "$oidc.okta.clientSecret", - } - oidcConfigString, err := yaml.Marshal(oidcConfig) + // override required oidc config fields but keep other configs as passed in + additionalOIDCConfig.Name = "Okta" + additionalOIDCConfig.Issuer = oidcServer.URL + additionalOIDCConfig.ClientID = "argo-cd" + additionalOIDCConfig.ClientSecret = "$oidc.okta.clientSecret" + oidcConfigString, err := yaml.Marshal(additionalOIDCConfig) require.NoError(t, err) cm.Data["oidc.config"] = string(oidcConfigString) // Avoid bothering with certs for local tests. @@ -589,9 +590,109 @@ connectors: argoCDOpts.DexServerAddr = ts.URL } argocd = NewServer(context.Background(), argoCDOpts) + var err error + argocd.ssoClientApp, err = oidc.NewClientApp(argocd.settings, argocd.DexServerAddr, argocd.DexTLSConfig, argocd.BaseHRef, cache.NewInMemoryCache(24*time.Hour)) + require.NoError(t, err) return argocd, oidcServer.URL } +func TestGetClaims(t *testing.T) { + + defaultExpiry := jwt.NewNumericDate(time.Now().Add(time.Hour * 24)) + defaultExpiryUnix := float64(defaultExpiry.Unix()) + + type testData struct { + test string + claims jwt.MapClaims + expectedErrorContains string + expectedClaims jwt.MapClaims + expectNewToken bool + additionalOIDCConfig settings_util.OIDCConfig + } + var tests = []testData{ + { + test: "GetClaims", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "", + expectedClaims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{}, + }, + { + // note: a passing test with user info groups can never be achieved since the user never logged in properly + // therefore the oidcClient's cache contains no accessToken for the user info endpoint + // and since the oidcClient cache is unexported (for good reasons) we can't mock this behaviour + test: "GetClaimsWithUserInfoGroupsEnabled", + claims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "invalid session", + expectedClaims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{ + EnableUserInfoGroups: true, + UserInfoPath: "/userinfo", + UserInfoCacheExpiration: "5m", + }, + }, + } + + for _, testData := range tests { + testDataCopy := testData + + t.Run(testDataCopy.test, func(t *testing.T) { + t.Parallel() + + // Must be declared here to avoid race. + ctx := context.Background() //nolint:ineffassign,staticcheck + + argocd, oidcURL := getTestServer(t, false, true, false, testDataCopy.additionalOIDCConfig) + + // create new JWT and store it on the context to simulate an incoming request + testDataCopy.claims["iss"] = oidcURL + testDataCopy.expectedClaims["iss"] = oidcURL + token := jwt.NewWithClaims(jwt.SigningMethodRS512, testDataCopy.claims) + key, err := jwt.ParseRSAPrivateKeyFromPEM(testutil.PrivateKey) + require.NoError(t, err) + tokenString, err := token.SignedString(key) + require.NoError(t, err) + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(apiclient.MetaDataTokenKey, tokenString)) + + gotClaims, newToken, err := argocd.getClaims(ctx) + + // Note: testutil.oidcMockHandler currently doesn't implement reissuing expired tokens + // so newToken will always be empty + if testDataCopy.expectNewToken { + assert.NotEmpty(t, newToken) + } + if testDataCopy.expectedClaims == nil { + assert.Nil(t, gotClaims) + } else { + assert.Equal(t, testDataCopy.expectedClaims, gotClaims) + } + if testDataCopy.expectedErrorContains != "" { + assert.ErrorContains(t, err, testDataCopy.expectedErrorContains, "getClaims should have thrown an error and return an error") + } else { + assert.NoError(t, err) + } + }) + } +} + func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Marshaling single strings to strings is typical, so we test for this relatively common behavior. jwt.MarshalSingleStringAsArray = false @@ -723,7 +824,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex) + argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex, settings_util.OIDCConfig{}) if testDataCopy.useDex { testDataCopy.claims.Issuer = fmt.Sprintf("%s/api/dex", oidcURL) @@ -779,7 +880,7 @@ func TestAuthenticate_no_request_metadata(t *testing.T) { t.Run(testDataCopy.test, func(t *testing.T) { t.Parallel() - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx := context.Background() ctx, err := argocd.Authenticate(ctx) @@ -825,7 +926,7 @@ func TestAuthenticate_no_SSO(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true) + argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true, settings_util.OIDCConfig{}) token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{Issuer: fmt.Sprintf("%s/api/dex", dexURL)}) tokenString, err := token.SignedString([]byte("key")) require.NoError(t, err) @@ -933,7 +1034,7 @@ func TestAuthenticate_bad_request_metadata(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx = metadata.NewIncomingContext(context.Background(), testDataCopy.metadata) ctx, err := argocd.Authenticate(ctx) diff --git a/util/cache/inmemory.go b/util/cache/inmemory.go index 53e690925d940b..f75688c2755467 100644 --- a/util/cache/inmemory.go +++ b/util/cache/inmemory.go @@ -16,6 +16,10 @@ func NewInMemoryCache(expiration time.Duration) *InMemoryCache { } } +func init() { + gob.Register([]interface{}{}) +} + // compile-time validation of adherance of the CacheClient contract var _ CacheClient = &InMemoryCache{} diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 3df31664901729..c9138ba4a79bbc 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -6,6 +6,7 @@ import ( "fmt" "html" "html/template" + "io" "net" "net/http" "net/url" @@ -21,9 +22,12 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" + httputil "github.com/argoproj/argo-cd/v2/util/http" + jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" "github.com/argoproj/argo-cd/v2/util/rand" "github.com/argoproj/argo-cd/v2/util/settings" ) @@ -31,9 +35,11 @@ import ( var InvalidRedirectURLError = fmt.Errorf("invalid return URL") const ( - GrantTypeAuthorizationCode = "authorization_code" - GrantTypeImplicit = "implicit" - ResponseTypeCode = "code" + GrantTypeAuthorizationCode = "authorization_code" + GrantTypeImplicit = "implicit" + ResponseTypeCode = "code" + UserInfoResponseCachePrefix = "userinfo_response" + AccessTokenCachePrefix = "access_token" ) // OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec @@ -57,6 +63,8 @@ type ClientApp struct { redirectURI string // URL of the issuer (e.g. https://argocd.example.com/api/dex) issuerURL string + // the path where the issuer providers user information (e.g /user-info for okta) + userInfoPath string // The URL endpoint at which the ArgoCD server is accessed. baseHRef string // client is the HTTP client which is used to query the IDp @@ -70,6 +78,8 @@ type ClientApp struct { encryptionKey []byte // provider is the OIDC provider provider Provider + // clientCache represent a cache of sso artifact + clientCache cache.CacheClient } func GetScopesOrDefault(scopes []string) []string { @@ -81,7 +91,7 @@ func GetScopesOrDefault(scopes []string) []string { // NewClientApp will register the Argo CD client app (either via Dex or external OIDC) and return an // object which has HTTP handlers for handling the HTTP responses for login and callback -func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string) (*ClientApp, error) { +func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string, cacheClient cache.CacheClient) (*ClientApp, error) { redirectURL, err := settings.RedirectURL() if err != nil { return nil, err @@ -95,8 +105,10 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTl clientSecret: settings.OAuth2ClientSecret(), redirectURI: redirectURL, issuerURL: settings.IssuerURL(), + userInfoPath: settings.UserInfoPath(), baseHRef: baseHRef, encryptionKey: encryptionKey, + clientCache: cacheClient, } log.Infof("Creating client app (%s)", a.clientID) u, err := url.Parse(settings.URL) @@ -376,6 +388,26 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + // save the accessToken in memory for later use + encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, "failed encrypting token", http.StatusInternalServerError) + log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON) + return + } + sub := jwtutil.StringField(claims, "sub") + err = a.clientCache.Set(&cache.Item{ + Key: fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub), + Object: encToken, + Expiration: getTokenExpiration(claims, time.Millisecond), + }) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) + return + } + if idTokenRAW != "" { cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...) if err != nil { @@ -509,3 +541,146 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc } return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil } + +// GetUserInfo queries the IDP userinfo endpoint for claims +func (a *ClientApp) GetUserInfo(sub, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { + // in case we got it in the cache, we just return the item + var claims jwt.MapClaims + var encClaims []byte + clientCacheKey := formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, sub) + if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { + claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) + if err != nil { + log.Errorf("decrypting the cached claims failed (sub=%s): %s", sub, err) + } else { + err = json.Unmarshal(claimsRaw, &claims) + if err != nil { + log.Errorf("cannot unmarshal cached claims structure: %s", err) + } else { + // return the cached claims since they are not yet expired, were successfully decrypted and unmarshaled + return claims, false, err + } + } + } + + // check if the accessToken for the user is still present + var encAccessToken []byte + err := a.clientCache.Get(formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, sub), &encAccessToken) + // without an accessToken we can't query the user info endpoint + // thus the user needs to reauthenticate for argocd to get a new accessToken + if err == cache.ErrCacheMiss { + return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err) + } else if err != nil { + return claims, true, fmt.Errorf("couldn't read accessToken from cache for %s: %w", sub, err) + } + + accessToken, err := crypto.Decrypt(encAccessToken, a.encryptionKey) + if err != nil { + return claims, true, fmt.Errorf("couldn't decrypt accessToken for %s: %w", sub, err) + } + + url := issuerURL + userInfoPath + request, err := http.NewRequest("GET", url, nil) + + if err != nil { + err = fmt.Errorf("failed creating new http request: %w", err) + return claims, false, err + } + + bearer := fmt.Sprintf("Bearer %s", accessToken) + request.Header.Set("Authorization", bearer) + + response, err := a.client.Do(request) + if err != nil { + return claims, false, fmt.Errorf("failed to query userinfo endpoint of IDP: %w", err) + } + defer response.Body.Close() + if response.StatusCode == http.StatusUnauthorized { + return claims, true, err + } + + // according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation + // the response should be validated + header := response.Header.Get("content-type") + rawBody, err := io.ReadAll(response.Body) + if err != nil { + return claims, false, fmt.Errorf("got error reading response body: %w", err) + } + switch header { + case "application/jwt": + // if body is JWT, first validate it before extracting claims + idToken, err := a.provider.Verify(string(rawBody), a.settings) + if err != nil { + return claims, false, fmt.Errorf("user info response in jwt format not valid: %w", err) + } + err = idToken.Claims(claims) + if err != nil { + return claims, false, fmt.Errorf("cannot get claims from userinfo jwt: %w", err) + } + default: + // if body is json, unsigned and unencrypted claims can be deserialized + err = json.Unmarshal(rawBody, &claims) + if err != nil { + return claims, false, fmt.Errorf("failed to decode response body to struct: %w", err) + } + } + + // in case response was successfully validated and there was no error, put item in cache + // but first let's determine the expiry of the cache + var cacheExpiry time.Duration + settingExpiry := time.Duration(a.settings.UserInfoCacheExpiration().Microseconds()) + tokenExpiry := getTokenExpiration(claims, time.Microsecond) + + // only use configured expiry if the token lives longer and the expiry is configured + // otherwise use the expiry of the token + if settingExpiry < tokenExpiry && settingExpiry != 0 { + cacheExpiry = settingExpiry + } else { + cacheExpiry = tokenExpiry + } + + rawClaims, err := json.Marshal(claims) + if err != nil { + return claims, false, fmt.Errorf("couldn't marshal claim to json: %w", err) + } + encClaims, err = crypto.Encrypt(rawClaims, a.encryptionKey) + if err != nil { + return claims, false, fmt.Errorf("couldn't encrypt user info response: %w", err) + } + + err = a.clientCache.Set(&cache.Item{ + Key: clientCacheKey, + Object: encClaims, + Expiration: cacheExpiry, + }) + if err != nil { + return claims, false, fmt.Errorf("couldn't put item to cache: %w", err) + } + + return claims, false, nil +} + +// getTokenExpiration returns a time.Duration in the given format until the token expires +func getTokenExpiration(claims jwt.MapClaims, format time.Duration) time.Duration { + // get duration until token expires + exp := jwtutil.Float64Field(claims, "exp") + tokenExpiry := time.Duration(time.Until(time.Unix(int64(exp), 0)).Seconds()) + switch format { + case time.Microsecond: + return time.Duration(tokenExpiry.Microseconds()) + case time.Nanosecond: + return time.Duration(tokenExpiry.Nanoseconds()) + case time.Millisecond: + return time.Duration(tokenExpiry.Milliseconds()) + case time.Minute: + return time.Duration(tokenExpiry.Minutes()) + default: + // default is second + return tokenExpiry + } +} + +// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache +func formatUserInfoResponseCacheKey(prefix, sub string) string { + return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) +} diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index fe5fa77eed3b5a..c6e1454a5d6f36 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -11,8 +11,10 @@ import ( "os" "strings" "testing" + "time" gooidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -20,6 +22,7 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" "github.com/argoproj/argo-cd/v2/util" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" "github.com/argoproj/argo-cd/v2/util/settings" @@ -126,7 +129,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -141,7 +144,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -166,7 +169,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -179,7 +182,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -211,7 +214,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), // The base href (the last argument for NewClientApp) is what HandleLogin will fall back to when no explicit // redirect URL is given. - app, err := NewClientApp(cdSettings, "", nil, "/") + app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w := httptest.NewRecorder() @@ -254,7 +257,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -269,7 +272,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -294,7 +297,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -307,7 +310,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -406,7 +409,7 @@ func TestGenerateAppState(t *testing.T) { signature, err := util.MakeSignature(32) require.NoError(t, err) expectedReturnURL := "http://argocd.example.com/" - app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "") + app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) generateResponse := httptest.NewRecorder() state, err := app.generateAppState(expectedReturnURL, generateResponse) @@ -443,7 +446,7 @@ func TestGenerateAppState_XSS(t *testing.T) { URL: "https://argocd.example.com", ServerSignature: signature, }, - "", nil, "", + "", nil, "", cache.NewInMemoryCache(24*time.Hour), ) require.NoError(t, err) @@ -495,7 +498,7 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { encrypted, err := crypto.Encrypt([]byte("123"), key) require.NoError(t, err) - app, err := NewClientApp(cdSettings, "", nil, "/argo-cd") + app, err := NewClientApp(cdSettings, "", nil, "/argo-cd", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: common.StateCookieName, Value: hex.EncodeToString(encrypted)}) @@ -503,3 +506,270 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "/argo-cd", returnURL) } + +func TestGetUserInfo(t *testing.T) { + + var tests = []struct { + name string + userInfoPath string + expectedOutput interface{} + expectError bool + expectUnauthenticated bool + expectedCacheItems []struct { // items to check in cache after function call + key string + value string + expectEncrypted bool + expectError bool + } + idpHandler func(w http.ResponseWriter, r *http.Request) + idpUser string + cache cache.CacheClient + cacheItems []struct { // items to put in cache before execution + key string + value string + encrypt bool + } + }{ + { + name: "call UserInfo with wrong userInfoPath", + userInfoPath: "/user", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: fmt.Sprintf("%s_randomUser", UserInfoResponseCachePrefix), + expectError: true, + }, + }, + idpUser: "randomUser", + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: fmt.Sprintf("%s_randomUser", AccessTokenCachePrefix), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with bad accessToken", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: false, + expectUnauthenticated: true, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: fmt.Sprintf("%s_randomUser", UserInfoResponseCachePrefix), + expectError: true, + }, + }, + idpUser: "randomUser", + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: fmt.Sprintf("%s_randomUser", AccessTokenCachePrefix), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with garbage returned", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: fmt.Sprintf("%s_randomUser", UserInfoResponseCachePrefix), + expectError: true, + }, + }, + idpUser: "randomUser", + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + notevenJsongarbage + ` + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusTeapot) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: fmt.Sprintf("%s_randomUser", AccessTokenCachePrefix), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo without accessToken in cache", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: true, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: fmt.Sprintf("%s_randomUser", UserInfoResponseCachePrefix), + expectError: true, + }, + }, + idpUser: "randomUser", + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + { + "groups":["githubOrg:engineers"] + }` + w.Header().Set("content-type", "application/json") + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + }, + { + name: "call UserInfo with valid accessToken in cache", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{"groups": []interface{}{"githubOrg:engineers"}}, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: fmt.Sprintf("%s_randomUser", UserInfoResponseCachePrefix), + value: "{\"groups\":[\"githubOrg:engineers\"]}", + expectEncrypted: true, + expectError: false, + }, + }, + idpUser: "randomUser", + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + { + "groups":["githubOrg:engineers"] + }` + w.Header().Set("content-type", "application/json") + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: fmt.Sprintf("%s_randomUser", AccessTokenCachePrefix), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(tt.idpHandler)) + defer ts.Close() + + signature, err := util.MakeSignature(32) + require.NoError(t, err) + cdSettings := &settings.ArgoCDSettings{ServerSignature: signature} + encryptionKey, err := cdSettings.GetServerEncryptionKey() + assert.NoError(t, err) + a, _ := NewClientApp(cdSettings, "", nil, "/argo-cd", tt.cache) + + for _, item := range tt.cacheItems { + var newValue []byte + newValue = []byte(item.value) + if item.encrypt { + newValue, err = crypto.Encrypt([]byte(item.value), encryptionKey) + assert.NoError(t, err) + } + err := a.clientCache.Set(&cache.Item{ + Key: item.key, + Object: newValue, + }) + require.NoError(t, err) + } + + got, unauthenticated, err := a.GetUserInfo(tt.idpUser, ts.URL, tt.userInfoPath) + assert.Equal(t, tt.expectedOutput, got) + assert.Equal(t, tt.expectUnauthenticated, unauthenticated) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + for _, item := range tt.expectedCacheItems { + var tmpValue []byte + err := a.clientCache.Get(item.key, &tmpValue) + if item.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + if item.expectEncrypted { + tmpValue, err = crypto.Decrypt(tmpValue, encryptionKey) + require.NoError(t, err) + } + assert.Equal(t, item.value, string(tmpValue)) + } + } + }) + } + +} diff --git a/util/settings/settings.go b/util/settings/settings.go index bc091e8b818ecc..baff450aa817ee 100644 --- a/util/settings/settings.go +++ b/util/settings/settings.go @@ -161,6 +161,9 @@ func (o *oidcConfig) toExported() *OIDCConfig { ClientID: o.ClientID, ClientSecret: o.ClientSecret, CLIClientID: o.CLIClientID, + UserInfoPath: o.UserInfoPath, + EnableUserInfoGroups: o.EnableUserInfoGroups, + UserInfoCacheExpiration: o.UserInfoCacheExpiration, RequestedScopes: o.RequestedScopes, RequestedIDTokenClaims: o.RequestedIDTokenClaims, LogoutURL: o.LogoutURL, @@ -175,6 +178,9 @@ type OIDCConfig struct { ClientID string `json:"clientID,omitempty"` ClientSecret string `json:"clientSecret,omitempty"` CLIClientID string `json:"cliClientID,omitempty"` + EnableUserInfoGroups bool `json:"enableUserInfoGroups,omitempty"` + UserInfoPath string `json:"userInfoPath,omitempty"` + UserInfoCacheExpiration string `json:"userInfoCacheExpiration,omitempty"` RequestedScopes []string `json:"requestedScopes,omitempty"` RequestedIDTokenClaims map[string]*oidc.Claim `json:"requestedIDTokenClaims,omitempty"` LogoutURL string `json:"logoutURL,omitempty"` @@ -1850,6 +1856,34 @@ func (a *ArgoCDSettings) IssuerURL() string { return "" } +// UserInfoGroupsEnabled returns whether group claims should be fetch from UserInfo endpoint +func (a *ArgoCDSettings) UserInfoGroupsEnabled() bool { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.EnableUserInfoGroups + } + return false +} + +// UserInfoPath returns the sub-path on which the IDP exposes the UserInfo endpoint +func (a *ArgoCDSettings) UserInfoPath() string { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.UserInfoPath + } + return "" +} + +// UserInfoCacheExpiration returns the expiry time of the UserInfo cache +func (a *ArgoCDSettings) UserInfoCacheExpiration() time.Duration { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil && oidcConfig.UserInfoCacheExpiration != "" { + userInfoCacheExpiration, err := time.ParseDuration(oidcConfig.UserInfoCacheExpiration) + if err != nil { + log.Warnf("Failed to parse 'oidc.config.userInfoCacheExpiration' key: %v", err) + } + return userInfoCacheExpiration + } + return 0 +} + func (a *ArgoCDSettings) OAuth2ClientID() string { if oidcConfig := a.OIDCConfig(); oidcConfig != nil { return oidcConfig.ClientID diff --git a/util/test/testutil.go b/util/test/testutil.go index 6fdbd4151d82cf..1cb23bc08bb3e8 100644 --- a/util/test/testutil.go +++ b/util/test/testutil.go @@ -168,6 +168,16 @@ func oidcMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.R "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], "claims_supported": ["sub", "aud", "exp"] }`, url)) + require.NoError(t, err) + case "/userinfo": + w.Header().Set("content-type", "application/json") + _, err := io.WriteString(w, fmt.Sprintf(` +{ + "groups":["githubOrg:engineers"], + "iss": "%[1]s", + "sub": "randomUser" +}`, url)) + require.NoError(t, err) case "/keys": pubKey, err := jwt.ParseRSAPublicKeyFromPEM(Cert)