From 5057224138db4ae8aa701d285aa1c6b7cef71e47 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Wed, 16 Oct 2024 10:16:35 +0100 Subject: [PATCH 1/5] spire-agent: use a LRU cache for the JWT-SVID cache Signed-off-by: Sorin Dumitru --- pkg/agent/manager/cache/jwt_cache.go | 54 +++++++++++++++++++++++----- pkg/agent/manager/cache/lru_cache.go | 5 +-- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/pkg/agent/manager/cache/jwt_cache.go b/pkg/agent/manager/cache/jwt_cache.go index 8f5fd0d9a6..b2612cf3a4 100644 --- a/pkg/agent/manager/cache/jwt_cache.go +++ b/pkg/agent/manager/cache/jwt_cache.go @@ -1,6 +1,7 @@ package cache import ( + "container/list" "context" "crypto/sha256" "encoding/base64" @@ -23,18 +24,28 @@ type JWTSVIDCache struct { log logrus.FieldLogger metrics telemetry.Metrics mu sync.RWMutex - svids map[string]*client.JWTSVID + + svids map[string]*list.Element + lruList *list.List + + // svidCacheMaxSize is a hard limit of max number of SVIDs that would be stored in cache + svidCacheMaxSize int } func (c *JWTSVIDCache) CountJWTSVIDs() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.svids) } func NewJWTSVIDCache(log logrus.FieldLogger, metrics telemetry.Metrics) *JWTSVIDCache { return &JWTSVIDCache{ - metrics: metrics, - log: log, - svids: make(map[string]*client.JWTSVID), + metrics: metrics, + log: log, + svids: make(map[string]*list.Element), + lruList: list.New(), + svidCacheMaxSize: 1024, // TODO: make configurable } } @@ -43,8 +54,14 @@ func (c *JWTSVIDCache) GetJWTSVID(spiffeID spiffeid.ID, audience []string) (*cli c.mu.Lock() defer c.mu.Unlock() - svid, ok := c.svids[key] - return svid, ok + + svidElement, ok := c.svids[key] + if !ok { + return nil, ok + } + c.lruList.MoveToFront(svidElement) + + return svidElement.Value.(*client.JWTSVID), ok } func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid *client.JWTSVID) { @@ -52,7 +69,26 @@ func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid c.mu.Lock() defer c.mu.Unlock() - c.svids[key] = svid + + if len(c.svids) > c.svidCacheMaxSize { + element := c.lruList.Back() + keyID, err := getKeyIDFromSVIDToken(element.Value.(*client.JWTSVID).Token) + if err != nil { + c.log.WithError(err).Error("Could not get key ID from cached JWT-SVID") + return + } + delete(c.svids, keyID) + c.lruList.Remove(element) + } + + svidElement, ok := c.svids[key] + if ok { + svidElement.Value = svid + c.lruList.MoveToFront(svidElement) + } else { + svidElement = c.lruList.PushFront(svid) + c.svids[key] = svidElement + } } func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities map[string]struct{}) { @@ -64,7 +100,8 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities removedKeyIDs := make(map[string]int) totalCount := 0 - for key, jwtSVID := range c.svids { + for key, element := range c.svids { + jwtSVID := element.Value.(*client.JWTSVID) keyID, err := getKeyIDFromSVIDToken(jwtSVID.Token) if err != nil { c.log.WithError(err).Error("Could not get key ID from cached JWT-SVID") @@ -72,6 +109,7 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities } if _, tainted := taintedJWTAuthorities[keyID]; tainted { delete(c.svids, key) + c.lruList.Remove(element) removedKeyIDs[keyID]++ totalCount++ } diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 4fe51f16af..43f99d5be4 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -218,10 +218,7 @@ func (c *LRUCache) CountX509SVIDs() int { } func (c *LRUCache) CountJWTSVIDs() int { - c.mu.RLock() - defer c.mu.RUnlock() - - return len(c.JWTSVIDCache.svids) + return c.JWTSVIDCache.CountJWTSVIDs() } func (c *LRUCache) CountRecords() int { From be010330d8fd3944eb80d2d0a20fea22af2fd00d Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Wed, 16 Oct 2024 11:14:49 +0100 Subject: [PATCH 2/5] Make JWT-SVID cache size configurable Signed-off-by: Sorin Dumitru --- .../api/delegatedidentity/v1/service_test.go | 2 +- pkg/agent/manager/cache/jwt_cache.go | 7 +++-- pkg/agent/manager/cache/jwt_cache_test.go | 27 ++++++++++--------- pkg/agent/manager/cache/lru_cache.go | 22 +++++++-------- pkg/agent/manager/cache/lru_cache_test.go | 4 +-- pkg/agent/manager/config.go | 2 +- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 6375354d9c..ca3a9aef3a 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -996,7 +996,7 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { func newTestCache() *cache.LRUCache { log, _ := test.NewNullLogger() - return cache.NewLRUCache(log, trustDomain1, bundle1, telemetry.Blackhole{}, cache.DefaultSVIDCacheMaxSize, clock.New()) + return cache.NewLRUCache(log, trustDomain1, bundle1, telemetry.Blackhole{}, cache.DefaultSVIDCacheMaxSize, cache.DefaultSVIDCacheMaxSize, clock.New()) } func generateSubscribeToX509SVIDMetrics() []fakemetrics.MetricItem { diff --git a/pkg/agent/manager/cache/jwt_cache.go b/pkg/agent/manager/cache/jwt_cache.go index b2612cf3a4..12b55db4a4 100644 --- a/pkg/agent/manager/cache/jwt_cache.go +++ b/pkg/agent/manager/cache/jwt_cache.go @@ -39,13 +39,16 @@ func (c *JWTSVIDCache) CountJWTSVIDs() int { return len(c.svids) } -func NewJWTSVIDCache(log logrus.FieldLogger, metrics telemetry.Metrics) *JWTSVIDCache { +func NewJWTSVIDCache(log logrus.FieldLogger, metrics telemetry.Metrics, svidCacheMaxSize int) *JWTSVIDCache { + if svidCacheMaxSize <= 0 { + svidCacheMaxSize = DefaultSVIDCacheMaxSize + } return &JWTSVIDCache{ metrics: metrics, log: log, svids: make(map[string]*list.Element), lruList: list.New(), - svidCacheMaxSize: 1024, // TODO: make configurable + svidCacheMaxSize: svidCacheMaxSize, } } diff --git a/pkg/agent/manager/cache/jwt_cache_test.go b/pkg/agent/manager/cache/jwt_cache_test.go index 00c1442445..1a3487055a 100644 --- a/pkg/agent/manager/cache/jwt_cache_test.go +++ b/pkg/agent/manager/cache/jwt_cache_test.go @@ -21,13 +21,14 @@ func TestJWTSVIDCache(t *testing.T) { now := time.Now() tok1 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg" tok2 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImNKMXI5TVY4OTZTWXBMY0RMUjN3Q29QRHprTXpkN25tIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3Mjg1NzEwMjUsImlhdCI6MTcyODU3MDcyNSwic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.1YnDj7nknwIHEuNKEN0cNypXKS4SUeILXlNOsOs2XElHzfKhhDcl0sYKYtQc1Itf6cygz9C16VOQ_Yjoos2Qfg" - jwtSVID := &client.JWTSVID{Token: tok1, IssuedAt: now, ExpiresAt: now.Add(time.Second)} + jwtSVID1 := &client.JWTSVID{Token: tok1, IssuedAt: now, ExpiresAt: now.Add(time.Second)} jwtSVID2 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Second)} + //jwtSVID3 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Second)} fakeMetrics := fakemetrics.New() log, logHook := test.NewNullLogger() log.Level = logrus.DebugLevel - cache := NewJWTSVIDCache(log, fakeMetrics) + cache := NewJWTSVIDCache(log, fakeMetrics, 8) spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog") @@ -37,10 +38,10 @@ func TestJWTSVIDCache(t *testing.T) { assert.Nil(t, actual) // JWT is cached - cache.SetJWTSVID(spiffeID, []string{"bar"}, jwtSVID) + cache.SetJWTSVID(spiffeID, []string{"bar"}, jwtSVID1) actual, ok = cache.GetJWTSVID(spiffeID, []string{"bar"}) assert.True(t, ok) - assert.Equal(t, jwtSVID, actual) + assert.Equal(t, jwtSVID1, actual) // Test tainting of JWt-SVIDs ctx := context.Background() @@ -57,7 +58,7 @@ func TestJWTSVIDCache(t *testing.T) { name: "one authority tainted, one JWT-SVID", taintedKeyIDs: map[string]struct{}{keyID1: {}}, setJWTSVIDsCached: func(cache *JWTSVIDCache) { - cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID) + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID1) }, expectLogs: []spiretest.LogEntry{ { @@ -93,8 +94,8 @@ func TestJWTSVIDCache(t *testing.T) { name: "one authority tainted, multiple JWT-SVIDs", taintedKeyIDs: map[string]struct{}{keyID1: {}}, setJWTSVIDsCached: func(cache *JWTSVIDCache) { - cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID) - cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID) + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID1) + cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID1) }, expectLogs: []spiretest.LogEntry{ { @@ -130,8 +131,8 @@ func TestJWTSVIDCache(t *testing.T) { name: "multiple authorities tainted, multiple JWT-SVIDs", taintedKeyIDs: map[string]struct{}{keyID1: {}, keyID2: {}}, setJWTSVIDsCached: func(cache *JWTSVIDCache) { - cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID) - cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID) + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID1) + cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID1) cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSVID2) }, expectLogs: []spiretest.LogEntry{ @@ -176,8 +177,8 @@ func TestJWTSVIDCache(t *testing.T) { name: "none of the authorities tainted is in cache", taintedKeyIDs: map[string]struct{}{"not-cached-1": {}, "not-cached-2": {}}, setJWTSVIDsCached: func(cache *JWTSVIDCache) { - cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID) - cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID) + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID1) + cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID1) cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSVID2) }, expectMetrics: []fakemetrics.MetricItem{ @@ -203,7 +204,7 @@ func TestJWTSVIDCache(t *testing.T) { } { tt := tt t.Run(tt.name, func(t *testing.T) { - cache := NewJWTSVIDCache(log, fakeMetrics) + cache := NewJWTSVIDCache(log, fakeMetrics, 8) if tt.setJWTSVIDsCached != nil { tt.setJWTSVIDsCached(cache) } @@ -229,7 +230,7 @@ func TestJWTSVIDCacheKeyHashing(t *testing.T) { fakeMetrics := fakemetrics.New() log, _ := test.NewNullLogger() log.Level = logrus.DebugLevel - cache := NewJWTSVIDCache(log, fakeMetrics) + cache := NewJWTSVIDCache(log, fakeMetrics, 8) cache.SetJWTSVID(spiffeID, []string{"ab", "cd"}, expected) // JWT is cached diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 43f99d5be4..ef42bebc5d 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -21,7 +21,7 @@ import ( ) const ( - // DefaultSVIDCacheMaxSize is set when svidCacheMaxSize is not provided + // DefaultSVIDCacheMaxSize is set when x509SvidCacheMaxSize is not provided DefaultSVIDCacheMaxSize = 1000 // SVIDSyncInterval is the interval at which SVIDs are synced with subscribers SVIDSyncInterval = 500 * time.Millisecond @@ -141,7 +141,7 @@ type LRUCache struct { svids map[string]*X509SVID // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache - svidCacheMaxSize int + x509SvidCacheMaxSize int subscribeBackoffFn func() backoff.BackOff @@ -150,14 +150,14 @@ type LRUCache struct { taintedBatchProcessedCh chan struct{} } -func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, svidCacheMaxSize int, clk clock.Clock) *LRUCache { - if svidCacheMaxSize <= 0 { - svidCacheMaxSize = DefaultSVIDCacheMaxSize +func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, x509SvidCacheMaxSize int, jwtSvidCacheMaxSize int, clk clock.Clock) *LRUCache { + if x509SvidCacheMaxSize <= 0 { + x509SvidCacheMaxSize = DefaultSVIDCacheMaxSize } return &LRUCache{ BundleCache: NewBundleCache(trustDomain, bundle), - JWTSVIDCache: NewJWTSVIDCache(log, metrics), + JWTSVIDCache: NewJWTSVIDCache(log, metrics, jwtSvidCacheMaxSize), log: log, metrics: metrics, @@ -168,9 +168,9 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl bundles: map[spiffeid.TrustDomain]*spiffebundle.Bundle{ trustDomain: bundle, }, - svids: make(map[string]*X509SVID), - svidCacheMaxSize: svidCacheMaxSize, - clk: clk, + svids: make(map[string]*X509SVID), + x509SvidCacheMaxSize: x509SvidCacheMaxSize, + clk: clk, subscribeBackoffFn: func() backoff.BackOff { return backoff.NewBackoff(clk, SVIDSyncInterval) }, @@ -439,7 +439,7 @@ func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.R // entries with active subscribers which are not cached will be put in staleEntries map; // irrespective of what svid cache size as we cannot deny identity to a subscriber activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() - extraSize := len(c.svids) - c.svidCacheMaxSize + extraSize := len(c.svids) - c.x509SvidCacheMaxSize // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime if extraSize > 0 { @@ -778,7 +778,7 @@ func (c *LRUCache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAcce lastAccessTimestamps = append(lastAccessTimestamps, newRecordAccessEvent(record.lastAccessTimestamp, id)) } - remainderSize := c.svidCacheMaxSize - len(c.svids) + remainderSize := c.x509SvidCacheMaxSize - len(c.svids) // add records which are not cached for remainder of cache size for id := range c.records { if len(c.staleEntries) >= remainderSize { diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index d42e6ff516..70b953446f 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -1284,12 +1284,12 @@ func BenchmarkLRUCacheGlobalNotification(b *testing.B) { func newTestLRUCache(t testing.TB) *LRUCache { log, _ := test.NewNullLogger() return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, - telemetry.Blackhole{}, 0, clock.NewMock(t)) + telemetry.Blackhole{}, 0, 0, clock.NewMock(t)) } func newTestLRUCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *LRUCache { log, _ := test.NewNullLogger() - return NewLRUCache(log, trustDomain1, bundleV1, telemetry.Blackhole{}, svidCacheMaxSize, clk) + return NewLRUCache(log, trustDomain1, bundleV1, telemetry.Blackhole{}, svidCacheMaxSize, svidCacheMaxSize, clk) } // numEntries should not be more than 12 digits diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index 16a33e3bde..e2dab9a481 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -66,7 +66,7 @@ func newManager(c *Config) *manager { } cache := managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.SVIDCacheMaxSize, c.Clk) + c.Metrics, c.SVIDCacheMaxSize, c.SVIDCacheMaxSize, c.Clk) rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), From 8b05cc20dfb790d8ba9f648aba122fa8fdb1393e Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Wed, 16 Oct 2024 13:05:37 +0100 Subject: [PATCH 3/5] Add a test for the LRU cache Signed-off-by: Sorin Dumitru --- pkg/agent/manager/cache/jwt_cache.go | 34 ++++++++++++------- pkg/agent/manager/cache/jwt_cache_test.go | 41 +++++++++++++++++++++++ 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/pkg/agent/manager/cache/jwt_cache.go b/pkg/agent/manager/cache/jwt_cache.go index 12b55db4a4..de24435c28 100644 --- a/pkg/agent/manager/cache/jwt_cache.go +++ b/pkg/agent/manager/cache/jwt_cache.go @@ -32,6 +32,11 @@ type JWTSVIDCache struct { svidCacheMaxSize int } +type jwtSvidElement struct { + key string + svid *client.JWTSVID +} + func (c *JWTSVIDCache) CountJWTSVIDs() int { c.mu.Lock() defer c.mu.Unlock() @@ -64,7 +69,7 @@ func (c *JWTSVIDCache) GetJWTSVID(spiffeID spiffeid.ID, audience []string) (*cli } c.lruList.MoveToFront(svidElement) - return svidElement.Value.(*client.JWTSVID), ok + return svidElement.Value.(jwtSvidElement).svid, ok } func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid *client.JWTSVID) { @@ -73,23 +78,26 @@ func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid c.mu.Lock() defer c.mu.Unlock() - if len(c.svids) > c.svidCacheMaxSize { + if len(c.svids) >= c.svidCacheMaxSize { element := c.lruList.Back() - keyID, err := getKeyIDFromSVIDToken(element.Value.(*client.JWTSVID).Token) - if err != nil { - c.log.WithError(err).Error("Could not get key ID from cached JWT-SVID") - return - } - delete(c.svids, keyID) + jwtSvidWithHash := element.Value.(jwtSvidElement) + delete(c.svids, jwtSvidWithHash.key) + c.log.Info("removing a svid") c.lruList.Remove(element) } svidElement, ok := c.svids[key] if ok { - svidElement.Value = svid + svidElement.Value = jwtSvidElement{ + key: key, + svid: svid, + } c.lruList.MoveToFront(svidElement) } else { - svidElement = c.lruList.PushFront(svid) + svidElement = c.lruList.PushFront(jwtSvidElement{ + key: key, + svid: svid, + }) c.svids[key] = svidElement } } @@ -104,15 +112,17 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities removedKeyIDs := make(map[string]int) totalCount := 0 for key, element := range c.svids { - jwtSVID := element.Value.(*client.JWTSVID) - keyID, err := getKeyIDFromSVIDToken(jwtSVID.Token) + jwtSvidElement := element.Value.(jwtSvidElement) + keyID, err := getKeyIDFromSVIDToken(jwtSvidElement.svid.Token) if err != nil { c.log.WithError(err).Error("Could not get key ID from cached JWT-SVID") continue } + if _, tainted := taintedJWTAuthorities[keyID]; tainted { delete(c.svids, key) c.lruList.Remove(element) + removedKeyIDs[keyID]++ totalCount++ } diff --git a/pkg/agent/manager/cache/jwt_cache_test.go b/pkg/agent/manager/cache/jwt_cache_test.go index 1a3487055a..19479cc4c3 100644 --- a/pkg/agent/manager/cache/jwt_cache_test.go +++ b/pkg/agent/manager/cache/jwt_cache_test.go @@ -222,6 +222,47 @@ func TestJWTSVIDCache(t *testing.T) { } } +func TestJWTSVIDCacheSize(t *testing.T) { + fakeMetrics := fakemetrics.New() + log, _ := test.NewNullLogger() + log.Level = logrus.DebugLevel + cache := NewJWTSVIDCache(log, fakeMetrics, 2) + + now := time.Now() + jwtSvid1 := &client.JWTSVID{Token: "1", IssuedAt: now, ExpiresAt: now.Add(time.Second)} + jwtSvid2 := &client.JWTSVID{Token: "3", IssuedAt: now, ExpiresAt: now.Add(time.Second)} + jwtSvid3 := &client.JWTSVID{Token: "3", IssuedAt: now, ExpiresAt: now.Add(time.Second)} + + spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog") + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSvid1) + cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSvid2) + cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSvid3) + + _, ok := cache.GetJWTSVID(spiffeID, []string{"audience-1"}) + assert.False(t, ok) + + actual, ok := cache.GetJWTSVID(spiffeID, []string{"audience-2"}) + assert.True(t, ok) + assert.Equal(t, jwtSvid2, actual) + + actual, ok = cache.GetJWTSVID(spiffeID, []string{"audience-3"}) + assert.True(t, ok) + assert.Equal(t, jwtSvid3, actual) + + // Make the second token the most recently used token + _, _ = cache.GetJWTSVID(spiffeID, []string{"audience-2"}) + + // Insert a token + cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSvid1) + + actual, ok = cache.GetJWTSVID(spiffeID, []string{"audience-2"}) + assert.True(t, ok) + assert.Equal(t, jwtSvid2, actual) + + _, ok = cache.GetJWTSVID(spiffeID, []string{"audience-3"}) + assert.False(t, ok) +} + func TestJWTSVIDCacheKeyHashing(t *testing.T) { spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog") now := time.Now() From 20ec3c11d4b08e6c5e962170d4f30b8252780115 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Thu, 17 Oct 2024 07:42:51 +0100 Subject: [PATCH 4/5] Make JWT-SVID cache size configurable Signed-off-by: Sorin Dumitru --- cmd/spire-agent/cli/run/run.go | 6 +++ doc/spire_agent.md | 3 +- pkg/agent/agent.go | 3 +- pkg/agent/config.go | 5 ++- pkg/agent/manager/config.go | 5 ++- pkg/agent/manager/manager_test.go | 62 ++++++++++++++++--------------- 6 files changed, 49 insertions(+), 35 deletions(-) diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index 609788b1be..e2522bca27 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -92,6 +92,7 @@ type agentConfig struct { AllowedForeignJWTClaims []string `hcl:"allowed_foreign_jwt_claims"` AvailabilityTarget string `hcl:"availability_target"` X509SVIDCacheMaxSize int `hcl:"x509_svid_cache_max_size"` + JWTSVIDCacheMaxSize int `hcl:"jwt_svid_cache_max_size"` AuthorizedDelegates []string `hcl:"authorized_delegates"` @@ -501,6 +502,11 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) } ac.X509SVIDCacheMaxSize = c.Agent.X509SVIDCacheMaxSize + if c.Agent.JWTSVIDCacheMaxSize < 0 { + return nil, errors.New("jwt_svid_cache_max_size should not be negative") + } + ac.JWTSVIDCacheMaxSize = c.Agent.JWTSVIDCacheMaxSize + td, err := common_cli.ParseTrustDomain(c.Agent.TrustDomain, logger) if err != nil { return nil, err diff --git a/doc/spire_agent.md b/doc/spire_agent.md index afbf7132a8..b747d99b83 100644 --- a/doc/spire_agent.md +++ b/doc/spire_agent.md @@ -70,7 +70,8 @@ This may be useful for templating configuration files, for example across differ | `trust_domain` | The trust domain that this agent belongs to (should be no more than 255 characters) | | | `workload_x509_svid_key_type` | The workload X509 SVID key type <rsa-2048|ec-p256> | ec-p256 | | `availability_target` | The minimum amount of time desired to gracefully handle SPIRE Server or Agent downtime. This configurable influences how aggressively X509 SVIDs should be rotated. If set, must be at least 24h. See [Availability Target](#availability-target) | | -| `x509_svid_cache_max_size` | Soft limit of max number of SVIDs that would be stored in LRU cache | 1000 | +| `x509_svid_cache_max_size` | Soft limit of max number of X509-SVIDs that would be stored in LRU cache | 1000 | +| `jwt_svid_cache_max_size` | Hard limit of max number of JWT-SVIDs that would be stored in LRU cache | 1000 | | experimental | Description | Default | |:------------------------------|--------------------------------------------------------------------------------------|-------------------------| diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index e32a500707..ebe7aa7855 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -279,7 +279,8 @@ func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog Storage: sto, SyncInterval: a.c.SyncInterval, UseSyncAuthorizedEntries: a.c.UseSyncAuthorizedEntries, - SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize, + X509SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize, + JWTSVIDCacheMaxSize: a.c.JWTSVIDCacheMaxSize, SVIDStoreCache: cache, NodeAttestor: na, RotationStrategy: rotationutil.NewRotationStrategy(a.c.AvailabilityTarget), diff --git a/pkg/agent/config.go b/pkg/agent/config.go index b7b21a7d8e..f1a28b6249 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -66,9 +66,12 @@ type Config struct { // is used to sync entries from the server. UseSyncAuthorizedEntries bool - // X509SVIDCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + // X509SVIDCacheMaxSize is a soft limit of max number of X509-SVIDs that would be stored in cache X509SVIDCacheMaxSize int + // JWTSVIDCacheMaxSize is a soft limit of max number of JWT-SVIDs that would be stored in cache + JWTSVIDCacheMaxSize int + // Trust domain and associated CA bundle TrustDomain spiffeid.TrustDomain TrustBundle []*x509.Certificate diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index e2dab9a481..b21b43be2b 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -38,7 +38,8 @@ type Config struct { UseSyncAuthorizedEntries bool RotationInterval time.Duration SVIDStoreCache *storecache.Cache - SVIDCacheMaxSize int + X509SVIDCacheMaxSize int + JWTSVIDCacheMaxSize int DisableLRUCache bool NodeAttestor nodeattestor.NodeAttestor RotationStrategy *rotationutil.RotationStrategy @@ -66,7 +67,7 @@ func newManager(c *Config) *manager { } cache := managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.SVIDCacheMaxSize, c.SVIDCacheMaxSize, c.Clk) + c.Metrics, c.X509SVIDCacheMaxSize, c.JWTSVIDCacheMaxSize, c.Clk) rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index ffc35407ff..f566155534 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -1027,22 +1027,23 @@ func TestSynchronizationWithLRUCache(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - RotationInterval: time.Hour, - SyncInterval: time.Hour, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 10, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), - RotationStrategy: rotationutil.NewRotationStrategy(0), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: time.Hour, + SyncInterval: time.Hour, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + X509SVIDCacheMaxSize: 10, + JWTSVIDCacheMaxSize: 10, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + RotationStrategy: rotationutil.NewRotationStrategy(0), } m := newManager(c) @@ -1347,20 +1348,21 @@ func TestSyncSVIDsWithLRUCache(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), - RotationStrategy: rotationutil.NewRotationStrategy(0), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + X509SVIDCacheMaxSize: 1, + JWTSVIDCacheMaxSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + RotationStrategy: rotationutil.NewRotationStrategy(0), } m := newManager(c) From 039e25e6f82d938d503a3e604930ed572e281190 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Thu, 21 Nov 2024 15:03:53 +0000 Subject: [PATCH 5/5] Address review comments Signed-off-by: Sorin Dumitru --- pkg/agent/manager/cache/jwt_cache.go | 1 - pkg/agent/manager/cache/jwt_cache_test.go | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/agent/manager/cache/jwt_cache.go b/pkg/agent/manager/cache/jwt_cache.go index de24435c28..31bb08314f 100644 --- a/pkg/agent/manager/cache/jwt_cache.go +++ b/pkg/agent/manager/cache/jwt_cache.go @@ -82,7 +82,6 @@ func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid element := c.lruList.Back() jwtSvidWithHash := element.Value.(jwtSvidElement) delete(c.svids, jwtSvidWithHash.key) - c.log.Info("removing a svid") c.lruList.Remove(element) } diff --git a/pkg/agent/manager/cache/jwt_cache_test.go b/pkg/agent/manager/cache/jwt_cache_test.go index 19479cc4c3..754ece5266 100644 --- a/pkg/agent/manager/cache/jwt_cache_test.go +++ b/pkg/agent/manager/cache/jwt_cache_test.go @@ -21,9 +21,8 @@ func TestJWTSVIDCache(t *testing.T) { now := time.Now() tok1 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg" tok2 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImNKMXI5TVY4OTZTWXBMY0RMUjN3Q29QRHprTXpkN25tIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3Mjg1NzEwMjUsImlhdCI6MTcyODU3MDcyNSwic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.1YnDj7nknwIHEuNKEN0cNypXKS4SUeILXlNOsOs2XElHzfKhhDcl0sYKYtQc1Itf6cygz9C16VOQ_Yjoos2Qfg" - jwtSVID1 := &client.JWTSVID{Token: tok1, IssuedAt: now, ExpiresAt: now.Add(time.Second)} - jwtSVID2 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Second)} - //jwtSVID3 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Second)} + jwtSVID1 := &client.JWTSVID{Token: tok1, IssuedAt: now, ExpiresAt: now.Add(time.Minute)} + jwtSVID2 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Minute)} fakeMetrics := fakemetrics.New() log, logHook := test.NewNullLogger() @@ -229,15 +228,16 @@ func TestJWTSVIDCacheSize(t *testing.T) { cache := NewJWTSVIDCache(log, fakeMetrics, 2) now := time.Now() - jwtSvid1 := &client.JWTSVID{Token: "1", IssuedAt: now, ExpiresAt: now.Add(time.Second)} - jwtSvid2 := &client.JWTSVID{Token: "3", IssuedAt: now, ExpiresAt: now.Add(time.Second)} - jwtSvid3 := &client.JWTSVID{Token: "3", IssuedAt: now, ExpiresAt: now.Add(time.Second)} + jwtSvid1 := &client.JWTSVID{Token: "1", IssuedAt: now, ExpiresAt: now.Add(time.Minute)} + jwtSvid2 := &client.JWTSVID{Token: "2", IssuedAt: now, ExpiresAt: now.Add(time.Minute)} + jwtSvid3 := &client.JWTSVID{Token: "3", IssuedAt: now, ExpiresAt: now.Add(time.Minute)} spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog") cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSvid1) cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSvid2) cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSvid3) + // The first SVID that was inserted into the cache should have been evicted. _, ok := cache.GetJWTSVID(spiffeID, []string{"audience-1"}) assert.False(t, ok) @@ -266,7 +266,7 @@ func TestJWTSVIDCacheSize(t *testing.T) { func TestJWTSVIDCacheKeyHashing(t *testing.T) { spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog") now := time.Now() - expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Second)} + expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Minute)} fakeMetrics := fakemetrics.New() log, _ := test.NewNullLogger()