Skip to content

Commit

Permalink
add case where veriy signature failed
Browse files Browse the repository at this point in the history
Signed-off-by: Marcos Yacob <[email protected]>
  • Loading branch information
MarcosDY committed Sep 23, 2024
1 parent 39d9e0d commit a838e4d
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 64 deletions.
16 changes: 11 additions & 5 deletions pkg/agent/manager/cache/lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,17 @@ func (c *LRUCache) scheduleRotation(ctx context.Context, entryIDs []string, tain

entriesLeftCount := len(entryIDs)
if entriesLeftCount == 0 {
c.log.Debug("Finished to process all tainted entries")
c.log.Debug("Finished processing all tainted entries")
c.notifyTaintedBatchProcessed()
return
}
c.log.WithField(telemetry.Count, entriesLeftCount).Debug("Entries left to process")
c.log.WithField(telemetry.Count, entriesLeftCount).Debug("Tainted entries left to be processed")
c.notifyTaintedBatchProcessed()

select {
case <-ticker.C:
case <-ctx.Done():
c.log.Debug("Context cancelled, exiting rotation schedule")
c.log.WithError(ctx.Err()).Warn("Context cancelled, exiting rotation schedule")
return
}
}
Expand All @@ -614,7 +614,6 @@ func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities
counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, "", telemetry.ProcessTaintedSVIDs)
defer counter.Done(nil)

// TODO: add metric fr time
taintedSVIDs := 0

c.mu.Lock()
Expand All @@ -628,7 +627,14 @@ func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities
}

