Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NewKeySet method to JWTAuth and demonstrate JWKS rotation #93

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)

type dynamicTokenAuth struct {
keySet []byte
}

func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) {
keySet, err := jwtauth.NewKeySet(d.keySet)
if err != nil {
return nil, err
}
return keySet, nil
}

var tokenAuth *jwtauth.JWTAuth

func init() {
Expand All @@ -76,7 +88,8 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate)
}

func main() {
Expand Down Expand Up @@ -105,6 +118,23 @@ func router() http.Handler {
})
})

r.Group(func(r chi.Router) {
dynamicTokenAuth := dynamicTokenAuth{keySet: keySet}
// Seek, verify and validate JWT tokens based on keys returned by the callback function
r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth))

// Handle valid / invalid tokens. In this example, we use
// the provided authenticator middleware, but you can write your
// own very easily, look at the Authenticator method in jwtauth.go
// and tweak it, its not scary.
r.Use(jwtauth.Authenticator)

r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) {
_, claims, _ := jwtauth.FromContext(r.Context())
w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -114,3 +144,20 @@ func router() http.Handler {

return r
}

var (
keySet = []byte(`{
"keys": [
{
"kty": "RSA",
"alg": "RS256",
"kid": "kid",
"use": "sig",
"n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w",
"e": "AQAB"
}
]
}`)

sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA`
)
29 changes: 29 additions & 0 deletions jwtauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package jwtauth

import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)

Expand All @@ -17,6 +19,7 @@ type JWTAuth struct {
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
validateOptions []jwt.ValidateOption
keySet jwk.Set
}

var (
Expand Down Expand Up @@ -50,6 +53,24 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions
return ja
}

// NewKeySet initializes a new JWTAuth instance with the provided key set.
// It takes a keySet parameter, which is a byte slice containing the key set in JSON format.
// The function returns a pointer to JWTAuth and an error.
// If the key set cannot be unmarshaled from the byte slice, an error is returned.
// Otherwise, the JWTAuth instance is created with the unmarshaled key set and a verifier is set using the key set.
func NewKeySet(keySet []byte) (*JWTAuth, error) {
ks := jwk.NewSet()
err := json.Unmarshal(keySet, &ks)
if err != nil {
return nil, err
}

ja := &JWTAuth{keySet: ks}
ja.verifier = jwt.WithKeySet(ks)

return ja, nil
}

// Verifier http middleware handler will verify a JWT string from a http request.
//
// Verifier will search for a JWT token in a http request, in the order:
Expand Down Expand Up @@ -119,13 +140,21 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
return token, nil
}

// Encode generates a JWT token string with the provided claims.
// It returns the encoded token as a string, along with the token object and any error encountered.
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
t = jwt.New()
for k, v := range claims {
if err := t.Set(k, v); err != nil {
return nil, "", err
}
}
// ja.sign() isn't going to work if ja.signKey is nil
if ja.signKey == nil {
// This generally means that you've called Encode on a KeySet
// which can't be supported.
return nil, "", errors.New("no signing key provided")
}
payload, err := ja.sign(t)
if err != nil {
return nil, "", err
Expand Down
217 changes: 207 additions & 10 deletions jwtauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jws"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
-----END PUBLIC KEY-----
`

KeySet = `{
"keys": [
{
"kty": "RSA",
"n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
"e": "AQAB",
"alg": "RS256",
"kid": "1",
"use": "sig"
},
{
"kty": "RSA",
"n": "foo",
"e": "AQAB",
"alg": "RS256",
"kid": "2",
"use": "sig"
}
]
}`
)

func init() {
Expand All @@ -51,6 +74,59 @@ func init() {
// Tests
//

func TestNewKeySet(t *testing.T) {
_, err := jwtauth.NewKeySet([]byte("not a valid key set"))
if err == nil {
t.Fatal("The error should not be nil")
}

_, err = jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}
}

func TestKeySetRSA(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))

privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)

if err != nil {
t.Fatalf(err.Error())
}

KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet))
claims := map[string]interface{}{
"key": "val",
"key2": "val2",
"key3": "val3",
}

signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims)

token, err := KeySetAuth.Decode(signed)

if err != nil {
t.Fatalf("Failed to decode token string %s\n", err.Error())
}

tokenClaims, err := token.AsMap(context.Background())
if err != nil {
t.Fatal(err.Error())
}

if !reflect.DeepEqual(claims, tokenClaims) {
t.Fatalf("The decoded claims don't match the original ones\n")
}

_, _, err = KeySetAuth.Encode(claims)
if err.Error() != "no signing key provided" {
t.Fatalf("Expect error to equal %s. Found: %s.", "no signing key provided", err.Error())
}
fmt.Println(token.PrivateClaims())

}

func TestSimple(t *testing.T) {
r := chi.NewRouter()

Expand Down Expand Up @@ -279,20 +355,118 @@ func TestMore(t *testing.T) {
}
}

func TestEncodeClaims(t *testing.T) {
func TestKeySet(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
t.Fatalf(err.Error())
}

r := chi.NewRouter()

keySet, err := jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}

// Protected routes
r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(keySet))

authenticator := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, _, err := jwtauth.FromContext(r.Context())

if err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

if err := jwt.Validate(token); err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

// Token is authenticated, pass it through
next.ServeHTTP(w, r)
})
}
r.Use(authenticator)

r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())

if err != nil {
w.Write([]byte(fmt.Sprintf("error! %v", err)))
return
}

w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
})

ts := httptest.NewServer(r)
defer ts.Close()

h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
}

func TestEncodeInvalidClaim(t *testing.T) {
ja := jwtauth.New("HS256", []byte("secretpass"), nil)
claims := map[string]interface{}{
"key1": "val1",
"key2": 2,
"key3": time.Now(),
"key4": []string{"1", "2"},
"key1": "val1",
"key2": 2,
"key3": time.Now(),
"key4": []string{"1", "2"},
jwt.JwtIDKey: 1, // This is invalid becasue it should be a string
}
claims[jwt.JwtIDKey] = 1
if _, _, err := TokenAuthHS256.Encode(claims); err == nil {
_, _, err := ja.Encode(claims)
if err == nil {

t.Fatal("encoding invalid claims succeeded")
}
claims[jwt.JwtIDKey] = "123"
if _, _, err := TokenAuthHS256.Encode(claims); err != nil {
t.Fatalf("unexpected error encoding valid claims: %v", err)
}
func TestEncode(t *testing.T) {
ja := jwtauth.New("HS256", []byte("secretpass"), nil)

claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"iat": 1516239022,
}

token, tokenString, err := ja.Encode(claims)
if err != nil {
t.Fatalf("Failed to encode claims: %s", err.Error())
}

if token == nil {
t.Fatal("Token should not be nil")
}

if tokenString == "" {
t.Fatal("Token string should not be empty")
}

// Verify the token string
verifiedToken, err := ja.Decode(tokenString)
if err != nil {
t.Fatalf("Failed to decode token string: %s", err.Error())
}

if !reflect.DeepEqual(token, verifiedToken) {
t.Fatal("Decoded token does not match the original token")
}
}

Expand Down Expand Up @@ -357,6 +531,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
return string(tokenPayload)
}

func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string {
token := jwt.New()
if len(claims) > 0 {
for k, v := range claims[0] {
token.Set(k, v)
}
}

headers := jws.NewHeaders()
if kid != "" {
err := headers.Set("kid", kid)
if err != nil {
log.Fatal(err)
}
}

tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers)))
if err != nil {
log.Fatal(err)
}
return string(tokenPayload)
}

func newAuthHeader(claims ...map[string]interface{}) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
Expand Down
Loading