Skip to content

Commit

Permalink
Add taint upstream authority (#5340)
Browse files Browse the repository at this point in the history
* POC to add taint upstream authority

Signed-off-by: Marcos Yacob <[email protected]>

* Propagate taining and revoke into downstream server, updating upstream
client

Signed-off-by: Marcos Yacob <[email protected]>

* start working in unit tests for 'common/coretypes/jwtkey'

Signed-off-by: Marcos Yacob <[email protected]>

* refactor x509certificate package

Signed-off-by: Marcos Yacob <[email protected]>

* Add update test in coretypes bundle

Signed-off-by: Marcos Yacob <[email protected]>

* Add more tests for api bundle

Signed-off-by: Marcos Yacob <[email protected]>

* Add tests for local authority service

Signed-off-by: Marcos Yacob <[email protected]>

* more test

Signed-off-by: Marcos Yacob <[email protected]>

* more tests

Signed-off-by: Marcos Yacob <[email protected]>

* more

Signed-off-by: Marcos Yacob <[email protected]>

* more test

Signed-off-by: Marcos Yacob <[email protected]>

* resolve some lints

Signed-off-by: Marcos Yacob <[email protected]>

* more

Signed-off-by: Marcos Yacob <[email protected]>

* upgrade spire-api-sdk, and resolve lint

Signed-off-by: Marcos Yacob <[email protected]>

* Resolve lint...

Signed-off-by: Marcos Yacob <[email protected]>

* PR changes

Signed-off-by: Marcos Yacob <[email protected]>

---------

Signed-off-by: Marcos Yacob <[email protected]>
  • Loading branch information
MarcosDY authored Aug 16, 2024
1 parent 0738b82 commit e3dac17
Show file tree
Hide file tree
Showing 59 changed files with 1,904 additions and 1,057 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ require (
github.com/sigstore/sigstore v1.8.8
github.com/sirupsen/logrus v1.9.3
github.com/spiffe/go-spiffe/v2 v2.3.0
github.com/spiffe/spire-api-sdk v1.2.5-0.20240722174251-0116a7186c35
github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c
github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d
github.com/stretchr/testify v1.9.0
github.com/uber-go/tally/v4 v4.1.16
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1426,8 +1426,8 @@ github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+
github.com/spiffe/go-spiffe/v2 v2.1.6/go.mod h1:eVDqm9xFvyqao6C+eQensb9ZPkyNEeaUbqbBpOhBnNk=
github.com/spiffe/go-spiffe/v2 v2.3.0 h1:g2jYNb/PDMB8I7mBGL2Zuq/Ur6hUhoroxGQFyD6tTj8=
github.com/spiffe/go-spiffe/v2 v2.3.0/go.mod h1:Oxsaio7DBgSNqhAO9i/9tLClaVlfRok7zvJnTV8ZyIY=
github.com/spiffe/spire-api-sdk v1.2.5-0.20240722174251-0116a7186c35 h1:Ah7jJvfjw2fYXtSJF69lWokspl5Vhge0yiSi/mFhzhM=
github.com/spiffe/spire-api-sdk v1.2.5-0.20240722174251-0116a7186c35/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI=
github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c h1:lK/B2paDUiqbngUGsLxDBmNX/BsG2yKxS8W/iGT+x2c=
github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI=
github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d h1:LCRQGU6vOqKLfRrG+GJQrwMwDILcAddAEIf4/1PaSVc=
github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d/go.mod h1:GA6o2PVLwyJdevT6KKt5ZXCY/ziAPna13y/seGk49Ik=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
26 changes: 20 additions & 6 deletions pkg/common/coretypes/bundle/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ MWnIPs59/JF8AiBeKSM/rkL2igQchDTvlJJWsyk9YL8UZI/XfZO7907TWA==
pkixBytes, _ = x509.MarshalPKIXPublicKey(publicKey)
apiJWTAuthoritiesGood = []*apitypes.JWTKey{
{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()},
{KeyId: "IDTainted", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix(), Tainted: true},
}
apiJWTAuthoritiesBad = []*apitypes.JWTKey{
{PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()},
{PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()},
}
apiX509AuthoritiesGood = []*apitypes.X509Certificate{
{Asn1: root.Raw},
{Asn1: root.Raw, Tainted: true},
}
apiX509AuthoritiesBad = []*apitypes.X509Certificate{
{Asn1: []byte("malformed")},
Expand Down Expand Up @@ -71,9 +74,12 @@ MWnIPs59/JF8AiBeKSM/rkL2igQchDTvlJJWsyk9YL8UZI/XfZO7907TWA==
SequenceNumber: 2,
}
commonInvalidTD = &common.Bundle{
TrustDomainId: "not a trustdomain id",
RootCas: []*common.Certificate{{DerBytes: root.Raw}},
JwtSigningKeys: []*common.PublicKey{{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()}},
TrustDomainId: "not a trustdomain id",
RootCas: []*common.Certificate{{DerBytes: root.Raw}},
JwtSigningKeys: []*common.PublicKey{
{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()},
{Kid: "IDTainted", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix(), TaintedKey: true},
},
RefreshHint: 1,
SequenceNumber: 2,
}
Expand All @@ -93,12 +99,14 @@ MWnIPs59/JF8AiBeKSM/rkL2igQchDTvlJJWsyk9YL8UZI/XfZO7907TWA==
}
pluginJWTAuthoritiesGood = []*plugintypes.JWTKey{
{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()},
{KeyId: "IDTainted", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix(), Tainted: true},
}
pluginJWTAuthoritiesBad = []*plugintypes.JWTKey{
{PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()},
}
pluginX509AuthoritiesGood = []*plugintypes.X509Certificate{
{Asn1: root.Raw},
{Asn1: root.Raw, Tainted: true},
}
pluginX509AuthoritiesBad = []*plugintypes.X509Certificate{
{Asn1: []byte("malformed")},
Expand Down Expand Up @@ -132,9 +140,15 @@ MWnIPs59/JF8AiBeKSM/rkL2igQchDTvlJJWsyk9YL8UZI/XfZO7907TWA==
SequenceNumber: 2,
}
commonGood = &common.Bundle{
TrustDomainId: "spiffe://example.org",
RootCas: []*common.Certificate{{DerBytes: root.Raw}},
JwtSigningKeys: []*common.PublicKey{{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()}},
TrustDomainId: "spiffe://example.org",
RootCas: []*common.Certificate{
{DerBytes: root.Raw},
{DerBytes: root.Raw, TaintedKey: true},
},
JwtSigningKeys: []*common.PublicKey{
{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()},
{Kid: "IDTainted", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix(), TaintedKey: true},
},
RefreshHint: 1,
SequenceNumber: 2,
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/common/coretypes/jwtkey/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func ToAPIProto(jwtKey JWTKey) (*apitypes.JWTKey, error) {
id, publicKey, expiresAt, err := toProtoFields(jwtKey)
id, publicKey, expiresAt, tainted, err := toProtoFields(jwtKey)
if err != nil {
return nil, err
}
Expand All @@ -15,6 +15,7 @@ func ToAPIProto(jwtKey JWTKey) (*apitypes.JWTKey, error) {
KeyId: id,
PublicKey: publicKey,
ExpiresAt: expiresAt,
Tainted: tainted,
}, nil
}

Expand All @@ -23,7 +24,7 @@ func ToAPIFromPluginProto(pb *plugintypes.JWTKey) (*apitypes.JWTKey, error) {
return nil, nil
}

jwtKey, err := fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt)
jwtKey, err := fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt, pb.Tainted)
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/common/coretypes/jwtkey/commontypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func FromCommonProto(pb *common.PublicKey) (JWTKey, error) {
return fromProtoFields(pb.Kid, pb.PkixBytes, pb.NotAfter)
return fromProtoFields(pb.Kid, pb.PkixBytes, pb.NotAfter, pb.TaintedKey)
}