// Check if the SVID is signed by any tainted authority
if x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities) {
isTainted, err := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities)
if err != nil {
c.log.WithError(err).
WithField(telemetry.RegistrationID, entryID).
Error("Failed to check if SVID is signed by tainted authority")
continue
}
if isTainted {
taintedSVIDs++
delete(c.svids, entryID)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/agent/manager/cache/lru_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ func TestTaintX509SVIDs(t *testing.T) {
},
{
Level: logrus.DebugLevel,
Message: "Entries left to process",
Message: "Tainted entries left to be processed",
Data: logrus.Fields{telemetry.Count: "6"},
},
}
Expand All @@ -1083,7 +1083,7 @@ func TestTaintX509SVIDs(t *testing.T) {
},
{
Level: logrus.DebugLevel,
Message: "Entries left to process",
Message: "Tainted entries left to be processed",
Data: logrus.Fields{telemetry.Count: "2"},
},
}
Expand All @@ -1104,7 +1104,7 @@ func TestTaintX509SVIDs(t *testing.T) {
},
{
Level: logrus.DebugLevel,
Message: "Finished to process all tainted entries",
Message: "Finished processing all tainted entries",
},
}
expectMetrics = append([]fakemetrics.MetricItem{
Expand Down
3 changes: 2 additions & 1 deletion pkg/agent/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/spiffe/spire/pkg/common/rotationutil"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/util"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/pkg/server/api/limits"
"github.com/spiffe/spire/proto/spire/common"
)
Expand Down Expand Up @@ -324,7 +325,7 @@ func (m *manager) runSynchronizer(ctx context.Context) error {

err := m.synchronize(ctx)
switch {
case nodeutil.IsUnknownAuthorityError(err):
case x509util.IsUnknownAuthorityError(err):
m.c.Log.WithError(err).Info("Synchronize failed, non-recoverable error")
return fmt.Errorf("failed to sync with SPIRE Server: %w", err)
case err != nil && nodeutil.ShouldAgentReattest(err):
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ func TestForceRotation(t *testing.T) {
// Wait until tainted authorities are fully processed, then retry synchronization
assert.Eventually(t, func() bool {
for _, logEntry := range logHook.Entries {
if logEntry.Message == "Finished to process all tainted entries" {
if logEntry.Message == "Finished processing all tainted entries" {
return true
}
}
Expand Down
10 changes: 9 additions & 1 deletion pkg/agent/manager/storecache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ func (c *Cache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x5
continue
}

if x509util.IsSignedByRoot(record.svid.Chain, taintedX509Authorities) {
isTainted, err := x509util.IsSignedByRoot(record.svid.Chain, taintedX509Authorities)
if err != nil {
c.c.Log.WithError(err).
WithField(telemetry.RegistrationID, record.entry.EntryId).
Error("Failed to check if SVID is signed by tainted authority")
continue
}

if isTainted {
taintedSVIDs++
record.svid = nil // Mark SVID as tainted by setting it to nil
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/agent/svid/rotator.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ func (r *rotator) NotifyTaintedAuthorities(taintedAuthorities []*x509.Certificat
return nil
}

tainted := x509util.IsSignedByRoot(state.SVID, taintedAuthorities)
tainted, err := x509util.IsSignedByRoot(state.SVID, taintedAuthorities)
if err != nil {
return fmt.Errorf("failed to check if SVID is tainted: %w", err)
}

if tainted {
r.c.Log.Debug("Agent SVID is tainted by a root authority, forcing rotation")
r.setTainted(tainted)
Expand Down
13 changes: 0 additions & 13 deletions pkg/common/nodeutil/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package nodeutil

import (
"errors"
"strings"

"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
"github.com/spiffe/spire/proto/spire/common"
Expand All @@ -20,7 +19,6 @@ var (
shouldShutDown = map[types.PermissionDeniedDetails_Reason]struct{}{
types.PermissionDeniedDetails_AGENT_BANNED: {},
}
unknowAuthorityErr = "x509: certificate signed by unknown authority"
)

// IsAgentBanned determines if a given attested node is banned or not.
Expand All @@ -34,17 +32,6 @@ func ShouldAgentReattest(err error) bool {
return isExpectedPermissionDenied(err, shouldReattest)
}

// IsUnknownAuthorityError returns tru if the Server returned an unknow authority error when verifying
// presented SVID
func IsUnknownAuthorityError(err error) bool {
if err == nil {
return false
}

// Since it is an rpc error we are unable to use errors.As since it is not possible to unwrap
return strings.Contains(err.Error(), unknowAuthorityErr)
}

// ShouldAgentShutdown returns true if the Server returned an error worth shutting down the Agent
func ShouldAgentShutdown(err error) bool {
return isExpectedPermissionDenied(err, shouldShutDown)
Expand Down
27 changes: 0 additions & 27 deletions pkg/common/nodeutil/node_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package nodeutil_test

import (
"errors"
"fmt"
"testing"

"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
"github.com/spiffe/spire/pkg/common/nodeutil"
"github.com/spiffe/spire/proto/spire/common"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -51,29 +47,6 @@ func TestShouldAgentReattest(t *testing.T) {
require.False(t, nodeutil.ShouldAgentReattest(getError(t, codes.PermissionDenied, nil)))
}

func TestIsUnknownAuthority(t *testing.T) {
t.Run("no error provided", func(t *testing.T) {
require.False(t, nodeutil.IsUnknownAuthorityError(nil))
})

t.Run("unexpected error", func(t *testing.T) {
require.False(t, nodeutil.IsUnknownAuthorityError(errors.New("oh no")))
})

t.Run("unknown authority err", func(t *testing.T) {
// Create two bundles with same TD and an SVID that is signed by one of them
ca := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td"))
ca2 := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td"))
svid := ca2.CreateX509SVID(spiffeid.RequireFromString("spiffe://test.td/w1"))

// Verify must fail
_, _, err := x509svid.Verify(svid.Certificates, ca.X509Bundle())
require.Error(t, err)

require.True(t, nodeutil.IsUnknownAuthorityError(err))
})
}

func TestShouldAgentShutdown(t *testing.T) {
agentExpired := &types.PermissionDeniedDetails{
Reason: types.PermissionDeniedDetails_AGENT_EXPIRED,
Expand Down
31 changes: 27 additions & 4 deletions pkg/common/x509util/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ import (
"crypto"
"crypto/rand"
"crypto/x509"
"fmt"
"strings"

"github.com/spiffe/spire/pkg/common/cryptoutil"
)

const (
unknowAuthorityErr = "x509: certificate signed by unknown authority"
)

func CreateCertificate(template, parent *x509.Certificate, pub, priv any) (*x509.Certificate, error) {
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
if err != nil {
Expand Down Expand Up @@ -73,10 +79,21 @@ func RawCertsFromCertificates(certs []*x509.Certificate) [][]byte {
return rawCerts
}

// IsUnknownAuthorityError returns tru if the Server returned an unknow authority error when verifying
// presented SVID
func IsUnknownAuthorityError(err error) bool {
if err == nil {
return false
}

// Since it is an rpc error we are unable to use errors.As since it is not possible to unwrap
return strings.Contains(err.Error(), unknowAuthorityErr)
}

// IsSignedByRoot checks if the provided certificate chain is signed by one of the specified root CAs.
func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) bool {
func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) (bool, error) {
if len(chain) == 0 {
return false
return false, nil
}
rootPool := x509.NewCertPool()
for _, x509Authority := range rootCAs {
Expand All @@ -93,7 +110,13 @@ func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) bool
Intermediates: intermediatePool,
Roots: rootPool,
})
if err == nil {
return true, nil
}

if IsUnknownAuthorityError(err) {
return false, nil
}

// TODO: may we verify if error is different to Signed by unknown authority?
return err == nil
return false, fmt.Errorf("failed to verify certificate chain: %w", err)
}
50 changes: 42 additions & 8 deletions pkg/common/x509util/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,40 @@ package x509util_test

import (
"crypto/x509"
"errors"
"testing"

"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestIsUnknownAuthority(t *testing.T) {
t.Run("no error provided", func(t *testing.T) {
require.False(t, x509util.IsUnknownAuthorityError(nil))
})

t.Run("unexpected error", func(t *testing.T) {
require.False(t, x509util.IsUnknownAuthorityError(errors.New("oh no")))
})

t.Run("unknown authority err", func(t *testing.T) {
// Create two bundles with same TD and an SVID that is signed by one of them
ca := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td"))
ca2 := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td"))
svid := ca2.CreateX509SVID(spiffeid.RequireFromString("spiffe://test.td/w1"))

// Verify must fail
_, _, err := x509svid.Verify(svid.Certificates, ca.X509Bundle())
require.Error(t, err)

require.True(t, x509util.IsUnknownAuthorityError(err))
})
}

func TestIsSignedByRoot(t *testing.T) {
td := spiffeid.RequireTrustDomainFromString("example.org")
ca1 := testca.New(t, td)
Expand All @@ -19,19 +45,27 @@ func TestIsSignedByRoot(t *testing.T) {
ca2 := testca.New(t, td)
svid2 := ca2.CreateX509SVID(spiffeid.RequireFromPath(td, "/w2"))

testSignedByRoot := func(t *testing.T, chain []*x509.Certificate, rootCAs []*x509.Certificate, expect bool) {
isSigned := x509util.IsSignedByRoot(chain, rootCAs)
invalidCertificate := []*x509.Certificate{{Raw: []byte("invalid")}}

testSignedByRoot := func(t *testing.T, chain []*x509.Certificate, rootCAs []*x509.Certificate, expect bool, expectError string) {
isSigned, err := x509util.IsSignedByRoot(chain, rootCAs)
if expect {
assert.True(t, isSigned, "Expected chain to be signed by root")
} else {
assert.False(t, isSigned, "Expected chain NOT to be signed by root")
}
if expectError != "" {
assert.ErrorContains(t, err, expectError)
} else {
assert.NoError(t, err)
}
}

testSignedByRoot(t, svid1.Certificates, ca1.X509Authorities(), true)
testSignedByRoot(t, svid2.Certificates, ca2.X509Authorities(), true)
testSignedByRoot(t, svid2.Certificates, ca1.X509Authorities(), false)
testSignedByRoot(t, svid1.Certificates, ca2.X509Authorities(), false)
testSignedByRoot(t, nil, ca2.X509Authorities(), false)
testSignedByRoot(t, svid1.Certificates, nil, false)
testSignedByRoot(t, svid1.Certificates, ca1.X509Authorities(), true, "")
testSignedByRoot(t, svid2.Certificates, ca2.X509Authorities(), true, "")
testSignedByRoot(t, svid2.Certificates, ca1.X509Authorities(), false, "")
testSignedByRoot(t, svid1.Certificates, ca2.X509Authorities(), false, "")
testSignedByRoot(t, nil, ca2.X509Authorities(), false, "")
testSignedByRoot(t, svid1.Certificates, nil, false, "")
testSignedByRoot(t, invalidCertificate, ca1.X509Authorities(), false, "failed to verify certificate chain: x509: certificate has expired or is not yet valid")
}

0 comments on commit a838e4d

Please sign in to comment.