From 91e063802bed767f35cc679e3383b01cc5da60d3 Mon Sep 17 00:00:00 2001 From: Roman Zavodskikh Date: Thu, 30 Nov 2023 16:57:17 +0100 Subject: [PATCH] Use sync.Map in the endpointregistry Signed-off-by: Roman Zavodskikh --- loadbalancer/algorithm_test.go | 4 +- loadbalancer/fadein_test.go | 10 ++- routing/endpointregistry.go | 126 ++++++++++++++----------------- routing/endpointregistry_test.go | 78 +++++++++++++------ 4 files changed, 122 insertions(+), 96 deletions(-) diff --git a/loadbalancer/algorithm_test.go b/loadbalancer/algorithm_test.go index 3db8e1fc19..1eb9bb310a 100644 --- a/loadbalancer/algorithm_test.go +++ b/loadbalancer/algorithm_test.go @@ -419,7 +419,7 @@ func TestConsistentHashBoundedLoadDistribution(t *testing.T) { t.Errorf("Expected in-flight requests for each endpoint to be less than %d. In-flight request counts: %d, %d, %d", limit, ifr0, ifr1, ifr2) } ep.Metrics.IncInflightRequest() - ctx.Registry.IncInflightRequest(ep.Host) + ctx.Registry.GetMetrics(ep.Host).IncInflightRequest() } } @@ -441,7 +441,7 @@ func TestConsistentHashKeyDistribution(t *testing.T) { func addInflightRequests(registry *routing.EndpointRegistry, endpoint routing.LBEndpoint, count int) { for i := 0; i < count; i++ { endpoint.Metrics.IncInflightRequest() - registry.IncInflightRequest(endpoint.Host) + registry.GetMetrics(endpoint.Host).IncInflightRequest() } } diff --git a/loadbalancer/fadein_test.go b/loadbalancer/fadein_test.go index c05311fe23..a06532295c 100644 --- a/loadbalancer/fadein_test.go +++ b/loadbalancer/fadein_test.go @@ -62,17 +62,18 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat LBFadeInDuration: fadeInDuration, LBFadeInExponent: 1, }, - Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}), } + detected := map[string]time.Time{} for i := range eps { ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{ Host: eps[i], Detected: detectionTimes[i], }) - ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i]) + detected[eps[i]] = detectionTimes[i] } ctx.LBEndpoints = ctx.Route.LBEndpoints + ctx.Registry = routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected}) return ctx, eps } @@ -326,14 +327,15 @@ func benchmarkFadeIn( LBFadeInDuration: fadeInDuration, LBFadeInExponent: 1, } - registry := routing.NewEndpointRegistry(routing.RegistryOptions{}) + detected := map[string]time.Time{} for i := range eps { route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{ Host: eps[i], Detected: detectionTimes[i], }) - registry.SetDetectedTime(eps[i], detectionTimes[i]) + detected[eps[i]] = detectionTimes[i] } + registry := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected}) var wg sync.WaitGroup diff --git a/routing/endpointregistry.go b/routing/endpointregistry.go index 2d1679ffa9..106935e63b 100644 --- a/routing/endpointregistry.go +++ b/routing/endpointregistry.go @@ -2,6 +2,7 @@ package routing import ( "sync" + "sync/atomic" "time" "github.com/zalando/skipper/eskip" @@ -14,37 +15,62 @@ const defaultLastSeenTimeout = 1 * time.Minute type Metrics interface { DetectedTime() time.Time InflightRequests() int64 + IncInflightRequest() + DecInflightRequest() } type entry struct { - detected time.Time - inflightRequests int64 + detected atomic.Value // time.Time + lastSeen atomic.Value // time.Time + inflightRequests atomic.Int64 } var _ Metrics = &entry{} func (e *entry) DetectedTime() time.Time { - return e.detected + return e.detected.Load().(time.Time) } func (e *entry) InflightRequests() int64 { - return e.inflightRequests + return e.inflightRequests.Load() +} + +func (e *entry) IncInflightRequest() { + e.inflightRequests.Add(1) +} + +func (e *entry) DecInflightRequest() { + e.inflightRequests.Add(-1) +} + +func (e *entry) setDetectedTime(detected time.Time) { + e.detected.CompareAndSwap(time.Time{}, detected) +} + +func (e *entry) setLastSeenTime(ts time.Time) { + e.lastSeen.Store(ts) +} + +func newEntry() (result *entry) { + result = &entry{} + result.detected.Store(time.Time{}) + result.lastSeen.Store(time.Time{}) + return } type EndpointRegistry struct { - lastSeen map[string]time.Time lastSeenTimeout time.Duration now func() time.Time - mu sync.Mutex - - data map[string]*entry + // map[string]*entry + data sync.Map } var _ PostProcessor = &EndpointRegistry{} type RegistryOptions struct { LastSeenTimeout time.Duration + Detected map[string]time.Time } func (r *EndpointRegistry) Do(routes []*Route) []*Route { @@ -53,26 +79,23 @@ func (r *EndpointRegistry) Do(routes []*Route) []*Route { for _, route := range routes { if route.BackendType == eskip.LBBackend { for _, epi := range route.LBEndpoints { - metrics := r.GetMetrics(epi.Host) - if metrics.DetectedTime().IsZero() { - r.SetDetectedTime(epi.Host, now) - } + e, _ := r.data.LoadOrStore(epi.Host, newEntry()) - r.lastSeen[epi.Host] = now + e.(*entry).setDetectedTime(now) + e.(*entry).setLastSeenTime(now) } } } - for host, ts := range r.lastSeen { - if ts.Add(r.lastSeenTimeout).Before(now) { - r.mu.Lock() - if r.data[host].inflightRequests == 0 { - delete(r.lastSeen, host) - delete(r.data, host) - } - r.mu.Unlock() + removeOlder := now.Add(-r.lastSeenTimeout) + r.data.Range(func(key, value any) bool { + e := value.(*entry) + if e.lastSeen.Load().(time.Time).Before(removeOlder) { + r.data.Delete(key) } - } + + return true + }) return routes } @@ -82,58 +105,23 @@ func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry { o.LastSeenTimeout = defaultLastSeenTimeout } - return &EndpointRegistry{ - data: map[string]*entry{}, - lastSeen: map[string]time.Time{}, + result := &EndpointRegistry{ + data: sync.Map{}, lastSeenTimeout: o.LastSeenTimeout, now: time.Now, } -} - -func (r *EndpointRegistry) GetMetrics(key string) Metrics { - r.mu.Lock() - defer r.mu.Unlock() - - e := r.getOrInitEntryLocked(key) - copy := &entry{} - *copy = *e - return copy -} - -func (r *EndpointRegistry) SetDetectedTime(key string, detected time.Time) { - r.mu.Lock() - defer r.mu.Unlock() - - e := r.getOrInitEntryLocked(key) - e.detected = detected -} - -func (r *EndpointRegistry) IncInflightRequest(key string) { - r.mu.Lock() - defer r.mu.Unlock() - - e := r.getOrInitEntryLocked(key) - e.inflightRequests++ -} -func (r *EndpointRegistry) DecInflightRequest(key string) { - r.mu.Lock() - defer r.mu.Unlock() + for host, detected := range o.Detected { + e := &entry{} + e.detected.Store(detected) + e.lastSeen.Store(result.now()) + result.data.Store(host, e) + } - e := r.getOrInitEntryLocked(key) - e.inflightRequests-- + return result } -// getOrInitEntryLocked returns pointer to endpoint registry entry -// which contains the information about endpoint representing the -// following key. r.mu must be held while calling this function and -// using of the entry returned. In general, key represents the "host:port" -// string -func (r *EndpointRegistry) getOrInitEntryLocked(key string) *entry { - e, ok := r.data[key] - if !ok { - e = &entry{} - r.data[key] = e - } - return e +func (r *EndpointRegistry) GetMetrics(key string) Metrics { + e, _ := r.data.LoadOrStore(key, newEntry()) + return e.(*entry) } diff --git a/routing/endpointregistry_test.go b/routing/endpointregistry_test.go index 1256da4d17..1aa9df9749 100644 --- a/routing/endpointregistry_test.go +++ b/routing/endpointregistry_test.go @@ -20,30 +20,33 @@ func TestEmptyRegistry(t *testing.T) { assert.Equal(t, int64(0), m.InflightRequests()) } +func TestRegistryWithInitData(t *testing.T) { + now := time.Now() + r := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: map[string]time.Time{"some key": now}}) + m := r.GetMetrics("some key") + + assert.Equal(t, now, m.DetectedTime()) + assert.Equal(t, int64(0), m.InflightRequests()) +} + func TestSetAndGet(t *testing.T) { r := routing.NewEndpointRegistry(routing.RegistryOptions{}) mBefore := r.GetMetrics("some key") - r.IncInflightRequest("some key") - mAfter := r.GetMetrics("some key") - assert.Equal(t, int64(0), mBefore.InflightRequests()) - assert.Equal(t, int64(1), mAfter.InflightRequests()) - ts, _ := time.Parse(time.DateOnly, "2023-08-29") - mBefore = r.GetMetrics("some key") - r.SetDetectedTime("some key", ts) - mAfter = r.GetMetrics("some key") + r.GetMetrics("some key").IncInflightRequest() + mAfter := r.GetMetrics("some key") - assert.Equal(t, time.Time{}, mBefore.DetectedTime()) - assert.Equal(t, ts, mAfter.DetectedTime()) + assert.Equal(t, int64(1), mBefore.InflightRequests()) + assert.Equal(t, int64(1), mAfter.InflightRequests()) } func TestSetAndGetAnotherKey(t *testing.T) { r := routing.NewEndpointRegistry(routing.RegistryOptions{}) - r.IncInflightRequest("some key") mToChange := r.GetMetrics("some key") + mToChange.IncInflightRequest() mConst := r.GetMetrics("another key") assert.Equal(t, int64(0), mConst.InflightRequests()) @@ -73,9 +76,9 @@ func TestDoRemovesOldEntries(t *testing.T) { assert.Equal(t, beginTestTs, mExist.DetectedTime()) assert.Equal(t, beginTestTs, mExistYet.DetectedTime()) - r.IncInflightRequest("endpoint1.test:80") - r.IncInflightRequest("endpoint2.test:80") - r.DecInflightRequest("endpoint2.test:80") + mExist.IncInflightRequest() + mExistYet.IncInflightRequest() + mExistYet.DecInflightRequest() routing.SetNow(r, func() time.Time { return beginTestTs.Add(routing.ExportDefaultLastSeenTimeout + time.Second) @@ -101,6 +104,38 @@ func TestDoRemovesOldEntries(t *testing.T) { assert.Equal(t, int64(0), mRemoved.InflightRequests()) } +func TestRaceReadWrite(t *testing.T) { + r := routing.NewEndpointRegistry(routing.RegistryOptions{}) + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r.GetMetrics("some key") + }() + go func() { + defer wg.Done() + r.GetMetrics("some key").IncInflightRequest() + }() + wg.Wait() +} + +func TestRaceTwoWriters(t *testing.T) { + r := routing.NewEndpointRegistry(routing.RegistryOptions{}) + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r.GetMetrics("some key").IncInflightRequest() + }() + go func() { + defer wg.Done() + r.GetMetrics("some key").DecInflightRequest() + }() + wg.Wait() +} + func printTotalMutexWaitTime(b *testing.B) { // Name of the metric we want to read. const myMetric = "/sync/mutex/wait/total:seconds" @@ -133,11 +168,12 @@ func benchmarkIncInflightRequests(b *testing.B, name string, goroutines int) { b.Run(name, func(b *testing.B) { r := routing.NewEndpointRegistry(routing.RegistryOptions{}) + for i := 1; i < mapSize; i++ { - r.IncInflightRequest(fmt.Sprintf("foo-%d", i)) + r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest() } - r.IncInflightRequest(key) - r.IncInflightRequest(key) + r.GetMetrics(key).IncInflightRequest() + r.GetMetrics(key).IncInflightRequest() wg := sync.WaitGroup{} b.ResetTimer() @@ -146,7 +182,7 @@ func benchmarkIncInflightRequests(b *testing.B, name string, goroutines int) { go func() { defer wg.Done() for n := 0; n < b.N/goroutines; n++ { - r.IncInflightRequest(key) + r.GetMetrics(key).IncInflightRequest() } }() } @@ -170,10 +206,10 @@ func benchmarkGetInflightRequests(b *testing.B, name string, goroutines int) { b.Run(name, func(b *testing.B) { r := routing.NewEndpointRegistry(routing.RegistryOptions{}) for i := 1; i < mapSize; i++ { - r.IncInflightRequest(fmt.Sprintf("foo-%d", i)) + r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest() } - r.IncInflightRequest(key) - r.IncInflightRequest(key) + r.GetMetrics(key).IncInflightRequest() + r.GetMetrics(key).IncInflightRequest() var dummy int64 wg := sync.WaitGroup{}