func FromCommonProtos(pbs []*common.PublicKey) ([]JWTKey, error) {
Expand All @@ -25,14 +25,15 @@ func FromCommonProtos(pbs []*common.PublicKey) ([]JWTKey, error) {
}

func ToCommonProto(jwtKey JWTKey) (*common.PublicKey, error) {
id, publicKey, expiresAt, err := toProtoFields(jwtKey)
id, publicKey, expiresAt, tainted, err := toProtoFields(jwtKey)
if err != nil {
return nil, err
}
return &common.PublicKey{
Kid: id,
PkixBytes: publicKey,
NotAfter: expiresAt,
Kid: id,
PkixBytes: publicKey,
NotAfter: expiresAt,
TaintedKey: tainted,
}, nil
}

Expand Down
17 changes: 9 additions & 8 deletions pkg/common/coretypes/jwtkey/jwtkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ type JWTKey struct {
ID string
PublicKey crypto.PublicKey
ExpiresAt time.Time
Tainted bool
}

func toProtoFields(jwtKey JWTKey) (string, []byte, int64, error) {
func toProtoFields(jwtKey JWTKey) (id string, publicKey []byte, expiresAt int64, tainted bool, err error) {
if jwtKey.ID == "" {
return "", nil, 0, errors.New("missing key ID for JWT key")
return "", nil, 0, false, errors.New("missing key ID for JWT key")
}

if jwtKey.PublicKey == nil {
return "", nil, 0, fmt.Errorf("missing public key for JWT key %q", jwtKey.ID)
return "", nil, 0, false, fmt.Errorf("missing public key for JWT key %q", jwtKey.ID)
}
publicKey, err := x509.MarshalPKIXPublicKey(jwtKey.PublicKey)
publicKey, err = x509.MarshalPKIXPublicKey(jwtKey.PublicKey)
if err != nil {
return "", nil, 0, fmt.Errorf("failed to marshal public key for JWT key %q: %w", jwtKey.ID, err)
return "", nil, 0, false, fmt.Errorf("failed to marshal public key for JWT key %q: %w", jwtKey.ID, err)
}

var expiresAt int64
if !jwtKey.ExpiresAt.IsZero() {
expiresAt = jwtKey.ExpiresAt.Unix()
}

return jwtKey.ID, publicKey, expiresAt, nil
return jwtKey.ID, publicKey, expiresAt, jwtKey.Tainted, nil
}

func fromProtoFields(keyID string, publicKeyPKIX []byte, expiresAtUnix int64) (JWTKey, error) {
func fromProtoFields(keyID string, publicKeyPKIX []byte, expiresAtUnix int64, tainted bool) (JWTKey, error) {
if keyID == "" {
return JWTKey{}, errors.New("missing key ID for JWT key")
}
Expand All @@ -57,5 +57,6 @@ func fromProtoFields(keyID string, publicKeyPKIX []byte, expiresAtUnix int64) (J
ID: keyID,
PublicKey: publicKey,
ExpiresAt: expiresAt,
Tainted: tainted,
}, nil
}
34 changes: 27 additions & 7 deletions pkg/common/coretypes/jwtkey/jwtkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ var (
pkixBytes, _ = x509.MarshalPKIXPublicKey(publicKey)
junk = []byte("JUNK")
jwtKeyGood = jwtkey.JWTKey{ID: "ID", PublicKey: publicKey, ExpiresAt: expiresAt}
jwtKeyTaintedGood = jwtkey.JWTKey{ID: "ID", PublicKey: publicKey, ExpiresAt: expiresAt, Tainted: true}
jwtKeyNoKeyID = jwtkey.JWTKey{PublicKey: publicKey, ExpiresAt: expiresAt}
jwtKeyNoPublicKey = jwtkey.JWTKey{ID: "ID", ExpiresAt: expiresAt}
jwtKeyBadPublicKey = jwtkey.JWTKey{ID: "ID", PublicKey: junk, ExpiresAt: expiresAt}
jwtKeyNoExpiresAt = jwtkey.JWTKey{ID: "ID", PublicKey: publicKey}
pluginGood = &plugintypes.JWTKey{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()}
pluginTaintedGood = &plugintypes.JWTKey{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix(), Tainted: true}
pluginNoKeyID = &plugintypes.JWTKey{PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()}
pluginNoPublicKey = &plugintypes.JWTKey{KeyId: "ID", ExpiresAt: expiresAt.Unix()}
pluginBadPublicKey = &plugintypes.JWTKey{KeyId: "ID", PublicKey: junk, ExpiresAt: expiresAt.Unix()}
pluginNoExpiresAt = &plugintypes.JWTKey{KeyId: "ID", PublicKey: pkixBytes}
commonGood = &common.PublicKey{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()}
commonTaintedGood = &common.PublicKey{Kid: "ID", PkixBytes: pkixBytes, NotAfter: expiresAt.Unix(), TaintedKey: true}
commonNoKeyID = &common.PublicKey{PkixBytes: pkixBytes, NotAfter: expiresAt.Unix()}
commonNoPublicKey = &common.PublicKey{Kid: "ID", NotAfter: expiresAt.Unix()}
commonBadPublicKey = &common.PublicKey{Kid: "ID", PkixBytes: junk, NotAfter: expiresAt.Unix()}
commonNoExpiresAt = &common.PublicKey{Kid: "ID", PkixBytes: pkixBytes}
apiGood = &apitypes.JWTKey{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()}
apiTaintedGood = &apitypes.JWTKey{KeyId: "ID", PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix(), Tainted: true}
apiNoKeyID = &apitypes.JWTKey{PublicKey: pkixBytes, ExpiresAt: expiresAt.Unix()}
apiNoPublicKey = &apitypes.JWTKey{KeyId: "ID", ExpiresAt: expiresAt.Unix()}
apiBadPublicKey = &apitypes.JWTKey{KeyId: "ID", PublicKey: junk, ExpiresAt: expiresAt.Unix()}
Expand All @@ -59,6 +63,7 @@ func TestFromCommonProto(t *testing.T) {
}

assertOK(t, commonGood, jwtKeyGood)
assertOK(t, commonTaintedGood, jwtKeyTaintedGood)
assertFail(t, commonNoKeyID, "missing key ID for JWT key")
assertFail(t, commonNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, commonBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand All @@ -80,7 +85,8 @@ func TestFromCommonProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireFromCommonProtos(in) })
}

assertOK(t, []*common.PublicKey{commonGood}, []jwtkey.JWTKey{jwtKeyGood})
assertOK(t, []*common.PublicKey{commonGood, commonTaintedGood},
[]jwtkey.JWTKey{jwtKeyGood, jwtKeyTaintedGood})
assertFail(t, []*common.PublicKey{commonNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -101,6 +107,7 @@ func TestToCommonProto(t *testing.T) {
}

assertOK(t, jwtKeyGood, commonGood)
assertOK(t, jwtKeyTaintedGood, commonTaintedGood)
assertFail(t, jwtKeyNoKeyID, "missing key ID for JWT key")
assertFail(t, jwtKeyNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, jwtKeyBadPublicKey, `failed to marshal public key for JWT key "ID": `)
Expand All @@ -122,7 +129,8 @@ func TestToCommonProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireToCommonProtos(in) })
}

assertOK(t, []jwtkey.JWTKey{jwtKeyGood}, []*common.PublicKey{commonGood})
assertOK(t, []jwtkey.JWTKey{jwtKeyGood, jwtKeyTaintedGood},
[]*common.PublicKey{commonGood, commonTaintedGood})
assertFail(t, []jwtkey.JWTKey{jwtKeyNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -143,6 +151,7 @@ func TestFromPluginProto(t *testing.T) {
}

assertOK(t, pluginGood, jwtKeyGood)
assertOK(t, pluginTaintedGood, jwtKeyTaintedGood)
assertFail(t, pluginNoKeyID, "missing key ID for JWT key")
assertFail(t, pluginNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, pluginBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand All @@ -164,7 +173,8 @@ func TestFromPluginProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireFromPluginProtos(in) })
}

assertOK(t, []*plugintypes.JWTKey{pluginGood}, []jwtkey.JWTKey{jwtKeyGood})
assertOK(t, []*plugintypes.JWTKey{pluginGood, pluginTaintedGood},
[]jwtkey.JWTKey{jwtKeyGood, jwtKeyTaintedGood})
assertFail(t, []*plugintypes.JWTKey{pluginNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -185,6 +195,7 @@ func TestToPluginProto(t *testing.T) {
}

assertOK(t, jwtKeyGood, pluginGood)
assertOK(t, jwtKeyTaintedGood, pluginTaintedGood)
assertFail(t, jwtKeyNoKeyID, "missing key ID for JWT key")
assertFail(t, jwtKeyNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, jwtKeyBadPublicKey, `failed to marshal public key for JWT key "ID": `)
Expand All @@ -206,7 +217,8 @@ func TestToPluginProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireToPluginProtos(in) })
}

