diff --git a/bench/performance/go.sum b/bench/performance/go.sum index ac6e2155e..e3e3d5e72 100644 --- a/bench/performance/go.sum +++ b/bench/performance/go.sum @@ -6,8 +6,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etly github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80= -github.com/lestrrat-go/blackmagic v1.0.1/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= diff --git a/cmd/jwx/go.mod b/cmd/jwx/go.mod index 4113b5e82..dc4128bd5 100644 --- a/cmd/jwx/go.mod +++ b/cmd/jwx/go.mod @@ -12,7 +12,7 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/lestrrat-go/blackmagic v1.0.1 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.4 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect diff --git a/cmd/jwx/go.sum b/cmd/jwx/go.sum index bf04325b7..0dd36f671 100644 --- a/cmd/jwx/go.sum +++ b/cmd/jwx/go.sum @@ -9,8 +9,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etly github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80= -github.com/lestrrat-go/blackmagic v1.0.1/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= diff --git a/docs/01-jwt.md b/docs/01-jwt.md index 4735bf1c0..01f7f4944 100644 --- a/docs/01-jwt.md +++ b/docs/01-jwt.md @@ -1234,7 +1234,8 @@ Please [look at the JWS documentation for it](./02-jws.md#parse-a-jws-message-an Any field in the token can be accessed in an uniform away using `(jwt.Token).Get()` ```go -v, ok := token.Get(name) +var v interface{} // can be concrete type, if you know the type beforehand +err := token.Get(name, &v) ``` If the field corresponding to `name` does not exist, the second return value will be `false`. diff --git a/docs/20-global-settings.md b/docs/20-global-settings.md index 3368a616c..23cb8cf38 100644 --- a/docs/20-global-settings.md +++ b/docs/20-global-settings.md @@ -79,8 +79,8 @@ This tells the decoder that when it encounters a JWT token with the field named access this value by using `Get()` ```go -v, _ := token.Get(`x-foo-bar`) -foobar := v.(mypkg.FooBar) +var v mypkg.FooBar +_ = token.Get(`x-foo-bar`, &v) ``` Do be aware that this has *global* effect. In the above example, all JWT tokens containing diff --git a/examples/jwe_example_test.go b/examples/jwe_example_test.go index 19f406af0..0a6e09cb1 100644 --- a/examples/jwe_example_test.go +++ b/examples/jwe_example_test.go @@ -82,24 +82,24 @@ func ExampleJWE_ComplexDecrypt() { // I would personally recommend creating a real type for your specific needs // instead of passing adhoc closures. YMMV. kp := func(ctx context.Context, sink jwe.KeySink, _ jwe.Recipient, msg *jwe.Message) error { - rawhint, _ := msg.ProtectedHeaders().Get(`jwx-hints`) - //nolint:forcetypeassert - hint, ok := rawhint.(string) - if ok && hint == `foobar` { - // This is where we are setting the key to be used. - // - // In real life you would look up the key or something. - // Here we just assign the key to use. - // - // You may opt to set both the algorithm and key here as well. - // BUT BE CAREFUL so that you don't accidentally create a - // vulnerability - sink.Key(jwa.RSA_OAEP, privkey) - return nil + var hint string + if err := msg.ProtectedHeaders().Get(`jwx-hints`, &hint); err == nil { + if hint == `foobar` { + // This is where we are setting the key to be used. + // + // In real life you would look up the key or something. + // Here we just assign the key to use. + // + // You may opt to set both the algorithm and key here as well. + // BUT BE CAREFUL so that you don't accidentally create a + // vulnerability + sink.Key(jwa.RSA_OAEP, privkey) + return nil + } } // If there were errors, just return it, and the whole jwe.Decrypt will fail. - return fmt.Errorf(`invalid value for jwx-hints: %s`, rawhint) + return fmt.Errorf(`invalid value for jwx-hints`) } // Calling jwe.Decrypt with the extra argument of jwe.WithPostParser(). diff --git a/examples/jwt_example_test.go b/examples/jwt_example_test.go index a2889c4c9..a0c35b382 100644 --- a/examples/jwt_example_test.go +++ b/examples/jwt_example_test.go @@ -126,15 +126,18 @@ func ExampleJWT_Sign_WithImportJWK() { fmt.Printf("%s\n", buf) - if v, ok := t.Get(`privateClaimKey`); ok { - fmt.Printf("privateClaimKey -> '%s'\n", v) + var pc string + if err := t.Get(`privateClaimKey`, &pc); err != nil { + fmt.Printf("failed to fetch private claim\n") + return } + fmt.Printf("privateClaimKey -> '%s'\n", pc) //convert jwk in bytes and return a new key jwkey, err := jwk.ParseKey([]byte(jwkStr)) - if err != nil { - log.Fatal("erro") + fmt.Printf("failed to parse key: %s\n", err) + return } // signed and return a jwt @@ -205,9 +208,12 @@ func ExampleJWT_Token() { fmt.Printf("%s\n", buf) fmt.Printf("aud -> '%s'\n", t.Audience()) fmt.Printf("iat -> '%s'\n", t.IssuedAt().Format(time.RFC3339)) - if v, ok := t.Get(`privateClaimKey`); ok { - fmt.Printf("privateClaimKey -> '%s'\n", v) + var pc string + if err := t.Get(`privateClaimKey`, &pc); err != nil { + fmt.Printf("failed to fetch private claim\n") + return } + fmt.Printf("privateClaimKey -> '%s'\n", pc) fmt.Printf("sub -> '%s'\n", t.Subject()) // OUTPUT: diff --git a/examples/jwt_get_claims_example_test.go b/examples/jwt_get_claims_example_test.go index acb4857cb..32cd10895 100644 --- a/examples/jwt_get_claims_example_test.go +++ b/examples/jwt_get_claims_example_test.go @@ -26,19 +26,23 @@ func ExampleJWT_GetClaims() { var _ string = tok.Issuer() var _ string = tok.Subject() - var v interface{} - var ok bool - // But you can also get them via the generic `.Get()` method. - // However, v is of type interface{}, so you might need to - // use a type switch to properly use its value. + // However, you would need to decide for yourself what the + // return type is. If you don't need the exact type, you could + // use interface{}, or you could use the specific time.Time + // type // // For the key name you could also use jwt.IssuedAtKey constant - v, ok = tok.Get(`iat`) + var iat time.Time + _ = tok.Get(`iat`, &iat) + + // var iat interface{} would also work, but you would need to + // convert the type if you need time.Time specific behavior // Private claims - v, ok = tok.Get(`claim1`) - v, ok = tok.Get(`claim2`) + var dummy interface{} + _ = tok.Get(`claim1`, &dummy) + _ = tok.Get(`claim2`, &dummy) // However, it is possible to globally specify that a private // claim should be parsed into a custom type. @@ -50,18 +54,13 @@ func ExampleJWT_GetClaims() { fmt.Printf(`failed to parse token: %s`, err) return } - v, ok = tok.Get(`claim2`) - if !ok { - fmt.Printf(`failed to get private claim "claim2"`) - return - } - if _, ok := v.(time.Time); !ok { - fmt.Printf(`claim2 expected to be time.Time, but got %T`, v) + + // now you can use the exact type + var claim2 time.Time + if err := tok.Get(`claim2`, &claim2); err != nil { + fmt.Printf("failed to get private claim \"claim2\": %s\n", err) return } - _ = v - _ = ok - // OUTPUT: } diff --git a/jwe/gh402_test.go b/jwe/gh402_test.go index 6d7011652..95dadf5d8 100644 --- a/jwe/gh402_test.go +++ b/jwe/gh402_test.go @@ -8,6 +8,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Pin represents the structured clevis data which can be used to decrypt the jwe message @@ -86,14 +87,8 @@ func TestGH402(t *testing.T) { return } - v, ok := m.ProtectedHeaders().Get("clevis") - if !assert.True(t, ok, `m.Get("clevis") should be true`) { - return - } - - if !assert.IsType(t, Pin{}, v, `result of m.Get("clevis") should be an instance of Pin{}`) { - return - } + var v Pin + require.NoError(t, m.ProtectedHeaders().Get("clevis", &v), `m.Get("clevis") should be succeed`) } } decrypt(false) diff --git a/jwe/headers_gen.go b/jwe/headers_gen.go index 329b0170f..cfa81dfb0 100644 --- a/jwe/headers_gen.go +++ b/jwe/headers_gen.go @@ -9,6 +9,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/json" @@ -36,7 +37,14 @@ const ( X509URLKey = "x5u" ) -// Headers describe a standard Header set. +// Headers describe a standard JWE Header set. It is part of the JWE message +// and is used to represent both Protected and Unprotected headers, +// which in turn can be found in each Recipient object. +// If you are not sure how this works, it is strongly recommended that +// you read RFC7516, especially the section +// that describes the full JSON serialization format of JWE messages. +// +// In most cases, you likely want to use the protected headers, as this is the part of the encrypted content type Headers interface { json.Marshaler json.Unmarshaler @@ -59,9 +67,21 @@ type Headers interface { Iterate(ctx context.Context) Iterator Walk(ctx context.Context, v Visitor) error AsMap(ctx context.Context) (map[string]interface{}, error) - Get(string) (interface{}, bool) + + // Get is used to extract the value of any field, including non-standard fields, out of the header. + // + // The first argument is the name of the field. The second argument is a pointer + // to a variable that will receive the value of the field. The method returns + // an error if the field does not exist, or if the value cannot be assigned to + // the destination variable. Note that a field is considered to "exist" even if + // the value is empty-ish (e.g. 0, false, ""), as long as it is explicitly set. + Get(string, interface{}) error Set(string, interface{}) error Remove(string) error + // Has returns true if the specified header has a value, even if + // the value is empty-ish (e.g. 0, false, "") as long as it has been + // explicitly set. + Has(string) bool Encode() ([]byte, error) Decode([]byte) error // PrivateParams returns the map containing the non-standard ('private') parameters @@ -292,94 +312,174 @@ func (h *stdHeaders) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *stdHeaders) Get(name string) (interface{}, bool) { +func (h *stdHeaders) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AgreementPartyUInfoKey: + return h.agreementPartyUInfo != nil + case AgreementPartyVInfoKey: + return h.agreementPartyVInfo != nil + case AlgorithmKey: + return h.algorithm != nil + case CompressionKey: + return h.compression != nil + case ContentEncryptionKey: + return h.contentEncryption != nil + case ContentTypeKey: + return h.contentType != nil + case CriticalKey: + return h.critical != nil + case EphemeralPublicKeyKey: + return h.ephemeralPublicKey != nil + case JWKKey: + return h.jwk != nil + case JWKSetURLKey: + return h.jwkSetURL != nil + case KeyIDKey: + return h.keyID != nil + case TypeKey: + return h.typ != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *stdHeaders) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case AgreementPartyUInfoKey: if h.agreementPartyUInfo == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.agreementPartyUInfo); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.agreementPartyUInfo, true case AgreementPartyVInfoKey: if h.agreementPartyVInfo == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.agreementPartyVInfo); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.agreementPartyVInfo, true case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true case CompressionKey: if h.compression == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.compression)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.compression), true case ContentEncryptionKey: if h.contentEncryption == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.contentEncryption)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.contentEncryption), true case ContentTypeKey: if h.contentType == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.contentType)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.contentType), true case CriticalKey: if h.critical == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.critical); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.critical, true case EphemeralPublicKeyKey: if h.ephemeralPublicKey == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.ephemeralPublicKey); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.ephemeralPublicKey, true case JWKKey: if h.jwk == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.jwk); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.jwk, true case JWKSetURLKey: if h.jwkSetURL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.jwkSetURL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.jwkSetURL), true case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyID), true case TypeKey: if h.typ == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.typ)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.typ), true case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprint), true case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *stdHeaders) Set(name string, value interface{}) error { diff --git a/jwe/headers_test.go b/jwe/headers_test.go index ce7469745..45cc8fdd8 100644 --- a/jwe/headers_test.go +++ b/jwe/headers_test.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwe" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var zeroval reflect.Value @@ -135,8 +136,8 @@ func TestHeaders(t *testing.T) { } for _, tc := range data { var values []interface{} - viaGet, ok := h.Get(tc.Key) - if !assert.True(t, ok, "value for %s should exist", tc.Key) { + var viaGet interface{} + if !assert.NoError(t, h.Get(tc.Key, &viaGet), `h.Get should be successful`) { return } values = append(values, viaGet) @@ -218,13 +219,9 @@ func TestHeaders(t *testing.T) { pair := iter.Pair() seen[pair.Key.(string)] = pair.Value - getV, ok := v.Get(pair.Key.(string)) - if !assert.True(t, ok, `v.Get should succeed for key %#v`, pair.Key) { - return - } - if !assert.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) { - return - } + var getV interface{} + require.NoError(t, v.Get(pair.Key.(string), &getV), `v.Get should succeed for key %#v`, pair.Key) + require.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) } if !assert.Equal(t, expected, seen, `values should match`) { return diff --git a/jwe/jwe.go b/jwe/jwe.go index 39bbe3ebc..f4af2de11 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -602,11 +602,11 @@ func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptio switch alg { case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: - epkif, ok := h2.Get(EphemeralPublicKeyKey) - if !ok { - return nil, fmt.Errorf(`failed to get 'epk' field`) + var epk interface{} + if err := h2.Get(EphemeralPublicKeyKey, &epk); err != nil { + return nil, fmt.Errorf(`failed to get 'epk' field: %w`, err) } - switch epk := epkif.(type) { + switch epk := epk.(type) { case jwk.ECDSAPublicKey: var pubkey ecdsa.PublicKey if err := epk.Raw(&pubkey); err != nil { @@ -620,7 +620,7 @@ func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptio } dec.PublicKey(pubkey) default: - return nil, fmt.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg) + return nil, fmt.Errorf("unexpected 'epk' type %T for alg %s", epk, alg) } if apu := h2.AgreementPartyUInfo(); len(apu) > 0 { @@ -630,54 +630,38 @@ func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptio dec.AgreementPartyVInfo(apv) } case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW: - ivB64, ok := h2.Get(InitializationVectorKey) - if ok { - ivB64Str, ok := ivB64.(string) - if !ok { - return nil, fmt.Errorf("unexpected type for 'iv': %T", ivB64) - } - iv, err := base64.DecodeString(ivB64Str) + var ivB64 string + if err := h2.Get(InitializationVectorKey, &ivB64); err == nil { + iv, err := base64.DecodeString(ivB64) if err != nil { return nil, fmt.Errorf(`failed to b64-decode 'iv': %w`, err) } dec.KeyInitializationVector(iv) } - tagB64, ok := h2.Get(TagKey) - if ok { - tagB64Str, ok := tagB64.(string) - if !ok { - return nil, fmt.Errorf("unexpected type for 'tag': %T", tagB64) - } - tag, err := base64.DecodeString(tagB64Str) + var tagB64 string + if err := h2.Get(TagKey, &tagB64); err == nil { + tag, err := base64.DecodeString(tagB64) if err != nil { return nil, fmt.Errorf(`failed to b64-decode 'tag': %w`, err) } dec.KeyTag(tag) } case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: - saltB64, ok := h2.Get(SaltKey) - if !ok { - return nil, fmt.Errorf(`failed to get 'p2s' field`) - } - saltB64Str, ok := saltB64.(string) - if !ok { - return nil, fmt.Errorf("unexpected type for 'p2s': %T", saltB64) + var saltB64 string + if err := h2.Get(SaltKey, &saltB64); err != nil { + return nil, fmt.Errorf(`failed to get %q field`, SaltKey) } - count, ok := h2.Get(CountKey) - if !ok { - return nil, fmt.Errorf(`failed to get 'p2c' field`) - } - countFlt, ok := count.(float64) - if !ok { - return nil, fmt.Errorf("unexpected type for 'p2c': %T", count) + var count float64 + if err := h2.Get(CountKey, &count); err != nil { + return nil, fmt.Errorf(`failed to get %q field`, CountKey) } - salt, err := base64.DecodeString(saltB64Str) + salt, err := base64.DecodeString(saltB64) if err != nil { return nil, fmt.Errorf(`failed to b64-decode 'salt': %w`, err) } dec.KeySalt(salt) - dec.KeyCount(int(countFlt)) + dec.KeyCount(int(count)) } plaintext, err := dec.Decrypt(recipient, dctx.msg.cipherText, dctx.msg) @@ -810,15 +794,16 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) { // you want to represent as a string formatted in RFC3339 in JSON, // but want it back as `time.Time`. // -// In that case you would register a custom field as follows +// In such case you would register a custom field as follows // -// jwe.RegisterCustomField(`x-birthday`, timeT) +// jws.RegisterCustomField(`x-birthday`, time.Time{}) // -// Then `hdr.Get("x-birthday")` will still return an `interface{}`, -// but you can convert its type to `time.Time` +// Then you can use a `time.Time` variable to extract the value +// of `x-birthday` field, instead of having to use `interface{}` +// and later convert it to `time.Time` // -// bdayif, _ := hdr.Get(`x-birthday`) -// bday := bdayif.(time.Time) +// var bday time.Time +// _ = hdr.Get(`x-birthday`, &bday) func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 66e02d3d0..865931f4c 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -721,11 +721,8 @@ func TestCustomField(t *testing.T) { return } - v, ok := msg.ProtectedHeaders().Get(`x-birthday`) - if !assert.True(t, ok, `msg.ProtectedHeaders().Get("x-birthday") should succeed`) { - return - } - + var v time.Time + require.NoError(t, msg.ProtectedHeaders().Get(`x-birthday`, &v), `msg.ProtectedHeaders().Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return } @@ -741,11 +738,8 @@ func TestCustomField(t *testing.T) { return } - v, ok = msg2.ProtectedHeaders().Get(`x-birthday`) - if !assert.True(t, ok, `msg2.ProtectedHeaders().Get("x-birthday") should succeed`) { - return - } - + v = time.Time{} // reset + require.NoError(t, msg2.ProtectedHeaders().Get(`x-birthday`, &v), `msg2.ProtectedHeaders().Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return } diff --git a/jwk/cache.go b/jwk/cache.go index 5d5b6b90b..1cdb63313 100644 --- a/jwk/cache.go +++ b/jwk/cache.go @@ -345,13 +345,13 @@ func (cs *CachedSet) Clone() (Set, error) { } // Get returns the value of non-Key field stored in the jwk.Set -func (cs *CachedSet) Get(name string) (interface{}, bool) { +func (cs *CachedSet) Get(name string, dst interface{}) error { set, err := cs.cached() if err != nil { - return nil, false + return err } - return set.Get(name) + return set.Get(name, dst) } // Key returns the Key at the specified index diff --git a/jwk/ecdsa_gen.go b/jwk/ecdsa_gen.go index 61a1d5340..fcea2f3c7 100644 --- a/jwk/ecdsa_gen.go +++ b/jwk/ecdsa_gen.go @@ -10,6 +10,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" @@ -182,71 +183,144 @@ func (h *ecdsaPublicKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *ecdsaPublicKey) Get(name string) (interface{}, bool) { +func (h *ecdsaPublicKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case ECDSACrvKey: + return h.crv != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case ECDSAXKey: + return h.x != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + case ECDSAYKey: + return h.y != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *ecdsaPublicKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case ECDSACrvKey: if h.crv == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.crv), true + if err := blackmagic.AssignIfCompatible(dst, *(h.crv)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyOps), true + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyUsage), true + return nil case ECDSAXKey: if h.x == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.x, true + if err := blackmagic.AssignIfCompatible(dst, h.x); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprint), true + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprintS256), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true + return nil case ECDSAYKey: if h.y == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.y); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.y, true + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *ecdsaPublicKey) Set(name string, value interface{}) error { @@ -753,76 +827,154 @@ func (h *ecdsaPrivateKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *ecdsaPrivateKey) Get(name string) (interface{}, bool) { +func (h *ecdsaPrivateKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case ECDSACrvKey: + return h.crv != nil + case ECDSADKey: + return h.d != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case ECDSAXKey: + return h.x != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + case ECDSAYKey: + return h.y != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *ecdsaPrivateKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.algorithm), true + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case ECDSACrvKey: if h.crv == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.crv)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.crv), true + return nil case ECDSADKey: if h.d == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.d); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.d, true + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyID), true + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyOps), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyUsage), true + return nil case ECDSAXKey: if h.x == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x, true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.x509CertChain, true + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprint), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true + return nil case ECDSAYKey: if h.y == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.y, true + if err := blackmagic.AssignIfCompatible(dst, h.y); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *ecdsaPrivateKey) Set(name string, value interface{}) error { diff --git a/jwk/headers_test.go b/jwk/headers_test.go index 1ed1d965a..0b4f19711 100644 --- a/jwk/headers_test.go +++ b/jwk/headers_test.go @@ -7,6 +7,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHeader(t *testing.T) { @@ -35,10 +36,8 @@ func TestHeader(t *testing.T) { return } - got, ok := h.Get(k) - if !assert.True(t, ok, "Get works for '%s'", k) { - return - } + var got interface{} + require.NoError(t, h.Get(k, &got), "Get works for '%s'", k) if !assert.Equal(t, v, got, "values match '%s'", k) { return @@ -122,10 +121,8 @@ func TestHeader(t *testing.T) { return } - got, ok := h.Get("alg") - if !assert.True(t, ok, "Get for alg should succeed") { - return - } + var got jwa.KeyAlgorithm + require.NoError(t, h.Get("alg", &got), "Get for alg should succeed") if !assert.Equal(t, value, got, "values match") { return diff --git a/jwk/interface.go b/jwk/interface.go index eaa7b0428..f03c30f39 100644 --- a/jwk/interface.go +++ b/jwk/interface.go @@ -68,7 +68,7 @@ type Set interface { // For the purposes of a key set, any field other than the "keys" field is // considered to be a private field. In other words, you cannot use this // method to directly access the list of keys in the set - Get(string) (interface{}, bool) + Get(string, interface{}) error // Set sets the value of a single field. // diff --git a/jwk/interface_gen.go b/jwk/interface_gen.go index cef61cd10..29b28d5ef 100644 --- a/jwk/interface_gen.go +++ b/jwk/interface_gen.go @@ -27,13 +27,20 @@ const ( // between each key types, so you should use type assertions // to perform more specific tasks with each key type Key interface { - // Get returns the value of a single field. The second boolean return value - // will be false if the field is not stored in the source + + // Has returns true if the specified field has a value, even if + // the value is empty-ish (e.g. 0, false, "") as long as it has been + // explicitly set. + Has(string) bool + + // Get is used to extract the value of any field, including non-standard fields, out of the key. // - // This method, which returns an `interface{}`, exists because - // these objects can contain extra _arbitrary_ fields that users can - // specify, and there is no way of knowing what type they could be - Get(string) (interface{}, bool) + // The first argument is the name of the field. The second argument is a pointer + // to a variable that will receive the value of the field. The method returns + // an error if the field does not exist, or if the value cannot be assigned to + // the destination variable. Note that a field is considered to "exist" even if + // the value is empty-ish (e.g. 0, false, ""), as long as it is explicitly set. + Get(string, interface{}) error // Set sets the value of a single field. Note that certain fields, // notably "kty", cannot be altered, but will not return an error diff --git a/jwk/jwk.go b/jwk/jwk.go index 302b7d799..533a8b341 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -578,7 +578,7 @@ func ParseString(s string, options ...ParseOption) (Set, error) { // section of the key, if it already doesn't have one. It uses Key.Thumbprint // method with crypto.SHA256 as the default hashing algorithm func AssignKeyID(key Key, options ...AssignKeyIDOption) error { - if _, ok := key.Get(KeyIDKey); ok { + if key.Has(KeyIDKey) { return nil } @@ -707,15 +707,16 @@ func asnEncode(key Key) (string, []byte, error) { // you want to represent as a string formatted in RFC3339 in JSON, // but want it back as `time.Time`. // -// In that case you would register a custom field as follows +// In such case you would register a custom field as follows // -// jwk.RegisterCustomField(`x-birthday`, timeT) +// jwk.RegisterCustomField(`x-birthday`, time.Time{}) // -// Then `key.Get("x-birthday")` will still return an `interface{}`, -// but you can convert its type to `time.Time` +// Then you can use a `time.Time` variable to extract the value +// of `x-birthday` field, instead of having to use `interface{}` +// and later convert it to `time.Time` // -// bdayif, _ := key.Get(`x-birthday`) -// bday := bdayif.(time.Time) +// var bday time.Time +// _ = key.Get(`x-birthday`, &bday) func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwk/jwk_internal_test.go b/jwk/jwk_internal_test.go index 3019abf00..92320d3b6 100644 --- a/jwk/jwk_internal_test.go +++ b/jwk/jwk_internal_test.go @@ -9,6 +9,7 @@ import ( "github.com/lestrrat-go/jwx/v3/internal/json" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestX509CertChain(t *testing.T) { @@ -34,11 +35,8 @@ func TestX509CertChain(t *testing.T) { return } - v, ok := key.Get(X509CertChainKey) - if !assert.True(t, ok, "Get for x5c should succeed") { - return - } - gotcerts := v.(*cert.Chain) + var gotcerts cert.Chain + require.NoError(t, key.Get(X509CertChainKey, &gotcerts), "Get for x5c should succeed") if !assert.Equal(t, gotcerts.Len(), 3, `should have 3 cert`) { return } @@ -63,10 +61,8 @@ func TestIterator(t *testing.T) { pair := iter.Pair() seen[pair.Key.(string)] = pair.Value - getV, ok := v.Get(pair.Key.(string)) - if !assert.True(t, ok, `v.Get should succeed for key %#v`, pair.Key) { - return - } + var getV interface{} + require.NoError(t, v.Get(pair.Key.(string), &getV), `v.Get should succeed for key %#v`, pair.Key) if !assert.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) { return } diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 33623f91f..507f00d04 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -183,10 +183,9 @@ func VerifyKey(t *testing.T, def map[string]keyDef) { k := k kdef := kdef t.Run(k, func(t *testing.T) { - getval, ok := key.Get(k) - if !assert.True(t, ok, `key.Get(%s) should succeed`, k) { - return - } + var getval interface{} + require.NoError(t, key.Get(k, &getval), `key.Get(%s) should succeed`, k) + expected := kdef.Expected if expected == nil { expected = kdef.Value @@ -1404,10 +1403,8 @@ func TestCustomField(t *testing.T) { return } - v, ok := key.Get(`x-birthday`) - if !assert.True(t, ok, `key.Get("x-birthday") should succeed`) { - return - } + var v interface{} + require.NoError(t, key.Get(`x-birthday`, &v), `key.Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return @@ -1541,10 +1538,9 @@ func TestTypedFields(t *testing.T) { return } - v, ok := got.Get("typed-field") - if !assert.True(t, ok, `got.Get() should succeed`) { - return - } + var v interface{} + require.NoError(t, got.Get("typed-field", &v), `got.Get() should succeed`) + field, err := tc.PostProcess(t, v) if !assert.NoError(t, err, `tc.PostProcess should succeed`) { return @@ -1583,10 +1579,8 @@ func TestTypedFields(t *testing.T) { for iter := got.Keys(ctx); iter.Next(ctx); { pair := iter.Pair() key, _ := pair.Value.(jwk.Key) - v, ok := key.Get("typed-field") - if !assert.True(t, ok, `key.Get() should succeed`) { - return - } + var v interface{} + require.NoError(t, key.Get("typed-field", &v), `key.Get() should succeed`) field, err := tc.PostProcess(t, v) if !assert.NoError(t, err, `tc.PostProcess should succeed`) { return @@ -1715,12 +1709,10 @@ func TestSetWithPrivateParams(t *testing.T) { return } - v, ok := set.Get(`renewal_kid`) - if !assert.True(t, ok, `set.Get("renewal_kid") should return ok = true`) { - return - } + var kid string + require.NoError(t, set.Get(`renewal_kid`, &kid), `set.Get("renewal_kid") should succeed`) - if !assert.Equal(t, `foo`, v, `set.Get("renewal_kid") should return "foo"`) { + if !assert.Equal(t, `foo`, kid, `set.Get("renewal_kid") should return "foo"`) { return } @@ -1729,12 +1721,10 @@ func TestSetWithPrivateParams(t *testing.T) { return } - v, ok = key.Get(`renewal_kid`) - if !assert.True(t, ok, `key.Get("renewal_kid") should return ok = true`) { - return - } + kid = "" + require.NoError(t, key.Get(`renewal_kid`, &kid), `key.Get("renewal_kid") should return ok = true`) - if !assert.Equal(t, `foo`, v, `key.Get("renewal_kid") should return "foo"`) { + if !assert.Equal(t, `foo`, kid, `key.Get("renewal_kid") should return "foo"`) { return } } @@ -1775,10 +1765,8 @@ func TestSetWithPrivateParams(t *testing.T) { return } - v, ok := set.Get(`renewal_kid`) - if !assert.True(t, ok, `set.Get("renewal_kid") should return ok = true`) { - return - } + var v interface{} + require.NoError(t, set.Get(`renewal_kid`, &v), `set.Get("renewal_kid") should return ok = true`) if !assert.Equal(t, `foo`, v, `set.Get("renewal_kid") should return "foo"`) { return @@ -1806,10 +1794,8 @@ func TestSetWithPrivateParams(t *testing.T) { return } - v, ok := set.Get(`renewal_kid`) - if !assert.True(t, ok, `set.Get("renewal_kid") should succeed`) { - return - } + var v interface{} + require.NoError(t, set.Get(`renewal_kid`, &v), `set.Get("renewal_kid") should succeed`) if !assert.Equal(t, `foo`, v, `set.Get("renewal_kid") should return "foo"`) { return diff --git a/jwk/okp_gen.go b/jwk/okp_gen.go index b3d75e115..4af18bf30 100644 --- a/jwk/okp_gen.go +++ b/jwk/okp_gen.go @@ -9,6 +9,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" @@ -171,66 +172,134 @@ func (h *okpPublicKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *okpPublicKey) Get(name string) (interface{}, bool) { +func (h *okpPublicKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case OKPCrvKey: + return h.crv != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case OKPXKey: + return h.x != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *okpPublicKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case OKPCrvKey: if h.crv == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.crv), true + if err := blackmagic.AssignIfCompatible(dst, *(h.crv)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyOps), true + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyUsage), true + return nil case OKPXKey: if h.x == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.x, true + if err := blackmagic.AssignIfCompatible(dst, h.x); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprint), true + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509URL), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *okpPublicKey) Set(name string, value interface{}) error { @@ -712,71 +781,144 @@ func (h *okpPrivateKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *okpPrivateKey) Get(name string) (interface{}, bool) { +func (h *okpPrivateKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case OKPCrvKey: + return h.crv != nil + case OKPDKey: + return h.d != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case OKPXKey: + return h.x != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *okpPrivateKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case OKPCrvKey: if h.crv == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.crv), true + if err := blackmagic.AssignIfCompatible(dst, *(h.crv)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case OKPDKey: if h.d == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.d, true + if err := blackmagic.AssignIfCompatible(dst, h.d); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyID), true + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyOps), true + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyUsage), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case OKPXKey: if h.x == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x, true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprint), true + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprintS256), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509URL), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *okpPrivateKey) Set(name string, value interface{}) error { diff --git a/jwk/refresh_test.go b/jwk/refresh_test.go index fd2e7b7c6..73f76ce04 100644 --- a/jwk/refresh_test.go +++ b/jwk/refresh_test.go @@ -14,6 +14,7 @@ import ( "github.com/lestrrat-go/jwx/v3/internal/jwxtest" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) //nolint:revive,golint @@ -24,10 +25,8 @@ func checkAccessCount(t *testing.T, ctx context.Context, src jwk.Set, expected . iter.Next(ctx) key := iter.Pair().Value.(jwk.Key) - v, ok := key.Get(`accessCount`) - if !assert.True(t, ok, `key.Get("accessCount") should succeed`) { - return false - } + var v float64 + require.NoError(t, key.Get(`accessCount`, &v), `key.Get("accessCount") should succeed`) for _, e := range expected { if v == float64(e) { diff --git a/jwk/rsa_gen.go b/jwk/rsa_gen.go index b06ee210a..e9b4ee67a 100644 --- a/jwk/rsa_gen.go +++ b/jwk/rsa_gen.go @@ -10,6 +10,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" @@ -174,66 +175,134 @@ func (h *rsaPublicKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *rsaPublicKey) Get(name string) (interface{}, bool) { +func (h *rsaPublicKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case RSAEKey: + return h.e != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case RSANKey: + return h.n != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *rsaPublicKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case RSAEKey: if h.e == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.e); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.e, true + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyOps), true + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyUsage), true + return nil case RSANKey: if h.n == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.n); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.n, true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprint), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *rsaPublicKey) Set(name string, value interface{}) error { @@ -755,96 +824,194 @@ func (h *rsaPrivateKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *rsaPrivateKey) Get(name string) (interface{}, bool) { +func (h *rsaPrivateKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case RSADKey: + return h.d != nil + case RSADPKey: + return h.dp != nil + case RSADQKey: + return h.dq != nil + case RSAEKey: + return h.e != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case RSANKey: + return h.n != nil + case RSAPKey: + return h.p != nil + case RSAQKey: + return h.q != nil + case RSAQIKey: + return h.qi != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *rsaPrivateKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case RSADKey: if h.d == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.d); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.d, true + return nil case RSADPKey: if h.dp == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.dp, true + if err := blackmagic.AssignIfCompatible(dst, h.dp); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case RSADQKey: if h.dq == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.dq); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.dq, true + return nil case RSAEKey: if h.e == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.e); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.e, true + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyOps), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyUsage), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case RSANKey: if h.n == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.n); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.n, true + return nil case RSAPKey: if h.p == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.p); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.p, true + return nil case RSAQKey: if h.q == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.q, true + if err := blackmagic.AssignIfCompatible(dst, h.q); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case RSAQIKey: if h.qi == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.qi); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.qi, true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprint), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *rsaPrivateKey) Set(name string, value interface{}) error { diff --git a/jwk/set.go b/jwk/set.go index 3755d1c0b..e1705882e 100644 --- a/jwk/set.go +++ b/jwk/set.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/arrayiter" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/internal/json" @@ -38,12 +39,18 @@ func (s *set) Set(n string, v interface{}) error { return nil } -func (s *set) Get(n string) (interface{}, bool) { +func (s *set) Get(name string, dst interface{}) error { s.mu.RLock() defer s.mu.RUnlock() - v, ok := s.privateParams[n] - return v, ok + v, ok := s.privateParams[name] + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil } func (s *set) Key(idx int) (Key, bool) { diff --git a/jwk/symmetric_gen.go b/jwk/symmetric_gen.go index 783339ada..9ccc216ef 100644 --- a/jwk/symmetric_gen.go +++ b/jwk/symmetric_gen.go @@ -9,6 +9,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" @@ -157,61 +158,124 @@ func (h *symmetricKey) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *symmetricKey) Get(name string) (interface{}, bool) { +func (h *symmetricKey) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case KeyIDKey: + return h.keyID != nil + case KeyOpsKey: + return h.keyOps != nil + case KeyUsageKey: + return h.keyUsage != nil + case SymmetricOctetsKey: + return h.octets != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *symmetricKey) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case KeyTypeKey: - return h.KeyType(), true + if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.algorithm), true + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case KeyOpsKey: if h.keyOps == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyOps)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyOps), true + return nil case KeyUsageKey: if h.keyUsage == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.keyUsage)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.keyUsage), true + return nil case SymmetricOctetsKey: if h.octets == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, h.octets); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.octets, true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.x509CertChain, true + if err := blackmagic.AssignIfCompatible(dst, h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprint), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprintS256), true + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509URL), true + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *symmetricKey) Set(name string, value interface{}) error { diff --git a/jws/headers_gen.go b/jws/headers_gen.go index 464800308..708bac327 100644 --- a/jws/headers_gen.go +++ b/jws/headers_gen.go @@ -9,6 +9,7 @@ import ( "sort" "sync" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v3/cert" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/json" @@ -31,7 +32,13 @@ const ( X509URLKey = "x5u" ) -// Headers describe a standard Header set. +// Headers describe a standard JWS Header set. It is part of the JWS message +// and is used to represet both Public or Protected headers, which in turn +// can be found in each Signature object. If you are not sure how this works, +// it is strongly recommended that you read RFC7515, especially the section +// that describes the full JSON serialization format of JWS messages. +// +// In most cases, you likely want to use the protected headers, as this is part of the signed content. type Headers interface { json.Marshaler json.Unmarshaler @@ -51,9 +58,20 @@ type Headers interface { AsMap(context.Context) (map[string]interface{}, error) Copy(context.Context, Headers) error Merge(context.Context, Headers) (Headers, error) - Get(string) (interface{}, bool) + // Get is used to extract the value of any field, including non-standard fields, out of the header. + // + // The first argument is the name of the field. The second argument is a pointer + // to a variable that will receive the value of the field. The method returns + // an error if the field does not exist, or if the value cannot be assigned to + // the destination variable. Note that a field is considered to "exist" even if + // the value is empty-ish (e.g. 0, false, ""), as long as it is explicitly set. + Get(string, interface{}) error Set(string, interface{}) error Remove(string) error + // Has returns true if the specified header has a value, even if + // the value is empty-ish (e.g. 0, false, "") as long as it has been + // explicitly set. + Has(string) bool // PrivateParams returns the non-standard elements in the source structure // WARNING: DO NOT USE PrivateParams() IF YOU HAVE CONCURRENT CODE ACCESSING THEM. @@ -259,69 +277,143 @@ func (h *stdHeaders) PrivateParams() map[string]interface{} { return h.privateParams } -func (h *stdHeaders) Get(name string) (interface{}, bool) { +func (h *stdHeaders) Has(name string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + switch name { + case AlgorithmKey: + return h.algorithm != nil + case ContentTypeKey: + return h.contentType != nil + case CriticalKey: + return h.critical != nil + case JWKKey: + return h.jwk != nil + case JWKSetURLKey: + return h.jwkSetURL != nil + case KeyIDKey: + return h.keyID != nil + case TypeKey: + return h.typ != nil + case X509CertChainKey: + return h.x509CertChain != nil + case X509CertThumbprintKey: + return h.x509CertThumbprint != nil + case X509CertThumbprintS256Key: + return h.x509CertThumbprintS256 != nil + case X509URLKey: + return h.x509URL != nil + default: + _, ok := h.privateParams[name] + return ok + } +} + +func (h *stdHeaders) Get(name string, dst interface{}) error { h.mu.RLock() defer h.mu.RUnlock() switch name { case AlgorithmKey: if h.algorithm == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.algorithm)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.algorithm), true + return nil case ContentTypeKey: if h.contentType == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.contentType), true + if err := blackmagic.AssignIfCompatible(dst, *(h.contentType)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case CriticalKey: if h.critical == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return h.critical, true + if err := blackmagic.AssignIfCompatible(dst, + h.critical); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case JWKKey: if h.jwk == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, + h.jwk); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.jwk, true + return nil case JWKSetURLKey: if h.jwkSetURL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.jwkSetURL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.jwkSetURL), true + return nil case KeyIDKey: if h.keyID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.keyID), true + if err := blackmagic.AssignIfCompatible(dst, *(h.keyID)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case TypeKey: if h.typ == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.typ)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.typ), true + return nil case X509CertChainKey: if h.x509CertChain == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, + h.x509CertChain); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return h.x509CertChain, true + return nil case X509CertThumbprintKey: if h.x509CertThumbprint == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprint)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) } - return *(h.x509CertThumbprint), true + return nil case X509CertThumbprintS256Key: if h.x509CertThumbprintS256 == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509CertThumbprintS256), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509CertThumbprintS256)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil case X509URLKey: if h.x509URL == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - return *(h.x509URL), true + if err := blackmagic.AssignIfCompatible(dst, *(h.x509URL)); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } + return nil default: v, ok := h.privateParams[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value for field %q: %w`, name, err) + } } + return nil } func (h *stdHeaders) Set(name string, value interface{}) error { diff --git a/jws/headers_test.go b/jws/headers_test.go index 6cb64162e..7ecfa2ff2 100644 --- a/jws/headers_test.go +++ b/jws/headers_test.go @@ -10,6 +10,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jws" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var zeroval reflect.Value @@ -112,10 +113,8 @@ func TestHeader(t *testing.T) { } for _, tc := range data { var values []interface{} - viaGet, ok := h.Get(tc.Key) - if !assert.True(t, ok, "value for %s should exist", tc.Key) { - return - } + var viaGet interface{} + require.NoError(t, h.Get(tc.Key, &viaGet), `h.Get should succeed`) values = append(values, viaGet) if method := tc.Method; method != "" { @@ -176,13 +175,9 @@ func TestHeader(t *testing.T) { pair := iter.Pair() seen[pair.Key.(string)] = pair.Value - getV, ok := v.Get(pair.Key.(string)) - if !assert.True(t, ok, `v.Get should succeed for key %#v`, pair.Key) { - return - } - if !assert.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) { - return - } + var getV interface{} + require.NoError(t, v.Get(pair.Key.(string), &getV), `v.Get should succeed`) + require.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) } if !assert.Equal(t, expected, seen, `values should match`) { return diff --git a/jws/jws.go b/jws/jws.go index 4ee4a4aea..1617b9077 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -416,15 +416,11 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { // If the field does not exist, returns true (default) // Otherwise return the value specified by the header field. func getB64Value(hdr Headers) bool { - b64raw, ok := hdr.Get("b64") - if !ok { + var b64 bool + if err := hdr.Get("b64", &b64); err != nil { return true // default } - b64, ok := b64raw.(bool) // default - if !ok { - return false - } return b64 } @@ -679,15 +675,16 @@ func parse(protected, payload, signature []byte) (*Message, error) { // you want to represent as a string formatted in RFC3339 in JSON, // but want it back as `time.Time`. // -// In that case you would register a custom field as follows +// In such case you would register a custom field as follows // -// jwe.RegisterCustomField(`x-birthday`, timeT) +// jws.RegisterCustomField(`x-birthday`, time.Time{}) // -// Then `hdr.Get("x-birthday")` will still return an `interface{}`, -// but you can convert its type to `time.Time` +// Then you can use a `time.Time` variable to extract the value +// of `x-birthday` field, instead of having to use `interface{}` +// and later convert it to `time.Time` // -// bdayif, _ := hdr.Get(`x-birthday`) -// bday := bdayif.(time.Time) +// var bday time.Time +// _ = hdr.Get(`x-birthday`, &bday) func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jws/jws_test.go b/jws/jws_test.go index bd5a7ea60..0838d0da9 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -1284,10 +1284,8 @@ func TestCustomField(t *testing.T) { return } - v, ok := msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`) - if !assert.True(t, ok, `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) { - return - } + var v interface{} + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return @@ -1304,10 +1302,8 @@ func TestCustomField(t *testing.T) { return } - v, ok = msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`) - if !assert.True(t, ok, `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) { - return - } + v = nil + require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return diff --git a/jwt/BUILD.bazel b/jwt/BUILD.bazel index 2bb84f4b7..612649416 100644 --- a/jwt/BUILD.bazel +++ b/jwt/BUILD.bazel @@ -30,6 +30,7 @@ go_library( "//jws", "//jwt/internal/types", "@com_github_lestrrat_go_iter//mapiter:go_default_library", + "@com_github_lestrrat_go_blackmagic//:go_default_library", "@com_github_lestrrat_go_option//:option", ], ) diff --git a/jwt/internal/types/BUILD.bazel b/jwt/internal/types/BUILD.bazel index 66710f119..4a99af17a 100644 --- a/jwt/internal/types/BUILD.bazel +++ b/jwt/internal/types/BUILD.bazel @@ -22,6 +22,7 @@ go_test( "//internal/json", "//jwt", "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", ], ) diff --git a/jwt/internal/types/date_test.go b/jwt/internal/types/date_test.go index 390f70399..ac11def77 100644 --- a/jwt/internal/types/date_test.go +++ b/jwt/internal/types/date_test.go @@ -10,6 +10,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwt" "github.com/lestrrat-go/jwx/v3/jwt/internal/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDate(t *testing.T) { @@ -100,12 +101,10 @@ func TestDate(t *testing.T) { if !assert.NoError(t, err) { return } - v, ok := t1.Get(jwt.IssuedAtKey) - if !assert.True(t, ok) { - return - } - realized := v.(time.Time) - if !assert.Equal(t, tc.Expected, realized) { + var v time.Time + require.NoError(t, t1.Get(jwt.IssuedAtKey, &v), `t1.Get should succeed`) + + if !assert.Equal(t, tc.Expected, v) { return } }) diff --git a/jwt/jwt.go b/jwt/jwt.go index fee0a136b..7f3f95cab 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -456,15 +456,16 @@ func (t *stdToken) Clone() (Token, error) { // you want to represent as a string formatted in RFC3339 in JSON, // but want it back as `time.Time`. // -// In that case you would register a custom field as follows +// In such case you would register a custom field as follows // -// jwt.RegisterCustomField(`x-birthday`, timeT) +// jwt.RegisterCustomField(`x-birthday`, time.Time) // -// Then `token.Get("x-birthday")` will still return an `interface{}`, -// but you can convert its type to `time.Time` +// Then you can use a `time.Time` variable to extract the value +// of `x-birthday` field, instead of having to use `interface{}` +// and later convert it to `time.Time` // -// bdayif, _ := token.Get(`x-birthday`) -// bday := bdayif.(time.Time) +// var bday time.Time +// _ = token.Get(`x-birthday`, &bday) func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index bde6a7082..b256145c7 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -611,13 +611,11 @@ func TestUnmarshalJSON(t *testing.T) { if !assert.NoError(t, json.Unmarshal([]byte(`{"aud":["foo", "bar", "baz"]}`), &t1), `jwt.Parse should succeed`) { return } - aud, ok := t1.Get(jwt.AudienceKey) - if !assert.True(t, ok, `jwt.Get(jwt.AudienceKey) should succeed`) { - t.Logf("%#v", t1) - return - } - if !assert.Equal(t, aud.([]string), []string{"foo", "bar", "baz"}, "audience should match. got %v", aud) { + var aud []string + require.NoError(t, t1.Get(jwt.AudienceKey, &aud), `jwt.Get(jwt.AudienceKey) should succeed`) + + if !assert.Equal(t, aud, []string{"foo", "bar", "baz"}, "audience should match. got %v", aud) { return } }) @@ -766,10 +764,8 @@ func TestCustomField(t *testing.T) { return } - v, ok := token.Get(`x-birthday`) - if !assert.True(t, ok, `token.Get("x-birthday") should succeed`) { - return - } + var v time.Time + require.NoError(t, token.Get(`x-birthday`, &v), `token.Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return @@ -781,10 +777,8 @@ func TestCustomField(t *testing.T) { return } - v, ok := token.Get(`x-birthday`) - if !assert.True(t, ok, `token.Get("x-birthday") should succeed`) { - return - } + var v time.Time + require.NoError(t, token.Get(`x-birthday`, &v), `token.Get("x-birthday") should succeed`) if !assert.Equal(t, expected, v, `values should match`) { return @@ -1186,10 +1180,9 @@ func TestJWTParseWithTypedClaim(t *testing.T) { return } - v, ok := got.Get("typed-claim") - if !assert.True(t, ok, `got.Get() should succeed`) { - return - } + var v interface{} + require.NoError(t, got.Get("typed-claim", &v), `got.Get() should succeed`) + claim, err := tc.PostProcess(t, v) if !assert.NoError(t, err, `tc.PostProcess should succeed`) { return diff --git a/jwt/openid/BUILD.bazel b/jwt/openid/BUILD.bazel index 70e6b17fb..ddc328829 100644 --- a/jwt/openid/BUILD.bazel +++ b/jwt/openid/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "//jwt", "//jwt/internal/types", "@com_github_lestrrat_go_iter//mapiter:go_default_library", + "@com_github_lestrrat_go_blackmagic//:go_default_library", ], ) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index f194748a8..0b76eaead 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -122,8 +122,8 @@ func TestAdressClaim(t *testing.T) { func TestOpenIDClaims(t *testing.T) { getVerify := func(token openid.Token, key string, expected interface{}) bool { - v, ok := token.Get(key) - if !assert.True(t, ok, `token.Get %#v should succeed`, key) { + var v interface{} + if assert.NoError(t, token.Get(key, &v), `token.Get %#v should succeed`, key) { return false } return assert.Equal(t, v, expected) @@ -378,10 +378,8 @@ func TestOpenIDClaims(t *testing.T) { Value: `dummy`, Key: `dummy`, Check: func(token openid.Token) { - v, ok := token.Get(`dummy`) - if !assert.True(t, ok, `token.Get should return valid value`) { - return - } + var v interface{} + require.NoError(t, token.Get(`dummy`, &v), `token.Get should return valid value`) if !assert.Equal(t, `dummy`, v, `values should match`) { return } @@ -500,10 +498,8 @@ func TestOpenIDClaims(t *testing.T) { pair := iter.Pair() seen[pair.Key.(string)] = pair.Value - getV, ok := v.Get(pair.Key.(string)) - if !assert.True(t, ok, `v.Get should succeed for key %#v`, pair.Key) { - return - } + var getV interface{} + require.NoError(t, v.Get(pair.Key.(string), &getV), `v.Get should succeed for key %#v`, pair.Key) if !assert.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) { return } diff --git a/jwt/openid/token_gen.go b/jwt/openid/token_gen.go index d7172986a..7a74c5425 100644 --- a/jwt/openid/token_gen.go +++ b/jwt/openid/token_gen.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/iter" @@ -132,21 +133,32 @@ type Token interface { // *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc. PrivateClaims() map[string]interface{} - // Get returns the value of the corresponding field in the token, such as - // `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not - // exist in the token, the second return value will be `false` + // Get is used to extract the value of any claim, including non-standard claims, out of the token. // - // If you need to access fields like `alg`, `kid`, `jku`, etc, you need - // to access the corresponding fields in the JWS/JWE message. For this, - // you will need to access them by directly parsing the payload using - // `jws.Parse` and `jwe.Parse` - Get(string) (interface{}, bool) + // The first argument is the name of the claim. The second argument is a pointer + // to a variable that will receive the value of the claim. The method returns + // an error if the claim does not exist, or if the value cannot be assigned to + // the destination variable. Note that a field is considered to "exist" even if + // the value is empty-ish (e.g. 0, false, ""), as long as it is explicitly set. + // + // For standard claims, you can use the corresponding getter method, such as + // `Issuer()`, `Subject()`, `Audience()`, `IssuedAt()`, `NotBefore()`, `ExpiresAt()` + // + // Note that fields of JWS/JWE are NOT accessible through this method. You need + // to use `jws.Parse` and `jwe.Parse` to obtain the JWS/JWE message (and NOT + // the payload, which presumably is the JWT), and then use their `Get` methods in their respective packages + Get(string, interface{}) error // Set assigns a value to the corresponding field in the token. Some // pre-defined fields such as `nbf`, `iat`, `iss` need their values to // be of a specific type. See the other getter methods in this interface // for the types of each of these fields Set(string, interface{}) error + + // Has returns true if the specified claim has a value, even if + // the value is empty-ish (e.g. 0, false, "") as long as it has been + // explicitly set. + Has(string) bool Remove(string) error // Options returns the per-token options associated with this token. @@ -207,169 +219,289 @@ func (t *stdToken) Options() *jwt.TokenOptionSet { return &t.options } -func (t *stdToken) Get(name string) (interface{}, bool) { +func (t *stdToken) Has(name string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + switch name { + case AddressKey: + return t.address != nil + case AudienceKey: + return t.audience != nil + case BirthdateKey: + return t.birthdate != nil + case EmailKey: + return t.email != nil + case EmailVerifiedKey: + return t.emailVerified != nil + case ExpirationKey: + return t.expiration != nil + case FamilyNameKey: + return t.familyName != nil + case GenderKey: + return t.gender != nil + case GivenNameKey: + return t.givenName != nil + case IssuedAtKey: + return t.issuedAt != nil + case IssuerKey: + return t.issuer != nil + case JwtIDKey: + return t.jwtID != nil + case LocaleKey: + return t.locale != nil + case MiddleNameKey: + return t.middleName != nil + case NameKey: + return t.name != nil + case NicknameKey: + return t.nickname != nil + case NotBeforeKey: + return t.notBefore != nil + case PhoneNumberKey: + return t.phoneNumber != nil + case PhoneNumberVerifiedKey: + return t.phoneNumberVerified != nil + case PictureKey: + return t.picture != nil + case PreferredUsernameKey: + return t.preferredUsername != nil + case ProfileKey: + return t.profile != nil + case SubjectKey: + return t.subject != nil + case UpdatedAtKey: + return t.updatedAt != nil + case WebsiteKey: + return t.website != nil + case ZoneinfoKey: + return t.zoneinfo != nil + default: + _, ok := t.privateClaims[name] + return ok + } +} + +func (t *stdToken) Get(name string, dst interface{}) error { t.mu.RLock() defer t.mu.RUnlock() switch name { case AddressKey: if t.address == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.address - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.address); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case AudienceKey: if t.audience == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.audience.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.audience.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case BirthdateKey: if t.birthdate == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.birthdate - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.birthdate); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case EmailKey: if t.email == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.email) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.email)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case EmailVerifiedKey: if t.emailVerified == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.emailVerified) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.emailVerified)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case ExpirationKey: if t.expiration == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.expiration.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.expiration.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case FamilyNameKey: if t.familyName == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.familyName) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.familyName)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case GenderKey: if t.gender == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.gender) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.gender)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case GivenNameKey: if t.givenName == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.givenName) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.givenName)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case IssuedAtKey: if t.issuedAt == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.issuedAt.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.issuedAt.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case IssuerKey: if t.issuer == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.issuer) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.issuer)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case JwtIDKey: if t.jwtID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.jwtID) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.jwtID)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case LocaleKey: if t.locale == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.locale) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.locale)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case MiddleNameKey: if t.middleName == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.middleName) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.middleName)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case NameKey: if t.name == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.name) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.name)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case NicknameKey: if t.nickname == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.nickname) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.nickname)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case NotBeforeKey: if t.notBefore == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.notBefore.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.notBefore.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case PhoneNumberKey: if t.phoneNumber == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.phoneNumber) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.phoneNumber)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case PhoneNumberVerifiedKey: if t.phoneNumberVerified == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.phoneNumberVerified) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.phoneNumberVerified)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case PictureKey: if t.picture == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.picture) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.picture)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case PreferredUsernameKey: if t.preferredUsername == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.preferredUsername) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.preferredUsername)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case ProfileKey: if t.profile == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.profile) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.profile)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case SubjectKey: if t.subject == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.subject) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.subject)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case UpdatedAtKey: if t.updatedAt == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.updatedAt.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.updatedAt.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case WebsiteKey: if t.website == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.website) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.website)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case ZoneinfoKey: if t.zoneinfo == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.zoneinfo) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.zoneinfo)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil default: v, ok := t.privateClaims[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil } } diff --git a/jwt/serialize.go b/jwt/serialize.go index 7f8464251..f661752fd 100644 --- a/jwt/serialize.go +++ b/jwt/serialize.go @@ -98,8 +98,9 @@ func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, err } type genericHeader interface { - Get(string) (interface{}, bool) + Get(string, interface{}) error Set(string, interface{}) error + Has(string) bool } func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error { @@ -110,7 +111,7 @@ func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error { if ctx.Step() == 1 { // We are executed immediately after json marshaling - if _, ok := hdrs.Get(typKey); !ok { + if !hdrs.Has(typKey) { if err := hdrs.Set(typKey, `JWT`); err != nil { return fmt.Errorf(`failed to set %s key to "JWT": %w`, typKey, err) } @@ -149,11 +150,10 @@ func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, // JWTs MUST NOT use b64 = false // https://datatracker.ietf.org/doc/html/rfc7797#section-7 - if v, ok := hdrs.Get("b64"); ok { - if bval, bok := v.(bool); bok { - if !bval { // b64 = false - return nil, fmt.Errorf(`b64 cannot be false for JWTs`) - } + var b64 bool + if err := hdrs.Get("b64", &b64); err == nil { + if !b64 { // b64 = false + return nil, fmt.Errorf(`b64 cannot be false for JWTs`) } } } diff --git a/jwt/token_gen.go b/jwt/token_gen.go index 601a8628c..bcf16c261 100644 --- a/jwt/token_gen.go +++ b/jwt/token_gen.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/iter/mapiter" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/iter" @@ -67,21 +68,32 @@ type Token interface { // *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc. PrivateClaims() map[string]interface{} - // Get returns the value of the corresponding field in the token, such as - // `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not - // exist in the token, the second return value will be `false` + // Get is used to extract the value of any claim, including non-standard claims, out of the token. // - // If you need to access fields like `alg`, `kid`, `jku`, etc, you need - // to access the corresponding fields in the JWS/JWE message. For this, - // you will need to access them by directly parsing the payload using - // `jws.Parse` and `jwe.Parse` - Get(string) (interface{}, bool) + // The first argument is the name of the claim. The second argument is a pointer + // to a variable that will receive the value of the claim. The method returns + // an error if the claim does not exist, or if the value cannot be assigned to + // the destination variable. Note that a field is considered to "exist" even if + // the value is empty-ish (e.g. 0, false, ""), as long as it is explicitly set. + // + // For standard claims, you can use the corresponding getter method, such as + // `Issuer()`, `Subject()`, `Audience()`, `IssuedAt()`, `NotBefore()`, `ExpiresAt()` + // + // Note that fields of JWS/JWE are NOT accessible through this method. You need + // to use `jws.Parse` and `jwe.Parse` to obtain the JWS/JWE message (and NOT + // the payload, which presumably is the JWT), and then use their `Get` methods in their respective packages + Get(string, interface{}) error // Set assigns a value to the corresponding field in the token. Some // pre-defined fields such as `nbf`, `iat`, `iss` need their values to // be of a specific type. See the other getter methods in this interface // for the types of each of these fields Set(string, interface{}) error + + // Has returns true if the specified claim has a value, even if + // the value is empty-ish (e.g. 0, false, "") as long as it has been + // explicitly set. + Has(string) bool Remove(string) error // Options returns the per-token options associated with this token. @@ -123,55 +135,99 @@ func (t *stdToken) Options() *TokenOptionSet { return &t.options } -func (t *stdToken) Get(name string) (interface{}, bool) { +func (t *stdToken) Has(name string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + switch name { + case AudienceKey: + return t.audience != nil + case ExpirationKey: + return t.expiration != nil + case IssuedAtKey: + return t.issuedAt != nil + case IssuerKey: + return t.issuer != nil + case JwtIDKey: + return t.jwtID != nil + case NotBeforeKey: + return t.notBefore != nil + case SubjectKey: + return t.subject != nil + default: + _, ok := t.privateClaims[name] + return ok + } +} + +func (t *stdToken) Get(name string, dst interface{}) error { t.mu.RLock() defer t.mu.RUnlock() switch name { case AudienceKey: if t.audience == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.audience.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.audience.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case ExpirationKey: if t.expiration == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.expiration.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.expiration.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case IssuedAtKey: if t.issuedAt == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.issuedAt.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.issuedAt.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case IssuerKey: if t.issuer == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.issuer) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.issuer)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case JwtIDKey: if t.jwtID == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.jwtID) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.jwtID)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case NotBeforeKey: if t.notBefore == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := t.notBefore.Get() - return v, true + if err := blackmagic.AssignIfCompatible(dst, t.notBefore.Get()); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil case SubjectKey: if t.subject == nil { - return nil, false + return fmt.Errorf(`field %q not found`, name) } - v := *(t.subject) - return v, true + if err := blackmagic.AssignIfCompatible(dst, *(t.subject)); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil default: v, ok := t.privateClaims[name] - return v, ok + if !ok { + return fmt.Errorf(`field %q not found`, name) + } + if err := blackmagic.AssignIfCompatible(dst, v); err != nil { + return fmt.Errorf(`failed to assign value to dst: %w`, err) + } + return nil } } diff --git a/jwt/token_test.go b/jwt/token_test.go index b7b30cd8e..8762d4766 100644 --- a/jwt/token_test.go +++ b/jwt/token_test.go @@ -10,6 +10,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -38,10 +39,8 @@ func TestHeader(t *testing.T) { if !assert.NoError(t, h.Set(k, v), `h.Set should succeed for key %#v`, k) { return } - got, ok := h.Get(k) - if !assert.True(t, ok, `h.Get should succeed for key %#v`, k) { - return - } + var got interface{} + require.NoError(t, h.Get(k, &got), `h.Get should succeed for key %#v`, k) if !reflect.DeepEqual(v, got) { t.Fatalf("Values do not match: (%v, %v)", v, got) } @@ -77,16 +76,11 @@ func TestHeader(t *testing.T) { if err != nil { t.Fatalf("Setting %s value failed", "default") } + var tmp interface{} for k := range values { - _, ok := h.Get(k) - if ok { - t.Fatalf("Getting %s value should have failed", k) - } - } - _, ok := h.Get("default") - if !ok { - t.Fatal("Failed to get default value") + require.Error(t, h.Get(k, &tmp), `Getting %s value should have failed`) } + require.NoError(t, h.Get("default", &tmp), `Getting %s value should have succeeded`) }) t.Run("GetError", func(t *testing.T) { @@ -209,10 +203,8 @@ func TestToken(t *testing.T) { t.Run("Get", func(t *testing.T) { rv := reflect.ValueOf(tok) for k, kdef := range def { - getval, ok := tok.Get(k) - if !assert.True(t, ok, `tok.Get(%s) should succeed`, k) { - return - } + var getval interface{} + require.NoError(t, tok.Get(k, &getval), `tok.Get(%s) should succeed`, k) if mname := kdef.Method; mname != "" { method := rv.MethodByName(mname) diff --git a/jwt/validate.go b/jwt/validate.go index db2a65959..54f32289c 100644 --- a/jwt/validate.go +++ b/jwt/validate.go @@ -464,14 +464,9 @@ func IsValidationError(err error) bool { } func (ccs claimContainsString) Validate(_ context.Context, t Token) ValidationError { - v, ok := t.Get(ccs.name) - if !ok { - return ccs.makeErr(fmt.Errorf(`claim %q not found`, ccs.name)) - } - - list, ok := v.([]string) - if !ok { - return ccs.makeErr(fmt.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v)) + var list []string + if err := t.Get(ccs.name, &list); err != nil { + return ccs.makeErr(fmt.Errorf(`claim %q does not exist or is not a []string: %w`, ccs.name, err)) } for _, v := range list { @@ -515,9 +510,9 @@ func ClaimValueIs(name string, value interface{}) Validator { } func (cv *claimValueIs) Validate(_ context.Context, t Token) ValidationError { - v, ok := t.Get(cv.name) - if !ok { - return cv.makeErr(fmt.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name)) + var v interface{} + if err := t.Get(cv.name, &v); err != nil { + return cv.makeErr(fmt.Errorf(`%[1]q not satisfied: claim %[1]q does not exist or is not a []string: %[2]w`, cv.name, err)) } if v != cv.value { return cv.makeErr(fmt.Errorf(`%q not satisfied: values do not match`, cv.name)) @@ -549,8 +544,7 @@ type isRequired string func (ir isRequired) Validate(_ context.Context, t Token) ValidationError { name := string(ir) - _, ok := t.Get(name) - if !ok { + if !t.Has(name) { return &missingRequiredClaimError{claim: name} } return nil diff --git a/jwt/validate_test.go b/jwt/validate_test.go index 5e285fcc5..a6ac005f4 100644 --- a/jwt/validate_test.go +++ b/jwt/validate_test.go @@ -519,8 +519,7 @@ func TestClaimValidator(t *testing.T) { const myClaim = "my-claim" err0 := errors.New(myClaim + " does not exist") v := jwt.ValidatorFunc(func(_ context.Context, tok jwt.Token) jwt.ValidationError { - _, ok := tok.Get(myClaim) - if !ok { + if !tok.Has(myClaim) { return jwt.NewValidationError(err0) } return nil diff --git a/tools/cmd/genjwe/main.go b/tools/cmd/genjwe/main.go index aa120a126..6e52fd3a6 100644 --- a/tools/cmd/genjwe/main.go +++ b/tools/cmd/genjwe/main.go @@ -87,7 +87,14 @@ func generateHeaders(obj *codegen.Object) error { } o.L(")") // end const - o.LL("// Headers describe a standard Header set.") + o.LL("// Headers describe a standard JWE Header set. It is part of the JWE message") + o.L("// and is used to represent both Protected and Unprotected headers,") + o.L("// which in turn can be found in each Recipient object.") + o.L("// If you are not sure how this works, it is strongly recommended that") + o.L("// you read RFC7516, especially the section") + o.L("// that describes the full JSON serialization format of JWE messages.") + o.L("//") + o.L("// In most cases, you likely want to use the protected headers, as this is the part of the encrypted content") o.L("type Headers interface {") o.L("json.Marshaler") o.L("json.Unmarshaler") @@ -102,9 +109,20 @@ func generateHeaders(obj *codegen.Object) error { o.L("AsMap(ctx context.Context) (map[string]interface{}, error)") // These are used to access a single element by key name - o.L("Get(string) (interface{}, bool)") + o.LL("// Get is used to extract the value of any field, including non-standard fields, out of the header.") + o.L("//") + o.L("// The first argument is the name of the field. The second argument is a pointer") + o.L("// to a variable that will receive the value of the field. The method returns") + o.L("// an error if the field does not exist, or if the value cannot be assigned to") + o.L("// the destination variable. Note that a field is considered to \"exist\" even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\"), as long as it is explicitly set.") + o.L("Get(string, interface{}) error") o.L("Set(string, interface{}) error") o.L("Remove(string) error") + o.L("// Has returns true if the specified header has a value, even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\") as long as it has been") + o.L("// explicitly set.") + o.L("Has(string) bool") // These are used to deal with encoded headers o.L("Encode() ([]byte, error)") @@ -185,25 +203,50 @@ func generateHeaders(obj *codegen.Object) error { o.L("return h.privateParams") o.L("}") - o.LL("func (h *stdHeaders) Get(name string) (interface{}, bool) {") + o.LL("func (h *stdHeaders) Has(name string) bool {") + o.L("h.mu.RLock()") + o.L("defer h.mu.RUnlock()") + o.L("switch name {") + for _, f := range obj.Fields() { + o.L("case %sKey:", f.Name(true)) + o.L("return h.%s != nil", f.Name(false)) + } + o.L("default:") + o.L("_, ok := h.privateParams[name]") + o.L("return ok") + o.L("}") + o.L("}") + + o.LL("func (h *stdHeaders) Get(name string, dst interface{}) error {") o.L("h.mu.RLock()") o.L("defer h.mu.RUnlock()") o.L("switch name {") for _, f := range obj.Fields() { o.L("case %sKey:", f.Name(true)) o.L("if h.%s == nil {", f.Name(false)) - o.L("return nil, false") + o.L("return fmt.Errorf(`field %%q not found`, name)") o.L("}") + + o.L("if err := blackmagic.AssignIfCompatible(dst, ") if fieldStorageTypeIsIndirect(f.Type()) { - o.L("return *(h.%s), true", f.Name(false)) + o.R("*(h.%s)", f.Name(false)) } else { - o.L("return h.%s, true", f.Name(false)) + o.R("h.%s", f.Name(false)) } + o.R("); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") } o.L("default:") o.L("v, ok := h.privateParams[name]") - o.L("return v, ok") + o.L("if !ok {") + o.L("return fmt.Errorf(`field %%q not found`, name)") + o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, v); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") o.L("}") // end switch name + o.L("return nil") o.L("}") // func (h *stdHeaders) Get(name string) (interface{}, bool) o.LL("func (h *stdHeaders) Set(name string, value interface{}) error {") diff --git a/tools/cmd/genjwk/main.go b/tools/cmd/genjwk/main.go index aa6bdbc5e..f783b0dca 100644 --- a/tools/cmd/genjwk/main.go +++ b/tools/cmd/genjwk/main.go @@ -273,12 +273,32 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { o.L("return h.privateParams") o.L("}") - o.LL("func (h *%s) Get(name string) (interface{}, bool) {", structName) + o.LL("func (h *%s) Has(name string) bool {", structName) + o.L("h.mu.RLock()") + o.L("defer h.mu.RUnlock()") + o.L("switch name {") + for _, f := range obj.Fields() { + if f.Bool(`is_std`) { + o.L("case %sKey:", f.Name(true)) + } else { + o.L("case %s%sKey:", kt.Prefix, f.Name(true)) + } + o.L("return h.%s != nil", f.Name(false)) + } + o.L("default:") + o.L("_, ok := h.privateParams[name]") + o.L("return ok") + o.L("}") + o.L("}") + + o.LL("func (h *%s) Get(name string, dst interface{}) error {", structName) o.L("h.mu.RLock()") o.L("defer h.mu.RUnlock()") o.L("switch name {") o.L("case KeyTypeKey:") - o.L("return h.KeyType(), true") + o.L("if err := blackmagic.AssignIfCompatible(dst, h.KeyType()); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") for _, f := range obj.Fields() { if f.Bool(`is_std`) { o.L("case %sKey:", f.Name(true)) @@ -287,20 +307,31 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { } o.L("if h.%s == nil {", f.Name(false)) - o.L("return nil, false") + o.L("return fmt.Errorf(`field %%q not found`, name)") o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, ") if f.Bool(`hasGet`) { - o.L("return h.%s.Get(), true", f.Name(false)) + o.R("h.%s.Get()", f.Name(false)) } else if fieldStorageTypeIsIndirect(f.Type()) { - o.L("return *(h.%s), true", f.Name(false)) + o.R("*(h.%s)", f.Name(false)) } else { - o.L("return h.%s, true", f.Name(false)) + o.R("h.%s", f.Name(false)) } + o.R("); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") + o.L("return nil") } o.L("default:") o.L("v, ok := h.privateParams[name]") - o.L("return v, ok") + o.L("if !ok {") + o.L("return fmt.Errorf(`field %%q not found`, name)") + o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, v); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") o.L("}") // end switch name + o.L("return nil") o.L("}") // func (h *%s) Get(name string) (interface{}, bool) o.LL("func (h *%s) Set(name string, value interface{}) error {", structName) @@ -624,12 +655,18 @@ func generateGenericHeaders(fields codegen.FieldList) error { o.L("// between each key types, so you should use type assertions") o.L("// to perform more specific tasks with each key") o.L("type Key interface {") - o.L("// Get returns the value of a single field. The second boolean return value") - o.L("// will be false if the field is not stored in the source") - o.L("//\n// This method, which returns an `interface{}`, exists because") - o.L("// these objects can contain extra _arbitrary_ fields that users can") - o.L("// specify, and there is no way of knowing what type they could be") - o.L("Get(string) (interface{}, bool)") + o.LL("// Has returns true if the specified field has a value, even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\") as long as it has been") + o.L("// explicitly set.") + o.L("Has(string) bool") + o.LL("// Get is used to extract the value of any field, including non-standard fields, out of the key.") + o.L("//") + o.L("// The first argument is the name of the field. The second argument is a pointer") + o.L("// to a variable that will receive the value of the field. The method returns") + o.L("// an error if the field does not exist, or if the value cannot be assigned to") + o.L("// the destination variable. Note that a field is considered to \"exist\" even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\"), as long as it is explicitly set.") + o.L("Get(string, interface{}) error") o.LL("// Set sets the value of a single field. Note that certain fields,") o.L("// notably \"kty\", cannot be altered, but will not return an error") o.L("//\n// This method, which takes an `interface{}`, exists because") diff --git a/tools/cmd/genjws/main.go b/tools/cmd/genjws/main.go index 66df83133..253688713 100644 --- a/tools/cmd/genjws/main.go +++ b/tools/cmd/genjws/main.go @@ -86,7 +86,13 @@ func generateHeaders(obj *codegen.Object) error { } o.L(")") // end const - o.LL("// Headers describe a standard Header set.") + o.LL("// Headers describe a standard JWS Header set. It is part of the JWS message") + o.L("// and is used to represet both Public or Protected headers, which in turn") + o.L("// can be found in each Signature object. If you are not sure how this works,") + o.L("// it is strongly recommended that you read RFC7515, especially the section") + o.L("// that describes the full JSON serialization format of JWS messages.") + o.L("//") + o.L("// In most cases, you likely want to use the protected headers, as this is part of the signed content.") o.L("type Headers interface {") o.L("json.Marshaler") o.L("json.Unmarshaler") @@ -107,9 +113,20 @@ func generateHeaders(obj *codegen.Object) error { o.L("Merge(context.Context, Headers) (Headers, error)") // These are used to access a single element by key name - o.L("Get(string) (interface{}, bool)") + o.L("// Get is used to extract the value of any field, including non-standard fields, out of the header.") + o.L("//") + o.L("// The first argument is the name of the field. The second argument is a pointer") + o.L("// to a variable that will receive the value of the field. The method returns") + o.L("// an error if the field does not exist, or if the value cannot be assigned to") + o.L("// the destination variable. Note that a field is considered to \"exist\" even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\"), as long as it is explicitly set.") + o.L("Get(string, interface{}) error") o.L("Set(string, interface{}) error") o.L("Remove(string) error") + o.L("// Has returns true if the specified header has a value, even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\") as long as it has been") + o.L("// explicitly set.") + o.L("Has(string) bool") o.LL("// PrivateParams returns the non-standard elements in the source structure") o.L("// WARNING: DO NOT USE PrivateParams() IF YOU HAVE CONCURRENT CODE ACCESSING THEM.") @@ -208,25 +225,50 @@ func generateHeaders(obj *codegen.Object) error { o.L("return h.privateParams") o.L("}") - o.LL("func (h *stdHeaders) Get(name string) (interface{}, bool) {") + o.LL("func (h *stdHeaders) Has(name string) bool {") + o.L("h.mu.RLock()") + o.L("defer h.mu.RUnlock()") + o.L("switch name {") + for _, f := range obj.Fields() { + o.L("case %sKey:", f.Name(true)) + o.L("return h.%s != nil", f.Name(false)) + } + o.L("default:") + o.L("_, ok := h.privateParams[name]") + o.L("return ok") + o.L("}") + o.L("}") + + o.LL("func (h *stdHeaders) Get(name string, dst interface{}) error {") o.L("h.mu.RLock()") o.L("defer h.mu.RUnlock()") o.L("switch name {") for _, f := range obj.Fields() { o.L("case %sKey:", f.Name(true)) o.L("if h.%s == nil {", f.Name(false)) - o.L("return nil, false") + o.L("return fmt.Errorf(`field %%q not found`, name)") o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, ") if fieldStorageTypeIsIndirect(f.Type()) { - o.L("return *(h.%s), true", f.Name(false)) + o.R("*(h.%s)", f.Name(false)) } else { - o.L("return h.%s, true", f.Name(false)) + o.L("h.%s", f.Name(false)) } + o.R("); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") + o.L("return nil") } o.L("default:") o.L("v, ok := h.privateParams[name]") - o.L("return v, ok") + o.L("if !ok {") + o.L("return fmt.Errorf(`field %%q not found`, name)") + o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, v); err != nil {") + o.L("return fmt.Errorf(`failed to assign value for field %%q: %%w`, name, err)") + o.L("}") o.L("}") // end switch name + o.L("return nil") o.L("}") // func (h *stdHeaders) Get(name string) (interface{}, bool) o.LL("func (h *stdHeaders) Set(name string, value interface{}) error {") diff --git a/tools/cmd/genjwt/main.go b/tools/cmd/genjwt/main.go index 2baa553e5..594ada6a7 100644 --- a/tools/cmd/genjwt/main.go +++ b/tools/cmd/genjwt/main.go @@ -139,21 +139,33 @@ func generateToken(obj *codegen.Object) error { o.LL("// PrivateClaims return the entire set of fields (claims) in the token") o.L("// *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc.") o.L("PrivateClaims() map[string]interface{}") - o.LL("// Get returns the value of the corresponding field in the token, such as") - o.L("// `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not") - o.L("// exist in the token, the second return value will be `false`") + o.LL("// Get is used to extract the value of any claim, including non-standard claims, out of the token.") o.L("//") - o.L("// If you need to access fields like `alg`, `kid`, `jku`, etc, you need") - o.L("// to access the corresponding fields in the JWS/JWE message. For this,") - o.L("// you will need to access them by directly parsing the payload using") - o.L("// `jws.Parse` and `jwe.Parse`") - o.L("Get(string) (interface{}, bool)") + o.L("// The first argument is the name of the claim. The second argument is a pointer") + o.L("// to a variable that will receive the value of the claim. The method returns") + o.L("// an error if the claim does not exist, or if the value cannot be assigned to") + o.L("// the destination variable. Note that a field is considered to \"exist\" even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\"), as long as it is explicitly set.") + o.L("//") + o.L("// For standard claims, you can use the corresponding getter method, such as") + o.L("// `Issuer()`, `Subject()`, `Audience()`, `IssuedAt()`, `NotBefore()`, `ExpiresAt()`") + o.L("//") + o.L("// Note that fields of JWS/JWE are NOT accessible through this method. You need") + o.L("// to use `jws.Parse` and `jwe.Parse` to obtain the JWS/JWE message (and NOT") + o.L("// the payload, which presumably is the JWT), and then use their `Get` methods in their respective packages") + o.L("Get(string, interface{}) error") o.LL("// Set assigns a value to the corresponding field in the token. Some") o.L("// pre-defined fields such as `nbf`, `iat`, `iss` need their values to") o.L("// be of a specific type. See the other getter methods in this interface") o.L("// for the types of each of these fields") o.L("Set(string, interface{}) error") + + o.LL("// Has returns true if the specified claim has a value, even if") + o.L("// the value is empty-ish (e.g. 0, false, \"\") as long as it has been") + o.L("// explicitly set.") + o.L("Has(string) bool") + o.L("Remove(string) error") var pkgPrefix string @@ -211,29 +223,53 @@ func generateToken(obj *codegen.Object) error { o.L("return &t.options") o.L("}") - o.LL("func (t *%s) Get(name string) (interface{}, bool) {", obj.Name(false)) + o.LL("func (t *%s) Has(name string) bool {", obj.Name(false)) + o.L("t.mu.RLock()") + o.L("defer t.mu.RUnlock()") + o.L("switch name {") + for _, f := range obj.Fields() { + o.L("case %sKey:", f.Name(true)) + o.L("return t.%s != nil", f.Name(false)) + } + o.L("default:") + o.L("_, ok := t.privateClaims[name]") + o.L("return ok") + o.L("}") + o.L("}") + + o.LL("func (t *%s) Get(name string, dst interface{}) error {", obj.Name(false)) o.L("t.mu.RLock()") o.L("defer t.mu.RUnlock()") o.L("switch name {") for _, f := range fields { o.L("case %sKey:", f.Name(true)) o.L("if t.%s == nil {", f.Name(false)) - o.L("return nil, false") + o.L("return fmt.Errorf(`field %%q not found`, name)") o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, ") if f.Bool(`hasGet`) { - o.L("v := t.%s.Get()", f.Name(false)) + o.R("t.%s.Get()", f.Name(false)) } else { if fieldStorageTypeIsIndirect(f.Type()) { - o.L("v := *(t.%s)", f.Name(false)) + o.R("*(t.%s)", f.Name(false)) } else { - o.L("v := t.%s", f.Name(false)) + o.R("t.%s", f.Name(false)) } } - o.L("return v, true") + o.R("); err != nil {") + o.L("return fmt.Errorf(`failed to assign value to dst: %%w`, err)") + o.L("}") + o.L("return nil") } o.L("default:") o.L("v, ok := t.privateClaims[name]") - o.L("return v, ok") + o.L("if !ok {") + o.L("return fmt.Errorf(`field %%q not found`, name)") + o.L("}") + o.L("if err := blackmagic.AssignIfCompatible(dst, v); err != nil {") + o.L("return fmt.Errorf(`failed to assign value to dst: %%w`, err)") + o.L("}") + o.L("return nil") o.L("}") // end switch name o.L("}") // end of Get