From 69be24c6eacc0981b676d3d3a7fa571e2aed4f91 Mon Sep 17 00:00:00 2001 From: Nathanael Liechti Date: Fri, 13 Jan 2023 16:49:48 +0100 Subject: [PATCH] feat(oidc): optionally query OIDC UserInfo to gather group claims Signed-off-by: Nathanael Liechti --- docs/operator-manual/user-management/index.md | 14 ++ server/server.go | 34 +++- server/server_test.go | 125 +++++++++++-- util/cache/inmemory.go | 4 + util/oidc/oidc.go | 177 +++++++++++++++++- util/oidc/oidc_test.go | 73 ++++++-- util/settings/settings.go | 70 +++++-- util/test/testutil.go | 10 + 8 files changed, 459 insertions(+), 48 deletions(-) diff --git a/docs/operator-manual/user-management/index.md b/docs/operator-manual/user-management/index.md index 8c3f2e169597c0..1682f0e6a917b7 100644 --- a/docs/operator-manual/user-management/index.md +++ b/docs/operator-manual/user-management/index.md @@ -381,6 +381,20 @@ For a simple case this can be: oidc.config: | requestedIDTokenClaims: {"groups": {"essential": true}} ``` + +### Retrieving group claims when not in the token + +Some OIDC providers don't return the group information for a user in the token, even if explicitly requested using the `requestedIDTokenClaims` setting (Okta for example). They instead provide the groups on the user info endpoint. With the following config, Argo CD queries the user info endpoint during login for groups information of a user: + +```yaml +oidc.config: | + enableUserInfoGroups: true + userInfoPath: /userinfo + userInfoCacheExpiration: "5m" +``` + +**Note: If you omit the `userInfoCacheExpiration` setting, the argocd-server will cache group information as long as the OIDC token is valid!** + ### Configuring a custom logout URL for your OIDC provider Optionally, if your OIDC provider exposes a logout API and you wish to configure a custom logout URL for the purposes of invalidating diff --git a/server/server.go b/server/server.go index a0fc5327e985cb..4bfd69aef65e8e 100644 --- a/server/server.go +++ b/server/server.go @@ -1122,7 +1122,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { // Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex) var err error mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(a.DexServerAddr, a.BaseHRef, a.DexTLSConfig)) - a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef) + a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef, cacheutil.NewRedisCache(a.RedisClient, a.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) errorsutil.CheckError(err) mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin) mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback) @@ -1316,7 +1316,37 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error if err != nil { return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - return claims, newToken, nil + + // Some SSO implementations (Okta) require a call to + // the OIDC user info path to get attributes like groups + // we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims + // otherwise this would cause a panic + var groupClaims jwt.MapClaims + if groupClaims, ok = claims.(jwt.MapClaims); !ok { + if tmpClaims, ok := claims.(*jwt.MapClaims); ok { + groupClaims = *tmpClaims + + } + } + iss := jwtutil.StringField(groupClaims, "iss") + if iss != util_session.SessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" { + sub := jwtutil.StringField(groupClaims, "sub") + userInfo, unauthorized, err := a.ssoClientApp.GetUserInfo(sub, a.settings.IssuerURL(), a.settings.UserInfoPath()) + if unauthorized { + log.Errorf("error while quering userinfo endpoint: %v", err) + return claims, "", status.Errorf(codes.Unauthenticated, "invalid session") + } + if err != nil { + log.Errorf("error fetching user info endpoint: %v", err) + return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") + } + if groupClaims["sub"] != userInfo["sub"] { + return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") + } + groupClaims["groups"] = userInfo["groups"] + } + + return groupClaims, newToken, nil } // getToken extracts the token from gRPC metadata or cookie headers diff --git a/server/server_test.go b/server/server_test.go index 303f938871f383..acfb32e57e5d4e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,8 +32,10 @@ import ( "github.com/argoproj/argo-cd/v2/server/rbacpolicy" "github.com/argoproj/argo-cd/v2/test" "github.com/argoproj/argo-cd/v2/util/assets" + "github.com/argoproj/argo-cd/v2/util/cache" cacheutil "github.com/argoproj/argo-cd/v2/util/cache" appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate" + "github.com/argoproj/argo-cd/v2/util/oidc" "github.com/argoproj/argo-cd/v2/util/rbac" settings_util "github.com/argoproj/argo-cd/v2/util/settings" testutil "github.com/argoproj/argo-cd/v2/util/test" @@ -533,7 +535,7 @@ func dexMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.Re } } -func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool) (argocd *ArgoCDServer, oidcURL string) { +func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool, additionalOIDCConfig settings_util.OIDCConfig) (argocd *ArgoCDServer, oidcURL string) { cm := test.NewFakeConfigMap() if anonymousEnabled { cm.Data["users.anonymous.enabled"] = "true" @@ -562,13 +564,12 @@ connectors: clientID: test-client clientSecret: $dex.oidc.clientSecret` } else { - oidcConfig := settings_util.OIDCConfig{ - Name: "Okta", - Issuer: oidcServer.URL, - ClientID: "argo-cd", - ClientSecret: "$oidc.okta.clientSecret", - } - oidcConfigString, err := yaml.Marshal(oidcConfig) + // override required oidc config fields but keep other configs as passed in + additionalOIDCConfig.Name = "Okta" + additionalOIDCConfig.Issuer = oidcServer.URL + additionalOIDCConfig.ClientID = "argo-cd" + additionalOIDCConfig.ClientSecret = "$oidc.okta.clientSecret" + oidcConfigString, err := yaml.Marshal(additionalOIDCConfig) require.NoError(t, err) cm.Data["oidc.config"] = string(oidcConfigString) // Avoid bothering with certs for local tests. @@ -589,9 +590,109 @@ connectors: argoCDOpts.DexServerAddr = ts.URL } argocd = NewServer(context.Background(), argoCDOpts) + var err error + argocd.ssoClientApp, err = oidc.NewClientApp(argocd.settings, argocd.DexServerAddr, argocd.DexTLSConfig, argocd.BaseHRef, cache.NewInMemoryCache(24*time.Hour)) + require.NoError(t, err) return argocd, oidcServer.URL } +func TestGetClaims(t *testing.T) { + + defaultExpiry := jwt.NewNumericDate(time.Now().Add(time.Hour * 24)) + defaultExpiryUnix := float64(defaultExpiry.Unix()) + + type testData struct { + test string + claims jwt.MapClaims + expectedErrorContains string + expectedClaims jwt.MapClaims + expectNewToken bool + additionalOIDCConfig settings_util.OIDCConfig + } + var tests = []testData{ + { + test: "GetClaims", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "", + expectedClaims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{}, + }, + { + // note: a passing test with user info groups can never be achieved since the user never logged in properly + // therefore the oidcClient's cache contains no accessToken for the user info endpoint + // and since the oidcClient cache is unexported (for good reasons) we can't mock this behaviour + test: "GetClaimsWithUserInfoGroupsEnabled", + claims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "invalid session", + expectedClaims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{ + EnableUserInfoGroups: true, + UserInfoPath: "/userinfo", + UserInfoCacheExpiration: "5m", + }, + }, + } + + for _, testData := range tests { + testDataCopy := testData + + t.Run(testDataCopy.test, func(t *testing.T) { + t.Parallel() + + // Must be declared here to avoid race. + ctx := context.Background() //nolint:ineffassign,staticcheck + + argocd, oidcURL := getTestServer(t, false, true, false, testDataCopy.additionalOIDCConfig) + + // create new JWT and store it on the context to simulate an incoming request + testDataCopy.claims["iss"] = oidcURL + testDataCopy.expectedClaims["iss"] = oidcURL + token := jwt.NewWithClaims(jwt.SigningMethodRS512, testDataCopy.claims) + key, err := jwt.ParseRSAPrivateKeyFromPEM(testutil.PrivateKey) + require.NoError(t, err) + tokenString, err := token.SignedString(key) + require.NoError(t, err) + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(apiclient.MetaDataTokenKey, tokenString)) + + gotClaims, newToken, err := argocd.getClaims(ctx) + + // Note: testutil.oidcMockHandler currently doesn't implement reissuing expired tokens + // so newToken will always be empty + if testDataCopy.expectNewToken { + assert.NotEmpty(t, newToken) + } + if testDataCopy.expectedClaims == nil { + assert.Nil(t, gotClaims) + } else { + assert.Equal(t, testDataCopy.expectedClaims, gotClaims) + } + if testDataCopy.expectedErrorContains != "" { + assert.ErrorContains(t, err, testDataCopy.expectedErrorContains, "getClaims should have thrown an error and return an error") + } else { + assert.NoError(t, err) + } + }) + } +} + func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Marshaling single strings to strings is typical, so we test for this relatively common behavior. jwt.MarshalSingleStringAsArray = false @@ -723,7 +824,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex) + argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex, settings_util.OIDCConfig{}) if testDataCopy.useDex { testDataCopy.claims.Issuer = fmt.Sprintf("%s/api/dex", oidcURL) @@ -779,7 +880,7 @@ func TestAuthenticate_no_request_metadata(t *testing.T) { t.Run(testDataCopy.test, func(t *testing.T) { t.Parallel() - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx := context.Background() ctx, err := argocd.Authenticate(ctx) @@ -825,7 +926,7 @@ func TestAuthenticate_no_SSO(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true) + argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true, settings_util.OIDCConfig{}) token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{Issuer: fmt.Sprintf("%s/api/dex", dexURL)}) tokenString, err := token.SignedString([]byte("key")) require.NoError(t, err) @@ -933,7 +1034,7 @@ func TestAuthenticate_bad_request_metadata(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx = metadata.NewIncomingContext(context.Background(), testDataCopy.metadata) ctx, err := argocd.Authenticate(ctx) diff --git a/util/cache/inmemory.go b/util/cache/inmemory.go index 53e690925d940b..f75688c2755467 100644 --- a/util/cache/inmemory.go +++ b/util/cache/inmemory.go @@ -16,6 +16,10 @@ func NewInMemoryCache(expiration time.Duration) *InMemoryCache { } } +func init() { + gob.Register([]interface{}{}) +} + // compile-time validation of adherance of the CacheClient contract var _ CacheClient = &InMemoryCache{} diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 3df31664901729..7218c1d49a7826 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -6,6 +6,7 @@ import ( "fmt" "html" "html/template" + "io" "net" "net/http" "net/url" @@ -21,9 +22,12 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" + httputil "github.com/argoproj/argo-cd/v2/util/http" + jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" "github.com/argoproj/argo-cd/v2/util/rand" "github.com/argoproj/argo-cd/v2/util/settings" ) @@ -31,9 +35,11 @@ import ( var InvalidRedirectURLError = fmt.Errorf("invalid return URL") const ( - GrantTypeAuthorizationCode = "authorization_code" - GrantTypeImplicit = "implicit" - ResponseTypeCode = "code" + GrantTypeAuthorizationCode = "authorization_code" + GrantTypeImplicit = "implicit" + ResponseTypeCode = "code" + UserInfoResponseCachePrefix = "userinfo_response" + AccessTokenCachePrefix = "access_token" ) // OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec @@ -57,6 +63,8 @@ type ClientApp struct { redirectURI string // URL of the issuer (e.g. https://argocd.example.com/api/dex) issuerURL string + // the path where the issuer providers user information (e.g /user-info for okta) + userInfoPath string // The URL endpoint at which the ArgoCD server is accessed. baseHRef string // client is the HTTP client which is used to query the IDp @@ -70,6 +78,8 @@ type ClientApp struct { encryptionKey []byte // provider is the OIDC provider provider Provider + // clientCache represent a cache of sso artifact + clientCache cache.CacheClient } func GetScopesOrDefault(scopes []string) []string { @@ -81,7 +91,7 @@ func GetScopesOrDefault(scopes []string) []string { // NewClientApp will register the Argo CD client app (either via Dex or external OIDC) and return an // object which has HTTP handlers for handling the HTTP responses for login and callback -func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string) (*ClientApp, error) { +func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string, cacheClient cache.CacheClient) (*ClientApp, error) { redirectURL, err := settings.RedirectURL() if err != nil { return nil, err @@ -95,8 +105,10 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTl clientSecret: settings.OAuth2ClientSecret(), redirectURI: redirectURL, issuerURL: settings.IssuerURL(), + userInfoPath: settings.UserInfoPath(), baseHRef: baseHRef, encryptionKey: encryptionKey, + clientCache: cacheClient, } log.Infof("Creating client app (%s)", a.clientID) u, err := url.Parse(settings.URL) @@ -376,6 +388,25 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + // save the accessToken in memory for later use + encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) + return + } + sub := jwtutil.StringField(claims, "sub") + err = a.clientCache.Set(&cache.Item{ + Key: fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub), + Object: encToken, + Expiration: getTokenExpiration(claims, time.Millisecond), + }) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) + return + } + if idTokenRAW != "" { cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...) if err != nil { @@ -509,3 +540,141 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc } return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil } + +// GetUserInfo queries the IDP userinfo endpoint for claims +func (a *ClientApp) GetUserInfo(sub, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { + // in case we got it in the cache, we just return the item + var claims jwt.MapClaims + var encClaims []byte + clientCacheKey := fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) + if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { + claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) + if err != nil { + log.Errorf("decrypting the cached claims failed (sub=%s): %s", sub, err) + } else { + err = json.Unmarshal(claimsRaw, &claims) + if err != nil { + log.Errorf("cannot unmarshal cached claims structure: %s", err) + } else { + // return the cached claims since they are not yet expired, were successfully decrypted and unmarshaled + return claims, false, err + } + } + } + + // check if the accessToken for the user is still present + var encAccessToken []byte + err := a.clientCache.Get(fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub), &encAccessToken) + // without an accessToken we can't query the user info endpoint + // thus the user needs to reauthenticate for argocd to get a new accessToken + if err == cache.ErrCacheMiss { + return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err) + } else if err != nil { + return claims, true, fmt.Errorf("couldn't read accessToken from cache for %s: %w", sub, err) + } + + accessToken, err := crypto.Decrypt(encAccessToken, a.encryptionKey) + if err != nil { + return claims, true, fmt.Errorf("couldn't decrypt accessToken for %s: %w", sub, err) + } + + url := issuerURL + userInfoPath + request, err := http.NewRequest("GET", url, nil) + + if err != nil { + err = fmt.Errorf("failed creating new http request: %w", err) + return claims, false, err + } + + bearer := fmt.Sprintf("Bearer %s", accessToken) + request.Header.Set("Authorization", bearer) + + response, err := a.client.Do(request) + if err != nil { + return claims, false, fmt.Errorf("failed to query userinfo endpoint of IDP: %w", err) + } + defer response.Body.Close() + if response.StatusCode == http.StatusUnauthorized { + return claims, true, err + } + + // according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation + // the response should be validated + header := response.Header.Get("content-type") + rawBody, err := io.ReadAll(response.Body) + if err != nil { + return claims, false, fmt.Errorf("got error reading response body: %w", err) + } + switch header { + case "application/jwt": + // if body is JWT, first validate it before extracting claims + idToken, err := a.provider.Verify(string(rawBody), a.settings) + if err != nil { + return claims, false, fmt.Errorf("user info response in jwt format not valid: %w", err) + } + err = idToken.Claims(claims) + if err != nil { + return claims, false, fmt.Errorf("cannot get claims from userinfo jwt: %w", err) + } + default: + // if body is json, unsigned and unencrypted claims can be deserialized + err = json.Unmarshal(rawBody, &claims) + if err != nil { + return claims, false, fmt.Errorf("failed to decode response body to struct: %w", err) + } + } + + // in case response was successfully validated and there was no error, put item in cache + // but first let's determine the expiry of the cache + var cacheExpiry time.Duration + settingExpiry := time.Duration(a.settings.UserInfoCacheExpiration().Microseconds()) + tokenExpiry := getTokenExpiration(claims, time.Microsecond) + + // only use configured expiry if the token lives longer and the expiry is configured + // otherwise use the expiry of the token + if settingExpiry < tokenExpiry && settingExpiry != 0 { + cacheExpiry = settingExpiry + } else { + cacheExpiry = tokenExpiry + } + + rawClaims, err := json.Marshal(claims) + if err != nil { + return claims, false, fmt.Errorf("couldn't marshal claim to json: %w", err) + } + encClaims, err = crypto.Encrypt(rawClaims, a.encryptionKey) + if err != nil { + return claims, false, fmt.Errorf("couldn't encrypt user info response: %w", err) + } + + err = a.clientCache.Set(&cache.Item{ + Key: clientCacheKey, + Object: encClaims, + Expiration: cacheExpiry, + }) + if err != nil { + return claims, false, fmt.Errorf("couldn't put item to cache: %w", err) + } + + return claims, false, nil +} + +// getClaimsExpiration returns a time.Duration in the given format until the token expires +func getTokenExpiration(claims jwt.MapClaims, format time.Duration) time.Duration { + // get duration until token expires + exp := jwtutil.Float64Field(claims, "exp") + tokenExpiry := time.Duration(time.Until(time.Unix(int64(exp), 0)).Seconds()) + switch format { + case time.Microsecond: + return time.Duration(tokenExpiry.Microseconds()) + case time.Nanosecond: + return time.Duration(tokenExpiry.Nanoseconds()) + case time.Millisecond: + return time.Duration(tokenExpiry.Milliseconds()) + case time.Minute: + return time.Duration(tokenExpiry.Minutes()) + default: + // default is second + return tokenExpiry + } +} diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index fe5fa77eed3b5a..8614445b1c59e0 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -11,6 +11,7 @@ import ( "os" "strings" "testing" + "time" gooidc "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/assert" @@ -20,6 +21,7 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" "github.com/argoproj/argo-cd/v2/util" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" "github.com/argoproj/argo-cd/v2/util/settings" @@ -126,7 +128,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -141,7 +143,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -166,7 +168,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -179,7 +181,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -211,7 +213,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), // The base href (the last argument for NewClientApp) is what HandleLogin will fall back to when no explicit // redirect URL is given. - app, err := NewClientApp(cdSettings, "", nil, "/") + app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w := httptest.NewRecorder() @@ -254,7 +256,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -269,7 +271,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -294,7 +296,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -307,7 +309,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -406,7 +408,7 @@ func TestGenerateAppState(t *testing.T) { signature, err := util.MakeSignature(32) require.NoError(t, err) expectedReturnURL := "http://argocd.example.com/" - app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "") + app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) generateResponse := httptest.NewRecorder() state, err := app.generateAppState(expectedReturnURL, generateResponse) @@ -443,7 +445,7 @@ func TestGenerateAppState_XSS(t *testing.T) { URL: "https://argocd.example.com", ServerSignature: signature, }, - "", nil, "", + "", nil, "", cache.NewInMemoryCache(24*time.Hour), ) require.NoError(t, err) @@ -495,7 +497,7 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { encrypted, err := crypto.Encrypt([]byte("123"), key) require.NoError(t, err) - app, err := NewClientApp(cdSettings, "", nil, "/argo-cd") + app, err := NewClientApp(cdSettings, "", nil, "/argo-cd", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: common.StateCookieName, Value: hex.EncodeToString(encrypted)}) @@ -503,3 +505,50 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "/argo-cd", returnURL) } + +func TestGetUserInfo(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + { + "groups":["githubOrg:engineers"] + }` + w.Header().Set("content-type", "application/json") + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + signature, err := util.MakeSignature(32) + require.NoError(t, err) + cdSettings := &settings.ArgoCDSettings{ServerSignature: signature} + encryptionKey, err := cdSettings.GetServerEncryptionKey() + assert.NoError(t, err) + a, _ := NewClientApp(cdSettings, "", nil, "/argo-cd", cache.NewInMemoryCache(24*time.Hour)) + + t.Run("call UserInfo without accessToken in cache", func(t *testing.T) { + got, _, err := a.GetUserInfo("randomUser", ts.URL, "/user-info") + assert.Nil(t, got) + assert.Error(t, err) + }) + t.Run("call UserInfo with valid accessToken in cache", func(t *testing.T) { + obj := "FakeAccessToken" + assert.NoError(t, err) + encObj, err := crypto.Encrypt([]byte(obj), encryptionKey) + assert.NoError(t, err) + err = a.clientCache.Set(&cache.Item{ + Key: fmt.Sprintf("%s_randomUser", AccessTokenCachePrefix), + Object: encObj, + }) + require.NoError(t, err) + + got, _, err := a.GetUserInfo("randomUser", ts.URL, "/user-info") + assert.NotNil(t, got["groups"]) + assert.Equal(t, nil, err) + }) + +} diff --git a/util/settings/settings.go b/util/settings/settings.go index b992a32ed164d8..cae348bd7e953b 100644 --- a/util/settings/settings.go +++ b/util/settings/settings.go @@ -155,28 +155,34 @@ func (o *oidcConfig) toExported() *OIDCConfig { return nil } return &OIDCConfig{ - Name: o.Name, - Issuer: o.Issuer, - ClientID: o.ClientID, - ClientSecret: o.ClientSecret, - CLIClientID: o.CLIClientID, - RequestedScopes: o.RequestedScopes, - RequestedIDTokenClaims: o.RequestedIDTokenClaims, - LogoutURL: o.LogoutURL, - RootCA: o.RootCA, + Name: o.Name, + Issuer: o.Issuer, + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + CLIClientID: o.CLIClientID, + UserInfoPath: o.UserInfoPath, + EnableUserInfoGroups: o.EnableUserInfoGroups, + UserInfoCacheExpiration: o.UserInfoCacheExpiration, + RequestedScopes: o.RequestedScopes, + RequestedIDTokenClaims: o.RequestedIDTokenClaims, + LogoutURL: o.LogoutURL, + RootCA: o.RootCA, } } type OIDCConfig struct { - Name string `json:"name,omitempty"` - Issuer string `json:"issuer,omitempty"` - ClientID string `json:"clientID,omitempty"` - ClientSecret string `json:"clientSecret,omitempty"` - CLIClientID string `json:"cliClientID,omitempty"` - RequestedScopes []string `json:"requestedScopes,omitempty"` - RequestedIDTokenClaims map[string]*oidc.Claim `json:"requestedIDTokenClaims,omitempty"` - LogoutURL string `json:"logoutURL,omitempty"` - RootCA string `json:"rootCA,omitempty"` + Name string `json:"name,omitempty"` + Issuer string `json:"issuer,omitempty"` + EnableUserInfoGroups bool `json:"enableUserInfoGroups,omitempty"` + UserInfoPath string `json:"userInfoPath,omitempty"` + UserInfoCacheExpiration string `json:"userInfoCacheExpiration,omitempty"` + ClientID string `json:"clientID,omitempty"` + ClientSecret string `json:"clientSecret,omitempty"` + CLIClientID string `json:"cliClientID,omitempty"` + RequestedScopes []string `json:"requestedScopes,omitempty"` + RequestedIDTokenClaims map[string]*oidc.Claim `json:"requestedIDTokenClaims,omitempty"` + LogoutURL string `json:"logoutURL,omitempty"` + RootCA string `json:"rootCA,omitempty"` } // DEPRECATED. Helm repository credentials are now managed using RepoCredentials @@ -1834,6 +1840,34 @@ func (a *ArgoCDSettings) IssuerURL() string { return "" } +// UserInfoGroupsEnabled returns whether group claims should be fetch from UserInfo endpoint +func (a *ArgoCDSettings) UserInfoGroupsEnabled() bool { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.EnableUserInfoGroups + } + return false +} + +// UserInfoPath returns the sub-path on which the IDP exposes the UserInfo endpoint +func (a *ArgoCDSettings) UserInfoPath() string { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.UserInfoPath + } + return "" +} + +// UserInfoCacheExpiration returns the expiry time of the UserInfo cache +func (a *ArgoCDSettings) UserInfoCacheExpiration() time.Duration { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil && oidcConfig.UserInfoCacheExpiration != "" { + userInfoCacheExpiration, err := time.ParseDuration(oidcConfig.UserInfoCacheExpiration) + if err != nil { + log.Warnf("Failed to parse 'oidc.config.userInfoCacheExpiration' key: %v", err) + } + return userInfoCacheExpiration + } + return 0 +} + func (a *ArgoCDSettings) OAuth2ClientID() string { if oidcConfig := a.OIDCConfig(); oidcConfig != nil { return oidcConfig.ClientID diff --git a/util/test/testutil.go b/util/test/testutil.go index 6fdbd4151d82cf..1cb23bc08bb3e8 100644 --- a/util/test/testutil.go +++ b/util/test/testutil.go @@ -168,6 +168,16 @@ func oidcMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.R "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], "claims_supported": ["sub", "aud", "exp"] }`, url)) + require.NoError(t, err) + case "/userinfo": + w.Header().Set("content-type", "application/json") + _, err := io.WriteString(w, fmt.Sprintf(` +{ + "groups":["githubOrg:engineers"], + "iss": "%[1]s", + "sub": "randomUser" +}`, url)) + require.NoError(t, err) case "/keys": pubKey, err := jwt.ParseRSAPublicKeyFromPEM(Cert)