assertOK(t, []jwtkey.JWTKey{jwtKeyGood}, []*plugintypes.JWTKey{pluginGood})
assertOK(t, []jwtkey.JWTKey{jwtKeyGood, jwtKeyTaintedGood},
[]*plugintypes.JWTKey{pluginGood, pluginTaintedGood})
assertFail(t, []jwtkey.JWTKey{jwtKeyNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -227,6 +239,7 @@ func TestToCommonFromPluginProto(t *testing.T) {
}

assertOK(t, pluginGood, commonGood)
assertOK(t, pluginTaintedGood, commonTaintedGood)
assertFail(t, pluginNoKeyID, "missing key ID for JWT key")
assertFail(t, pluginNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, pluginBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand All @@ -248,7 +261,8 @@ func TestToCommonFromPluginProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireToCommonFromPluginProtos(in) })
}

assertOK(t, []*plugintypes.JWTKey{pluginGood}, []*common.PublicKey{commonGood})
assertOK(t, []*plugintypes.JWTKey{pluginGood, pluginTaintedGood},
[]*common.PublicKey{commonGood, commonTaintedGood})
assertFail(t, []*plugintypes.JWTKey{pluginNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -269,6 +283,7 @@ func TestToPluginFromCommonProto(t *testing.T) {
}

assertOK(t, commonGood, pluginGood)
assertOK(t, commonTaintedGood, pluginTaintedGood)
assertFail(t, commonNoKeyID, "missing key ID for JWT key")
assertFail(t, commonNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, commonBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand All @@ -290,7 +305,8 @@ func TestToPluginFromCommonProtos(t *testing.T) {
assert.Panics(t, func() { jwtkey.RequireToPluginFromCommonProtos(in) })
}

assertOK(t, []*common.PublicKey{commonGood}, []*plugintypes.JWTKey{pluginGood})
assertOK(t, []*common.PublicKey{commonGood, commonTaintedGood},
[]*plugintypes.JWTKey{pluginGood, pluginTaintedGood})
assertFail(t, []*common.PublicKey{commonNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -308,6 +324,7 @@ func TestToPluginFromAPIProto(t *testing.T) {
}

assertOK(t, apiGood, pluginGood)
assertOK(t, apiTaintedGood, pluginTaintedGood)
assertFail(t, apiNoKeyID, "missing key ID for JWT key")
assertFail(t, apiNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, apiBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand All @@ -328,7 +345,8 @@ func TestToPluginFromAPIProtos(t *testing.T) {
assert.Empty(t, actualOut)
}

assertOK(t, []*apitypes.JWTKey{apiGood}, []*plugintypes.JWTKey{pluginGood})
assertOK(t, []*apitypes.JWTKey{apiGood, apiTaintedGood},
[]*plugintypes.JWTKey{pluginGood, pluginTaintedGood})
assertFail(t, []*apitypes.JWTKey{apiNoKeyID}, "missing key ID for JWT key")
assertOK(t, nil, nil)
}
Expand All @@ -347,6 +365,7 @@ func TestToAPIProto(t *testing.T) {
}

assertOK(t, jwtKeyGood, apiGood)
assertOK(t, jwtKeyTaintedGood, apiTaintedGood)
assertFail(t, jwtKeyNoKeyID, "missing key ID for JWT key")
assertFail(t, jwtKeyNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, jwtKeyBadPublicKey, `failed to marshal public key for JWT key "ID": `)
Expand All @@ -367,6 +386,7 @@ func TestToAPIFromPluginProto(t *testing.T) {
}

assertOK(t, pluginGood, apiGood)
assertOK(t, pluginTaintedGood, apiTaintedGood)
assertFail(t, pluginNoKeyID, "missing key ID for JWT key")
assertFail(t, pluginNoPublicKey, `missing public key for JWT key "ID"`)
assertFail(t, pluginBadPublicKey, `failed to unmarshal public key for JWT key "ID": `)
Expand Down
7 changes: 4 additions & 3 deletions pkg/common/coretypes/jwtkey/plugintypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func FromPluginProto(pb *plugintypes.JWTKey) (JWTKey, error) {
return fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt)
return fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt, pb.Tainted)
}

func FromPluginProtos(pbs []*plugintypes.JWTKey) ([]JWTKey, error) {
Expand All @@ -26,14 +26,15 @@ func FromPluginProtos(pbs []*plugintypes.JWTKey) ([]JWTKey, error) {
}

func ToPluginProto(jwtKey JWTKey) (*plugintypes.JWTKey, error) {
id, publicKey, expiresAt, err := toProtoFields(jwtKey)
id, publicKey, expiresAt, tainted, err := toProtoFields(jwtKey)
if err != nil {
return nil, err
}
return &plugintypes.JWTKey{
KeyId: id,
PublicKey: publicKey,
ExpiresAt: expiresAt,
Tainted: tainted,
}, nil
}

Expand Down Expand Up @@ -80,7 +81,7 @@ func ToPluginFromAPIProto(pb *apitypes.JWTKey) (*plugintypes.JWTKey, error) {
return nil, nil
}

jwtKey, err := fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt)
jwtKey, err := fromProtoFields(pb.KeyId, pb.PublicKey, pb.ExpiresAt, pb.Tainted)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit e3dac17

Please sign in to comment.