From c8e3dea679844dda4154b6aba674a6a6f6aa5b90 Mon Sep 17 00:00:00 2001 From: colindickson Date: Fri, 15 Sep 2023 22:20:32 -0400 Subject: [PATCH] cache: refactor, and fix of race condition between expiration goroutine and external calls. --- pkg/cache/ttl.go | 70 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/pkg/cache/ttl.go b/pkg/cache/ttl.go index a64b692..b9813cb 100644 --- a/pkg/cache/ttl.go +++ b/pkg/cache/ttl.go @@ -20,7 +20,7 @@ type sortableItem struct { type TTLMap struct { m map[string]*item - l sync.Mutex + l sync.RWMutex maxItems int metrics Metrics @@ -39,15 +39,17 @@ func NewTTLMap(maxItems int, name, namespace string) (m *TTLMap) { go func() { for now := range time.Tick(time.Second * 1) { + m.l.Lock() for k, v := range m.m { if v.invincible { continue } if v.expiresAt.Before(now) { - m.Delete(k) + m.delete(k, v.value, v.expiresAt) } } + m.l.Unlock() } }() @@ -75,13 +77,18 @@ func (m *TTLMap) OnItemAdded(f func(string, interface{}, time.Time)) { } func (m *TTLMap) Delete(k string) { - val, expiresAt, err := m.Get(k) + m.l.Lock() + defer m.l.Unlock() + + val, expiresAt, err := m.get(k) if err != nil { return } - m.l.Lock() + m.delete(k, val, expiresAt) +} +func (m *TTLMap) delete(k string, val interface{}, expiresAt time.Time) { delete(m.m, k) m.metrics.ObserveOperations(OperationDEL, 1) @@ -89,13 +96,11 @@ func (m *TTLMap) Delete(k string) { for _, f := range m.deletedCallbacks { go f(k, val, expiresAt) } - - m.l.Unlock() } func (m *TTLMap) evictItemToClosestToExpiry() { // This is a very naive implementation. - items := []sortableItem{} + evictableItems := make([]sortableItem, 0, len(m.m)) // Get all non-invincible items. for k, v := range m.m { @@ -103,35 +108,47 @@ func (m *TTLMap) evictItemToClosestToExpiry() { continue } - items = append(items, sortableItem{ + evictableItems = append(evictableItems, sortableItem{ key: k, expiresAt: v.expiresAt, }) } - sort.Slice(items, func(i, j int) bool { - return items[i].expiresAt.Before(items[j].expiresAt) + if len(evictableItems) == 0 { + return + } + + sort.Slice(evictableItems, func(i, j int) bool { + return evictableItems[i].expiresAt.Before(evictableItems[j].expiresAt) }) - if len(items) > 0 { - m.Delete(items[0].key) - m.metrics.ObserveOperations(OperationEVICT, 1) - } + m.delete(evictableItems[0].key, evictableItems[0].expiresAt, evictableItems[0].expiresAt) + m.metrics.ObserveOperations(OperationEVICT, 1) } func (m *TTLMap) Len() int { + m.l.RLock() + defer m.l.RUnlock() + + return m.len() +} + +func (m *TTLMap) len() int { return len(m.m) } func (m *TTLMap) Add(k string, v interface{}, expiresAt time.Time, invincible bool) { - if m.Len() >= m.maxItems { - m.evictItemToClosestToExpiry() - } - m.l.Lock() - defer m.l.Unlock() + m.add(k, v, expiresAt, invincible) +} + +func (m *TTLMap) add(k string, v interface{}, expiresAt time.Time, invincible bool) { + if m.len() >= m.maxItems { + m.evictItemToClosestToExpiry() + } + it, ok := m.m[k] if !ok { it = &item{ @@ -150,16 +167,23 @@ func (m *TTLMap) Add(k string, v interface{}, expiresAt time.Time, invincible bo } func (m *TTLMap) Get(k string) (interface{}, time.Time, error) { - m.metrics.ObserveOperations(OperationGET, 1) + m.l.RLock() + itv, expires, err := m.get(k) + m.l.RUnlock() - m.l.Lock() + if err != nil { + return nil, time.Now(), err + } - defer m.l.Unlock() + return itv, expires, err +} + +func (m *TTLMap) get(k string) (interface{}, time.Time, error) { + m.metrics.ObserveOperations(OperationGET, 1) it, ok := m.m[k] if !ok { m.metrics.ObserveMiss() - return nil, time.Now(), errors.New("not found") }