From d229720b77d3bf48a91b1d2373e8efe7ab968842 Mon Sep 17 00:00:00 2001
From: Roman Zavodskikh
Date: Thu, 30 Nov 2023 16:57:17 +0100
Subject: [PATCH] Use sync.Map in the endpointregistry

Signed-off-by: Roman Zavodskikh
---
Reviewer notes (free text placed between the "---" separator and the
diffstat; not part of the commit message or of any hunk):

 * EndpointRegistry.data changes from map[string]*entry guarded by a
   sync.Mutex to a sync.Map holding *entry values; the separate lastSeen
   map and the mutex are removed.
 * entry.detected and entry.lastSeen become atomic.Value (holding
   time.Time) and entry.inflightRequests becomes atomic.Int64. newEntry()
   stores a zero time.Time into both atomic.Value fields before the entry
   is used, so the Load().(time.Time) assertions in DetectedTime() and
   LastSeen() always find a time.Time.
 * GetMetrics now returns the stored *entry itself rather than a copy, so
   a Metrics value obtained earlier observes later updates (TestSetAndGet
   asserts this through mBefore).
 * The registry-level SetDetectedTime, IncInflightRequest and
   DecInflightRequest are removed; callers go through
   GetMetrics(key).SetDetected / IncInflightRequest / DecInflightRequest.
 * SetDetected only takes effect while the stored value still equals the
   zero time.Time (CompareAndSwap against time.Time{}).
 * NOTE(review): Do() used to keep a timed-out entry while its
   inflightRequests was non-zero; the new Range callback deletes on
   lastSeen alone. Confirm this change is intended.
 * NOTE(review): Do() and GetMetrics() call LoadOrStore(key, newEntry()),
   which constructs a new entry on every call even when the key is already
   present. Consider Load first and LoadOrStore only on a miss.
 * NOTE(review): as transcribed here, the first hunk of
   routing/endpointregistry_test.go does not seem to add up against its
   header (-17,37 +17,50); verify against the original commit before
   applying.

 loadbalancer/algorithm_test.go   |   4 +-
 loadbalancer/fadein_test.go      |   4 +-
 routing/endpointregistry.go      | 133 ++++++++++++------------
 routing/endpointregistry_test.go | 168 +++++++++++++++++++++++++++----
 4 files changed, 213 insertions(+), 96 deletions(-)

diff --git a/loadbalancer/algorithm_test.go b/loadbalancer/algorithm_test.go
index 3db8e1fc19..1eb9bb310a 100644
--- a/loadbalancer/algorithm_test.go
+++ b/loadbalancer/algorithm_test.go
@@ -419,7 +419,7 @@ func TestConsistentHashBoundedLoadDistribution(t *testing.T) {
 			t.Errorf("Expected in-flight requests for each endpoint to be less than %d. In-flight request counts: %d, %d, %d", limit, ifr0, ifr1, ifr2)
 		}
 		ep.Metrics.IncInflightRequest()
-		ctx.Registry.IncInflightRequest(ep.Host)
+		ctx.Registry.GetMetrics(ep.Host).IncInflightRequest()
 	}
 }
 
@@ -441,7 +441,7 @@ func TestConsistentHashKeyDistribution(t *testing.T) {
 func addInflightRequests(registry *routing.EndpointRegistry, endpoint routing.LBEndpoint, count int) {
 	for i := 0; i < count; i++ {
 		endpoint.Metrics.IncInflightRequest()
-		registry.IncInflightRequest(endpoint.Host)
+		registry.GetMetrics(endpoint.Host).IncInflightRequest()
 	}
 }
 
diff --git a/loadbalancer/fadein_test.go b/loadbalancer/fadein_test.go
index c05311fe23..4768ce3dff 100644
--- a/loadbalancer/fadein_test.go
+++ b/loadbalancer/fadein_test.go
@@ -70,7 +70,7 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat
 			Host:     eps[i],
 			Detected: detectionTimes[i],
 		})
