diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index 7c66301f65..627182b525 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -120,10 +120,6 @@ type experimentalConfig struct { UseSyncAuthorizedEntries bool `hcl:"use_sync_authorized_entries"` Flags fflag.RawConfig `hcl:"feature_flags"` - - UnusedKeyPositions map[string][]token.Pos `hcl:",unusedKeyPositions"` - X509SVIDCacheMaxSize int `hcl:"x509_svid_cache_max_size"` - DisableLRUCache bool `hcl:"disable_lru_cache"` } type Command struct { @@ -498,19 +494,6 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) ac.LogReopener = log.ReopenOnSignal(logger, reopenableFile) } - if c.Agent.Experimental.X509SVIDCacheMaxSize < 0 { - return nil, errors.New("x509_svid_cache_max_size should not be negative") - } - if c.Agent.Experimental.X509SVIDCacheMaxSize > 0 || c.Agent.Experimental.DisableLRUCache { - logger.Warn("The `x509_svid_cache_max_size` and `disable_lru_cache` configurations are deprecated. They will be removed in a future release.") - } - ac.X509SVIDCacheMaxSize = c.Agent.Experimental.X509SVIDCacheMaxSize - - if c.Agent.Experimental.DisableLRUCache && ac.X509SVIDCacheMaxSize != 0 { - return nil, errors.New("x509_svid_cache_max_size should not be set when disable_lru_cache is set") - } - ac.DisableLRUCache = c.Agent.Experimental.DisableLRUCache - td, err := common_cli.ParseTrustDomain(c.Agent.TrustDomain, logger) if err != nil { return nil, err diff --git a/cmd/spire-agent/cli/run/run_test.go b/cmd/spire-agent/cli/run/run_test.go index 0bea4aebc4..164b4e0e4d 100644 --- a/cmd/spire-agent/cli/run/run_test.go +++ b/cmd/spire-agent/cli/run/run_test.go @@ -897,106 +897,6 @@ func TestNewAgentConfig(t *testing.T) { require.Nil(t, c) }, }, - { - msg: "x509_svid_cache_max_size is set", - input: func(c *Config) { - c.Agent.Experimental.X509SVIDCacheMaxSize = 100 - }, - logOptions: func(t *testing.T) []log.Option { - return []log.Option{ - func(logger *log.Logger) error { - logger.SetOutput(io.Discard) - hook := test.NewLocal(logger.Logger) - t.Cleanup(func() { - spiretest.AssertLogsContainEntries(t, hook.AllEntries(), []spiretest.LogEntry{ - { - Level: logrus.WarnLevel, - Message: "The `x509_svid_cache_max_size` and `disable_lru_cache` " + - "configurations are deprecated. They will be removed in a future release.", - }, - }) - }) - return nil - }, - } - }, - test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 100, c.X509SVIDCacheMaxSize) - }, - }, - { - msg: "x509_svid_cache_max_size is not set", - input: func(c *Config) { - }, - test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) - }, - }, - { - msg: "x509_svid_cache_max_size is zero", - input: func(c *Config) { - c.Agent.Experimental.X509SVIDCacheMaxSize = 0 - }, - test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) - }, - }, - { - msg: "x509_svid_cache_max_size is negative", - expectError: true, - input: func(c *Config) { - c.Agent.Experimental.X509SVIDCacheMaxSize = -10 - }, - test: func(t *testing.T, c *agent.Config) { - require.Nil(t, c) - }, - }, - { - msg: "disable_lru_cache is set", - input: func(c *Config) { - c.Agent.Experimental.DisableLRUCache = true - }, - logOptions: func(t *testing.T) []log.Option { - return []log.Option{ - func(logger *log.Logger) error { - logger.SetOutput(io.Discard) - hook := test.NewLocal(logger.Logger) - t.Cleanup(func() { - spiretest.AssertLogsContainEntries(t, hook.AllEntries(), []spiretest.LogEntry{ - { - Level: logrus.WarnLevel, - Message: "The `x509_svid_cache_max_size` and `disable_lru_cache` " + - "configurations are deprecated. They will be removed in a future release.", - }, - }) - }) - return nil - }, - } - }, - test: func(t *testing.T, c *agent.Config) { - require.True(t, c.DisableLRUCache) - }, - }, - { - msg: "both disable_lru_cache and x509_svid_cache_max_size are set", - expectError: true, - input: func(c *Config) { - c.Agent.Experimental.DisableLRUCache = true - c.Agent.Experimental.X509SVIDCacheMaxSize = 100 - }, - test: func(t *testing.T, c *agent.Config) { - require.Nil(t, c) - }, - }, - { - msg: "disable_lru_cache is not set", - input: func(c *Config) { - }, - test: func(t *testing.T, c *agent.Config) { - require.False(t, c.DisableLRUCache) - }, - }, { msg: "allowed_foreign_jwt_claims provided", input: func(c *Config) { diff --git a/doc/spire_agent.md b/doc/spire_agent.md index 4954d6b176..c9e82ea940 100644 --- a/doc/spire_agent.md +++ b/doc/spire_agent.md @@ -75,8 +75,6 @@ This may be useful for templating configuration files, for example across differ |:------------------------------|--------------------------------------------------------------------------------------|-------------------------| | `named_pipe_name` | Pipe name to bind the SPIRE Agent API named pipe (Windows only) | \spire-agent\public\api | | `sync_interval` | Sync interval with SPIRE server with exponential backoff | 5 sec | -| `x509_svid_cache_max_size` | Soft limit of max number of SVIDs that would be stored in LRU cache (deprecated) | 1000 | -| `disable_lru_cache` | Reverts back to use the SPIRE Agent non-LRU cache for storing SVIDs (deprecated) | false | | `use_sync_authorized_entries` | Use SyncAuthorizedEntries API for periodically synchronization of authorized entries | false | ### Initial trust bundle configuration @@ -387,7 +385,7 @@ There are two ways the trusted delegate workload can request SVIDs for other wor In this approach, the trusted delegate workload is entirely responsible for attesting the other workload and building the attested selectors. When those selectors are presented to the SPIRE Agent, the SPIRE Agent will simply return SVIDs for any workload registration entries that match the provided selectors. No other checks or attestations will be performed by the SPIRE Agent. - + 1. By obtaining a PID for the other workload, and providing that PID to the SPIRE Agent over the Delegated Identity API. In this approach, the SPIRE Agent will do attestation for the provided PID, build the attested selectors, and return SVIDs for any workload registration entries that match the selectors the SPIRE Agent attested from that PID. This differs from the previous approach in that the SPIRE Agent itself (not the trusted delegate) handles the attestation of the other workload. diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 1c7dc9050c..64ed3b5211 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -275,8 +275,6 @@ func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog Storage: sto, SyncInterval: a.c.SyncInterval, UseSyncAuthorizedEntries: a.c.UseSyncAuthorizedEntries, - SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize, - DisableLRUCache: a.c.DisableLRUCache, SVIDStoreCache: cache, NodeAttestor: na, RotationStrategy: rotationutil.NewRotationStrategy(a.c.AvailabilityTarget), diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index ea1d9f68a9..6aaec07928 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" @@ -993,9 +994,9 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { return myCache.BundleCache.SubscribeToBundleChanges() } -func newTestCache() *cache.Cache { +func newTestCache() *cache.LRUCache { log, _ := test.NewNullLogger() - return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}) + return cache.NewLRUCache(log, trustDomain1, bundle1, telemetry.Blackhole{}, clock.New()) } func generateSubscribeToX509SVIDMetrics() []fakemetrics.MetricItem { diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 1d964be1d9..1239b820c6 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -66,12 +66,6 @@ type Config struct { // is used to sync entries from the server. UseSyncAuthorizedEntries bool - // X509SVIDCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache - X509SVIDCacheMaxSize int - - // DisableLRUCache disables the SPIRE Agent LRU cache used for storing SVIDs and fallback to original cache - DisableLRUCache bool - // Trust domain and associated CA bundle TrustDomain spiffeid.TrustDomain TrustBundle []*x509.Certificate diff --git a/pkg/agent/manager/cache/bundle_cache.go b/pkg/agent/manager/cache/bundle_cache.go index e2d79f413f..f7f6c5a6d8 100644 --- a/pkg/agent/manager/cache/bundle_cache.go +++ b/pkg/agent/manager/cache/bundle_cache.go @@ -2,9 +2,12 @@ package cache import ( "github.com/imkira/go-observer" + "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" ) +type Bundle = spiffebundle.Bundle + type BundleCache struct { trustDomain spiffeid.TrustDomain bundles observer.Property diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go deleted file mode 100644 index 7ad5293090..0000000000 --- a/pkg/agent/manager/cache/cache.go +++ /dev/null @@ -1,798 +0,0 @@ -package cache - -import ( - "context" - "crypto" - "crypto/x509" - "sort" - "sync" - "time" - - "github.com/sirupsen/logrus" - "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" - "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/spiffe/spire/pkg/common/telemetry" - "github.com/spiffe/spire/proto/spire/common" -) - -type Selectors []*common.Selector -type Bundle = spiffebundle.Bundle - -// Identity holds the data for a single workload identity -type Identity struct { - Entry *common.RegistrationEntry - SVID []*x509.Certificate - PrivateKey crypto.Signer -} - -// WorkloadUpdate is used to convey workload information to cache subscribers -type WorkloadUpdate struct { - Identities []Identity - Bundle *spiffebundle.Bundle - FederatedBundles map[spiffeid.TrustDomain]*spiffebundle.Bundle -} - -func (u *WorkloadUpdate) HasIdentity() bool { - return len(u.Identities) > 0 -} - -// Update holds information for an entries update to the cache. -type UpdateEntries struct { - // Bundles is a set of ALL trust bundles available to the agent, keyed by trust domain - Bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle - - // RegistrationEntries is a set of ALL registration entries available to the - // agent, keyed by registration entry id. - RegistrationEntries map[string]*common.RegistrationEntry -} - -// Update holds information for an SVIDs update to the cache. -type UpdateSVIDs struct { - // X509SVIDs is a set of updated X509-SVIDs that should be merged into - // the cache, keyed by registration entry id. - X509SVIDs map[string]*X509SVID -} - -// X509SVID holds onto the SVID certificate chain and private key. -type X509SVID struct { - Chain []*x509.Certificate - PrivateKey crypto.Signer -} - -// Cache caches each registration entry, signed X509-SVIDs for those entries, -// bundles, and JWT SVIDs for the agent. It allows subscriptions by (workload) -// selector sets and notifies subscribers when: -// -// 1) a registration entry related to the selectors: -// - is modified -// - has a new X509-SVID signed for it -// - federates with a federated bundle that is updated -// -// 2) the trust bundle for the agent trust domain is updated -// -// When notified, the subscriber is given a WorkloadUpdate containing -// related identities and trust bundles. -// -// The cache does this efficiently by building an index for each unique -// selector it encounters. Each selector index tracks the subscribers (i.e -// workloads) and registration entries that have that selector. -// -// When registration entries are added/updated/removed, the set of relevant -// selectors are gathered and the indexes for those selectors are combed for -// all relevant subscribers. -// -// For each relevant subscriber, the selector index for each selector of the -// subscriber is combed for registration whose selectors are a subset of the -// subscriber selector set. Identities for those entries are added to the -// workload update returned to the subscriber. -// -// NOTE: The cache is intended to be able to handle thousands of workload -// subscriptions, which can involve thousands of certificates, keys, bundles, -// and registration entries, etc. The selector index itself is intended to be -// scalable, but the objects themselves can take a considerable amount of -// memory. For maximal safety, the objects should be cloned both coming in and -// leaving the cache. However, during global updates (e.g. trust bundle is -// updated for the agent trust domain) in particular, cloning all of the -// relevant objects for each subscriber causes HUGE amounts of memory pressure -// which adds non-trivial amounts of latency and causes a giant memory spike -// that could OOM the agent on smaller VMs. For this reason, the cache is -// presumed to own ALL data passing in and out of the cache. Producers and -// consumers MUST NOT mutate the data. -type Cache struct { - *BundleCache - *JWTSVIDCache - - log logrus.FieldLogger - trustDomain spiffeid.TrustDomain - - metrics telemetry.Metrics - - mu sync.RWMutex - - // records holds the records for registration entries, keyed by registration entry ID - records map[string]*cacheRecord - - // selectors holds the selector indices, keyed by a selector key - selectors map[selector]*selectorIndex - - // staleEntries holds stale registration entries - staleEntries map[string]bool - - // bundles holds the trust bundles, keyed by trust domain id (i.e. "spiffe://domain.test") - bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle -} - -// StaleEntry holds stale entries with SVIDs expiration time -type StaleEntry struct { - // Entry stale registration entry - Entry *common.RegistrationEntry - // SVIDs expiration time - SVIDExpiresAt time.Time -} - -func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics) *Cache { - return &Cache{ - BundleCache: NewBundleCache(trustDomain, bundle), - JWTSVIDCache: NewJWTSVIDCache(), - - log: log, - metrics: metrics, - trustDomain: trustDomain, - records: make(map[string]*cacheRecord), - selectors: make(map[selector]*selectorIndex), - staleEntries: make(map[string]bool), - bundles: map[spiffeid.TrustDomain]*spiffebundle.Bundle{ - trustDomain: bundle, - }, - } -} - -// Identities is only used by manager tests -// TODO: We should remove this and find a better way -func (c *Cache) Identities() []Identity { - c.mu.RLock() - defer c.mu.RUnlock() - - out := make([]Identity, 0, len(c.records)) - for _, record := range c.records { - if record.svid == nil { - // The record does not have an SVID yet and should not be returned - // from the cache. - continue - } - out = append(out, makeIdentity(record)) - } - sortIdentities(out) - return out -} - -func (c *Cache) CountX509SVIDs() int { - c.mu.RLock() - defer c.mu.RUnlock() - - var records int - for _, record := range c.records { - if record.svid == nil { - // The record does not have an SVID yet and should not be returned - // from the cache. - continue - } - records++ - } - - return records -} - -func (c *Cache) CountJWTSVIDs() int { - return c.JWTSVIDCache.CountJWTSVIDs() -} - -func (c *Cache) MatchingIdentities(selectors []*common.Selector) []Identity { - set, setDone := allocSelectorSet(selectors...) - defer setDone() - - c.mu.RLock() - defer c.mu.RUnlock() - return c.matchingIdentities(set) -} - -func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { - set, setDone := allocSelectorSet(selectors...) - defer setDone() - - c.mu.RLock() - defer c.mu.RUnlock() - return c.buildWorkloadUpdate(set) -} - -func (c *Cache) SubscribeToWorkloadUpdates(_ context.Context, selectors Selectors) (Subscriber, error) { - return c.subscribeToWorkloadUpdates(selectors), nil -} - -// UpdateEntries updates the cache with the provided registration entries and bundles and -// notifies impacted subscribers. The checkSVID callback, if provided, is used to determine -// if the SVID for the entry is stale, or otherwise in need of rotation. Entries marked stale -// through the checkSVID callback are returned from GetStaleEntries() until the SVID is -// updated through a call to UpdateSVIDs. -func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *X509SVID) bool) { - c.mu.Lock() - defer c.mu.Unlock() - - // Remove bundles that no longer exist. The bundle for the agent trust - // domain should NOT be removed even if not present (which should only be - // the case if there is a bug on the server) since it is necessary to - // authenticate the server. - bundleRemoved := false - for id := range c.bundles { - if _, ok := update.Bundles[id]; !ok && id != c.trustDomain { - bundleRemoved = true - // bundle no longer exists. - c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle removed") - delete(c.bundles, id) - } - } - - // Update bundles with changes, populating a "changed" set that we can - // check when processing registration entries to know if they need to spawn - // a notification. - bundleChanged := make(map[spiffeid.TrustDomain]bool) - for id, bundle := range update.Bundles { - existing, ok := c.bundles[id] - if !(ok && existing.Equal(bundle)) { - if !ok { - c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle added") - } else { - c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle updated") - } - bundleChanged[id] = true - c.bundles[id] = bundle - } - } - trustDomainBundleChanged := bundleChanged[c.trustDomain] - - // Allocate sets from the pool to track changes to selectors and - // federatesWith declarations. These sets must be cleared after EACH use - // and returned to their respective pools when done processing the - // updates. - notifySets := make([]selectorSet, 0) - selAdd, selAddDone := allocSelectorSet() - defer selAddDone() - selRem, selRemDone := allocSelectorSet() - defer selRemDone() - fedAdd, fedAddDone := allocStringSet() - defer fedAddDone() - fedRem, fedRemDone := allocStringSet() - defer fedRemDone() - - // Remove records for registration entries that no longer exist - for id, record := range c.records { - if _, ok := update.RegistrationEntries[id]; !ok { - c.log.WithFields(logrus.Fields{ - telemetry.Entry: id, - telemetry.SPIFFEID: record.entry.SpiffeId, - }).Debug("Entry removed") - - // built a set of selectors for the record being removed, drop the - // record for each selector index, and add the entry selectors to - // the notify set. - notifySet, notifySetDone := allocSelectorSet(record.entry.Selectors...) - defer notifySetDone() - c.delSelectorIndicesRecord(notifySet, record) - notifySets = append(notifySets, notifySet) - delete(c.records, id) - // Remove stale entry since, registration entry is no longer on cache. - delete(c.staleEntries, id) - } - } - - // Add/update records for registration entries in the update - for _, newEntry := range update.RegistrationEntries { - clearSelectorSet(selAdd) - clearSelectorSet(selRem) - clearStringSet(fedAdd) - clearStringSet(fedRem) - - record, existingEntry := c.updateOrCreateRecord(newEntry) - - // Calculate the difference in selectors, add/remove the record - // from impacted selector indices, and add the selector diff to the - // notify set. - c.diffSelectors(existingEntry, newEntry, selAdd, selRem) - selectorsChanged := len(selAdd) > 0 || len(selRem) > 0 - c.addSelectorIndicesRecord(selAdd, record) - c.delSelectorIndicesRecord(selRem, record) - - // Determine if there were changes to FederatesWith declarations or - // if any federated bundles related to the entry were updated. - c.diffFederatesWith(existingEntry, newEntry, fedAdd, fedRem) - federatedBundlesChanged := len(fedAdd) > 0 || len(fedRem) > 0 - if !federatedBundlesChanged { - for _, id := range newEntry.FederatesWith { - td, err := spiffeid.TrustDomainFromString(id) - if err != nil { - c.log.WithFields(logrus.Fields{ - telemetry.TrustDomainID: id, - logrus.ErrorKey: err, - }).Warn("Invalid federated trust domain") - continue - } - if bundleChanged[td] { - federatedBundlesChanged = true - break - } - } - } - - // If any selectors or federated bundles were changed, then make - // sure subscribers for the new and existing entry selector sets - // are notified. - if selectorsChanged { - if existingEntry != nil { - notifySet, notifySetDone := allocSelectorSet(existingEntry.Selectors...) - defer notifySetDone() - notifySets = append(notifySets, notifySet) - } - } - - if federatedBundlesChanged || selectorsChanged { - notifySet, notifySetDone := allocSelectorSet(newEntry.Selectors...) - defer notifySetDone() - notifySets = append(notifySets, notifySet) - } - - // Invoke the svid checker callback for this record - if checkSVID != nil && checkSVID(existingEntry, newEntry, record.svid) { - c.staleEntries[newEntry.EntryId] = true - } - - // Log all the details of the update to the DEBUG log - if federatedBundlesChanged || selectorsChanged { - log := c.log.WithFields(logrus.Fields{ - telemetry.Entry: newEntry.EntryId, - telemetry.SPIFFEID: newEntry.SpiffeId, - }) - if len(selAdd) > 0 { - log = log.WithField(telemetry.SelectorsAdded, len(selAdd)) - } - if len(selRem) > 0 { - log = log.WithField(telemetry.SelectorsRemoved, len(selRem)) - } - if len(fedAdd) > 0 { - log = log.WithField(telemetry.FederatedAdded, len(fedAdd)) - } - if len(fedRem) > 0 { - log = log.WithField(telemetry.FederatedRemoved, len(fedRem)) - } - if existingEntry != nil { - log.Debug("Entry updated") - } else { - log.Debug("Entry created") - } - } - } - - if bundleRemoved || len(bundleChanged) > 0 { - c.BundleCache.Update(c.bundles) - } - - if trustDomainBundleChanged { - c.notifyAll() - } else { - c.notifyBySelectorSet(notifySets...) - } -} - -func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { - c.mu.Lock() - defer c.mu.Unlock() - - // Allocate a set of selectors that - notifySet, notifySetDone := allocSelectorSet() - defer notifySetDone() - - // Add/update records for registration entries in the update - for entryID, svid := range update.X509SVIDs { - record, existingEntry := c.records[entryID] - if !existingEntry { - c.log.WithField(telemetry.RegistrationID, entryID).Error("Entry not found") - continue - } - - record.svid = svid - notifySet.Merge(record.entry.Selectors...) - log := c.log.WithFields(logrus.Fields{ - telemetry.Entry: record.entry.EntryId, - telemetry.SPIFFEID: record.entry.SpiffeId, - }) - log.Debug("SVID updated") - - // Registration entry is updated, remove it from stale map - delete(c.staleEntries, entryID) - c.notifyBySelectorSet(notifySet) - clearSelectorSet(notifySet) - } -} - -// GetStaleEntries obtains a list of stale entries -func (c *Cache) GetStaleEntries() []*StaleEntry { - c.mu.Lock() - defer c.mu.Unlock() - - var staleEntries []*StaleEntry - for entryID := range c.staleEntries { - cachedEntry, ok := c.records[entryID] - if !ok { - c.log.WithField(telemetry.RegistrationID, entryID).Debug("Stale marker found for unknown entry. Please fill a bug") - delete(c.staleEntries, entryID) - continue - } - - var expiresAt time.Time - if cachedEntry.svid != nil { - expiresAt = cachedEntry.svid.Chain[0].NotAfter - } - - staleEntries = append(staleEntries, &StaleEntry{ - Entry: cachedEntry.entry, - SVIDExpiresAt: expiresAt, - }) - } - - return staleEntries -} - -func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { - c.mu.RLock() - defer c.mu.RUnlock() - - set, setDone := allocSelectorSet(selectors...) - defer setDone() - - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() - - // Return identities in ascending "entry id" order to maintain a consistent - // ordering. - // TODO: figure out how to determine the "default" identity - out := make([]*common.RegistrationEntry, 0, len(records)) - for record := range records { - out = append(out, record.entry) - } - sortEntriesByID(out) - return out -} - -func (c *Cache) Entries() []*common.RegistrationEntry { - c.mu.RLock() - defer c.mu.RUnlock() - - out := make([]*common.RegistrationEntry, 0, len(c.records)) - for _, record := range c.records { - out = append(out, record.entry) - } - sortEntriesByID(out) - return out -} - -func (c *Cache) SyncSVIDsWithSubscribers() { - c.log.Error("SyncSVIDsWithSubscribers method is not implemented") -} - -func (c *Cache) subscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { - c.mu.Lock() - defer c.mu.Unlock() - - sub := newSubscriber(c, selectors) - for s := range sub.set { - c.addSelectorIndexSub(s, sub) - } - c.notify(sub) - return sub -} - -func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { - var existingEntry *common.RegistrationEntry - record, recordExists := c.records[newEntry.EntryId] - if !recordExists { - record = newCacheRecord() - c.records[newEntry.EntryId] = record - } else { - existingEntry = record.entry - } - record.entry = newEntry - return record, existingEntry -} - -func (c *Cache) diffSelectors(existingEntry, newEntry *common.RegistrationEntry, added, removed selectorSet) { - // Make a set of all the selectors being added - if newEntry != nil { - added.Merge(newEntry.Selectors...) - } - - // Make a set of all the selectors that are being removed - if existingEntry != nil { - for _, selector := range existingEntry.Selectors { - s := makeSelector(selector) - if _, ok := added[s]; ok { - // selector already exists in entry - delete(added, s) - } else { - // selector has been removed from entry - removed[s] = struct{}{} - } - } - } -} - -func (c *Cache) diffFederatesWith(existingEntry, newEntry *common.RegistrationEntry, added, removed stringSet) { - // Make a set of all the selectors being added - if newEntry != nil { - added.Merge(newEntry.FederatesWith...) - } - - // Make a set of all the selectors that are being removed - if existingEntry != nil { - for _, id := range existingEntry.FederatesWith { - if _, ok := added[id]; ok { - // Bundle already exists in entry - delete(added, id) - } else { - // Bundle has been removed from entry - removed[id] = struct{}{} - } - } - } -} - -func (c *Cache) addSelectorIndicesRecord(selectors selectorSet, record *cacheRecord) { - for selector := range selectors { - c.addSelectorIndexRecord(selector, record) - } -} - -func (c *Cache) addSelectorIndexRecord(s selector, record *cacheRecord) { - index := c.getSelectorIndexForWrite(s) - index.records[record] = struct{}{} -} - -func (c *Cache) delSelectorIndicesRecord(selectors selectorSet, record *cacheRecord) { - for selector := range selectors { - c.delSelectorIndexRecord(selector, record) - } -} - -// delSelectorIndexRecord removes the record from the selector index. If -// the selector index is empty afterwards, it is also removed. -func (c *Cache) delSelectorIndexRecord(s selector, record *cacheRecord) { - index, ok := c.selectors[s] - if ok { - delete(index.records, record) - if index.isEmpty() { - delete(c.selectors, s) - } - } -} - -func (c *Cache) addSelectorIndexSub(s selector, sub *subscriber) { - index := c.getSelectorIndexForWrite(s) - index.subs[sub] = struct{}{} -} - -// delSelectorIndexSub removes the subscription from the selector index. If -// the selector index is empty afterwards, it is also removed. -func (c *Cache) delSelectorIndexSub(s selector, sub *subscriber) { - index, ok := c.selectors[s] - if ok { - delete(index.subs, sub) - if index.isEmpty() { - delete(c.selectors, s) - } - } -} - -func (c *Cache) unsubscribe(sub *subscriber) { - c.mu.Lock() - defer c.mu.Unlock() - for selector := range sub.set { - c.delSelectorIndexSub(selector, sub) - } -} - -func (c *Cache) notifyAll() { - subs, subsDone := c.allSubscribers() - defer subsDone() - for sub := range subs { - c.notify(sub) - } -} - -func (c *Cache) notifyBySelectorSet(sets ...selectorSet) { - notifiedSubs, notifiedSubsDone := allocSubscriberSet() - defer notifiedSubsDone() - for _, set := range sets { - subs, subsDone := c.getSubscribers(set) - defer subsDone() - for sub := range subs { - if _, notified := notifiedSubs[sub]; !notified && sub.set.SuperSetOf(set) { - c.notify(sub) - notifiedSubs[sub] = struct{}{} - } - } - } -} - -func (c *Cache) notify(sub *subscriber) { - update := c.buildWorkloadUpdate(sub.set) - sub.notify(update) -} - -func (c *Cache) allSubscribers() (subscriberSet, func()) { - subs, subsDone := allocSubscriberSet() - for _, index := range c.selectors { - for sub := range index.subs { - subs[sub] = struct{}{} - } - } - return subs, subsDone -} - -func (c *Cache) getSubscribers(set selectorSet) (subscriberSet, func()) { - subs, subsDone := allocSubscriberSet() - for s := range set { - if index := c.getSelectorIndexForRead(s); index != nil { - for sub := range index.subs { - subs[sub] = struct{}{} - } - } - } - return subs, subsDone -} - -func (c *Cache) matchingIdentities(set selectorSet) []Identity { - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() - - if len(records) == 0 { - return nil - } - - // Return identities in ascending "entry id" order to maintain a consistent - // ordering. - // TODO: figure out how to determine the "default" identity - out := make([]Identity, 0, len(records)) - for record := range records { - out = append(out, makeIdentity(record)) - } - sortIdentities(out) - return out -} - -func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { - w := &WorkloadUpdate{ - Bundle: c.bundles[c.trustDomain], - FederatedBundles: make(map[spiffeid.TrustDomain]*spiffebundle.Bundle), - Identities: c.matchingIdentities(set), - } - - // Add in the bundles the workload is federated with. - for _, identity := range w.Identities { - for _, federatesWith := range identity.Entry.FederatesWith { - td, err := spiffeid.TrustDomainFromString(federatesWith) - if err != nil { - c.log.WithFields(logrus.Fields{ - telemetry.TrustDomainID: federatesWith, - logrus.ErrorKey: err, - }).Warn("Invalid federated trust domain") - continue - } - if federatedBundle := c.bundles[td]; federatedBundle != nil { - w.FederatedBundles[td] = federatedBundle - } else { - c.log.WithFields(logrus.Fields{ - telemetry.RegistrationID: identity.Entry.EntryId, - telemetry.SPIFFEID: identity.Entry.SpiffeId, - telemetry.FederatedBundle: federatesWith, - }).Warn("Federated bundle contents missing") - } - } - } - - return w -} - -func (c *Cache) getRecordsForSelectors(set selectorSet) (recordSet, func()) { - // Build and dedup a list of candidate entries. Ignore those without an - // SVID but otherwise don't check for selector set inclusion yet, since - // that is a more expensive operation and we could easily have duplicate - // entries to check. - records, recordsDone := allocRecordSet() - for selector := range set { - if index := c.getSelectorIndexForRead(selector); index != nil { - for record := range index.records { - if record.svid == nil { - continue - } - records[record] = struct{}{} - } - } - } - - // Filter out records whose registration entry selectors are not within - // inside the selector set. - for record := range records { - for _, s := range record.entry.Selectors { - if !set.In(s) { - delete(records, record) - } - } - } - return records, recordsDone -} - -// getSelectorIndexForWrite gets the selector index for the selector. If one -// doesn't exist, it is created. Callers must hold the write lock. If the index -// is only being read, then getSelectorIndexForRead should be used instead. -func (c *Cache) getSelectorIndexForWrite(s selector) *selectorIndex { - index, ok := c.selectors[s] - if !ok { - index = newSelectorIndex() - c.selectors[s] = index - } - return index -} - -// getSelectorIndexForRead gets the selector index for the selector. If one -// doesn't exist, nil is returned. Callers should hold the read or write lock. -// If the index is being modified, callers should use getSelectorIndexForWrite -// instead. -func (c *Cache) getSelectorIndexForRead(s selector) *selectorIndex { - if index, ok := c.selectors[s]; ok { - return index - } - return nil -} - -type cacheRecord struct { - entry *common.RegistrationEntry - svid *X509SVID - subs map[*subscriber]struct{} -} - -func newCacheRecord() *cacheRecord { - return &cacheRecord{ - subs: make(map[*subscriber]struct{}), - } -} - -type selectorIndex struct { - // subs holds the subscriptions related to this selector - subs map[*subscriber]struct{} - - // records holds the cache records related to this selector - records map[*cacheRecord]struct{} -} - -func (x *selectorIndex) isEmpty() bool { - return len(x.subs) == 0 && len(x.records) == 0 -} - -func newSelectorIndex() *selectorIndex { - return &selectorIndex{ - subs: make(map[*subscriber]struct{}), - records: make(map[*cacheRecord]struct{}), - } -} - -func sortIdentities(identities []Identity) { - sort.Slice(identities, func(a, b int) bool { - return identities[a].Entry.EntryId < identities[b].Entry.EntryId - }) -} - -func makeIdentity(record *cacheRecord) Identity { - return Identity{ - Entry: record.entry, - SVID: record.svid.Chain, - PrivateKey: record.svid.PrivateKey, - } -} diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go deleted file mode 100644 index 9dd0d03e60..0000000000 --- a/pkg/agent/manager/cache/cache_test.go +++ /dev/null @@ -1,851 +0,0 @@ -package cache - -import ( - "crypto/x509" - "fmt" - "runtime" - "testing" - "time" - - "github.com/sirupsen/logrus/hooks/test" - "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" - "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/spiffe/spire/pkg/common/telemetry" - "github.com/spiffe/spire/proto/spire/common" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var ( - trustDomain1 = spiffeid.RequireTrustDomainFromString("domain.test") - trustDomain2 = spiffeid.RequireTrustDomainFromString("otherdomain.test") - bundleV1 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{1}}}) - bundleV2 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{2}}}) - bundleV3 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{3}}}) - otherBundleV1 = spiffebundle.FromX509Authorities(trustDomain2, []*x509.Certificate{{Raw: []byte{4}}}) - otherBundleV2 = spiffebundle.FromX509Authorities(trustDomain2, []*x509.Certificate{{Raw: []byte{5}}}) - defaultX509SVIDTTL = int32(700) - defaultJwtSVIDTTL = int32(800) -) - -func TestFetchWorkloadUpdate(t *testing.T) { - cache := newTestCache() - // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - bar.FederatesWith = makeFederatesWith(otherBundleV1) - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - } - cache.UpdateEntries(updateEntries, nil) - - workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) - assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") - - updateSVIDs := &UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - } - cache.UpdateSVIDs(updateSVIDs) - - workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) - assert.Equal(t, &WorkloadUpdate{ - Bundle: bundleV1, - FederatedBundles: makeBundles(otherBundleV1), - Identities: []Identity{ - {Entry: bar}, - {Entry: foo}, - }, - }, workloadUpdate) -} - -func TestMatchingIdentities(t *testing.T) { - cache := newTestCache() - - // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - } - cache.UpdateEntries(updateEntries, nil) - - identities := cache.MatchingIdentities(makeSelectors("A", "B")) - assert.Len(t, identities, 0, "identities should not be returned that don't have SVIDs") - - updateSVIDs := &UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - } - cache.UpdateSVIDs(updateSVIDs) - - identities = cache.MatchingIdentities(makeSelectors("A", "B")) - assert.Equal(t, []Identity{ - {Entry: bar}, - {Entry: foo}, - }, identities) -} - -func TestCountSVIDs(t *testing.T) { - cache := newTestCache() - - // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - } - cache.UpdateEntries(updateEntries, nil) - - // No SVIDs expected - require.Equal(t, 0, cache.CountX509SVIDs()) - - updateSVIDs := &UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - } - cache.UpdateSVIDs(updateSVIDs) - - // Only one SVID expected - require.Equal(t, 1, cache.CountX509SVIDs()) -} - -func TestBundleChanges(t *testing.T) { - cache := newTestCache() - - bundleStream := cache.SubscribeToBundleChanges() - assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) - - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV1), - }, nil) - if assert.True(t, bundleStream.HasNext(), "has new bundle value after adding bundle") { - bundleStream.Next() - assert.Equal(t, makeBundles(bundleV1, otherBundleV1), bundleStream.Value()) - } - - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - }, nil) - - if assert.True(t, bundleStream.HasNext(), "has new bundle value after removing bundle") { - bundleStream.Next() - assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) - } -} - -func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { - cache := newTestCache() - - // create some subscribers and assert they get the initial bundle - subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer subA.Finish() - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) - - subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) - defer subB.Finish() - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) - - // update the bundle and assert all subscribers gets the updated bundle - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) -} - -func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { - cache := newTestCache() - - // initialize the cache with an entry FOO that has a valid SVID and - // selector "A" - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - // subscribe to A and B and assert initial updates are received. - subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer subA.Finish() - assertAnyWorkloadUpdate(t, subA) - - subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) - defer subB.Finish() - assertAnyWorkloadUpdate(t, subB) - - // add the federated bundle with no registration entries federating with - // it and make sure nobody is notified. - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - assertNoWorkloadUpdate(t, subA) - assertNoWorkloadUpdate(t, subB) - - // update FOO to federate with otherdomain.test and make sure subA is - // notified but not subB. - foo = makeRegistrationEntry("FOO", "A") - foo.FederatesWith = makeFederatesWith(otherBundleV1) - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - FederatedBundles: makeBundles(otherBundleV1), - Identities: []Identity{{Entry: foo}}, - }) - assertNoWorkloadUpdate(t, subB) - - // now change the federated bundle and make sure subA gets notified, but - // again, not subB. - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - FederatedBundles: makeBundles(otherBundleV2), - Identities: []Identity{{Entry: foo}}, - }) - assertNoWorkloadUpdate(t, subB) - - // now drop the federation and make sure subA is again notified and no - // longer has the federated bundle. - foo = makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1, otherBundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - assertNoWorkloadUpdate(t, subB) -} - -func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { - cache := newTestCache() - - // create subscribers for each combination of selectors - subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer subA.Finish() - subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) - defer subB.Finish() - subAB := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) - defer subAB.Finish() - - // assert all subscribers get the initial update - initialUpdate := &WorkloadUpdate{Bundle: bundleV1} - assertWorkloadUpdateEqual(t, subA, initialUpdate) - assertWorkloadUpdateEqual(t, subB, initialUpdate) - assertWorkloadUpdateEqual(t, subAB, initialUpdate) - - // create entry FOO that will target any subscriber with containing (A) - foo := makeRegistrationEntry("FOO", "A") - - // create entry BAR that will target any subscriber with containing (A,C) - bar := makeRegistrationEntry("BAR", "A", "C") - - // update the cache with foo and bar - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - }) - - // subA selector set contains (A), but not (A, C), so it should only get FOO - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - - // subB selector set does not contain either (A) or (A,C) so it isn't even - // notified. - assertNoWorkloadUpdate(t, subB) - - // subAB selector set contains (A) but not (A, C), so it should get FOO - assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) -} - -func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - assertAnyWorkloadUpdate(t, sub) - - // Second update is the same (other than X509SVIDs, which, when set, - // always constitute a "change" for the impacted registration entries. - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - - assertNoWorkloadUpdate(t, sub) -} - -func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - assertAnyWorkloadUpdate(t, sub) - - // Update SVID - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) -} - -func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { - cache := newTestCache() - - // initialize the cache with a FOO entry with selector A and an SVID - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - // create subscribers for A and make sure the initial update has FOO - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - - // update FOO to have selectors (A,B) and make sure the subscriber loses - // FOO, since (A,B) is not a subset of the subscriber set (A). - foo = makeRegistrationEntry("FOO", "A", "B") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - }) - - // update FOO to drop B and make sure the subscriber regains FOO - foo = makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) -} - -func newTestCache() *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}) -} - -func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { - cache := newTestCache() - - subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer subA.Finish() - assertAnyWorkloadUpdate(t, subA) - - // subB's job here is to just make sure we don't notify unrelated - // subscribers when dropping registration entries - subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) - defer subB.Finish() - assertAnyWorkloadUpdate(t, subB) - - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - } - cache.UpdateEntries(updateEntries, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - // make sure subA gets notified with FOO but not subB - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - assertNoWorkloadUpdate(t, subB) - - // Swap out FOO for BAR - updateEntries.RegistrationEntries = makeRegistrationEntries(bar) - cache.UpdateEntries(updateEntries, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(bar), - }) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - }) - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: bar}}, - }) - - // Drop both - updateEntries.RegistrationEntries = nil - cache.UpdateEntries(updateEntries, nil) - assertNoWorkloadUpdate(t, subA) - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ - Bundle: bundleV1, - }) - - // Make sure trying to update SVIDs of removed entry does not notify - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assertNoWorkloadUpdate(t, subB) -} - -func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - } - cache.UpdateEntries(updateEntries, nil) - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - - // workload update does not include the identity because it has no SVID. - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - }) - - // update to include the SVID and now we should get the update - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) -} - -func TestSubscribersDoNotBlockNotifications(t *testing.T) { - cache := newTestCache() - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - }, nil) - - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV3), - }, nil) - - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV3, - }) -} - -func TestCheckSVIDCallback(t *testing.T) { - cache := newTestCache() - - // no calls because there are no registration entries - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - assert.Fail(t, "should not be called if there are no registration entries") - - return false - }) - - foo := makeRegistrationEntryWithTTL("FOO", 70, 80) - - // called once for FOO with no SVID - callCount := 0 - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - callCount++ - assert.Equal(t, "FOO", newEntry.EntryId) - - // there is no already existing entry, only the new entry - assert.Nil(t, existingEntry) - assert.Equal(t, foo, newEntry) - assert.Nil(t, svid) - - return false - }) - assert.Equal(t, 1, callCount) - assert.Empty(t, cache.staleEntries) - - // called once for FOO with new SVID - callCount = 0 - svids := makeX509SVIDs(foo) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: svids, - }) - - // called once for FOO with existing SVID - callCount = 0 - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - callCount++ - assert.Equal(t, "FOO", newEntry.EntryId) - if assert.NotNil(t, svid) { - assert.Exactly(t, svids["FOO"], svid) - } - - return true - }) - assert.Equal(t, 1, callCount) - assert.Equal(t, map[string]bool{foo.EntryId: true}, cache.staleEntries) -} - -func TestGetStaleEntries(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntryWithTTL("FOO", 70, 80) - - // Create entry but don't mark it stale - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - return false - }) - assert.Empty(t, cache.GetStaleEntries()) - - // Update entry and mark it as stale - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - return true - }) - // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. - expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} - assert.Equal(t, expectedEntries, cache.GetStaleEntries()) - - // Update the SVID for the stale entry - svids := make(map[string]*X509SVID) - expiredAt := time.Now() - svids[foo.EntryId] = &X509SVID{ - Chain: []*x509.Certificate{{NotAfter: expiredAt}}, - } - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: svids, - }) - // Assert that updating the SVID removes stale marker from entry - assert.Empty(t, cache.GetStaleEntries()) - - // Update entry again and mark it as stale - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - return true - }) - - // Assert that the entry again returns as stale. This time the `ExpiresAt` field should be populated with the expiration of the SVID. - expectedEntries = []*StaleEntry{{ - Entry: cache.records[foo.EntryId].entry, - SVIDExpiresAt: expiredAt, - }} - assert.Equal(t, expectedEntries, cache.GetStaleEntries()) - - // Remove registration entry and assert that it is no longer returned as stale - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - return true - }) - assert.Empty(t, cache.GetStaleEntries()) -} - -func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - }) - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - assertAnyWorkloadUpdate(t, sub) - - // Update SVID - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(bar), - }) - - assertNoWorkloadUpdate(t, sub) -} - -func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A", "C") - bar := makeRegistrationEntry("FOO", "A", "B") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - }) - - sub := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) - defer sub.Finish() - assertAnyWorkloadUpdate(t, sub) - - // Update SVID - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - - assertNoWorkloadUpdate(t, sub) -} - -func BenchmarkCacheGlobalNotification(b *testing.B) { - cache := newTestCache() - - const numEntries = 1000 - const numWorkloads = 1000 - const selectorsPerEntry = 3 - const selectorsPerWorkload = 10 - - // build a set of 1000 registration entries with distinct selectors - bundlesV1 := makeBundles(bundleV1) - bundlesV2 := makeBundles(bundleV2) - updateEntries := &UpdateEntries{ - Bundles: bundlesV1, - RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), - } - for i := 0; i < numEntries; i++ { - entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ - EntryId: entryID, - ParentId: "spiffe://domain.test/node", - SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), - Selectors: distinctSelectors(i, selectorsPerEntry), - } - } - - cache.UpdateEntries(updateEntries, nil) - for i := 0; i < numWorkloads; i++ { - selectors := distinctSelectors(i, selectorsPerWorkload) - cache.subscribeToWorkloadUpdates(selectors) - } - - runtime.GC() - - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - if i%2 == 0 { - updateEntries.Bundles = bundlesV2 - } else { - updateEntries.Bundles = bundlesV1 - } - cache.UpdateEntries(updateEntries, nil) - } -} - -func TestMatchingRegistrationEntries(t *testing.T) { - cache := newTestCache() - - // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - - // check empty result - assert.Equal(t, []*common.RegistrationEntry{}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) - - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - } - cache.UpdateEntries(updateEntries, nil) - - // Update SVIDs and MatchingRegistrationEntries should return both entries - updateSVIDs := &UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo, bar), - } - cache.UpdateSVIDs(updateSVIDs) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) -} - -func TestEntries(t *testing.T) { - cache := newTestCache() - - // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") - updateEntries := &UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - } - cache.UpdateEntries(updateEntries, nil) - - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, cache.Entries()) -} - -func distinctSelectors(id, n int) []*common.Selector { - out := make([]*common.Selector, 0, n) - for i := 0; i < n; i++ { - out = append(out, &common.Selector{ - Type: "test", - Value: fmt.Sprintf("id:%d:n:%d", id, i), - }) - } - return out -} - -func assertNoWorkloadUpdate(t *testing.T, sub Subscriber) { - select { - case update := <-sub.Updates(): - assert.FailNow(t, "unexpected workload update", update) - default: - } -} - -func assertAnyWorkloadUpdate(t *testing.T, sub Subscriber) { - select { - case <-sub.Updates(): - case <-time.After(time.Minute): - assert.FailNow(t, "timed out waiting for any workload update") - } -} - -func assertWorkloadUpdateEqual(t *testing.T, sub Subscriber, expected *WorkloadUpdate) { - select { - case actual := <-sub.Updates(): - assert.NotNil(t, actual.Bundle, "bundle is not set") - assert.True(t, actual.Bundle.Equal(expected.Bundle), "bundles don't match") - assert.Equal(t, expected.Identities, actual.Identities, "identities don't match") - case <-time.After(time.Minute): - assert.FailNow(t, "timed out waiting for workload update") - } -} - -func makeBundles(bundles ...*Bundle) map[spiffeid.TrustDomain]*Bundle { - out := make(map[spiffeid.TrustDomain]*Bundle) - for _, bundle := range bundles { - td := spiffeid.RequireTrustDomainFromString(bundle.TrustDomain().IDString()) - out[td] = bundle - } - return out -} - -func makeX509SVIDs(entries ...*common.RegistrationEntry) map[string]*X509SVID { - out := make(map[string]*X509SVID) - for _, entry := range entries { - out[entry.EntryId] = &X509SVID{} - } - return out -} - -func makeRegistrationEntry(id string, selectors ...string) *common.RegistrationEntry { - return &common.RegistrationEntry{ - EntryId: id, - SpiffeId: "spiffe://domain.test/" + id, - Selectors: makeSelectors(selectors...), - DnsNames: []string{fmt.Sprintf("name-%s", id)}, - X509SvidTtl: defaultX509SVIDTTL, - JwtSvidTtl: defaultJwtSVIDTTL, - } -} - -func makeRegistrationEntryWithTTL(id string, x509SVIDTTL int32, jwtSVIDTTL int32, selectors ...string) *common.RegistrationEntry { - return &common.RegistrationEntry{ - EntryId: id, - SpiffeId: "spiffe://domain.test/" + id, - Selectors: makeSelectors(selectors...), - DnsNames: []string{fmt.Sprintf("name-%s", id)}, - X509SvidTtl: x509SVIDTTL, - JwtSvidTtl: jwtSVIDTTL, - } -} - -func makeRegistrationEntries(entries ...*common.RegistrationEntry) map[string]*common.RegistrationEntry { - out := make(map[string]*common.RegistrationEntry) - for _, entry := range entries { - out[entry.EntryId] = entry - } - return out -} - -func makeSelectors(values ...string) []*common.Selector { - var out []*common.Selector - for _, value := range values { - out = append(out, &common.Selector{Type: "test", Value: value}) - } - return out -} - -func makeFederatesWith(bundles ...*Bundle) []string { - var out []string - for _, bundle := range bundles { - out = append(out, bundle.TrustDomain().IDString()) - } - return out -} diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index eb5b4e5140..100bda2e87 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -18,12 +18,30 @@ import ( ) const ( - // DefaultSVIDCacheMaxSize is set when svidCacheMaxSize is not provided - DefaultSVIDCacheMaxSize = 1000 + // SVIDCacheMaxSize is the size for the cache + SVIDCacheMaxSize = 1000 // SVIDSyncInterval is the interval at which SVIDs are synced with subscribers SVIDSyncInterval = 500 * time.Millisecond ) +// UpdateEntries holds information for an entries update to the cache. +type UpdateEntries struct { + // Bundles is a set of ALL trust bundles available to the agent, keyed by trust domain + Bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle + + // RegistrationEntries is a set of all registration entries available to the + // agent, keyed by registration entry id. + RegistrationEntries map[string]*common.RegistrationEntry +} + +// StaleEntry holds stale entries with SVIDs expiration time +type StaleEntry struct { + // Entry stale registration entry + Entry *common.RegistrationEntry + // SVIDs expiration time + SVIDExpiresAt time.Time +} + // Cache caches each registration entry, bundles, and JWT SVIDs for the agent. // The signed X509-SVIDs for those entries are stored in LRU-like cache. // It allows subscriptions by (workload) selector sets and notifies subscribers when: @@ -42,7 +60,7 @@ const ( // selector it encounters. Each selector index tracks the subscribers (i.e // workloads) and registration entries that have that selector. // -// The LRU-like SVID cache has configurable size limit and expiry period. +// The LRU-like SVID cache has a size limit and expiry period. // 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then // that SVID is never removed from cache. // 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. @@ -106,17 +124,10 @@ type LRUCache struct { // svids are stored by entry IDs svids map[string]*X509SVID - // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache - svidCacheMaxSize int subscribeBackoffFn func() backoff.BackOff } -func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, - svidCacheMaxSize int, clk clock.Clock) *LRUCache { - if svidCacheMaxSize <= 0 { - svidCacheMaxSize = DefaultSVIDCacheMaxSize - } - +func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, clk clock.Clock) *LRUCache { return &LRUCache{ BundleCache: NewBundleCache(trustDomain, bundle), JWTSVIDCache: NewJWTSVIDCache(), @@ -130,9 +141,8 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl bundles: map[spiffeid.TrustDomain]*spiffebundle.Bundle{ trustDomain: bundle, }, - svids: make(map[string]*X509SVID), - svidCacheMaxSize: svidCacheMaxSize, - clk: clk, + svids: make(map[string]*X509SVID), + clk: clk, subscribeBackoffFn: func() backoff.BackOff { return backoff.NewBackoff(clk, SVIDSyncInterval) }, @@ -403,7 +413,7 @@ func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.R // entries with active subscribers which are not cached will be put in staleEntries map; // irrespective of what svid cache size as we cannot deny identity to a subscriber activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() - extraSize := len(c.svids) - c.svidCacheMaxSize + extraSize := len(c.svids) - SVIDCacheMaxSize // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime if extraSize > 0 { @@ -412,7 +422,7 @@ func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.R for _, record := range recordsWithLastAccessTime { if extraSize <= 0 { - // no need to delete SVIDs any further as cache size <= svidCacheMaxSize + // no need to delete SVIDs any further as cache size <= SVIDCacheMaxSize break } if _, ok := c.svids[record.id]; ok { @@ -633,7 +643,7 @@ func (c *LRUCache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAcce lastAccessTimestamps = append(lastAccessTimestamps, newRecordAccessEvent(record.lastAccessTimestamp, id)) } - remainderSize := c.svidCacheMaxSize - len(c.svids) + remainderSize := SVIDCacheMaxSize - len(c.svids) // add records which are not cached for remainder of cache size for id := range c.records { if len(c.staleEntries) >= remainderSize { diff --git a/pkg/agent/manager/cache/lru_cache_subscriber.go b/pkg/agent/manager/cache/lru_cache_subscriber.go index 00556f89a9..7b23c81b3a 100644 --- a/pkg/agent/manager/cache/lru_cache_subscriber.go +++ b/pkg/agent/manager/cache/lru_cache_subscriber.go @@ -6,6 +6,11 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) +type Subscriber interface { + Updates() <-chan *WorkloadUpdate + Finish() +} + type lruCacheSubscriber struct { cache *LRUCache set selectorSet diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index 4691826842..63da336a42 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -19,6 +19,18 @@ import ( "github.com/stretchr/testify/require" ) +var ( + trustDomain1 = spiffeid.RequireTrustDomainFromString("domain.test") + trustDomain2 = spiffeid.RequireTrustDomainFromString("otherdomain.test") + bundleV1 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{1}}}) + bundleV2 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{2}}}) + bundleV3 = spiffebundle.FromX509Authorities(trustDomain1, []*x509.Certificate{{Raw: []byte{3}}}) + otherBundleV1 = spiffebundle.FromX509Authorities(trustDomain2, []*x509.Certificate{{Raw: []byte{4}}}) + otherBundleV2 = spiffebundle.FromX509Authorities(trustDomain2, []*x509.Certificate{{Raw: []byte{5}}}) + defaultX509SVIDTTL = int32(700) + defaultJwtSVIDTTL = int32(800) +) + func TestLRUCacheFetchWorkloadUpdate(t *testing.T) { cache := newTestLRUCache(t) // populate the cache with FOO and BAR without SVIDS @@ -644,7 +656,7 @@ func TestLRUCacheSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { func TestLRUCacheSVIDCacheExpiry(t *testing.T) { clk := clock.NewMock(t) - cache := newTestLRUCacheWithConfig(10, clk) + cache := newTestLRUCacheWithConfig(clk) clk.Add(1 * time.Second) foo := makeRegistrationEntry("FOO", "A") @@ -687,8 +699,8 @@ func TestLRUCacheSVIDCacheExpiry(t *testing.T) { // Move clk by 2 seconds clk.Add(2 * time.Second) - // update total of 12 entries - updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) + // update total of size+2 entries + updateEntries := createUpdateEntries(SVIDCacheMaxSize, makeBundles(bundleV1)) updateEntries.RegistrationEntries[foo.EntryId] = foo updateEntries.RegistrationEntries[bar.EntryId] = bar @@ -705,10 +717,10 @@ func TestLRUCacheSVIDCacheExpiry(t *testing.T) { sub.Finish() } } - assert.Equal(t, 12, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize+2, cache.CountX509SVIDs()) cache.UpdateEntries(updateEntries, nil) - assert.Equal(t, 10, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) // foo SVID should be removed from cache as it does not have active subscriber assert.False(t, cache.notifySubscriberIfSVIDAvailable(makeSelectors("A"), subA.(*lruCacheSubscriber))) @@ -724,24 +736,24 @@ func TestLRUCacheSVIDCacheExpiry(t *testing.T) { require.Len(t, cache.GetStaleEntries(), 1) assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) - assert.Equal(t, 10, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) } func TestLRUCacheMaxSVIDCacheSize(t *testing.T) { clk := clock.NewMock(t) - cache := newTestLRUCacheWithConfig(10, clk) + cache := newTestLRUCacheWithConfig(clk) // create entries more than maxSvidCacheSize - updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) + updateEntries := createUpdateEntries(SVIDCacheMaxSize+2, makeBundles(bundleV1)) cache.UpdateEntries(updateEntries, nil) - require.Len(t, cache.GetStaleEntries(), 10) + require.Len(t, cache.GetStaleEntries(), SVIDCacheMaxSize) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), }) require.Len(t, cache.GetStaleEntries(), 0) - assert.Equal(t, 10, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize foo := makeRegistrationEntry("FOO", "A") @@ -752,25 +764,25 @@ func TestLRUCacheMaxSVIDCacheSize(t *testing.T) { cache.UpdateEntries(updateEntries, nil) require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, 10, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assert.Equal(t, 11, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize+1, cache.CountX509SVIDs()) require.Len(t, cache.GetStaleEntries(), 0) } func TestSyncSVIDsWithSubscribers(t *testing.T) { clk := clock.NewMock(t) - cache := newTestLRUCacheWithConfig(5, clk) + cache := newTestLRUCacheWithConfig(clk) - updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) + updateEntries := createUpdateEntries(SVIDCacheMaxSize, makeBundles(bundleV1)) cache.UpdateEntries(updateEntries, nil) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), }) - assert.Equal(t, 5, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) // Update foo but its SVID is not yet cached foo := makeRegistrationEntry("FOO", "A") @@ -788,7 +800,7 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { require.Len(t, cache.GetStaleEntries(), 1) assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) - assert.Equal(t, 5, cache.CountX509SVIDs()) + assert.Equal(t, SVIDCacheMaxSize, cache.CountX509SVIDs()) } func TestNotifySubscriberWhenSVIDIsAvailable(t *testing.T) { @@ -813,15 +825,15 @@ func TestNotifySubscriberWhenSVIDIsAvailable(t *testing.T) { func TestSubscribeToWorkloadUpdatesLRUNoSelectors(t *testing.T) { clk := clock.NewMock(t) - cache := newTestLRUCacheWithConfig(1, clk) + cache := newTestLRUCacheWithConfig(clk) // Creating test entries, but this will not affect current test... foo := makeRegistrationEntry("FOO", "A") bar := makeRegistrationEntry("BAR", "B") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - }, nil) + updateEntries := createUpdateEntries(SVIDCacheMaxSize, makeBundles(bundleV1)) + updateEntries.RegistrationEntries[foo.EntryId] = foo + updateEntries.RegistrationEntries[bar.EntryId] = bar + cache.UpdateEntries(updateEntries, nil) subWaitCh := make(chan struct{}, 1) subErrCh := make(chan error, 1) @@ -859,7 +871,7 @@ func TestSubscribeToWorkloadUpdatesLRUNoSelectors(t *testing.T) { <-subWaitCh cache.SyncSVIDsWithSubscribers() - assert.Len(t, cache.GetStaleEntries(), 1) + assert.Len(t, cache.GetStaleEntries(), SVIDCacheMaxSize) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo, bar), }) @@ -875,7 +887,7 @@ func TestSubscribeToWorkloadUpdatesLRUNoSelectors(t *testing.T) { func TestSubscribeToLRUCacheChanges(t *testing.T) { clk := clock.NewMock(t) - cache := newTestLRUCacheWithConfig(1, clk) + cache := newTestLRUCacheWithConfig(clk) foo := makeRegistrationEntry("FOO", "A") bar := makeRegistrationEntry("BAR", "B") @@ -1011,13 +1023,9 @@ func TestMetrics(t *testing.T) { } func TestNewLRUCache(t *testing.T) { - // negative value - cache := newTestLRUCacheWithConfig(-5, clock.NewMock(t)) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) - - // zero value - cache = newTestLRUCacheWithConfig(0, clock.NewMock(t)) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) + // expected cache size + cache := newTestLRUCacheWithConfig(clock.NewMock(t)) + require.NotNil(t, cache) } func BenchmarkLRUCacheGlobalNotification(b *testing.B) { @@ -1068,13 +1076,12 @@ func BenchmarkLRUCacheGlobalNotification(b *testing.B) { func newTestLRUCache(t testing.TB) *LRUCache { log, _ := test.NewNullLogger() return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, - telemetry.Blackhole{}, 0, clock.NewMock(t)) + telemetry.Blackhole{}, clock.NewMock(t)) } -func newTestLRUCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *LRUCache { +func newTestLRUCacheWithConfig(clk clock.Clock) *LRUCache { log, _ := test.NewNullLogger() - return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, - svidCacheMaxSize, clk) + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, clk) } // numEntries should not be more than 12 digits @@ -1117,3 +1124,104 @@ func subscribeToWorkloadUpdates(t *testing.T, cache *LRUCache, selectors []*comm assert.NoError(t, err) return subscriber } + +func distinctSelectors(id, n int) []*common.Selector { + out := make([]*common.Selector, 0, n) + for i := 0; i < n; i++ { + out = append(out, &common.Selector{ + Type: "test", + Value: fmt.Sprintf("id:%d:n:%d", id, i), + }) + } + return out +} + +func assertNoWorkloadUpdate(t *testing.T, sub Subscriber) { + select { + case update := <-sub.Updates(): + assert.FailNow(t, "unexpected workload update", update) + default: + } +} + +func assertAnyWorkloadUpdate(t *testing.T, sub Subscriber) { + select { + case <-sub.Updates(): + case <-time.After(time.Minute): + assert.FailNow(t, "timed out waiting for any workload update") + } +} + +func assertWorkloadUpdateEqual(t *testing.T, sub Subscriber, expected *WorkloadUpdate) { + select { + case actual := <-sub.Updates(): + assert.NotNil(t, actual.Bundle, "bundle is not set") + assert.True(t, actual.Bundle.Equal(expected.Bundle), "bundles don't match") + assert.Equal(t, expected.Identities, actual.Identities, "identities don't match") + case <-time.After(time.Minute): + assert.FailNow(t, "timed out waiting for workload update") + } +} + +func makeBundles(bundles ...*Bundle) map[spiffeid.TrustDomain]*Bundle { + out := make(map[spiffeid.TrustDomain]*Bundle) + for _, bundle := range bundles { + td := spiffeid.RequireTrustDomainFromString(bundle.TrustDomain().IDString()) + out[td] = bundle + } + return out +} + +func makeX509SVIDs(entries ...*common.RegistrationEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.EntryId] = &X509SVID{} + } + return out +} + +func makeRegistrationEntry(id string, selectors ...string) *common.RegistrationEntry { + return &common.RegistrationEntry{ + EntryId: id, + SpiffeId: "spiffe://domain.test/" + id, + Selectors: makeSelectors(selectors...), + DnsNames: []string{fmt.Sprintf("name-%s", id)}, + X509SvidTtl: defaultX509SVIDTTL, + JwtSvidTtl: defaultJwtSVIDTTL, + } +} + +func makeRegistrationEntryWithTTL(id string, x509SVIDTTL int32, jwtSVIDTTL int32, selectors ...string) *common.RegistrationEntry { + return &common.RegistrationEntry{ + EntryId: id, + SpiffeId: "spiffe://domain.test/" + id, + Selectors: makeSelectors(selectors...), + DnsNames: []string{fmt.Sprintf("name-%s", id)}, + X509SvidTtl: x509SVIDTTL, + JwtSvidTtl: jwtSVIDTTL, + } +} + +func makeRegistrationEntries(entries ...*common.RegistrationEntry) map[string]*common.RegistrationEntry { + out := make(map[string]*common.RegistrationEntry) + for _, entry := range entries { + out[entry.EntryId] = entry + } + return out +} + +func makeSelectors(values ...string) []*common.Selector { + var out []*common.Selector + for _, value := range values { + out = append(out, &common.Selector{Type: "test", Value: value}) + } + return out +} + +func makeFederatesWith(bundles ...*Bundle) []string { + var out []string + for _, bundle := range bundles { + out = append(out, bundle.TrustDomain().IDString()) + } + return out +} diff --git a/pkg/agent/manager/cache/sets.go b/pkg/agent/manager/cache/sets.go index f471d56d45..c081224d8b 100644 --- a/pkg/agent/manager/cache/sets.go +++ b/pkg/agent/manager/cache/sets.go @@ -13,24 +13,12 @@ var ( }, } - subscriberSetPool = sync.Pool{ - New: func() any { - return make(subscriberSet) - }, - } - selectorSetPool = sync.Pool{ New: func() any { return make(selectorSet) }, } - recordSetPool = sync.Pool{ - New: func() any { - return make(recordSet) - }, - } - lruCacheRecordSetPool = sync.Pool{ New: func() any { return make(lruCacheRecordSet) @@ -67,23 +55,6 @@ func (set stringSet) Merge(ss ...string) { } } -// unique set of subscribers, allocated from a pool -type subscriberSet map[*subscriber]struct{} - -func allocSubscriberSet() (subscriberSet, func()) { - set := subscriberSetPool.Get().(subscriberSet) - return set, func() { - clearSubscriberSet(set) - subscriberSetPool.Put(set) - } -} - -func clearSubscriberSet(set subscriberSet) { - for k := range set { - delete(set, k) - } -} - // unique set of selectors, allocated from a pool type selector struct { Type string @@ -144,23 +115,6 @@ func (set selectorSet) SuperSetOf(other selectorSet) bool { return true } -// unique set of cache records, allocated from a pool -type recordSet map[*cacheRecord]struct{} - -func allocRecordSet() (recordSet, func()) { - set := recordSetPool.Get().(recordSet) - return set, func() { - clearRecordSet(set) - recordSetPool.Put(set) - } -} - -func clearRecordSet(set recordSet) { - for k := range set { - delete(set, k) - } -} - // unique set of LRU cache records, allocated from a pool type lruCacheRecordSet map[*lruCacheRecord]struct{} diff --git a/pkg/agent/manager/cache/subscriber.go b/pkg/agent/manager/cache/subscriber.go deleted file mode 100644 index c0d48ced43..0000000000 --- a/pkg/agent/manager/cache/subscriber.go +++ /dev/null @@ -1,65 +0,0 @@ -package cache - -import ( - "sync" - - "github.com/spiffe/spire/proto/spire/common" -) - -type Subscriber interface { - Updates() <-chan *WorkloadUpdate - Finish() -} - -type subscriber struct { - cache *Cache - set selectorSet - setFree func() - - mu sync.Mutex - c chan *WorkloadUpdate - done bool -} - -func newSubscriber(cache *Cache, selectors []*common.Selector) *subscriber { - set, setFree := allocSelectorSet(selectors...) - return &subscriber{ - cache: cache, - set: set, - setFree: setFree, - c: make(chan *WorkloadUpdate, 1), - } -} - -func (s *subscriber) Updates() <-chan *WorkloadUpdate { - return s.c -} - -func (s *subscriber) Finish() { - s.mu.Lock() - done := s.done - if !done { - s.done = true - close(s.c) - } - s.mu.Unlock() - if !done { - s.cache.unsubscribe(s) - s.setFree() - s.set = nil - } -} - -func (s *subscriber) notify(update *WorkloadUpdate) { - s.mu.Lock() - defer s.mu.Unlock() - if s.done { - return - } - - select { - case <-s.c: - default: - } - s.c <- update -} diff --git a/pkg/agent/manager/cache/util.go b/pkg/agent/manager/cache/util.go index ab365514fd..adcbbabd3b 100644 --- a/pkg/agent/manager/cache/util.go +++ b/pkg/agent/manager/cache/util.go @@ -11,3 +11,9 @@ func sortEntriesByID(entries []*common.RegistrationEntry) { return entries[a].EntryId < entries[b].EntryId }) } + +func sortIdentities(identities []Identity) { + sort.Slice(identities, func(a, b int) bool { + return identities[a].Entry.EntryId < identities[b].Entry.EntryId + }) +} diff --git a/pkg/agent/manager/cache/workload.go b/pkg/agent/manager/cache/workload.go new file mode 100644 index 0000000000..8360e97fdb --- /dev/null +++ b/pkg/agent/manager/cache/workload.go @@ -0,0 +1,43 @@ +package cache + +import ( + "crypto" + "crypto/x509" + + "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/proto/spire/common" +) + +type Selectors []*common.Selector + +// Identity holds the data for a single workload identity +type Identity struct { + Entry *common.RegistrationEntry + SVID []*x509.Certificate + PrivateKey crypto.Signer +} + +// UpdateSVIDs holds information for an SVIDs update to the cache. +type UpdateSVIDs struct { + // X509SVIDs is a set of updated X509-SVIDs that should be merged into + // the cache, keyed by registration entry id. + X509SVIDs map[string]*X509SVID +} + +// WorkloadUpdate is used to convey workload information to cache subscribers +type WorkloadUpdate struct { + Identities []Identity + Bundle *spiffebundle.Bundle + FederatedBundles map[spiffeid.TrustDomain]*spiffebundle.Bundle +} + +func (u *WorkloadUpdate) HasIdentity() bool { + return len(u.Identities) > 0 +} + +// X509SVID holds onto the SVID certificate chain and private key. +type X509SVID struct { + Chain []*x509.Certificate + PrivateKey crypto.Signer +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index f5d71bbe12..cd90472525 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -65,15 +65,8 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - var cache Cache - if c.DisableLRUCache { - cache = managerCache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics) - } else { - // use LRU cache implementation - cache = managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.SVIDCacheMaxSize, c.Clk) - } + cache := managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics, c.Clk) rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 141d615fe1..6e826716cc 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -36,12 +36,8 @@ type SVIDCache interface { } func (m *manager) syncSVIDs(ctx context.Context) (err error) { - // perform syncSVIDs only if using LRU cache - if m.c.SVIDCacheMaxSize > 0 { - m.cache.SyncSVIDsWithSubscribers() - return m.updateSVIDs(ctx, m.c.Log.WithField(telemetry.CacheType, "workload"), m.cache) - } - return nil + m.cache.SyncSVIDsWithSubscribers() + return m.updateSVIDs(ctx, m.c.Log.WithField(telemetry.CacheType, "workload"), m.cache) } // synchronize fetches the authorized entries from the server, updates the diff --git a/test/integration/suites/agent-cli/conf/agent/agent.conf b/test/integration/suites/agent-cli/conf/agent/agent.conf index 20020d1df6..44085a50e4 100644 --- a/test/integration/suites/agent-cli/conf/agent/agent.conf +++ b/test/integration/suites/agent-cli/conf/agent/agent.conf @@ -7,9 +7,6 @@ agent { trust_bundle_path = "/opt/spire/conf/agent/bootstrap.crt" trust_domain = "domain.test" admin_socket_path = "/opt/debug.sock" - experimental { - x509_svid_cache_max_size = 8 - } } plugins { diff --git a/test/integration/suites/fetch-x509-svids/04-create-registration-entries b/test/integration/suites/fetch-x509-svids/04-create-registration-entries index 318b53162d..c70ae606b8 100755 --- a/test/integration/suites/fetch-x509-svids/04-create-registration-entries +++ b/test/integration/suites/fetch-x509-svids/04-create-registration-entries @@ -1,6 +1,8 @@ #!/bin/bash -SIZE=10 +# LRU Cache size is 1000; we expect uid:1001 to receive all 1002 identities, +# and later on disconnect for the cache to be pruned back to 1000 +SIZE=1002 # Create entries for uid 1001 for ((m=1;m<=$SIZE;m++)); do diff --git a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids index 2518884a74..e80fea74b0 100755 --- a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids @@ -1,7 +1,7 @@ #!/bin/bash -ENTRYCOUNT=10 -CACHESIZE=8 +ENTRYCOUNT=1002 +CACHESIZE=1000 X509SVIDCOUNT=$(docker compose exec -u 1001 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ @@ -14,4 +14,4 @@ else fi # Call agent debug endpoints and check if extra X.509-SVIDs from cache are cleaned up -check-x509-svid-count "spire-agent" $CACHESIZE +check-x509-svid-count "spire-agent" 1000 diff --git a/test/integration/suites/fetch-x509-svids/06-create-registration-entries b/test/integration/suites/fetch-x509-svids/06-create-registration-entries index cb0f9333d6..5e50ceb10e 100755 --- a/test/integration/suites/fetch-x509-svids/06-create-registration-entries +++ b/test/integration/suites/fetch-x509-svids/06-create-registration-entries @@ -1,6 +1,8 @@ #!/bin/bash -SIZE=10 +# LRU Cache size is 1000; we expect uid:1002 to receive all 1002 identities, +# and later on disconnect for the cache to be pruned back to 1000 +SIZE=1002 # Create entries for uid 1002 for ((m=1;m<=$SIZE;m++)); do diff --git a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids index 7ff7f43b14..103f2e3544 100755 --- a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids @@ -1,7 +1,7 @@ #!/bin/bash -CACHESIZE=8 -ENTRYCOUNT=10 +ENTRYCOUNT=1002 +CACHESIZE=1000 X509SVIDCOUNT=$(docker compose exec -u 1002 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ diff --git a/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf b/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf index bdbc803a95..11012a904b 100644 --- a/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf +++ b/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf @@ -7,9 +7,6 @@ agent { trust_bundle_path = "/opt/spire/conf/agent/bootstrap.crt" trust_domain = "domain.test" admin_socket_path = "/opt/debug.sock" - experimental { - x509_svid_cache_max_size = 8 - } } plugins {