From 7552877b95a3881fb2ad74381a7d7d5eec16a223 Mon Sep 17 00:00:00 2001 From: Alex Lovell-Troy Date: Mon, 19 Aug 2024 14:37:13 -0400 Subject: [PATCH] Add NewKeySet method to JWTAuth This commit adds support for KeySets through a new method `NewKeySet` to the `JWTAuth` struct. It includes tests and comments that seek to explain how it works inline. There's also an example in the _example directory that shows how to use and rotate a KeySet. --- _example/main.go | 49 ++++++++++- jwtauth.go | 29 +++++++ jwtauth_test.go | 217 ++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 284 insertions(+), 11 deletions(-) diff --git a/_example/main.go b/_example/main.go index a6d7559..082beb0 100644 --- a/_example/main.go +++ b/_example/main.go @@ -68,6 +68,18 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) +type dynamicTokenAuth struct { + keySet []byte +} + +func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) { + keySet, err := jwtauth.NewKeySet(d.keySet) + if err != nil { + return nil, err + } + return keySet, nil +} + var tokenAuth *jwtauth.JWTAuth func init() { @@ -76,7 +88,8 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) - fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate) } func main() { @@ -105,6 +118,23 @@ func router() http.Handler { }) }) + r.Group(func(r chi.Router) { + dynamicTokenAuth := dynamicTokenAuth{keySet: keySet} + // Seek, verify and validate JWT tokens based on keys returned by the callback function + r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth)) + + // Handle valid / invalid tokens. In this example, we use + // the provided authenticator middleware, but you can write your + // own very easily, look at the Authenticator method in jwtauth.go + // and tweak it, its not scary. + r.Use(jwtauth.Authenticator) + + r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) { + _, claims, _ := jwtauth.FromContext(r.Context()) + w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"]))) + }) + }) + // Public routes r.Group(func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -114,3 +144,20 @@ func router() http.Handler { return r } + +var ( + keySet = []byte(`{ + "keys": [ + { + "kty": "RSA", + "alg": "RS256", + "kid": "kid", + "use": "sig", + "n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w", + "e": "AQAB" + } + ] +}`) + + sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA` +) diff --git a/jwtauth.go b/jwtauth.go index 6199448..b7286a9 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -2,12 +2,14 @@ package jwtauth import ( "context" + "encoding/json" "errors" "net/http" "strings" "time" "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" ) @@ -17,6 +19,7 @@ type JWTAuth struct { verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms verifier jwt.ParseOption validateOptions []jwt.ValidateOption + keySet jwk.Set } var ( @@ -50,6 +53,24 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions return ja } +// NewKeySet initializes a new JWTAuth instance with the provided key set. +// It takes a keySet parameter, which is a byte slice containing the key set in JSON format. +// The function returns a pointer to JWTAuth and an error. +// If the key set cannot be unmarshaled from the byte slice, an error is returned. +// Otherwise, the JWTAuth instance is created with the unmarshaled key set and a verifier is set using the key set. +func NewKeySet(keySet []byte) (*JWTAuth, error) { + ks := jwk.NewSet() + err := json.Unmarshal(keySet, &ks) + if err != nil { + return nil, err + } + + ja := &JWTAuth{keySet: ks} + ja.verifier = jwt.WithKeySet(ks) + + return ja, nil +} + // Verifier http middleware handler will verify a JWT string from a http request. // // Verifier will search for a JWT token in a http request, in the order: @@ -119,6 +140,8 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) { return token, nil } +// Encode generates a JWT token string with the provided claims. +// It returns the encoded token as a string, along with the token object and any error encountered. func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) { t = jwt.New() for k, v := range claims { @@ -126,6 +149,12 @@ func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenStri return nil, "", err } } + // ja.sign() isn't going to work if ja.signKey is nil + if ja.signKey == nil { + // This generally means that you've called Encode on a KeySet + // which can't be supported. + return nil, "", errors.New("no signing key provided") + } payload, err := ja.sign(t) if err != nil { return nil, "", err diff --git a/jwtauth_test.go b/jwtauth_test.go index d9e2c52..60e736d 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/v2/jwa" @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ== -----END PUBLIC KEY----- ` + + KeySet = `{ + "keys": [ + { + "kty": "RSA", + "n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ", + "e": "AQAB", + "alg": "RS256", + "kid": "1", + "use": "sig" + }, + { + "kty": "RSA", + "n": "foo", + "e": "AQAB", + "alg": "RS256", + "kid": "2", + "use": "sig" + } + ] +}` ) func init() { @@ -51,6 +74,59 @@ func init() { // Tests // +func TestNewKeySet(t *testing.T) { + _, err := jwtauth.NewKeySet([]byte("not a valid key set")) + if err == nil { + t.Fatal("The error should not be nil") + } + + _, err = jwtauth.NewKeySet([]byte(KeySet)) + if err != nil { + t.Fatalf(err.Error()) + } +} + +func TestKeySetRSA(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + + if err != nil { + t.Fatalf(err.Error()) + } + + KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet)) + claims := map[string]interface{}{ + "key": "val", + "key2": "val2", + "key3": "val3", + } + + signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims) + + token, err := KeySetAuth.Decode(signed) + + if err != nil { + t.Fatalf("Failed to decode token string %s\n", err.Error()) + } + + tokenClaims, err := token.AsMap(context.Background()) + if err != nil { + t.Fatal(err.Error()) + } + + if !reflect.DeepEqual(claims, tokenClaims) { + t.Fatalf("The decoded claims don't match the original ones\n") + } + + _, _, err = KeySetAuth.Encode(claims) + if err.Error() != "no signing key provided" { + t.Fatalf("Expect error to equal %s. Found: %s.", "no signing key provided", err.Error()) + } + fmt.Println(token.PrivateClaims()) + +} + func TestSimple(t *testing.T) { r := chi.NewRouter() @@ -279,20 +355,118 @@ func TestMore(t *testing.T) { } } -func TestEncodeClaims(t *testing.T) { +func TestKeySet(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + t.Fatalf(err.Error()) + } + + r := chi.NewRouter() + + keySet, err := jwtauth.NewKeySet([]byte(KeySet)) + if err != nil { + t.Fatalf(err.Error()) + } + + // Protected routes + r.Group(func(r chi.Router) { + r.Use(jwtauth.Verifier(keySet)) + + authenticator := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, _, err := jwtauth.FromContext(r.Context()) + + if err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + if err := jwt.Validate(token); err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) + } + r.Use(authenticator) + + r.Get("/admin", func(w http.ResponseWriter, r *http.Request) { + _, claims, err := jwtauth.FromContext(r.Context()) + + if err != nil { + w.Write([]byte(fmt.Sprintf("error! %v", err))) + return + } + + w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"]))) + }) + }) + + // Public routes + r.Group(func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("welcome")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + h := http.Header{} + h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { + t.Fatalf(resp) + } +} + +func TestEncodeInvalidClaim(t *testing.T) { + ja := jwtauth.New("HS256", []byte("secretpass"), nil) claims := map[string]interface{}{ - "key1": "val1", - "key2": 2, - "key3": time.Now(), - "key4": []string{"1", "2"}, + "key1": "val1", + "key2": 2, + "key3": time.Now(), + "key4": []string{"1", "2"}, + jwt.JwtIDKey: 1, // This is invalid becasue it should be a string } - claims[jwt.JwtIDKey] = 1 - if _, _, err := TokenAuthHS256.Encode(claims); err == nil { + _, _, err := ja.Encode(claims) + if err == nil { + t.Fatal("encoding invalid claims succeeded") } - claims[jwt.JwtIDKey] = "123" - if _, _, err := TokenAuthHS256.Encode(claims); err != nil { - t.Fatalf("unexpected error encoding valid claims: %v", err) +} +func TestEncode(t *testing.T) { + ja := jwtauth.New("HS256", []byte("secretpass"), nil) + + claims := map[string]interface{}{ + "sub": "1234567890", + "name": "John Doe", + "iat": 1516239022, + } + + token, tokenString, err := ja.Encode(claims) + if err != nil { + t.Fatalf("Failed to encode claims: %s", err.Error()) + } + + if token == nil { + t.Fatal("Token should not be nil") + } + + if tokenString == "" { + t.Fatal("Token string should not be empty") + } + + // Verify the token string + verifiedToken, err := ja.Decode(tokenString) + if err != nil { + t.Fatalf("Failed to decode token string: %s", err.Error()) + } + + if !reflect.DeepEqual(token, verifiedToken) { + t.Fatal("Decoded token does not match the original token") } } @@ -357,6 +531,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string { return string(tokenPayload) } +func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string { + token := jwt.New() + if len(claims) > 0 { + for k, v := range claims[0] { + token.Set(k, v) + } + } + + headers := jws.NewHeaders() + if kid != "" { + err := headers.Set("kid", kid) + if err != nil { + log.Fatal(err) + } + } + + tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers))) + if err != nil { + log.Fatal(err) + } + return string(tokenPayload) +} + func newAuthHeader(claims ...map[string]interface{}) http.Header { h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))