-		ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i])
+		ctx.Registry.GetMetrics(eps[i]).SetDetected(detectionTimes[i])
 	}
 	ctx.LBEndpoints = ctx.Route.LBEndpoints
 
@@ -332,7 +332,7 @@ func benchmarkFadeIn(
 				Host:     eps[i],
 				Detected: detectionTimes[i],
 			})
-			registry.SetDetectedTime(eps[i], detectionTimes[i])
+			registry.GetMetrics(eps[i]).SetDetected(detectionTimes[i])
 		}
 
 		var wg sync.WaitGroup
diff --git a/routing/endpointregistry.go b/routing/endpointregistry.go
index 2d1679ffa9..e4c94e15ec 100644
--- a/routing/endpointregistry.go
+++ b/routing/endpointregistry.go
@@ -2,6 +2,7 @@ package routing
 
 import (
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/zalando/skipper/eskip"
@@ -13,32 +14,69 @@ const defaultLastSeenTimeout = 1 * time.Minute
 // used to perform better load balancing, fadeIn, etc.
 type Metrics interface {
 	DetectedTime() time.Time
+	SetDetected(detected time.Time)
+
+	LastSeen() time.Time
+	SetLastSeen(lastSeen time.Time)
+
 	InflightRequests() int64
+	IncInflightRequest()
+	DecInflightRequest()
 }
 
 type entry struct {
-	detected         time.Time
-	inflightRequests int64
+	detected         atomic.Value // time.Time
+	lastSeen         atomic.Value // time.Time
+	inflightRequests atomic.Int64
 }
 
 var _ Metrics = &entry{}
 
 func (e *entry) DetectedTime() time.Time {
-	return e.detected
+	return e.detected.Load().(time.Time)
+}
+
+func (e *entry) LastSeen() time.Time {
+	return e.lastSeen.Load().(time.Time)
 }
 
 func (e *entry) InflightRequests() int64 {
-	return e.inflightRequests
+	return e.inflightRequests.Load()
+}
+
+func (e *entry) IncInflightRequest() {
+	e.inflightRequests.Add(1)
+}
+
+func (e *entry) DecInflightRequest() {
+	e.inflightRequests.Add(-1)
+}
+
+func (e *entry) setDetectedForce(detected time.Time) {
+	e.detected.Store(detected)
+}
+
+func (e *entry) SetDetected(detected time.Time) {
+	e.detected.CompareAndSwap(time.Time{}, detected)
+}
+
+func (e *entry) SetLastSeen(ts time.Time) {
+	e.lastSeen.Store(ts)
+}
+
+func newEntry() (result *entry) {
+	result = &entry{}
+	result.setDetectedForce(time.Time{})
+	result.SetLastSeen(time.Time{})
+	return
 }
 
 type EndpointRegistry struct {
-	lastSeen        map[string]time.Time
 	lastSeenTimeout time.Duration
 	now             func() time.Time
 
-	mu sync.Mutex
-
-	data map[string]*entry
+	// map[string]*entry
+	data sync.Map
 }
 
 var _ PostProcessor = &EndpointRegistry{}
@@ -53,26 +91,23 @@ func (r *EndpointRegistry) Do(routes []*Route) []*Route {
 	for _, route := range routes {
 		if route.BackendType == eskip.LBBackend {
 			for _, epi := range route.LBEndpoints {
-				metrics := r.GetMetrics(epi.Host)
-				if metrics.DetectedTime().IsZero() {
-					r.SetDetectedTime(epi.Host, now)
-				}
+				e, _ := r.data.LoadOrStore(epi.Host, newEntry())
 
-				r.lastSeen[epi.Host] = now
+				e.(*entry).SetDetected(now)
+				e.(*entry).SetLastSeen(now)
 			}
 		}
 	}
 
-	for host, ts := range r.lastSeen {
-		if ts.Add(r.lastSeenTimeout).Before(now) {
-			r.mu.Lock()
-			if r.data[host].inflightRequests == 0 {
-				delete(r.lastSeen, host)
-				delete(r.data, host)
-			}
-			r.mu.Unlock()
+	removeOlder := now.Add(-r.lastSeenTimeout)
+	r.data.Range(func(key, value any) bool {
+		e := value.(*entry)
+		if e.lastSeen.Load().(time.Time).Before(removeOlder) {
+			r.data.Delete(key)
 		}
-	}
+
+		return true
+	})
 
 	return routes
 }
@@ -82,58 +117,16 @@ func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry {
 		o.LastSeenTimeout = defaultLastSeenTimeout
 	}
 
-	return &EndpointRegistry{
-		data:            map[string]*entry{},
-		lastSeen:        map[string]time.Time{},
+	result := &EndpointRegistry{
+		data:            sync.Map{},
 		lastSeenTimeout: o.LastSeenTimeout,
 		now:             time.Now,
 	}
-}
-
-func (r *EndpointRegistry) GetMetrics(key string) Metrics {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-
-	e := r.getOrInitEntryLocked(key)
-	copy := &entry{}
-	*copy = *e
-	return copy
-}
-
-func (r *EndpointRegistry) SetDetectedTime(key string, detected time.Time) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
 
-	e := r.getOrInitEntryLocked(key)
-	e.detected = detected
+	return result
 }
 
-func (r *EndpointRegistry) IncInflightRequest(key string) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-
-	e := r.getOrInitEntryLocked(key)
-	e.inflightRequests++
-}
-
-func (r *EndpointRegistry) DecInflightRequest(key string) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-
-	e := r.getOrInitEntryLocked(key)
-	e.inflightRequests--
-}
-
-// getOrInitEntryLocked returns pointer to endpoint registry entry
-// which contains the information about endpoint representing the
-// following key. r.mu must be held while calling this function and
-// using of the entry returned. In general, key represents the "host:port"
-// string
-func (r *EndpointRegistry) getOrInitEntryLocked(key string) *entry {
-	e, ok := r.data[key]
-	if !ok {
-		e = &entry{}
-		r.data[key] = e
-	}
-	return e
+func (r *EndpointRegistry) GetMetrics(key string) Metrics {
+	e, _ := r.data.LoadOrStore(key, newEntry())
+	return e.(*entry)
 }
diff --git a/routing/endpointregistry_test.go b/routing/endpointregistry_test.go
index 1256da4d17..f1a319c4a4 100644
--- a/routing/endpointregistry_test.go
+++ b/routing/endpointregistry_test.go
@@ -17,37 +17,50 @@ func TestEmptyRegistry(t *testing.T) {
 	m := r.GetMetrics("some key")
 
 	assert.Equal(t, time.Time{}, m.DetectedTime())
+	assert.Equal(t, time.Time{}, m.LastSeen())
 	assert.Equal(t, int64(0), m.InflightRequests())
 }
 
 func TestSetAndGet(t *testing.T) {
+	now := time.Now()
 	r := routing.NewEndpointRegistry(routing.RegistryOptions{})
 
 	mBefore := r.GetMetrics("some key")
-	r.IncInflightRequest("some key")
-	mAfter := r.GetMetrics("some key")
-
+	assert.Equal(t, time.Time{}, mBefore.DetectedTime())
+	assert.Equal(t, time.Time{}, mBefore.LastSeen())
 	assert.Equal(t, int64(0), mBefore.InflightRequests())
-	assert.Equal(t, int64(1), mAfter.InflightRequests())
 
-	ts, _ := time.Parse(time.DateOnly, "2023-08-29")
-	mBefore = r.GetMetrics("some key")
-	r.SetDetectedTime("some key", ts)
-	mAfter = r.GetMetrics("some key")
+	r.GetMetrics("some key").SetDetected(now.Add(-time.Second))
+	r.GetMetrics("some key").SetLastSeen(now)
+	r.GetMetrics("some key").IncInflightRequest()
+	mAfter := r.GetMetrics("some key")
 
-	assert.Equal(t, time.Time{}, mBefore.DetectedTime())
-	assert.Equal(t, ts, mAfter.DetectedTime())
+	assert.Equal(t, now.Add(-time.Second), mBefore.DetectedTime())
+	assert.Equal(t, now, mBefore.LastSeen())
+	assert.Equal(t, int64(1), mBefore.InflightRequests())
+
+	assert.Equal(t, now.Add(-time.Second), mAfter.DetectedTime())
+	assert.Equal(t, now, mAfter.LastSeen())
+	assert.Equal(t, int64(1), mAfter.InflightRequests())
 }
 
 func TestSetAndGetAnotherKey(t *testing.T) {
+	now := time.Now()
 	r := routing.NewEndpointRegistry(routing.RegistryOptions{})
 
-	r.IncInflightRequest("some key")
 	mToChange := r.GetMetrics("some key")
+	mToChange.IncInflightRequest()
+	mToChange.SetDetected(now.Add(-time.Second))
+	mToChange.SetLastSeen(now)
 	mConst := r.GetMetrics("another key")
 
 	assert.Equal(t, int64(0), mConst.InflightRequests())
+	assert.Equal(t, time.Time{}, mConst.DetectedTime())
+	assert.Equal(t, time.Time{}, mConst.LastSeen())
+
+	assert.Equal(t, int64(1), mToChange.InflightRequests())
+	assert.Equal(t, now.Add(-time.Second), mToChange.DetectedTime())
+	assert.Equal(t, now, mToChange.LastSeen())
 }
 
 func TestDoRemovesOldEntries(t *testing.T) {
@@ -73,9 +86,9 @@ func TestDoRemovesOldEntries(t *testing.T) {
 	assert.Equal(t, beginTestTs, mExist.DetectedTime())
 	assert.Equal(t, beginTestTs, mExistYet.DetectedTime())
 
-	r.IncInflightRequest("endpoint1.test:80")
-	r.IncInflightRequest("endpoint2.test:80")
-	r.DecInflightRequest("endpoint2.test:80")
+	mExist.IncInflightRequest()
+	mExistYet.IncInflightRequest()
+	mExistYet.DecInflightRequest()
 
 	routing.SetNow(r, func() time.Time {
 		return beginTestTs.Add(routing.ExportDefaultLastSeenTimeout + time.Second)
@@ -101,6 +114,72 @@ func TestDoRemovesOldEntries(t *testing.T) {
 	assert.Equal(t, int64(0), mRemoved.InflightRequests())
 }
 
+func TestRaceReadWrite(t *testing.T) {
+	r := routing.NewEndpointRegistry(routing.RegistryOptions{})
+	duration := time.Second
+
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		stop := time.After(duration)
+		for {
+			r.GetMetrics("some key")
+			select {
+			case <-stop:
+				return
+			default:
+			}
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		stop := time.After(duration)
+		for {
+			r.GetMetrics("some key").IncInflightRequest()
+			select {
+			case <-stop:
+				return
+			default:
+			}
+		}
+	}()
+	wg.Wait()
+}
+
+func TestRaceTwoWriters(t *testing.T) {
+	r := routing.NewEndpointRegistry(routing.RegistryOptions{})
+	duration := time.Second
+
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		stop := time.After(duration)
+		for {
+			r.GetMetrics("some key").IncInflightRequest()
+			select {
+			case <-stop:
+				return
+			default:
+			}
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		stop := time.After(duration)
+		for {
+			r.GetMetrics("some key").DecInflightRequest()
+			select {
+			case <-stop:
+				return
+			default:
+			}
+		}
+	}()
+	wg.Wait()
+}
+
 func printTotalMutexWaitTime(b *testing.B) {
 	// Name of the metric we want to read.
 	const myMetric = "/sync/mutex/wait/total:seconds"
@@ -133,20 +212,23 @@ func benchmarkIncInflightRequests(b *testing.B, name string, goroutines int) {
 
 	b.Run(name, func(b *testing.B) {
 		r := routing.NewEndpointRegistry(routing.RegistryOptions{})
+
 		for i := 1; i < mapSize; i++ {
-			r.IncInflightRequest(fmt.Sprintf("foo-%d", i))
+			r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest()
 		}
-		r.IncInflightRequest(key)
-		r.IncInflightRequest(key)
+		r.GetMetrics(key).IncInflightRequest()
+		r.GetMetrics(key).IncInflightRequest()
 
 		wg := sync.WaitGroup{}
 		b.ResetTimer()
+		b.ReportAllocs()
 		for i := 0; i < goroutines; i++ {
 			wg.Add(1)
 			go func() {
 				defer wg.Done()
+				metrics := r.GetMetrics(key)
 				for n := 0; n < b.N/goroutines; n++ {
-					r.IncInflightRequest(key)
+					metrics.IncInflightRequest()
 				}
 			}()
 		}
@@ -170,20 +252,22 @@ func benchmarkGetInflightRequests(b *testing.B, name string, goroutines int) {
 	b.Run(name, func(b *testing.B) {
 		r := routing.NewEndpointRegistry(routing.RegistryOptions{})
 		for i := 1; i < mapSize; i++ {
-			r.IncInflightRequest(fmt.Sprintf("foo-%d", i))
+			r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest()
 		}
-		r.IncInflightRequest(key)
-		r.IncInflightRequest(key)
+		r.GetMetrics(key).IncInflightRequest()
+		r.GetMetrics(key).IncInflightRequest()
 
 		var dummy int64
 		wg := sync.WaitGroup{}
 		b.ResetTimer()
+		b.ReportAllocs()
 		for i := 0; i < goroutines; i++ {
 			wg.Add(1)
 			go func() {
 				defer wg.Done()
+				metrics := r.GetMetrics(key)
 				for n := 0; n < b.N/goroutines; n++ {
-					dummy = r.GetMetrics(key).InflightRequests()
+					dummy = metrics.InflightRequests()
 				}
 			}()
 		}
@@ -200,3 +284,43 @@ func BenchmarkGetInflightRequests(b *testing.B) {
 		benchmarkGetInflightRequests(b, fmt.Sprintf("%d goroutines", goroutines), goroutines)
 	}
 }
+
+func benchmarkGetDetectedTime(b *testing.B, name string, goroutines int) {
+	const key string = "some key"
+	const mapSize int = 10000
+
+	b.Run(name, func(b *testing.B) {
+		r := routing.NewEndpointRegistry(routing.RegistryOptions{})
+		for i := 1; i < mapSize; i++ {
+			r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest()
+		}
+		r.GetMetrics(key).IncInflightRequest()
+		r.GetMetrics(key).IncInflightRequest()
+
+		var dummy time.Time
+		wg := sync.WaitGroup{}
+		b.ResetTimer()
+		b.ReportAllocs()
+		for i := 0; i < goroutines; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				metrics := r.GetMetrics(key)
+				for n := 0; n < b.N/goroutines; n++ {
+					dummy = metrics.DetectedTime()
+				}
+			}()
+		}
+		dummy = dummy.Add(time.Second)
+		wg.Wait()
+
+		printTotalMutexWaitTime(b)
+	})
+}
+
+func BenchmarkGetDetectedTime(b *testing.B) {
+	goroutinesNums := []int{1, 2, 3, 4, 5, 6, 7, 8, 12, 16, 24, 32, 48, 64, 128, 256, 512, 768, 1024, 1536, 2048, 4096}
+	for _, goroutines := range goroutinesNums {
+		benchmarkGetDetectedTime(b, fmt.Sprintf("%d goroutines", goroutines), goroutines)
+	}
+}