diff --git a/.schema/config.schema.json b/.schema/config.schema.json index 1e9818cbf..6d6ef2c91 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -1140,6 +1140,25 @@ "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", "default": "15m", "examples": ["1h", "1m", "30s"] + }, + "cache": { + "additionalProperties": false, + "type": "object", + "properties": { + "enabled": { + "title": "Enabled", + "type": "boolean", + "default": true, + "examples": [false, true], + "description": "En-/disables this component." + }, + "max_cost": { + "type": "integer", + "default": 33554432, + "title": "Maximum Cached Cost", + "description": "Max cost to cache." + } + } } }, "additionalProperties": false diff --git a/.schemas/mutators.id_token.schema.json b/.schemas/mutators.id_token.schema.json index c160f0bc0..ef07faa17 100644 --- a/.schemas/mutators.id_token.schema.json +++ b/.schemas/mutators.id_token.schema.json @@ -32,6 +32,25 @@ "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", "default": "1m", "examples": ["1h", "1m", "30s"] + }, + "cache": { + "additionalProperties": false, + "type": "object", + "properties": { + "enabled": { + "title": "Enabled", + "type": "boolean", + "default": true, + "examples": [false, true], + "description": "En-/disables this component." + }, + "max_cost": { + "type": "integer", + "default": 33554432, + "title": "Maximum Cached Cost", + "description": "Max cost to cache." + } + } } }, "additionalProperties": false diff --git a/pipeline/mutate/mutator_id_token.go b/pipeline/mutate/mutator_id_token.go index 438f47e44..3ea6cf062 100644 --- a/pipeline/mutate/mutator_id_token.go +++ b/pipeline/mutate/mutator_id_token.go @@ -38,15 +38,20 @@ type MutatorIDToken struct { templates *template.Template templatesLock sync.Mutex - tokenCache *ristretto.Cache[string, *idTokenCacheContainer] - tokenCacheEnabled bool + tokenCache *ristretto.Cache[string, *idTokenCacheContainer] } type CredentialsIDTokenConfig struct { - Claims string `json:"claims"` - IssuerURL string `json:"issuer_url"` - JWKSURL string `json:"jwks_url"` - TTL string `json:"ttl"` + Claims string `json:"claims"` + IssuerURL string `json:"issuer_url"` + JWKSURL string `json:"jwks_url"` + TTL string `json:"ttl"` + Cache IdTokenCacheConfig `json:"cache"` +} + +type IdTokenCacheConfig struct { + Enabled bool `json:"enabled"` + MaxCost int `json:"max_cost"` } func (c *CredentialsIDTokenConfig) ClaimsTemplateID() string { @@ -54,12 +59,7 @@ func (c *CredentialsIDTokenConfig) ClaimsTemplateID() string { } func NewMutatorIDToken(c configuration.Provider, r MutatorIDTokenRegistry) *MutatorIDToken { - cache, _ := ristretto.NewCache(&ristretto.Config[string, *idTokenCacheContainer]{ - NumCounters: 10000, - MaxCost: 1 << 25, - BufferItems: 64, - }) - return &MutatorIDToken{r: r, c: c, templates: x.NewTemplate("id_token"), tokenCache: cache, tokenCacheEnabled: true} + return &MutatorIDToken{r: r, c: c, templates: x.NewTemplate("id_token")} } func (a *MutatorIDToken) GetID() string { @@ -70,10 +70,6 @@ func (a *MutatorIDToken) WithCache(t *template.Template) { a.templates = t } -func (a *MutatorIDToken) SetCaching(token bool) { - a.tokenCacheEnabled = token -} - type idTokenCacheContainer struct { ExpiresAt time.Time TTL time.Duration @@ -87,7 +83,7 @@ func (a *MutatorIDToken) cacheKey(config *CredentialsIDTokenConfig, ttl time.Dur } func (a *MutatorIDToken) tokenFromCache(config *CredentialsIDTokenConfig, session *authn.AuthenticationSession, claims []byte, ttl time.Duration) (string, bool) { - if !a.tokenCacheEnabled { + if !config.Cache.Enabled { return "", false } @@ -107,7 +103,7 @@ func (a *MutatorIDToken) tokenFromCache(config *CredentialsIDTokenConfig, sessio } func (a *MutatorIDToken) tokenToCache(config *CredentialsIDTokenConfig, session *authn.AuthenticationSession, claims []byte, ttl time.Duration, expiresAt time.Time, token string) { - if !a.tokenCacheEnabled { + if !config.Cache.Enabled { return } @@ -199,7 +195,11 @@ func (a *MutatorIDToken) Validate(config json.RawMessage) error { } func (a *MutatorIDToken) Config(config json.RawMessage) (*CredentialsIDTokenConfig, error) { - var c CredentialsIDTokenConfig + c := CredentialsIDTokenConfig{ + Cache: IdTokenCacheConfig{ + Enabled: true, // default to true + }, + } if err := a.c.MutatorConfig(a.GetID(), config, &c); err != nil { return nil, NewErrMutatorMisconfigured(a, err) } @@ -208,5 +208,28 @@ func (a *MutatorIDToken) Config(config json.RawMessage) (*CredentialsIDTokenConf c.TTL = "15m" } + cost := int64(c.Cache.MaxCost) + if cost == 0 { + cost = 1 << 25 + } + + if a.tokenCache == nil || a.tokenCache.MaxCost() != cost { + cache, err := ristretto.NewCache(&ristretto.Config[string, *idTokenCacheContainer]{ + NumCounters: cost * 10, + // Allocate a max + MaxCost: cost, + // This is a best-practice value. + BufferItems: 64, + Cost: func(container *idTokenCacheContainer) int64 { + return int64(len(container.Token)) + }, + }) + + if err != nil { + return nil, err + } + a.tokenCache = cache + } + return &c, nil } diff --git a/pipeline/mutate/mutator_id_token_test.go b/pipeline/mutate/mutator_id_token_test.go index 92ab11721..db261f70d 100644 --- a/pipeline/mutate/mutator_id_token_test.go +++ b/pipeline/mutate/mutator_id_token_test.go @@ -300,6 +300,18 @@ func TestMutatorIDToken(t *testing.T) { config, _ := sjson.SetBytes(config, "jwks_url", "file://../../test/stub/jwks-hs.json") assert.NotEqual(t, prev, mutate(t, *session, config)) }) + + t.Run("subcase=different tokens because cache disabled", func(t *testing.T) { + config, _ := sjson.SetBytes(config, "cache", map[string]bool{"enabled": false}) + prev := mutate(t, *session, config) + assert.NotEqual(t, prev, mutate(t, *session, config)) + }) + + t.Run("subcase=different tokens because exceeded cost", func(t *testing.T) { + config, _ := sjson.SetBytes(config, "cache", map[string]int{"max_cost": -1}) + prev := mutate(t, *session, config) + assert.NotEqual(t, prev, mutate(t, *session, config)) + }) }) t.Run("case=ensure template cache", func(t *testing.T) { @@ -386,8 +398,8 @@ func BenchmarkMutatorIDToken(b *testing.B) { } { b.Run("alg="+alg, func(b *testing.B) { for _, enableCache := range []bool{true, false} { - a.(*MutatorIDToken).SetCaching(enableCache) b.Run(fmt.Sprintf("cache=%v", enableCache), func(b *testing.B) { + conf.SetForTest(b, "mutators.id_token.config.cache.enabled", enableCache) var tc idTokenTestCase var config []byte diff --git a/spec/config.schema.json b/spec/config.schema.json index bb2bef34f..2660f1123 100644 --- a/spec/config.schema.json +++ b/spec/config.schema.json @@ -1140,6 +1140,25 @@ "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", "default": "15m", "examples": ["1h", "1m", "30s"] + }, + "cache": { + "additionalProperties": false, + "type": "object", + "properties": { + "enabled": { + "title": "Enabled", + "type": "boolean", + "default": true, + "examples": [false, true], + "description": "En-/disables this component." + }, + "max_cost": { + "type": "integer", + "default": 33554432, + "title": "Maximum Cached Cost", + "description": "Max cost to cache." + } + } } }, "additionalProperties": false