diff --git a/examples/go.mod b/examples/go.mod index 5e3c0f5aa..738e61809 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -4,9 +4,9 @@ go 1.16 require ( github.com/cloudflare/circl v1.3.3 - github.com/lestrrat-go/jwx/v2 v2.0.11 + github.com/lestrrat-go/jwx/v2 v2.0.12-0.20230824024517-a077c65f16eb ) replace github.com/cloudflare/circl v1.0.0 => github.com/cloudflare/circl v1.0.1-0.20210104183656-96a0695de3c3 -replace github.com/lestrrat-go/jwx/v2 v2.0.11 => ../ +replace github.com/lestrrat-go/jwx/v2 v2.0.11 => ../ \ No newline at end of file diff --git a/examples/go.sum b/examples/go.sum index 1bfab243b..8bdc85cf3 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -17,6 +17,8 @@ github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJG github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.0.12-0.20230824024517-a077c65f16eb h1:qPUmVTD6gWn0S8zfmAzjgzF5xdYtJrGhroN+i7u/TrE= +github.com/lestrrat-go/jwx/v2 v2.0.12-0.20230824024517-a077c65f16eb/go.mod h1:Mq4KN1mM7bp+5z/W5HS8aCNs5RKZ911G/0y2qUjAQuQ= github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= diff --git a/examples/jwk_example_test.go b/examples/jwk_example_test.go index a41de033e..21804203a 100644 --- a/examples/jwk_example_test.go +++ b/examples/jwk_example_test.go @@ -32,7 +32,7 @@ func ExampleJWK_Usage() { key := pair.Value.(jwk.Key) var rawkey interface{} // This is the raw key, like *rsa.PrivateKey or *ecdsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { log.Printf("failed to create public key: %s", err) return } diff --git a/internal/jwxtest/jwxtest.go b/internal/jwxtest/jwxtest.go index cd9e8af2e..24220acee 100644 --- a/internal/jwxtest/jwxtest.go +++ b/internal/jwxtest/jwxtest.go @@ -270,7 +270,7 @@ func DecryptJweFile(ctx context.Context, file string, alg jwa.KeyEncryptionAlgor } var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err) } @@ -288,19 +288,19 @@ func EncryptJweFile(ctx context.Context, payload []byte, keyalg jwa.KeyEncryptio switch keyalg { case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256: var rawkey rsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey.PublicKey case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: var rawkey ecdsa.PrivateKey - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey.PublicKey default: var rawkey []byte - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err) } keyif = rawkey @@ -326,7 +326,7 @@ func VerifyJwsFile(ctx context.Context, file string, alg jwa.SignatureAlgorithm, } var rawkey, pubkey interface{} - if err := key.Raw(&rawkey); err != nil { + if err := jwk.Raw(key, &rawkey); err != nil { return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err) } pubkey = rawkey diff --git a/internal/keyconv/keyconv.go b/internal/keyconv/keyconv.go index 807da1dee..0114325ed 100644 --- a/internal/keyconv/keyconv.go +++ b/internal/keyconv/keyconv.go @@ -17,7 +17,7 @@ import ( func RSAPrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw rsa.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce rsa.PrivateKey from %T: %w`, src, err) } src = &raw @@ -42,7 +42,7 @@ func RSAPrivateKey(dst, src interface{}) error { func RSAPublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw rsa.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce rsa.PublicKey from %T: %w`, src, err) } src = &raw @@ -66,7 +66,7 @@ func RSAPublicKey(dst, src interface{}) error { func ECDSAPrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ecdsa.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ecdsa.PrivateKey from %T: %w`, src, err) } src = &raw @@ -89,7 +89,7 @@ func ECDSAPrivateKey(dst, src interface{}) error { func ECDSAPublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ecdsa.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ecdsa.PublicKey from %T: %w`, src, err) } src = &raw @@ -110,7 +110,7 @@ func ECDSAPublicKey(dst, src interface{}) error { func ByteSliceKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw []byte - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce []byte from %T: %w`, src, err) } src = raw @@ -125,7 +125,7 @@ func ByteSliceKey(dst, src interface{}) error { func Ed25519PrivateKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ed25519.PrivateKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ed25519.PrivateKey from %T: %w`, src, err) } src = &raw @@ -146,7 +146,7 @@ func Ed25519PrivateKey(dst, src interface{}) error { func Ed25519PublicKey(dst, src interface{}) error { if jwkKey, ok := src.(jwk.Key); ok { var raw ed25519.PublicKey - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return fmt.Errorf(`failed to produce ed25519.PublicKey from %T: %w`, src, err) } src = &raw diff --git a/jwe/internal/keyenc/keyenc_test.go b/jwe/internal/keyenc/keyenc_test.go index f2ed9c994..169e6a24b 100644 --- a/jwe/internal/keyenc/keyenc_test.go +++ b/jwe/internal/keyenc/keyenc_test.go @@ -101,7 +101,7 @@ func TestDeriveECDHES(t *testing.T) { if !assert.NoError(t, err, `jwk.ParseKey should succeed`) { return } - if !assert.NoError(t, aliceWebKey.Raw(&aliceKey), `aliceWebKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(aliceWebKey, &aliceKey), `aliceWebKey.Raw should succeed`) { return } @@ -109,7 +109,7 @@ func TestDeriveECDHES(t *testing.T) { if !assert.NoError(t, err, `jwk.ParseKey should succeed`) { return } - if !assert.NoError(t, bobWebKey.Raw(&bobKey), `bobWebKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(bobWebKey, &bobKey), `bobWebKey.Raw should succeed`) { return } diff --git a/jwe/jwe.go b/jwe/jwe.go index 67b8e97b3..3d2c84a4d 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -76,7 +76,7 @@ func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm keyID = jwkKey.KeyID() var raw interface{} - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return nil, nil, fmt.Errorf(`failed to retrieve raw key out of %T: %w`, b.key, err) } @@ -573,7 +573,7 @@ func (dctx *decryptCtx) try(ctx context.Context, recipient Recipient, keyUsed in func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptionAlgorithm, key interface{}, recipient Recipient) ([]byte, error) { if jwkKey, ok := key.(jwk.Key); ok { var raw interface{} - if err := jwkKey.Raw(&raw); err != nil { + if err := jwk.Raw(jwkKey, &raw); err != nil { return nil, fmt.Errorf(`failed to retrieve raw key from %T: %w`, key, err) } key = raw @@ -609,13 +609,13 @@ func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptio switch epk := epkif.(type) { case jwk.ECDSAPublicKey: var pubkey ecdsa.PublicKey - if err := epk.Raw(&pubkey); err != nil { + if err := jwk.Raw(epk, &pubkey); err != nil { return nil, fmt.Errorf(`failed to get public key: %w`, err) } dec.PublicKey(&pubkey) case jwk.OKPPublicKey: var pubkey interface{} - if err := epk.Raw(&pubkey); err != nil { + if err := jwk.Raw(epk, &pubkey); err != nil { return nil, fmt.Errorf(`failed to get public key: %w`, err) } dec.PublicKey(pubkey) diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 64800de9d..b9459b6d7 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -50,7 +50,7 @@ func init() { panic(err) } - if err := privkey.Raw(&rsaPrivKey); err != nil { + if err := jwk.Raw(privkey, &rsaPrivKey); err != nil { panic(err) } } @@ -168,7 +168,7 @@ func TestParse_RSAES_OAEP_AES_GCM(t *testing.T) { } var rawkey rsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Raw(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -503,7 +503,7 @@ func Test_GHIssue207(t *testing.T) { } var key ecdsa.PrivateKey - if !assert.NoError(t, webKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(webKey, &key), `jwk.Raw should succeed`) { return } @@ -630,7 +630,7 @@ func TestDecodePredefined_Direct(t *testing.T) { } var key []byte - if !assert.NoError(t, webKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(webKey, &key), `jwk.Raw should succeed`) { return } diff --git a/jwk/BUILD.bazel b/jwk/BUILD.bazel index a61a919f5..0bc749805 100644 --- a/jwk/BUILD.bazel +++ b/jwk/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "jwk", srcs = [ "cache.go", + "convert.go", "ecdsa.go", "ecdsa_gen.go", "fetch.go", @@ -16,6 +17,7 @@ go_library( "okp_gen.go", "options.go", "options_gen.go", + "pem.go", "rsa.go", "rsa_gen.go", "set.go", diff --git a/jwk/convert.go b/jwk/convert.go new file mode 100644 index 000000000..cb548c137 --- /dev/null +++ b/jwk/convert.go @@ -0,0 +1,266 @@ +package jwk + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "fmt" + "sync" + + "github.com/lestrrat-go/jwx/v2/x25519" +) + +type RawFromKeyer interface { + RawFromKey(Key, interface{}) error +} + +type ChainedRawFromKeyer interface { + Next(RawFromKeyer, Key, interface{}) error +} + +type ChainedRawFromKeyFunc func(RawFromKeyer, Key, interface{}) error + +func (fn ChainedRawFromKeyFunc) Next(n RawFromKeyer, key Key, raw interface{}) error { + return fn(n, key, raw) +} + +type chainedRawFromKey struct { + mu sync.RWMutex + list []ChainedRawFromKeyer +} + +type chainedRawFromKeyCallState struct { + current int + parent *chainedRawFromKey +} + +func (c *chainedRawFromKey) Add(rfk ChainedRawFromKeyer) { + if rfk == nil { + return // no-op + } + c.mu.Lock() + defer c.mu.Unlock() + c.list = append(c.list, rfk) +} + +func (c *chainedRawFromKey) Next(key Key, raw interface{}) error { + c.mu.RLock() + lrfk := len(c.list) + c.mu.RUnlock() + st := &chainedRawFromKeyCallState{parent: c, current: lrfk} + return st.RawFromKey(key, raw) +} + +func (s *chainedRawFromKeyCallState) RawFromKey(key Key, raw interface{}) error { + idx := s.current - 1 + + s.parent.mu.RLock() + defer s.parent.mu.RUnlock() + + llist := len(s.parent.list) + if idx < 0 || idx >= llist { + return fmt.Errorf(`jwk.Raw: invalid raw key type %T`, raw) + } + s.current = idx + + rfk := s.parent.list[idx] + return rfk.Next(s, key, raw) +} + +type chainedKeyFromRaw struct { + mu sync.RWMutex + list []ChainedKeyFromRawer +} + +type chainedKeyFromRawCallState struct { + current int + parent *chainedKeyFromRaw +} + +func (c *chainedKeyFromRaw) Add(kfr ChainedKeyFromRawer) { + if kfr == nil { + return // no-op + } + c.mu.Lock() + defer c.mu.Unlock() + c.list = append(c.list, kfr) +} + +func (c *chainedKeyFromRaw) Next(raw interface{}) (Key, error) { + c.mu.RLock() + lkfr := len(c.list) + c.mu.RUnlock() + st := &chainedKeyFromRawCallState{parent: c, current: lkfr} + return st.KeyFromRaw(raw) +} + +func (s *chainedKeyFromRawCallState) KeyFromRaw(raw interface{}) (Key, error) { + idx := s.current - 1 + + s.parent.mu.RLock() + defer s.parent.mu.RUnlock() + + llist := len(s.parent.list) + if idx < 0 || idx >= llist { + return nil, fmt.Errorf(`jwk.FromRaw: invalid raw key type %T`, raw) + } + s.current = idx + + kfr := s.parent.list[idx] + return kfr.Next(s, raw) +} + +var chainedKFR = &chainedKeyFromRaw{ + list: []ChainedKeyFromRawer{ChainedKeyFromRawFunc(fromRaw)}, +} + +var chainedRFK = &chainedRawFromKey{ + list: []ChainedRawFromKeyer{ChainedRawFromKeyFunc(toRaw)}, +} + +type KeyFromRawer interface { + KeyFromRaw(interface{}) (Key, error) +} + +// ChainedKeyFromRawer describes a type that can build a Key from a raw key +// +// ChainedKeyFromRawer objects are expected to be called in sequence. When a new +// object is added to the list of KeyFromRawer objects, they are called +// from the most recently added all the way up to the default object, +// if you choose to do so by invokind the first argument. +type ChainedKeyFromRawer interface { + // Next calls the handler in the subsequent chain of handlers. + // + // The first argument invokes the _next_ KeyFromRawer that can be called in the + // chain of possible KeyFromRawers that are registered. For example, + // if your KeyFromRawer failed to match any key type that you can handle, + // you can defer to the next KeyFromRawer to see if it can handle it + Next(KeyFromRawer, interface{}) (Key, error) +} + +// ChainedKeyFromRawFunc is an instance of ChainedKeyFromRawer represented by a function +type ChainedKeyFromRawFunc func(KeyFromRawer, interface{}) (Key, error) + +func (fn ChainedKeyFromRawFunc) Next(n KeyFromRawer, raw interface{}) (Key, error) { + return fn(n, raw) +} + +// AddKeyFromRaw adds a new KeyFromRawer object that is used in the FromRaw() function, which +// in turn will handle converting a raw key to a Key. +func AddKeyFromRaw(kfr ChainedKeyFromRawer) { + chainedKFR.Add(kfr) +} + +func fromRaw(_ KeyFromRawer, key interface{}) (Key, error) { + var ptr interface{} + switch v := key.(type) { + case rsa.PrivateKey: + ptr = &v + case rsa.PublicKey: + ptr = &v + case ecdsa.PrivateKey: + ptr = &v + case ecdsa.PublicKey: + ptr = &v + default: + ptr = v + } + + switch rawKey := ptr.(type) { + case *rsa.PrivateKey: + k := newRSAPrivateKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case *rsa.PublicKey: + k := newRSAPublicKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case *ecdsa.PrivateKey: + k := newECDSAPrivateKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case *ecdsa.PublicKey: + k := newECDSAPublicKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case ed25519.PrivateKey: + k := newOKPPrivateKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case ed25519.PublicKey: + k := newOKPPublicKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case x25519.PrivateKey: + k := newOKPPrivateKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case x25519.PublicKey: + k := newOKPPublicKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + case []byte: + k := newSymmetricKey() + if err := k.FromRaw(rawKey); err != nil { + return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) + } + return k, nil + default: + return nil, fmt.Errorf(`invalid key type '%T' for jwk.FromRaw`, key) + } +} + +// FromRaw creates a Key from the given key (RSA/ECDSA/symmetric keys). +// +// The constructor auto-detects the type of key to be instantiated +// based on the input type: +// +// - "crypto/rsa".PrivateKey and "crypto/rsa".PublicKey creates an RSA based key +// - "crypto/ecdsa".PrivateKey and "crypto/ecdsa".PublicKey creates an EC based key +// - "crypto/ed25519".PrivateKey and "crypto/ed25519".PublicKey creates an OKP based key +// - []byte creates a symmetric key +// +// This function also takes care of additional key types added by external +// libraries such as secp256k1 keys. +func FromRaw(raw interface{}) (Key, error) { + if raw == nil { + return nil, fmt.Errorf(`jwk.FromRaw requires a non-nil key`) + } + + return chainedKFR.Next(raw) +} + +// Raw converts a jwk.Key to its raw form and stores in the `raw` variable. +// `raw` must be a pointer to a compatible object, otherwise an error will +// be returned. +// +// As of v2.0.12, it is recommended to use `jwk.Raw()` instead of `keyObject.Raw()`. +// The latter will NOT take care of converting additional key types added by +// external libraries, such as secp256k1 keys. +func Raw(key Key, raw interface{}) error { + return chainedRFK.Next(key, raw) +} + +func toRaw(_ RawFromKeyer, key Key, raw interface{}) error { + return key.Raw(raw) +} + +func AddRawFromKey(rfk ChainedRawFromKeyer) { + chainedRFK.Add(rfk) +} diff --git a/jwk/ecdsa.go b/jwk/ecdsa.go index 67a14ba63..a8293223d 100644 --- a/jwk/ecdsa.go +++ b/jwk/ecdsa.go @@ -186,7 +186,7 @@ func (k ecdsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key ecdsa.PublicKey - if err := k.Raw(&key); err != nil { + if err := Raw(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize ecdsa.PublicKey for thumbprint generation: %w`, err) } @@ -210,7 +210,7 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key ecdsa.PrivateKey - if err := k.Raw(&key); err != nil { + if err := Raw(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize ecdsa.PrivateKey for thumbprint generation: %w`, err) } diff --git a/jwk/es256k.go b/jwk/es256k.go index 1a9d2346a..b87354fb6 100644 --- a/jwk/es256k.go +++ b/jwk/es256k.go @@ -4,11 +4,212 @@ package jwk import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "fmt" + "sync" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lestrrat-go/blackmagic" "github.com/lestrrat-go/jwx/v2/internal/ecutil" "github.com/lestrrat-go/jwx/v2/jwa" ) func init() { ecutil.RegisterCurve(secp256k1.S256(), jwa.Secp256k1) + + AddKeyFromRaw(ChainedKeyFromRawFunc(secp256k1FromRaw)) + AddRawFromKey(ChainedRawFromKeyFunc(secp256k1Raw)) + AddASN1Encoder(ChainedASN1EncodeFunc(secp256k1ASN1Encode)) + AddASN1Decoder(ChainedASN1DecodeFunc(secp256k1ASN1Decode)) +} + +var secp256k1OID = asn1.ObjectIdentifier{1, 3, 132, 0, 10} +var secp256k1PkPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32+1) // 32 bytes + 1 + }, +} + +func getPkBuf(size int) []byte { + buf := secp256k1PkPool.Get().([]byte) + if cap(buf) < size { + buf = make([]byte, size) + } else { + buf = buf[:size] + } + return buf +} + +func releasePkBuf(buf []byte) { + // XXX Replace this with clear() when we remove support for go < 1.21 + for i := 0; i < len(buf); i++ { + buf[i] = byte(0) + } + + secp256k1PkPool.Put(buf) +} + +type secp256k1ASN1PrivateKey struct { + Version int + PrivateKey []byte + NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` + PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"` +} + +type secp256k1ASN1PublicKey struct { + Algorithm pkix.AlgorithmIdentifier + BitString asn1.BitString +} + +func secp256k1ASN1Encode(n ASN1Encoder, key Key) (string, []byte, error) { + switch key := key.(type) { + case ECDSAPrivateKey: + if key.Crv() == jwa.Secp256k1 { + var raw secp256k1.PrivateKey + if err := Raw(key, &raw); err != nil { + return "", nil, fmt.Errorf(`failed to convert jwk.Key into raw key: %w`, err) + } + return secp256k1EncodePrivateKey(&raw) + } + case ECDSAPublicKey: + if key.Crv() == jwa.Secp256k1 { + var raw secp256k1.PublicKey + if err := Raw(key, &raw); err != nil { + return "", nil, fmt.Errorf(`failed to convert jwk.Key into raw key: %w`, err) + } + return secp256k1EncodePublicKey(&raw) + } + } + + return n.ASN1Encode(key) +} + +func secp256k1ASN1Decode(n ASN1Decoder, buf []byte) (interface{}, []byte, error) { + block, rest := pem.Decode(buf) + if block == nil { + return nil, buf, fmt.Errorf(`jwk: PEM block decoded to nil`) + } + + if block.Type == pmECPrivateKey { + var priv secp256k1ASN1PrivateKey + // for 1-3, we're going to believe that this may have been + // another EC key that can be decoded by the next decoder + if _, err := asn1.Unmarshal(block.Bytes, &priv); err != nil { // (1) + return n.ASN1Decode(buf) + } + + if !priv.NamedCurveOID.Equal(secp256k1OID) { // (2) + return n.ASN1Decode(buf) + } + + if priv.Version != 1 { // (3) + return n.ASN1Decode(buf) + } + + key := secp256k1.PrivKeyFromBytes(priv.PrivateKey) + return key, rest, nil + } + // All other cases including secp256k1 public key can be handled + // by the default handler + return n.ASN1Decode(buf) +} + +func secp256k1EncodePrivateKey(key *secp256k1.PrivateKey) (string, []byte, error) { + asECDSA := key.ToECDSA() + size := (asECDSA.Curve.Params().N.BitLen() + 7) / 8 + pkbuf := getPkBuf(size) + defer releasePkBuf(pkbuf) + + buf, err := asn1.Marshal(secp256k1ASN1PrivateKey{ + Version: 1, + PrivateKey: asECDSA.D.FillBytes(pkbuf), + NamedCurveOID: secp256k1OID, + PublicKey: asn1.BitString{ + Bytes: elliptic.Marshal(asECDSA.Curve, asECDSA.X, asECDSA.Y), + }, + }) + if err != nil { + return "", nil, fmt.Errorf(`failed to marshal secp256k1 private key: %w`, err) + } + + return pmECPrivateKey, buf, nil +} + +func secp256k1EncodePublicKey(key *secp256k1.PublicKey) (string, []byte, error) { + asECDSA := key.ToECDSA() + + pkbuf := elliptic.Marshal(asECDSA.Curve, asECDSA.X, asECDSA.Y) + + oidBuf, err := asn1.Marshal(secp256k1OID) + if err != nil { + return "", nil, fmt.Errorf(`failed to marshal oid in ASN.1 format`) + } + + buf, err := asn1.Marshal(secp256k1ASN1PublicKey{ + Algorithm: pkix.AlgorithmIdentifier{ + Algorithm: secp256k1OID, + Parameters: asn1.RawValue{ + FullBytes: oidBuf, + }, + }, + BitString: asn1.BitString{ + Bytes: pkbuf, + BitLength: 8 * len(pkbuf), + }, + }) + if err != nil { + return "", nil, fmt.Errorf(`failed to marshal secp256k1 public key: %w`, err) + } + + return pmPublicKey, buf, nil +} + +func secp256k1FromRaw(nextKFR KeyFromRawer, key interface{}) (Key, error) { + switch key := key.(type) { + case *secp256k1.PrivateKey: + return nextKFR.KeyFromRaw(key.ToECDSA()) + case *secp256k1.PublicKey: + return nextKFR.KeyFromRaw(key.ToECDSA()) + default: + return nextKFR.KeyFromRaw(key) + } +} + +func secp256k1Raw(nextRFK RawFromKeyer, key Key, raw interface{}) error { + // for secp256k1Raw keys, you can either create a ecdsa.* key or a + // secp256k1.* key. + switch raw := raw.(type) { + case *secp256k1.PrivateKey: + // we first get a ecdsa.PrivateKey, then convert it to secp256k1.PrivateKey + var ecdsaKey ecdsa.PrivateKey + if err := key.Raw(&ecdsaKey); err != nil { + return fmt.Errorf(`failed to convert JWK into raw ecdsa.PrivateKey: %w`, err) + } + // Make sure the curve is secp256k1 + if ecdsaKey.Curve.Params().Name != secp256k1.S256().Params().Name { + return fmt.Errorf(`invalid curve for secp256k1: %s`, ecdsaKey.Curve.Params().Name) + } + return blackmagic.AssignIfCompatible(raw, secp256k1.PrivKeyFromBytes(ecdsaKey.D.Bytes())) + case *secp256k1.PublicKey: + // we first get a ecdsa.PublicKey, then convert it to secp256k1.PublicKey + var ecdsaKey ecdsa.PublicKey + if err := key.Raw(&ecdsaKey); err != nil { + return fmt.Errorf(`failed to convert JWK into raw ecdsa.PublicKey: %w`, err) + } + // Make sure the curve is secp256k1 + if ecdsaKey.Curve.Params().Name != secp256k1.S256().Params().Name { + return fmt.Errorf(`invalid curve for secp256k1: %s`, ecdsaKey.Curve.Params().Name) + } + var x secp256k1.FieldVal + var y secp256k1.FieldVal + x.SetByteSlice(ecdsaKey.X.Bytes()) + y.SetByteSlice(ecdsaKey.Y.Bytes()) + return blackmagic.AssignIfCompatible(raw, secp256k1.NewPublicKey(&x, &y)) + default: + return nextRFK.RawFromKey(key, raw) + } } diff --git a/jwk/es256k_test.go b/jwk/es256k_test.go index 728aba812..b9c224631 100644 --- a/jwk/es256k_test.go +++ b/jwk/es256k_test.go @@ -19,6 +19,22 @@ import ( func TestES256K(t *testing.T) { require.True(t, ecutil.IsAvailable(jwa.Secp256k1), `jwa.Secp256k1 should be available`) + + t.Run("PEM", func(t *testing.T) { + privkey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err, `secp256k1.GeneratePrivateKey should succeed`) + + key, err := jwk.FromRaw(privkey) + require.NoError(t, err, `jwk.FromRaw should succeed`) + + buf, err := jwk.Pem(key) + require.NoError(t, err, `jwk.Pem should succeed`) + + parsed, err := jwk.ParseKey(buf, jwk.WithPEM(true)) + require.NoError(t, err, `jwk.ParseKey should succeed`) + + require.Equal(t, key, parsed, `jwk.ParseKey should return the same key`) + }) } func BenchmarkKeyInstantiation(b *testing.B) { diff --git a/jwk/jwk.go b/jwk/jwk.go index d46e3e046..5ad3b47de 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -11,7 +11,6 @@ import ( "crypto/elliptic" "crypto/rsa" "crypto/x509" - "encoding/pem" "fmt" "io" "math/big" @@ -32,94 +31,6 @@ func bigIntToBytes(n *big.Int) ([]byte, error) { return n.Bytes(), nil } -// FromRaw creates a jwk.Key from the given key (RSA/ECDSA/symmetric keys). -// -// The constructor auto-detects the type of key to be instantiated -// based on the input type: -// -// - "crypto/rsa".PrivateKey and "crypto/rsa".PublicKey creates an RSA based key -// - "crypto/ecdsa".PrivateKey and "crypto/ecdsa".PublicKey creates an EC based key -// - "crypto/ed25519".PrivateKey and "crypto/ed25519".PublicKey creates an OKP based key -// - []byte creates a symmetric key -func FromRaw(key interface{}) (Key, error) { - if key == nil { - return nil, fmt.Errorf(`jwk.FromRaw requires a non-nil key`) - } - - var ptr interface{} - switch v := key.(type) { - case rsa.PrivateKey: - ptr = &v - case rsa.PublicKey: - ptr = &v - case ecdsa.PrivateKey: - ptr = &v - case ecdsa.PublicKey: - ptr = &v - default: - ptr = v - } - - switch rawKey := ptr.(type) { - case *rsa.PrivateKey: - k := newRSAPrivateKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case *rsa.PublicKey: - k := newRSAPublicKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case *ecdsa.PrivateKey: - k := newECDSAPrivateKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case *ecdsa.PublicKey: - k := newECDSAPublicKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case ed25519.PrivateKey: - k := newOKPPrivateKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case ed25519.PublicKey: - k := newOKPPublicKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case x25519.PrivateKey: - k := newOKPPrivateKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case x25519.PublicKey: - k := newOKPPublicKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - case []byte: - k := newSymmetricKey() - if err := k.FromRaw(rawKey); err != nil { - return nil, fmt.Errorf(`failed to initialize %T from %T: %w`, k, rawKey, err) - } - return k, nil - default: - return nil, fmt.Errorf(`invalid key type '%T' for jwk.New`, key) - } -} - // PublicSetOf returns a new jwk.Set consisting of // public keys of the keys contained in the set. // @@ -188,7 +99,7 @@ func PublicRawKeyOf(v interface{}) (interface{}, error) { } var raw interface{} - if err := pubk.Raw(&raw); err != nil { + if err := Raw(pubk, &raw); err != nil { return nil, fmt.Errorf(`failed to obtain raw key from %T: %w`, pubk, err) } return raw, nil @@ -233,14 +144,6 @@ func PublicRawKeyOf(v interface{}) (interface{}, error) { } } -const ( - pmPrivateKey = `PRIVATE KEY` - pmPublicKey = `PUBLIC KEY` - pmECPrivateKey = `EC PRIVATE KEY` - pmRSAPublicKey = `RSA PUBLIC KEY` - pmRSAPrivateKey = `RSA PRIVATE KEY` -) - // EncodeX509 encodes the key into a byte sequence in ASN.1 DER format // suitable for to be PEM encoded. The key can be a jwk.Key or a raw key // instance, but it must be one of the types supported by `x509` package. @@ -254,9 +157,9 @@ const ( // The second return value is the encoded byte sequence. func EncodeX509(v interface{}) (string, []byte, error) { // we can't import jwk, so just use the interface - if key, ok := v.(interface{ Raw(interface{}) error }); ok { + if key, ok := v.(Key); ok { var raw interface{} - if err := key.Raw(&raw); err != nil { + if err := Raw(key, &raw); err != nil { return "", nil, fmt.Errorf(`failed to get raw key out of %T: %w`, key, err) } @@ -290,77 +193,6 @@ func EncodeX509(v interface{}) (string, []byte, error) { } } -// EncodePEM encodes the key into a PEM encoded ASN.1 DER format. -// The key can be a jwk.Key or a raw key instance, but it must be one of -// the types supported by `x509` package. -// -// Internally, it uses the same routine as `jwk.EncodeX509()`, and therefore -// the same caveats apply -func EncodePEM(v interface{}) ([]byte, error) { - typ, marshaled, err := EncodeX509(v) - if err != nil { - return nil, fmt.Errorf(`failed to encode key in x509: %w`, err) - } - - block := &pem.Block{ - Type: typ, - Bytes: marshaled, - } - return pem.EncodeToMemory(block), nil -} - -// DecodePEM decodes a key in PEM encoded ASN.1 DER format. -// and returns a raw key -func DecodePEM(src []byte) (interface{}, []byte, error) { - block, rest := pem.Decode(src) - if block == nil { - return nil, nil, fmt.Errorf(`failed to decode PEM data`) - } - - switch block.Type { - // Handle the semi-obvious cases - case pmRSAPrivateKey: - key, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse PKCS1 private key: %w`, err) - } - return key, rest, nil - case pmRSAPublicKey: - key, err := x509.ParsePKCS1PublicKey(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse PKCS1 public key: %w`, err) - } - return key, rest, nil - case pmECPrivateKey: - key, err := x509.ParseECPrivateKey(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse EC private key: %w`, err) - } - return key, rest, nil - case pmPublicKey: - // XXX *could* return dsa.PublicKey - key, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse PKIX public key: %w`, err) - } - return key, rest, nil - case pmPrivateKey: - key, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse PKCS8 private key: %w`, err) - } - return key, rest, nil - case "CERTIFICATE": - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, nil, fmt.Errorf(`failed to parse certificate: %w`, err) - } - return cert.PublicKey, rest, nil - default: - return nil, nil, fmt.Errorf(`invalid PEM block type %s`, block.Type) - } -} - // ParseRawKey is a combination of ParseKey and Raw. It parses a single JWK key, // and assigns the "raw" key to the given parameter. The key must either be // a pointer to an empty interface, or a pointer to the actual raw key type @@ -371,7 +203,7 @@ func ParseRawKey(data []byte, rawkey interface{}) error { return fmt.Errorf(`failed to parse key: %w`, err) } - if err := key.Raw(rawkey); err != nil { + if err := Raw(key, rawkey); err != nil { return fmt.Errorf(`failed to assign to raw key variable: %w`, err) } @@ -634,71 +466,6 @@ func cloneKey(src Key) (Key, error) { return dst, nil } -// Pem serializes the given jwk.Key in PEM encoded ASN.1 DER format, -// using either PKCS8 for private keys and PKIX for public keys. -// If you need to encode using PKCS1 or SEC1, you must do it yourself. -// -// # Argument must be of type jwk.Key or jwk.Set -// -// Currently only EC (including Ed25519) and RSA keys (and jwk.Set -// comprised of these key types) are supported. -func Pem(v interface{}) ([]byte, error) { - var set Set - switch v := v.(type) { - case Key: - set = NewSet() - if err := set.AddKey(v); err != nil { - return nil, fmt.Errorf(`failed to add key to set: %w`, err) - } - case Set: - set = v - default: - return nil, fmt.Errorf(`argument to Pem must be either jwk.Key or jwk.Set: %T`, v) - } - - var ret []byte - for i := 0; i < set.Len(); i++ { - key, _ := set.Key(i) - typ, buf, err := asnEncode(key) - if err != nil { - return nil, fmt.Errorf(`failed to encode content for key #%d: %w`, i, err) - } - - var block pem.Block - block.Type = typ - block.Bytes = buf - ret = append(ret, pem.EncodeToMemory(&block)...) - } - return ret, nil -} - -func asnEncode(key Key) (string, []byte, error) { - switch key := key.(type) { - case RSAPrivateKey, ECDSAPrivateKey, OKPPrivateKey: - var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { - return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) - } - buf, err := x509.MarshalPKCS8PrivateKey(rawkey) - if err != nil { - return "", nil, fmt.Errorf(`failed to marshal PKCS8: %w`, err) - } - return pmPrivateKey, buf, nil - case RSAPublicKey, ECDSAPublicKey, OKPPublicKey: - var rawkey interface{} - if err := key.Raw(&rawkey); err != nil { - return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) - } - buf, err := x509.MarshalPKIXPublicKey(rawkey) - if err != nil { - return "", nil, fmt.Errorf(`failed to marshal PKIX: %w`, err) - } - return pmPublicKey, buf, nil - default: - return "", nil, fmt.Errorf(`unsupported key type %T`, key) - } -} - // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -728,6 +495,20 @@ func CurveForAlgorithm(alg jwa.EllipticCurveAlgorithm) (elliptic.Curve, bool) { return ecutil.CurveForAlgorithm(alg) } +// KeySpec is a specification for additional key types +// to be added to the jwk system. +// +// This mechanism should be considered experimental and subject +// to change, even between micro versions. If you are adding a +// new key type, please be ready to update your code when +// a new version of this library is released. +type KeySpec struct { + Curve elliptic.Curve + Algorithm jwa.EllipticCurveAlgorithm + RawFromKey ChainedRawFromKeyer + KeyFromRaw ChainedKeyFromRawer +} + // Equal compares two keys and returns true if they are equal. The comparison // is solely done against the thumbprints of k1 and k2. It is possible for keys // that have, for example, different key IDs, key usage, etc, to be considered equal. diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index af7dfd4d4..7d9f8efb2 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -279,7 +279,7 @@ func VerifyKey(t *testing.T, def map[string]keyDef) { typ := expectedRawKeyType(key) var rawkey interface{} - if !assert.NoError(t, key.Raw(&rawkey), `Raw() should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `Raw() should succeed`) { return } if !assert.IsType(t, rawkey, typ, `raw key should be of this type`) { @@ -391,7 +391,7 @@ func TestParse(t *testing.T) { t.Helper() var irawkey interface{} - if !assert.NoError(t, key.Raw(&irawkey), `key.Raw(&interface) should ucceed`) { + if !assert.NoError(t, jwk.Raw(key, &irawkey), `jwk.Raw(key,&interface) should ucceed`) { return } @@ -399,19 +399,19 @@ func TestParse(t *testing.T) { switch k := key.(type) { case jwk.RSAPrivateKey: var rawkey rsa.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&rsa.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&rsa.PrivateKey) should succeed`) { return } crawkey = &rawkey case jwk.RSAPublicKey: var rawkey rsa.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&rsa.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&rsa.PublicKey) should succeed`) { return } crawkey = &rawkey case jwk.ECDSAPrivateKey: var rawkey ecdsa.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ecdsa.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&ecdsa.PrivateKey) should succeed`) { return } crawkey = &rawkey @@ -419,13 +419,13 @@ func TestParse(t *testing.T) { switch k.Crv() { case jwa.Ed25519: var rawkey ed25519.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ed25519.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&ed25519.PrivateKey) should succeed`) { return } crawkey = rawkey case jwa.X25519: var rawkey x25519.PrivateKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&x25519.PrivateKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&x25519.PrivateKey) should succeed`) { return } crawkey = rawkey @@ -439,13 +439,13 @@ func TestParse(t *testing.T) { switch k.Crv() { case jwa.Ed25519: var rawkey ed25519.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&ed25519.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&ed25519.PublicKey) should succeed`) { return } crawkey = rawkey case jwa.X25519: var rawkey x25519.PublicKey - if !assert.NoError(t, key.Raw(&rawkey), `key.Raw(&x25519.PublicKey) should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &rawkey), `jwk.Raw(key,&x25519.PublicKey) should succeed`) { return } crawkey = rawkey @@ -934,7 +934,7 @@ func TestPublicKeyOf(t *testing.T) { // Get the raw key to compare var rawKey interface{} - if !assert.NoError(t, pubJwkKey.Raw(&rawKey), `pubJwkKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(pubJwkKey, &rawKey), `pubJwkKey.Raw should succeed`) { return } @@ -987,7 +987,7 @@ func TestPublicKeyOf(t *testing.T) { // Get the raw key to compare var rawKey interface{} - if !assert.NoError(t, setKey.Raw(&rawKey), `pubJwkKey.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(setKey, &rawKey), `pubJwkKey.Raw should succeed`) { return } @@ -1453,7 +1453,7 @@ c4wOvhbalcX0FqTM3mXCgMFRbibquhwdxbU= } var pubkey rsa.PublicKey - if !assert.NoError(t, key.Raw(&pubkey), `key.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(key, &pubkey), `key.Raw should succeed`) { return } @@ -2183,5 +2183,5 @@ func TestGH947(t *testing.T) { k, err := jwk.ParseKey(raw) require.NoError(t, err, `jwk.ParseKey should succeed`) var exported []byte - require.Error(t, k.Raw(&exported), `(okpkey).Raw with 0-length OKP key should fail`) + require.Error(t, jwk.Raw(k, &exported), `(okpkey).Raw with 0-length OKP key should fail`) } diff --git a/jwk/pem.go b/jwk/pem.go new file mode 100644 index 000000000..9637217a5 --- /dev/null +++ b/jwk/pem.go @@ -0,0 +1,289 @@ +package jwk + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "sync" +) + +const ( + pmPrivateKey = `PRIVATE KEY` + pmPublicKey = `PUBLIC KEY` + pmECPrivateKey = `EC PRIVATE KEY` + pmRSAPublicKey = `RSA PUBLIC KEY` + pmRSAPrivateKey = `RSA PRIVATE KEY` +) + +// ASN1Decoder decodes a given byte sequence into a key. +type ASN1Decoder interface { + ASN1Decode([]byte) (interface{}, []byte, error) +} + +type ChainedASN1Decoder interface { + Next(ASN1Decoder, []byte) (interface{}, []byte, error) +} + +type ChainedASN1DecodeFunc func(ASN1Decoder, []byte) (interface{}, []byte, error) + +func (fn ChainedASN1DecodeFunc) Next(n ASN1Decoder, src []byte) (interface{}, []byte, error) { + return fn(n, src) +} + +type chainedASN1Decoder struct { + mu sync.RWMutex + list []ChainedASN1Decoder +} + +func (c *chainedASN1Decoder) Add(d ChainedASN1Decoder) { + c.mu.Lock() + defer c.mu.Unlock() + c.list = append(c.list, d) +} + +func (c *chainedASN1Decoder) Next(src []byte) (interface{}, []byte, error) { + c.mu.RLock() + llist := len(c.list) + c.mu.RUnlock() + st := &chainedASN1DecoderCallState{parent: c, current: llist} + return st.ASN1Decode(src) +} + +type chainedASN1DecoderCallState struct { + current int + parent *chainedASN1Decoder +} + +func (st *chainedASN1DecoderCallState) ASN1Decode(src []byte) (interface{}, []byte, error) { + idx := st.current - 1 + + st.parent.mu.RLock() + defer st.parent.mu.RUnlock() + + llist := len(st.parent.list) + if idx < 0 || idx >= llist { + return nil, nil, fmt.Errorf(`failed to decode PEM data`) + } + + st.current = idx + + d := st.parent.list[idx] + return d.Next(st, src) +} + +func AddASN1Decoder(dec ChainedASN1Decoder) { + chainedASN1D.Add(dec) +} + +var chainedASN1D = &chainedASN1Decoder{ + list: []ChainedASN1Decoder{ChainedASN1DecodeFunc(asn1Decode)}, +} + +type ASN1Encoder interface { + ASN1Encode(Key) (string, []byte, error) +} + +type ChainedASN1Encoder interface { + Next(ASN1Encoder, Key) (string, []byte, error) +} + +type ChainedASN1EncodeFunc func(ASN1Encoder, Key) (string, []byte, error) + +func (fn ChainedASN1EncodeFunc) Next(n ASN1Encoder, key Key) (string, []byte, error) { + return fn(n, key) +} + +type chainedASN1Encoder struct { + mu sync.RWMutex + list []ChainedASN1Encoder +} + +func (c *chainedASN1Encoder) Add(e ChainedASN1Encoder) { + c.mu.Lock() + defer c.mu.Unlock() + c.list = append(c.list, e) +} + +func (c *chainedASN1Encoder) Next(key Key) (string, []byte, error) { + c.mu.RLock() + llist := len(c.list) + c.mu.RUnlock() + st := &chainedASN1EncoderCallState{parent: c, current: llist} + return st.ASN1Encode(key) +} + +type chainedASN1EncoderCallState struct { + current int + parent *chainedASN1Encoder +} + +func (st *chainedASN1EncoderCallState) ASN1Encode(key Key) (string, []byte, error) { + idx := st.current - 1 + + st.parent.mu.RLock() + defer st.parent.mu.RUnlock() + + llist := len(st.parent.list) + if idx < 0 || idx >= llist { + return "", nil, fmt.Errorf(`failed to encode to jwk.Key %T to PEM`, key) + } + + st.current = idx + + e := st.parent.list[idx] + return e.Next(st, key) +} + +var chainedASN1E = &chainedASN1Encoder{ + list: []ChainedASN1Encoder{ChainedASN1EncodeFunc(asn1Encode)}, +} + +// AddASN1Encoder allows users +func AddASN1Encoder(enc ChainedASN1Encoder) { + chainedASN1E.Add(enc) +} + +// Encodes a Key in DER ASN.1 format. Can handle RSA, EC, OKP keys. +func EncodeASN1(key Key) (string, []byte, error) { + return chainedASN1E.Next(key) +} + +func asn1Encode(_ ASN1Encoder, key Key) (string, []byte, error) { + switch key := key.(type) { + case RSAPrivateKey, ECDSAPrivateKey, OKPPrivateKey: + var rawkey interface{} + if err := Raw(key, &rawkey); err != nil { + return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) + } + buf, err := x509.MarshalPKCS8PrivateKey(rawkey) + if err != nil { + return "", nil, fmt.Errorf(`failed to marshal PKCS8: %w`, err) + } + return pmPrivateKey, buf, nil + case RSAPublicKey, ECDSAPublicKey, OKPPublicKey: + var rawkey interface{} + if err := Raw(key, &rawkey); err != nil { + return "", nil, fmt.Errorf(`failed to get raw key from jwk.Key: %w`, err) + } + buf, err := x509.MarshalPKIXPublicKey(rawkey) + if err != nil { + return "", nil, fmt.Errorf(`failed to marshal PKIX: %w`, err) + } + return pmPublicKey, buf, nil + default: + return "", nil, fmt.Errorf(`encoding key to ASN.1 failed: unsupported key type %T`, key) + } +} + +// Pem serializes the given jwk.Key in PEM encoded ASN.1 DER format, +// using either PKCS8 for private keys and PKIX for public keys. +// If you need to encode using PKCS1 or SEC1, you must do it yourself. +// +// # Argument must be of type jwk.Key or jwk.Set +// +// Currently only EC (including Ed25519) and RSA keys (and jwk.Set +// comprised of these key types) are supported. +func Pem(v interface{}) ([]byte, error) { + var set Set + switch v := v.(type) { + case Key: + set = NewSet() + if err := set.AddKey(v); err != nil { + return nil, fmt.Errorf(`failed to add key to set: %w`, err) + } + case Set: + set = v + default: + return nil, fmt.Errorf(`argument to Pem must be either jwk.Key or jwk.Set: %T`, v) + } + + var ret []byte + for i := 0; i < set.Len(); i++ { + key, _ := set.Key(i) + typ, buf, err := chainedASN1E.Next(key) + if err != nil { + return nil, fmt.Errorf(`failed to encode content for key #%d: %w`, i, err) + } + + var block pem.Block + block.Type = typ + block.Bytes = buf + ret = append(ret, pem.EncodeToMemory(&block)...) + } + return ret, nil +} + +// DecodePEM decodes a key in PEM encoded ASN.1 DER format. +// and returns a raw key +func DecodePEM(src []byte) (interface{}, []byte, error) { + return chainedASN1D.Next(src) +} + +func asn1Decode(_ ASN1Decoder, src []byte) (interface{}, []byte, error) { + block, rest := pem.Decode(src) + if block == nil { + return nil, nil, fmt.Errorf(`failed to decode PEM data`) + } + + switch block.Type { + // Handle the semi-obvious cases + case pmRSAPrivateKey: + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse PKCS1 private key: %w`, err) + } + return key, rest, nil + case pmRSAPublicKey: + key, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse PKCS1 public key: %w`, err) + } + return key, rest, nil + case pmECPrivateKey: + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse EC private key: %w`, err) + } + return key, rest, nil + case pmPublicKey: + // XXX *could* return dsa.PublicKey + key, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse PKIX public key: %w`, err) + } + return key, rest, nil + case pmPrivateKey: + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse PKCS8 private key: %w`, err) + } + return key, rest, nil + case "CERTIFICATE": + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf(`failed to parse certificate: %w`, err) + } + return cert.PublicKey, rest, nil + default: + return nil, nil, fmt.Errorf(`invalid PEM block type %s`, block.Type) + } +} + +// EncodePEM encodes the key into a PEM encoded ASN.1 DER format. +// The key can be a jwk.Key or a raw key instance, but it must be one of +// the types supported by `x509` package. +// +// Internally, it uses the same routine as `jwk.EncodeX509()`, and therefore +// the same caveats apply +func EncodePEM(v interface{}) ([]byte, error) { + typ, marshaled, err := EncodeX509(v) + if err != nil { + return nil, fmt.Errorf(`failed to encode key in x509: %w`, err) + } + + block := &pem.Block{ + Type: typ, + Bytes: marshaled, + } + return pem.EncodeToMemory(block), nil +} diff --git a/jwk/rsa.go b/jwk/rsa.go index 5de6b6358..d37b0fcd6 100644 --- a/jwk/rsa.go +++ b/jwk/rsa.go @@ -129,7 +129,7 @@ func (k *rsaPrivateKey) Raw(v interface{}) error { pubk := newRSAPublicKey() pubk.n = k.n pubk.e = k.e - if err := pubk.Raw(&key.PublicKey); err != nil { + if err := Raw(pubk, &key.PublicKey); err != nil { return fmt.Errorf(`failed to materialize RSA public key: %w`, err) } @@ -208,7 +208,7 @@ func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key rsa.PrivateKey - if err := k.Raw(&key); err != nil { + if err := Raw(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize RSA private key: %w`, err) } return rsaThumbprint(hash, &key.PublicKey) @@ -219,7 +219,7 @@ func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { defer k.mu.RUnlock() var key rsa.PublicKey - if err := k.Raw(&key); err != nil { + if err := Raw(&k, &key); err != nil { return nil, fmt.Errorf(`failed to materialize RSA public key: %w`, err) } return rsaThumbprint(hash, &key) diff --git a/jwk/symmetric.go b/jwk/symmetric.go index d2498e334..9f86403c2 100644 --- a/jwk/symmetric.go +++ b/jwk/symmetric.go @@ -35,7 +35,7 @@ func (k *symmetricKey) Thumbprint(hash crypto.Hash) ([]byte, error) { k.mu.RLock() defer k.mu.RUnlock() var octets []byte - if err := k.Raw(&octets); err != nil { + if err := Raw(k, &octets); err != nil { return nil, fmt.Errorf(`failed to materialize symmetric key: %w`, err) } diff --git a/jws/jws_test.go b/jws/jws_test.go index 53d1803b1..763095ece 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -511,7 +511,7 @@ func TestEncode(t *testing.T) { t.Fatal("Failed to parse JWK") } var key interface{} - if !assert.NoError(t, jwkKey.Raw(&key), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(jwkKey, &key), `jwk.Raw should succeed`) { return } var jwsCompact []byte @@ -583,7 +583,7 @@ func TestEncode(t *testing.T) { } var rawkey rsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Raw(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -660,7 +660,7 @@ func TestEncode(t *testing.T) { } var rawkey ecdsa.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Raw(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -745,7 +745,7 @@ func TestEncode(t *testing.T) { } var rawkey ed25519.PrivateKey - if !assert.NoError(t, privkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Raw(privkey, &rawkey), `obtaining raw key should succeed`) { return } @@ -1024,7 +1024,7 @@ func TestDecode_ES384Compact_NoSigTrim(t *testing.T) { } var rawkey ecdsa.PublicKey - if !assert.NoError(t, pubkey.Raw(&rawkey), `obtaining raw key should succeed`) { + if !assert.NoError(t, jwk.Raw(pubkey, &rawkey), `obtaining raw key should succeed`) { return } diff --git a/jwx_test.go b/jwx_test.go index b74243404..5111e49ab 100644 --- a/jwx_test.go +++ b/jwx_test.go @@ -167,7 +167,7 @@ func TestJoseCompatibility(t *testing.T) { } } - if !assert.NoError(t, webkey.Raw(&tc.Raw), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(webkey, &tc.Raw), `jwk.Raw should succeed`) { return } }) @@ -298,17 +298,17 @@ func joseInteropTest(ctx context.Context, spec interopTest, t *testing.T) { switch spec.alg { case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256: var rawkey rsa.PrivateKey - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(jwxJwk, &rawkey), `jwk.Raw should succeed`) { return } case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: var rawkey ecdsa.PrivateKey - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(jwxJwk, &rawkey), `jwk.Raw should succeed`) { return } default: var rawkey []byte - if !assert.NoError(t, jwxJwk.Raw(&rawkey), `jwk.Raw should succeed`) { + if !assert.NoError(t, jwk.Raw(jwxJwk, &rawkey), `jwk.Raw should succeed`) { return } }