From 72066951fd1fc139f270033ea6538563b837147c Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 21 Nov 2023 21:26:05 +0900 Subject: [PATCH] Streamline exporting from jwk.Key to raw key Remove Raw() from keys, and implement jwk.Export --- Changes-v3.md | 10 +- examples/jwk_example_test.go | 2 +- .../jwx_register_ec_and_key_example_test.go | 11 +- internal/jwxtest/jwxtest.go | 10 +- internal/keyconv/keyconv.go | 14 +- jwe/internal/keyenc/keyenc_test.go | 4 +- jwe/jwe.go | 8 +- jwe/jwe_test.go | 8 +- jwk/convert.go | 41 ++++-- jwk/ecdsa.go | 28 +--- jwk/interface_gen.go | 16 -- jwk/jwk.go | 16 +- jwk/jwk_test.go | 26 ++-- jwk/okp.go | 35 +++-- jwk/rsa.go | 138 ++++++++++-------- jwk/symmetric.go | 25 +++- jws/jws_test.go | 10 +- jwx_test.go | 8 +- tools/cmd/genjwk/main.go | 13 -- 19 files changed, 213 insertions(+), 210 deletions(-) diff --git a/Changes-v3.md b/Changes-v3.md index 466a7859f..87ff46b7f 100644 --- a/Changes-v3.md +++ b/Changes-v3.md @@ -8,7 +8,7 @@ These are changes that are incompatible with the v2.x.x version. ## Module -* This module now requires Go 1.20.x +* This module now requires Go 1.21 * All `xxx.Get()` methods have been changed from `Get(string) (interface{}, error)` to `Get(string, interface{}) error`, where the second argument should be a pointer @@ -42,7 +42,9 @@ These are changes that are incompatible with the v2.x.x version. type to instantiate, and aids implementing your own `jwk.KeyParser`. Also see `jwk.RegisterKeyProbe()` -* Conversion between raw keys and `jwk.Key` can be customized using `jwk.KeyConverter`. - Also see `jwk.RegisterKeyConverter()` +* Conversion between raw keys and `jwk.Key` can be customized using `jwk.KeyImporter` and `jwk.KeyExporter`. + Also see `jwk.RegisterKeyImporter()` and `jwk.RegisterKeyExporter()` -* Added `jwk/ecdsa` to keep track of which curves are available for ECDSA keys. \ No newline at end of file +* Added `jwk/ecdsa` to keep track of which curves are available for ECDSA keys. + +* `(jwk.Key).Raw()` has been deprecated. Use `jwk.Export()` instead. diff --git a/examples/jwk_example_test.go b/examples/jwk_example_test.go index 96332056e..ec38bfac4 100644 --- a/examples/jwk_example_test.go +++ b/examples/jwk_example_test.go @@ -37,7 +37,7 @@ func ExampleJWK_Usage() { // jws and jwe operations can be performed using jwk.Key, but you could also // covert it to their "raw" forms, such as *rsa.PrivateKey or *ecdsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { log.Printf("failed to create public key: %s", err) return } diff --git a/examples/jwx_register_ec_and_key_example_test.go b/examples/jwx_register_ec_and_key_example_test.go index 9086365c6..e03779566 100644 --- a/examples/jwx_register_ec_and_key_example_test.go +++ b/examples/jwx_register_ec_and_key_example_test.go @@ -53,7 +53,10 @@ func convertShangMiSm2(key interface{}) (jwk.Key, error) { } func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error) { - ecdsaKey := key.(jwk.ECDSAPrivateKey) + ecdsaKey, ok := key.(jwk.ECDSAPrivateKey) + if !ok { + return nil, fmt.Errorf(`invalid key type %T: %w`, key, jwk.ContinueError()) + } if ecdsaKey.Crv() != SM2 { return nil, fmt.Errorf(`cannot convert curve of type %s to ShangMi key: %w`, ecdsaKey.Crv(), jwk.ContinueError()) } @@ -87,7 +90,7 @@ func ExampleShangMiSm2() { { // Create a ShangMi SM2 private key back from the jwk.Key var clone sm2.PrivateKey - if err := shangmi2JWK.Raw(&clone); err != nil { + if err := jwk.Export(shangmi2JWK, &clone); err != nil { fmt.Printf("failed to create ShangMi private key from jwk.Key: %s\n", err) return } @@ -116,7 +119,7 @@ func ExampleShangMiSm2() { { // Can do the same thing for interface{} var clone interface{} - if err := shangmi2JWK.Raw(&clone); err != nil { + if err := jwk.Export(shangmi2JWK, &clone); err != nil { fmt.Printf("failed to create ShangMi private key from jwk.Key (via interface{}): %s\n", err) return } @@ -135,7 +138,7 @@ func ExampleShangMiSm2() { return } var clone ecdsa.PrivateKey - if err := eckjwk.Raw(&clone); err != nil { + if err := jwk.Export(eckjwk, &clone); err != nil { fmt.Printf("failed to create ShangMi public key from jwk.Key: %s\n", err) return } diff --git a/internal/jwxtest/jwxtest.go b/internal/jwxtest/jwxtest.go index 11ea59399..594b00246 100644 --- a/internal/jwxtest/jwxtest.go +++ b/internal/jwxtest/jwxtest.go @@ -267,7 +267,7 @@ func DecryptJweFile(ctx context.Context, file string, alg jwa.KeyEncryptionAlgor } var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err) } @@ -285,19 +285,19 @@ func EncryptJweFile(ctx context.Context, payload []byte, keyalg jwa.KeyEncryptio switch keyalg { case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256: var rawkey rsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey.PublicKey case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: var rawkey ecdsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey.PublicKey default: var rawkey []byte - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey @@ -323,7 +323,7 @@ func VerifyJwsFile(ctx context.Context, file string, alg jwa.SignatureAlgorithm, } var rawkey, pubkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Export(key, &rawkey); err != nil { return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err) } pubkey = rawkey diff --git a/internal/keyconv/keyconv.go b/internal/keyconv/keyconv.go index a8b291a2b..044ca49bc 100644 --- a/internal/keyconv/keyconv.go +++ b/internal/keyconv/keyconv.go @@ -17,7 +17,7 @@ import ( func RSAPrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw rsa.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce rsa.PrivateKey from %T: %w`, src, err) } src = &raw @@ -42,7 +42,7 @@ func RSAPrivateKey(dst, src interface{}) error { func RSAPublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw rsa.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce rsa.PublicKey from %T: %w`, src, err) } src = &raw @@ -66,7 +66,7 @@ func RSAPublicKey(dst, src interface{}) error { func ECDSAPrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ecdsa.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ecdsa.PrivateKey from %T: %w`, src, err) } src = &raw @@ -89,7 +89,7 @@ func ECDSAPrivateKey(dst, src interface{}) error { func ECDSAPublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ecdsa.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ecdsa.PublicKey from %T: %w`, src, err) } src = &raw @@ -110,7 +110,7 @@ func ECDSAPublicKey(dst, src interface{}) error { func ByteSliceKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw []byte - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce []byte from %T: %w`, src, err) } src = raw @@ -125,7 +125,7 @@ func ByteSliceKey(dst, src interface{}) error { func Ed25519PrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ed25519.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ed25519.PrivateKey from %T: %w`, src, err) } src = &raw @@ -146,7 +146,7 @@ func Ed25519PrivateKey(dst, src interface{}) error { func Ed25519PublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ed25519.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ed25519.PublicKey from %T: %w`, src, err) } src = &raw diff --git a/jwe/internal/keyenc/keyenc_test.go b/jwe/internal/keyenc/keyenc_test.go index 808d27331..396a257c1 100644 --- a/jwe/internal/keyenc/keyenc_test.go +++ b/jwe/internal/keyenc/keyenc_test.go @@ -101,7 +101,7 @@ func TestDeriveECDHES(t *testing.T) { if !assert.NoError(t, err, `jwk.ParseKey should succeed`) { return } - if !assert.NoError(t, aliceWebKey.Raw(&aliceKey), `aliceWebKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(aliceWebKey, &aliceKey), `jwk.Export(aliceWebKey) should succeed`) { return } @@ -109,7 +109,7 @@ func TestDeriveECDHES(t *testing.T) { if !assert.NoError(t, err, `jwk.ParseKey should succeed`) { return } - if !assert.NoError(t, bobWebKey.Raw(&bobKey), `bobWebKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(bobWebKey, &bobKey), `jwk.Export(bobWebKey) should succeed`) { return } diff --git a/jwe/jwe.go b/jwe/jwe.go index c63abf724..f7a161e1c 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -76,7 +76,7 @@ func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm keyID = jwkKey.KeyID() var raw interface{} - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return nil, nil, fmt.Errorf(`failed to retrieve raw key out of %T: %w`, b.key, err) } @@ -572,7 +572,7 @@ func (dctx *decryptCtx) try(ctx context.Context, recipient Recipient, keyUsed in func (dctx *decryptCtx) decryptContent(alg jwa.KeyEncryptionAlgorithm, key interface{}, recipient Recipient) ([]byte, error) { if jwkKey, ok := key.(jwk.Key); ok { var raw interface{} - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Export(jwkKey, &raw); err != nil { return nil, fmt.Errorf(`failed to retrieve raw key from %T: %w`, key, err) } key = raw @@ -608,13 +608,13 @@ func (dctx *decryptCtx) decryptContent(alg jwa.KeyEncryptionAlgorithm, key inter switch epk := epk.(type) { case jwk.ECDSAPublicKey: var pubkey ecdsa.PublicKey - if err := epk.Raw(&pubkey); err != nil { + if err := jwk.Export(epk, &pubkey); err != nil { return nil, fmt.Errorf(`failed to get public key: %w`, err) } dec.PublicKey(&pubkey) case jwk.OKPPublicKey: var pubkey interface{} - if err := epk.Raw(&pubkey); err != nil { + if err := jwk.Export(epk, &pubkey); err != nil { return nil, fmt.Errorf(`failed to get public key: %w`, err) } dec.PublicKey(pubkey) diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 53ba2065a..e3b2f7982 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -50,7 +50,7 @@ func init() { panic(err) } - if err := privkey.Raw(&rsaPrivKey); err != nil { + if err := jwk.Export(privkey, &rsaPrivKey); err != nil { panic(err) } } @@ -168,7 +168,7 @@ func TestParse_RSAES_OAEP_AES_GCM(t *testing.T) { } var rawkey rsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Export(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -501,7 +501,7 @@ func Test_GHIssue207(t *testing.T) { } var key ecdsa.PrivateKey - if !assert.NoError(t, webKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(webKey, &key), `jwk.Export should succeed`) { return } @@ -628,7 +628,7 @@ func TestDecodePredefined_Direct(t *testing.T) { } var key []byte - if !assert.NoError(t, webKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(webKey, &key), `jwk.Export should succeed`) { return } diff --git a/jwk/convert.go b/jwk/convert.go index 3c9a39950..517b10cf0 100644 --- a/jwk/convert.go +++ b/jwk/convert.go @@ -21,15 +21,15 @@ import ( var keyImporters = make(map[reflect.Type]KeyImporter) var keyExporters = make(map[jwa.KeyType][]KeyExporter) -var myKeyImporters sync.RWMutex +var muKeyImporters sync.RWMutex var muKeyExporters sync.RWMutex // RegisterKeyImporter registers a KeyImporter for the given raw key. When `jwk.FromRaw()` is called, // the library will look up the appropriate KeyImporter for the given raw key type (via `reflect`) // and execute the KeyImporters in succession until either one of them succeeds, or all of them fail. func RegisterKeyImporter(from interface{}, conv KeyImporter) { - myKeyImporters.Lock() - defer myKeyImporters.Unlock() + muKeyImporters.Lock() + defer muKeyImporters.Unlock() keyImporters[reflect.TypeOf(from)] = conv } @@ -254,32 +254,43 @@ func bytesToKey(src interface{}) (Key, error) { return k, nil } -// All objects call this method to convert themselves to a raw key. -// It's done this way to centralize the logic (mapping) of which keys are converted -// to what raw key. -func raw(key Key, dst interface{}) error { - myKeyImporters.RLock() - defer myKeyImporters.RUnlock() +// Export converts a `jwk.Key` to a Export key. The dst argument must be a pointer to the +// object that the user wants the result to be assigned to. +// +// Normally you would pass a pointer to the zero value of the raw key type +// such as &(*rsa.PrivateKey) or &(*ecdsa.PublicKey), which gets assigned +// the converted key. +// +// If you do not know the exact type of a jwk.Key before attempting +// to obtain the raw key, you can simply pass a pointer to an +// empty interface as the second argument +// +// If you already know the exact type, it is recommended that you +// pass a pointer to the zero value of the actual key type for efficiency. +func Export(key Key, dst interface{}) error { // dst better be a pointer rv := reflect.ValueOf(dst) if rv.Kind() != reflect.Ptr { - return fmt.Errorf(`destination object must be a pointer`) + return fmt.Errorf(`jwk.Export: destination object must be a pointer`) } - if convs, ok := keyExporters[key.KeyType()]; ok { - for _, conv := range convs { + muKeyExporters.RLock() + exporters, ok := keyExporters[key.KeyType()] + muKeyExporters.RUnlock() + if ok { + for _, conv := range exporters { v, err := conv.Export(key, dst) if err != nil { if IsContinueError(err) { continue } - return fmt.Errorf(`failed to convert jwk.Key to raw format: %w`, err) + return fmt.Errorf(`jwk.Export: failed to export jwk.Key to raw format: %w`, err) } if err := blackmagic.AssignIfCompatible(dst, v); err != nil { - return fmt.Errorf(`failed to assign key: %w`, err) + return fmt.Errorf(`jwk.Export: failed to assign key: %w`, err) } return nil } } - return fmt.Errorf(`failed to find converter for key type '%T'`, key) + return fmt.Errorf(`jwk.Export: failed to find exporter for key type '%T'`, key) } diff --git a/jwk/ecdsa.go b/jwk/ecdsa.go index 6abf8b5bc..6ecec962d 100644 --- a/jwk/ecdsa.go +++ b/jwk/ecdsa.go @@ -7,7 +7,6 @@ import ( "fmt" "math/big" - "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/ecutil" "github.com/lestrrat-go/jwx/v3/jwa" @@ -19,7 +18,7 @@ func init() { ourecdsa.RegisterCurve(jwa.P384, elliptic.P384()) ourecdsa.RegisterCurve(jwa.P521, elliptic.P521()) - RegisterKeyExporter(jwa.EC, KeyExportFunc(ecdsaPrivateJWKToRaw)) + RegisterKeyExporter(jwa.EC, KeyExportFunc(ecdsaJWKToRaw)) } func (k *ecdsaPublicKey) FromRaw(rawKey *ecdsa.PublicKey) error { @@ -103,24 +102,7 @@ func buildECDSAPublicKey(alg jwa.EllipticCurveAlgorithm, xbuf, ybuf []byte) (*ec return &ecdsa.PublicKey{Curve: crv, X: &x, Y: &y}, nil } -// Raw returns the EC-DSA public key represented by this JWK -func (k *ecdsaPublicKey) Raw(v interface{}) error { - k.mu.RLock() - defer k.mu.RUnlock() - - pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y) - if err != nil { - return fmt.Errorf(`failed to build public key: %w`, err) - } - - return blackmagic.AssignIfCompatible(v, pubk) -} - -func (k *ecdsaPrivateKey) Raw(v interface{}) error { - return raw(k, v) -} - -func ecdsaPrivateJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { +func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { switch k := keyif.(type) { case *ecdsaPublicKey: switch hint.(type) { @@ -131,7 +113,6 @@ func ecdsaPrivateJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { k.mu.RLock() defer k.mu.RUnlock() - return buildECDSAPublicKey(k.Crv(), k.x, k.y) case *ecdsaPrivateKey: switch hint.(type) { @@ -142,7 +123,6 @@ func ecdsaPrivateJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { k.mu.RLock() defer k.mu.RUnlock() - pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y) if err != nil { return nil, fmt.Errorf(`failed to build public key: %w`, err) @@ -210,7 +190,7 @@ func (k ecdsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key ecdsa.PublicKey - if err := k.Raw(&key); err != nil { + if err := Export(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize ecdsa.PublicKey for thumbprint generation: %w`, err) } @@ -234,7 +214,7 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key ecdsa.PrivateKey - if err := k.Raw(&key); err != nil { + if err := Export(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize ecdsa.PrivateKey for thumbprint generation: %w`, err) } diff --git a/jwk/interface_gen.go b/jwk/interface_gen.go index 2ec82798d..1d08dd984 100644 --- a/jwk/interface_gen.go +++ b/jwk/interface_gen.go @@ -67,22 +67,6 @@ type Key interface { // called by the user Validate() error - // Raw creates the corresponding raw key from the jwk.Key. For example, - // EC keys would create *ecdsa.PublicKey or *ecdsa.PrivateKey, - // and OctetSeq types create a []byte key. - // - // If you do not know the exact type of a jwk.Key before attempting - // to obtain the raw key, you can simply pass a pointer to an - // empty interface as the first argument (important caveat: this can only - // be done for keys that are defined in this package. If you are using keys - // imported from third party modules, you will need to know the exact type - // before calling this method). - // - // If you already know the exact type, it is recommended that you - // pass a pointer to the zero value of the actual key type (e.g. &rsa.PrivateKey) - // for efficiency. - Raw(interface{}) error - // Thumbprint returns the JWK thumbprint using the indicated // hashing algorithm, according to RFC 7638 Thumbprint(crypto.Hash) ([]byte, error) diff --git a/jwk/jwk.go b/jwk/jwk.go index 6a43db20a..6948d10db 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -81,9 +81,9 @@ func FromRaw(raw interface{}) (Key, error) { return nil, fmt.Errorf(`jwk.FromRaw requires a non-nil key`) } - myKeyImporters.RLock() + muKeyImporters.RLock() conv, ok := keyImporters[reflect.TypeOf(raw)] - myKeyImporters.RUnlock() + muKeyImporters.RUnlock() if !ok { return nil, fmt.Errorf(`jwk.FromRaw: failed to convert %T to jwk.Key: no converters were able to convert`, raw) } @@ -174,7 +174,7 @@ func PublicRawKeyOf(v interface{}) (interface{}, error) { } var raw interface{} - if err := pubk.Raw(&raw); err != nil { + if err := Export(pubk, &raw); err != nil { return nil, fmt.Errorf(`jwk.PublicRawKeyOf: failed to obtain raw key from %T: %w`, pubk, err) } return raw, nil @@ -201,9 +201,9 @@ const ( // The second return value is the encoded byte sequence. func EncodeX509(v interface{}) (string, []byte, error) { // we can't import jwk, so just use the interface - if key, ok := v.(interface{ Raw(interface{}) error }); ok { + if key, ok := v.(Key); ok { var raw interface{} - if err := key.Raw(&raw); err != nil { + if err := Export(key, &raw); err != nil { return "", nil, fmt.Errorf(`failed to get raw key out of %T: %w`, key, err) } @@ -318,7 +318,7 @@ func ParseRawKey(data []byte, rawkey interface{}) error { return fmt.Errorf(`failed to parse key: %w`, err) } - if err := key.Raw(rawkey); err != nil { + if err := Export(key, rawkey); err != nil { return fmt.Errorf(`failed to assign to raw key variable: %w`, err) } @@ -603,7 +603,7 @@ func asnEncode(key Key) (string, []byte, error) { switch key := key.(type) { case RSAPrivateKey, ECDSAPrivateKey, OKPPrivateKey: var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := Export(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) } buf, err := x509.MarshalPKCS8PrivateKey(rawkey) @@ -613,7 +613,7 @@ func asnEncode(key Key) (string, []byte, error) { return pmPrivateKey, buf, nil case RSAPublicKey, ECDSAPublicKey, OKPPublicKey: var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := Export(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) } buf, err := x509.MarshalPKIXPublicKey(rawkey) diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 1c35b9a17..e5ec212de 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -267,7 +267,7 @@ func VerifyKey(t *testing.T, def map[string]keyDef) { typ := expectedRawKeyType(key) var rawkey interface{} - if !assert.NoError(t, key.Raw(&rawkey), `Raw() should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `Raw() should succeed`) { return } if !assert.IsType(t, rawkey, typ, `raw key should be of this type`) { @@ -377,7 +377,7 @@ func TestParse(t *testing.T) { t.Helper() var irawkey interface{} - if !assert.NoError(t, key.Raw(&irawkey), `key.Raw(&interface) should ucceed`) { + if !assert.NoError(t, jwk.Export(key, &irawkey), `key.Raw(&interface) should ucceed`) { return } @@ -393,7 +393,7 @@ func TestParse(t *testing.T) { return } var rawkey rsa.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&rsa.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&rsa.PrivateKey) should succeed`) { return } crawkey = &rawkey @@ -402,7 +402,7 @@ func TestParse(t *testing.T) { return } var rawkey rsa.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&rsa.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&rsa.PublicKey) should succeed`) { return } crawkey = &rawkey @@ -411,7 +411,7 @@ func TestParse(t *testing.T) { return } var rawkey ecdsa.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ecdsa.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ecdsa.PrivateKey) should succeed`) { return } crawkey = &rawkey @@ -422,13 +422,13 @@ func TestParse(t *testing.T) { switch k.Crv() { case jwa.Ed25519: var rawkey ed25519.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ed25519.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ed25519.PrivateKey) should succeed`) { return } crawkey = rawkey case jwa.X25519: var rawkey ecdh.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ecdh.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ecdh.PrivateKey) should succeed`) { return } crawkey = &rawkey @@ -445,13 +445,13 @@ func TestParse(t *testing.T) { switch k.Crv() { case jwa.Ed25519: var rawkey ed25519.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ed25519.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ed25519.PublicKey) should succeed`) { return } crawkey = rawkey case jwa.X25519: var rawkey ecdh.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ecdh.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ecdh.PublicKey) should succeed`) { return } crawkey = &rawkey @@ -940,7 +940,7 @@ func TestPublicKeyOf(t *testing.T) { // Get the raw key to compare var rawKey interface{} - if !assert.NoError(t, pubJwkKey.Raw(&rawKey), `pubJwkKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(pubJwkKey, &rawKey), `pubJwkKey.Raw should succeed`) { return } @@ -993,7 +993,7 @@ func TestPublicKeyOf(t *testing.T) { // Get the raw key to compare var rawKey interface{} - if !assert.NoError(t, setKey.Raw(&rawKey), `pubJwkKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(setKey, &rawKey), `pubJwkKey.Raw should succeed`) { return } @@ -1476,7 +1476,7 @@ c4wOvhbalcX0FqTM3mXCgMFRbibquhwdxbU= } var pubkey rsa.PublicKey - if !assert.NoError(t, key.Raw(&pubkey), `key.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(key, &pubkey), `key.Raw should succeed`) { return } @@ -2182,7 +2182,7 @@ func TestGH947(t *testing.T) { k, err := jwk.ParseKey(raw) require.NoError(t, err, `jwk.ParseKey should succeed`) var exported []byte - require.Error(t, k.Raw(&exported), `(okpkey).Raw with 0-length OKP key should fail`) + require.Error(t, jwk.Export(k, &exported), `(okpkey).Raw with 0-length OKP key should fail`) } func TestValidation(t *testing.T) { diff --git a/jwk/okp.go b/jwk/okp.go index f78200d09..77a3576eb 100644 --- a/jwk/okp.go +++ b/jwk/okp.go @@ -12,6 +12,10 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" ) +func init() { + RegisterKeyExporter(jwa.OKP, KeyExportFunc(okpJWKToRaw)) +} + // Mental note: // // Curve25519 refers to a particular curve, and is represented in its Montgomery form. @@ -134,19 +138,30 @@ func buildOKPPrivateKey(alg jwa.EllipticCurveAlgorithm, xbuf []byte, dbuf []byte } } -func (k *okpPrivateKey) Raw(v interface{}) error { - k.mu.RLock() - defer k.mu.RUnlock() +// This is half baked. I think it will blow up if we used ecdh.* keys and/or x25519 keys +func okpJWKToRaw(key Key, hint interface{}) (interface{}, error) { + switch key := key.(type) { + case *okpPrivateKey: + key.mu.RLock() + defer key.mu.RUnlock() - privk, err := buildOKPPrivateKey(k.Crv(), k.x, k.d) - if err != nil { - return fmt.Errorf(`jwk.OKPPrivateKey: failed to build public key: %w`, err) - } + privk, err := buildOKPPrivateKey(key.Crv(), key.x, key.d) + if err != nil { + return nil, fmt.Errorf(`jwk.OKPPrivateKey: failed to build public key: %w`, err) + } + return privk, nil + case *okpPublicKey: + key.mu.RLock() + defer key.mu.RUnlock() - if err := blackmagic.AssignIfCompatible(v, privk); err != nil { - return fmt.Errorf(`jwk.OKPPrivateKey: failed to assign to destination variable: %w`, err) + pubk, err := buildOKPPublicKey(key.Crv(), key.x) + if err != nil { + return nil, fmt.Errorf(`jwk.OKPPublicKey: failed to build public key: %w`, err) + } + return pubk, nil + default: + return nil, ContinueError() } - return nil } func makeOKPPublicKey(src Key) (Key, error) { diff --git a/jwk/rsa.go b/jwk/rsa.go index f4104a837..6382316f6 100644 --- a/jwk/rsa.go +++ b/jwk/rsa.go @@ -7,11 +7,15 @@ import ( "fmt" "math/big" - "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v3/internal/base64" "github.com/lestrrat-go/jwx/v3/internal/pool" + "github.com/lestrrat-go/jwx/v3/jwa" ) +func init() { + RegisterKeyExporter(jwa.RSA, KeyExportFunc(rsaJWKToRaw)) +} + func (k *rsaPrivateKey) FromRaw(rawKey *rsa.PrivateKey) error { k.mu.Lock() defer k.mu.Unlock() @@ -97,78 +101,86 @@ func (k *rsaPublicKey) FromRaw(rawKey *rsa.PublicKey) error { return nil } -func (k *rsaPrivateKey) Raw(v interface{}) error { - k.mu.RLock() - defer k.mu.RUnlock() - - var d, q, p big.Int // note: do not use from sync.Pool - - d.SetBytes(k.d) - q.SetBytes(k.q) - p.SetBytes(k.p) - - // optional fields - var dp, dq, qi *big.Int - if len(k.dp) > 0 { - dp = &big.Int{} // note: do not use from sync.Pool - dp.SetBytes(k.dp) - } - - if len(k.dq) > 0 { - dq = &big.Int{} // note: do not use from sync.Pool - dq.SetBytes(k.dq) - } +func buildRSAPublicKey(key *rsa.PublicKey, n, e []byte) { + bin := pool.GetBigInt() + bie := pool.GetBigInt() + defer pool.ReleaseBigInt(bie) - if len(k.qi) > 0 { - qi = &big.Int{} // note: do not use from sync.Pool - qi.SetBytes(k.qi) - } + bin.SetBytes(n) + bie.SetBytes(e) - var key rsa.PrivateKey + key.N = bin + key.E = int(bie.Int64()) +} - pubk := newRSAPublicKey() - pubk.n = k.n - pubk.e = k.e - if err := pubk.Raw(&key.PublicKey); err != nil { - return fmt.Errorf(`failed to materialize RSA public key: %w`, err) - } +func rsaJWKToRaw(key Key, hint interface{}) (interface{}, error) { + switch key := key.(type) { + case *rsaPublicKey: + switch hint.(type) { + case *rsa.PublicKey, *interface{}: + default: + return nil, fmt.Errorf(`invalid destination object type %T for public RSA JWK: %w`, hint, ContinueError()) + } - key.D = &d - key.Primes = []*big.Int{&p, &q} + key.mu.RLock() + defer key.mu.RUnlock() + var pubkey rsa.PublicKey + buildRSAPublicKey(&pubkey, key.n, key.e) - if dp != nil { - key.Precomputed.Dp = dp - } - if dq != nil { - key.Precomputed.Dq = dq - } - if qi != nil { - key.Precomputed.Qinv = qi - } - key.Precomputed.CRTValues = []rsa.CRTValue{} + return &pubkey, nil + case *rsaPrivateKey: + switch hint.(type) { + case *rsa.PrivateKey, *interface{}: + default: + return nil, fmt.Errorf(`invalid destination object type %T for private RSA JWK: %w`, hint, ContinueError()) + } + key.mu.RLock() + defer key.mu.RUnlock() - return blackmagic.AssignIfCompatible(v, &key) -} + var d, q, p big.Int // note: do not use from sync.Pool -// Raw takes the values stored in the Key object, and creates the -// corresponding *rsa.PublicKey object. -func (k *rsaPublicKey) Raw(v interface{}) error { - k.mu.RLock() - defer k.mu.RUnlock() + d.SetBytes(key.d) + q.SetBytes(key.q) + p.SetBytes(key.p) - var key rsa.PublicKey + // optional fields + var dp, dq, qi *big.Int + if len(key.dp) > 0 { + dp = &big.Int{} // note: do not use from sync.Pool + dp.SetBytes(key.dp) + } - n := pool.GetBigInt() - e := pool.GetBigInt() - defer pool.ReleaseBigInt(e) + if len(key.dq) > 0 { + dq = &big.Int{} // note: do not use from sync.Pool + dq.SetBytes(key.dq) + } - n.SetBytes(k.n) - e.SetBytes(k.e) + if len(key.qi) > 0 { + qi = &big.Int{} // note: do not use from sync.Pool + qi.SetBytes(key.qi) + } - key.N = n - key.E = int(e.Int64()) + var privkey rsa.PrivateKey + buildRSAPublicKey(&privkey.PublicKey, key.n, key.e) + privkey.D = &d + privkey.Primes = []*big.Int{&p, &q} - return blackmagic.AssignIfCompatible(v, &key) + if dp != nil { + privkey.Precomputed.Dp = dp + } + if dq != nil { + privkey.Precomputed.Dq = dq + } + if qi != nil { + privkey.Precomputed.Qinv = qi + } + // This may look like a no-op, but it's required if we want to + // compare it against a key generated by rsa.GenerateKey + privkey.Precomputed.CRTValues = []rsa.CRTValue{} + return &privkey, nil + default: + return nil, ContinueError() + } } func makeRSAPublicKey(src Key) (Key, error) { @@ -208,7 +220,7 @@ func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key rsa.PrivateKey - if err := k.Raw(&key); err != nil { + if err := Export(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize RSA private key: %w`, err) } return rsaThumbprint(hash, &key.PublicKey) @@ -219,7 +231,7 @@ func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key rsa.PublicKey - if err := k.Raw(&key); err != nil { + if err := Export(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize RSA public key: %w`, err) } return rsaThumbprint(hash, &key) diff --git a/jwk/symmetric.go b/jwk/symmetric.go index af582f429..e58f17920 100644 --- a/jwk/symmetric.go +++ b/jwk/symmetric.go @@ -4,10 +4,14 @@ import ( "crypto" "fmt" - "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v3/internal/base64" + "github.com/lestrrat-go/jwx/v3/jwa" ) +func init() { + RegisterKeyExporter(jwa.OctetSeq, KeyExportFunc(octetSeqToRaw)) +} + func (k *symmetricKey) FromRaw(rawKey []byte) error { k.mu.Lock() defer k.mu.Unlock() @@ -21,12 +25,17 @@ func (k *symmetricKey) FromRaw(rawKey []byte) error { return nil } -// Raw returns the octets for this symmetric key. -// Since this is a symmetric key, this just calls Octets -func (k *symmetricKey) Raw(v interface{}) error { - k.mu.RLock() - defer k.mu.RUnlock() - return blackmagic.AssignIfCompatible(v, k.octets) +func octetSeqToRaw(key Key, hint interface{}) (interface{}, error) { + switch key := key.(type) { + case *symmetricKey: + key.mu.RLock() + defer key.mu.RUnlock() + octets := make([]byte, len(key.octets)) + copy(octets, key.octets) + return octets, nil + default: + return nil, ContinueError() + } } // Thumbprint returns the JWK thumbprint using the indicated @@ -35,7 +44,7 @@ func (k *symmetricKey) Thumbprint(hash crypto.Hash) ([]byte, error) { k.mu.RLock() defer k.mu.RUnlock() var octets []byte - if err := k.Raw(&octets); err != nil { + if err := Export(k, &octets); err != nil { return nil, fmt.Errorf(`failed to materialize symmetric key: %w`, err) } diff --git a/jws/jws_test.go b/jws/jws_test.go index 08a4a8eb4..8e73ff985 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -511,7 +511,7 @@ func TestEncode(t *testing.T) { t.Fatal("Failed to parse JWK") } var key interface{} - if !assert.NoError(t, jwkKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(jwkKey, &key), `jwk.Export should succeed`) { return } var jwsCompact []byte @@ -583,7 +583,7 @@ func TestEncode(t *testing.T) { } var rawkey rsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Export(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -660,7 +660,7 @@ func TestEncode(t *testing.T) { } var rawkey ecdsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Export(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -745,7 +745,7 @@ func TestEncode(t *testing.T) { } var rawkey ed25519.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Export(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -1024,7 +1024,7 @@ func TestDecode_ES384Compact_NoSigTrim(t *testing.T) { } var rawkey ecdsa.PublicKey - if !assert.NoError(t, pubkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Export(pubkey, &rawkey), `obtaining raw key should succeed`) { return } diff --git a/jwx_test.go b/jwx_test.go index b3dbca8d7..d9e466a0b 100644 --- a/jwx_test.go +++ b/jwx_test.go @@ -162,7 +162,7 @@ func TestJoseCompatibility(t *testing.T) { } } - if !assert.NoError(t, webkey.Raw(&tc.Raw), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(webkey, &tc.Raw), `jwk.Export should succeed`) { return } }) @@ -293,17 +293,17 @@ func joseInteropTest(ctx context.Context, spec interopTest, t *testing.T) { switch spec.alg { case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256: var rawkey rsa.PrivateKey - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(jwxJwk, &rawkey), `jwk.Export should succeed`) { return } case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: var rawkey ecdsa.PrivateKey - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(jwxJwk, &rawkey), `jwk.Export should succeed`) { return } default: var rawkey []byte - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Export(jwxJwk, &rawkey), `jwk.Export should succeed`) { return } } diff --git a/tools/cmd/genjwk/main.go b/tools/cmd/genjwk/main.go index cd975e946..40992d776 100644 --- a/tools/cmd/genjwk/main.go +++ b/tools/cmd/genjwk/main.go @@ -687,19 +687,6 @@ func generateGenericHeaders(fields codegen.FieldList) error { o.L("// Validate is never called by `UnmarshalJSON()` or `Set`. It must explicitly be") o.L("// called by the user") o.L("Validate() error") - o.LL("// Raw creates the corresponding raw key from the jwk.Key. For example,") - o.L("// EC keys would create *ecdsa.PublicKey or *ecdsa.PrivateKey,") - o.L("// and OctetSeq types create a []byte key.") - o.L("//\n// If you do not know the exact type of a jwk.Key before attempting") - o.L("// to obtain the raw key, you can simply pass a pointer to an") - o.L("// empty interface as the first argument (important caveat: this can only") - o.L("// be done for keys that are defined in this package. If you are using keys") - o.L("// imported from third party modules, you will need to know the exact type") - o.L("// before calling this method).") - o.L("//\n// If you already know the exact type, it is recommended that you") - o.L("// pass a pointer to the zero value of the actual key type (e.g. &rsa.PrivateKey)") - o.L("// for efficiency.") - o.L("Raw(interface{}) error") o.LL("// Thumbprint returns the JWK thumbprint using the indicated") o.L("// hashing algorithm, according to RFC 7638") o.L("Thumbprint(crypto.Hash) ([]byte, error)")