From ece2ec519f868605836628ece74193a18453340d Mon Sep 17 00:00:00 2001 From: Andrew Wormald Date: Thu, 16 Nov 2023 10:41:11 +0000 Subject: [PATCH] rpatterns: Make memcursor concurrent safe --- rpatterns/cursor.go | 7 +++++++ rpatterns/cursor_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/rpatterns/cursor.go b/rpatterns/cursor.go index 9a8c592..20da7e3 100644 --- a/rpatterns/cursor.go +++ b/rpatterns/cursor.go @@ -3,6 +3,7 @@ package rpatterns import ( "context" "strconv" + "sync" "github.com/luno/reflex" ) @@ -75,14 +76,20 @@ func MemCursorStore(opts ...MemOpt) reflex.CursorStore { } type memCursorStore struct { + mu sync.Mutex cursors map[string]string } func (m *memCursorStore) GetCursor(_ context.Context, consumerName string) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() return m.cursors[consumerName], nil } func (m *memCursorStore) SetCursor(_ context.Context, consumerName string, cursor string) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.cursors == nil { m.cursors = make(map[string]string) } diff --git a/rpatterns/cursor_test.go b/rpatterns/cursor_test.go index e699998..525d9df 100644 --- a/rpatterns/cursor_test.go +++ b/rpatterns/cursor_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "sync" "testing" "github.com/luno/jettison/jtest" @@ -160,3 +161,35 @@ func TestMemoryCStore(t *testing.T) { require.NoError(t, err) require.Equal(t, c2, actual) } + +func TestConcurrentWrites(t *testing.T) { + var ( + writerReadyGroup sync.WaitGroup + writerCompletedGroup sync.WaitGroup + ) + + store := rpatterns.MemCursorStore() + + writerCount := 100 + writerReadyGroup.Add(writerCount) + for i := 0; i < writerCount; i++ { + writerCompletedGroup.Add(1) + go func(writerReadyGroup *sync.WaitGroup, writerCompletedGroup *sync.WaitGroup) { + writerReadyGroup.Done() + for i := 0; i < 100; i++ { + // Write the thing + val := fmt.Sprintf("%v", i) + err := store.SetCursor(context.TODO(), "single-key", val) + jtest.RequireNil(t, err) + + // Flush is a noop but covering it will ensure that it stays concurrent safe + err = store.Flush(context.TODO()) + jtest.RequireNil(t, err) + } + writerCompletedGroup.Done() + }(&writerReadyGroup, &writerCompletedGroup) + } + + writerReadyGroup.Wait() + writerCompletedGroup.Wait() +}