Skip to content

Commit

Permalink
Use sync.Map in the endpointregistry
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <roman.zavodskikh@zalando.de>
  • Loading branch information
Roman Zavodskikh committed Dec 13, 2023
1 parent 5f706e5 commit 91e0638
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 96 deletions.
4 changes: 2 additions & 2 deletions loadbalancer/algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func TestConsistentHashBoundedLoadDistribution(t *testing.T) {
t.Errorf("Expected in-flight requests for each endpoint to be less than %d. In-flight request counts: %d, %d, %d", limit, ifr0, ifr1, ifr2)
}
ep.Metrics.IncInflightRequest()
ctx.Registry.IncInflightRequest(ep.Host)
ctx.Registry.GetMetrics(ep.Host).IncInflightRequest()
}
}

Expand All @@ -441,7 +441,7 @@ func TestConsistentHashKeyDistribution(t *testing.T) {
func addInflightRequests(registry *routing.EndpointRegistry, endpoint routing.LBEndpoint, count int) {
for i := 0; i < count; i++ {
endpoint.Metrics.IncInflightRequest()
registry.IncInflightRequest(endpoint.Host)
registry.GetMetrics(endpoint.Host).IncInflightRequest()
}
}

Expand Down
10 changes: 6 additions & 4 deletions loadbalancer/fadein_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,18 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat
LBFadeInDuration: fadeInDuration,
LBFadeInExponent: 1,
},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
}

detected := map[string]time.Time{}
for i := range eps {
ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{
Host: eps[i],
Detected: detectionTimes[i],
})
ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i])
detected[eps[i]] = detectionTimes[i]
}
ctx.LBEndpoints = ctx.Route.LBEndpoints
ctx.Registry = routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected})

return ctx, eps
}
Expand Down Expand Up @@ -326,14 +327,15 @@ func benchmarkFadeIn(
LBFadeInDuration: fadeInDuration,
LBFadeInExponent: 1,
}
registry := routing.NewEndpointRegistry(routing.RegistryOptions{})
detected := map[string]time.Time{}
for i := range eps {
route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{
Host: eps[i],
Detected: detectionTimes[i],
})
registry.SetDetectedTime(eps[i], detectionTimes[i])
detected[eps[i]] = detectionTimes[i]
}
registry := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected})

var wg sync.WaitGroup

Expand Down
126 changes: 57 additions & 69 deletions routing/endpointregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"sync"
"sync/atomic"
"time"

"github.com/zalando/skipper/eskip"
Expand All @@ -14,37 +15,62 @@ const defaultLastSeenTimeout = 1 * time.Minute
// Metrics is the per-endpoint view handed out by the endpoint registry.
// In this file it is implemented by *entry.
type Metrics interface {
// DetectedTime returns the detection time stored for the endpoint.
DetectedTime() time.Time
// InflightRequests returns the current value of the in-flight request counter.
InflightRequests() int64
// IncInflightRequest increases the in-flight request counter by one.
IncInflightRequest()
// DecInflightRequest decreases the in-flight request counter by one.
DecInflightRequest()
}

type entry struct {
detected time.Time
inflightRequests int64
detected atomic.Value // time.Time
lastSeen atomic.Value // time.Time
inflightRequests atomic.Int64
}

var _ Metrics = &entry{}

func (e *entry) DetectedTime() time.Time {
return e.detected
return e.detected.Load().(time.Time)
}

func (e *entry) InflightRequests() int64 {
return e.inflightRequests
return e.inflightRequests.Load()
}

// IncInflightRequest atomically adds one to the entry's in-flight
// request counter.
func (e *entry) IncInflightRequest() {
e.inflightRequests.Add(1)
}

// DecInflightRequest atomically subtracts one from the entry's in-flight
// request counter. No lower bound is enforced here, so the counter can
// become negative if calls are unbalanced.
func (e *entry) DecInflightRequest() {
e.inflightRequests.Add(-1)
}

// setDetectedTime records detected as the entry's detection time, but only
// when the currently stored value is still the zero time.Time. Once a
// non-zero detection time has been stored, further calls leave it unchanged.
func (e *entry) setDetectedTime(detected time.Time) {
	// The swap only succeeds while the stored value equals the zero time.
	var unset time.Time
	e.detected.CompareAndSwap(unset, detected)
}

// setLastSeenTime unconditionally replaces the entry's stored last-seen
// timestamp with ts.
func (e *entry) setLastSeenTime(ts time.Time) {
e.lastSeen.Store(ts)
}

// newEntry allocates an entry whose detected and lastSeen values are both
// pre-populated with the zero time.Time.
//
// The explicit stores matter: readers type-assert the loaded value to
// time.Time, and setDetectedTime compares the stored value against the zero
// time, so neither atomic.Value may be left without a stored value.
func newEntry() *entry {
	var zero time.Time

	e := new(entry)
	e.detected.Store(zero)
	e.lastSeen.Store(zero)
	return e
}

type EndpointRegistry struct {
lastSeen map[string]time.Time
lastSeenTimeout time.Duration
now func() time.Time

mu sync.Mutex

data map[string]*entry
// map[string]*entry
data sync.Map
}

var _ PostProcessor = &EndpointRegistry{}

// RegistryOptions configures an EndpointRegistry created by
// NewEndpointRegistry.
type RegistryOptions struct {
// LastSeenTimeout is copied into the registry's lastSeenTimeout; Do
// deletes entries whose last-seen time is older than now minus this
// duration.
LastSeenTimeout time.Duration
// Detected maps an endpoint key to a detection time. NewEndpointRegistry
// stores one entry per key, with that detection time and with the
// last-seen time set to the registry's current time.
Detected map[string]time.Time
}

func (r *EndpointRegistry) Do(routes []*Route) []*Route {
Expand All @@ -53,26 +79,23 @@ func (r *EndpointRegistry) Do(routes []*Route) []*Route {
for _, route := range routes {
if route.BackendType == eskip.LBBackend {
for _, epi := range route.LBEndpoints {
metrics := r.GetMetrics(epi.Host)
if metrics.DetectedTime().IsZero() {
r.SetDetectedTime(epi.Host, now)
}
e, _ := r.data.LoadOrStore(epi.Host, newEntry())

r.lastSeen[epi.Host] = now
e.(*entry).setDetectedTime(now)
e.(*entry).setLastSeenTime(now)
}
}
}

for host, ts := range r.lastSeen {
if ts.Add(r.lastSeenTimeout).Before(now) {
r.mu.Lock()
if r.data[host].inflightRequests == 0 {
delete(r.lastSeen, host)
delete(r.data, host)
}
r.mu.Unlock()
removeOlder := now.Add(-r.lastSeenTimeout)
r.data.Range(func(key, value any) bool {
e := value.(*entry)
if e.lastSeen.Load().(time.Time).Before(removeOlder) {
r.data.Delete(key)
}
}

return true
})

return routes
}
Expand All @@ -82,58 +105,23 @@ func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry {
o.LastSeenTimeout = defaultLastSeenTimeout
}

return &EndpointRegistry{
data: map[string]*entry{},
lastSeen: map[string]time.Time{},
result := &EndpointRegistry{
data: sync.Map{},
lastSeenTimeout: o.LastSeenTimeout,
now: time.Now,
}
}

func (r *EndpointRegistry) GetMetrics(key string) Metrics {
r.mu.Lock()
defer r.mu.Unlock()

e := r.getOrInitEntryLocked(key)
copy := &entry{}
*copy = *e
return copy
}

func (r *EndpointRegistry) SetDetectedTime(key string, detected time.Time) {
r.mu.Lock()
defer r.mu.Unlock()

e := r.getOrInitEntryLocked(key)
e.detected = detected
}

func (r *EndpointRegistry) IncInflightRequest(key string) {
r.mu.Lock()
defer r.mu.Unlock()

e := r.getOrInitEntryLocked(key)
e.inflightRequests++
}

func (r *EndpointRegistry) DecInflightRequest(key string) {
r.mu.Lock()
defer r.mu.Unlock()
for host, detected := range o.Detected {
e := &entry{}
e.detected.Store(detected)
e.lastSeen.Store(result.now())
result.data.Store(host, e)
}

e := r.getOrInitEntryLocked(key)
e.inflightRequests--
return result
}

// getOrInitEntryLocked returns pointer to endpoint registry entry
// which contains the information about endpoint representing the
// following key. r.mu must be held while calling this function and
// using of the entry returned. In general, key represents the "host:port"
// string
func (r *EndpointRegistry) getOrInitEntryLocked(key string) *entry {
e, ok := r.data[key]
if !ok {
e = &entry{}
r.data[key] = e
}
return e
// GetMetrics returns the metrics entry stored under key, creating and
// storing an empty entry first when the key is not present yet.
//
// The returned value is the registry's own entry, not a copy, so calling
// IncInflightRequest or DecInflightRequest on it is visible to every other
// holder of the same key's metrics.
func (r *EndpointRegistry) GetMetrics(key string) Metrics {
	// Fast path: the argument to LoadOrStore is evaluated before the call,
	// so using it alone would build (and immediately discard) a new entry
	// on every lookup of an existing key.
	if e, ok := r.data.Load(key); ok {
		return e.(*entry)
	}

	// Slow path: the key was missing. LoadOrStore keeps this safe when
	// several goroutines miss concurrently: exactly one entry wins and all
	// callers receive that same entry.
	e, _ := r.data.LoadOrStore(key, newEntry())
	return e.(*entry)
}
78 changes: 57 additions & 21 deletions routing/endpointregistry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,33 @@ func TestEmptyRegistry(t *testing.T) {
assert.Equal(t, int64(0), m.InflightRequests())
}

// TestRegistryWithInitData checks that a detection time supplied through
// RegistryOptions.Detected is returned by GetMetrics for the same key, and
// that such a pre-populated entry starts with zero in-flight requests.
func TestRegistryWithInitData(t *testing.T) {
	const key = "some key"
	detectedAt := time.Now()

	opts := routing.RegistryOptions{
		Detected: map[string]time.Time{key: detectedAt},
	}
	metrics := routing.NewEndpointRegistry(opts).GetMetrics(key)

	assert.Equal(t, detectedAt, metrics.DetectedTime())
	assert.Equal(t, int64(0), metrics.InflightRequests())
}

func TestSetAndGet(t *testing.T) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})

mBefore := r.GetMetrics("some key")
r.IncInflightRequest("some key")
mAfter := r.GetMetrics("some key")

assert.Equal(t, int64(0), mBefore.InflightRequests())
assert.Equal(t, int64(1), mAfter.InflightRequests())

ts, _ := time.Parse(time.DateOnly, "2023-08-29")
mBefore = r.GetMetrics("some key")
r.SetDetectedTime("some key", ts)
mAfter = r.GetMetrics("some key")
r.GetMetrics("some key").IncInflightRequest()
mAfter := r.GetMetrics("some key")

assert.Equal(t, time.Time{}, mBefore.DetectedTime())
assert.Equal(t, ts, mAfter.DetectedTime())
assert.Equal(t, int64(1), mBefore.InflightRequests())
assert.Equal(t, int64(1), mAfter.InflightRequests())
}

func TestSetAndGetAnotherKey(t *testing.T) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})

r.IncInflightRequest("some key")
mToChange := r.GetMetrics("some key")
mToChange.IncInflightRequest()
mConst := r.GetMetrics("another key")

assert.Equal(t, int64(0), mConst.InflightRequests())
Expand Down Expand Up @@ -73,9 +76,9 @@ func TestDoRemovesOldEntries(t *testing.T) {
assert.Equal(t, beginTestTs, mExist.DetectedTime())
assert.Equal(t, beginTestTs, mExistYet.DetectedTime())

r.IncInflightRequest("endpoint1.test:80")
r.IncInflightRequest("endpoint2.test:80")
r.DecInflightRequest("endpoint2.test:80")
mExist.IncInflightRequest()
mExistYet.IncInflightRequest()
mExistYet.DecInflightRequest()

routing.SetNow(r, func() time.Time {
return beginTestTs.Add(routing.ExportDefaultLastSeenTimeout + time.Second)
Expand All @@ -101,6 +104,38 @@ func TestDoRemovesOldEntries(t *testing.T) {
assert.Equal(t, int64(0), mRemoved.InflightRequests())
}

// TestRaceReadWrite runs a lookup and an increment on the same key in two
// concurrent goroutines and waits for both to finish. It makes no
// assertions; presumably it is meant to be run under the race detector.
func TestRaceReadWrite(t *testing.T) {
	registry := routing.NewEndpointRegistry(routing.RegistryOptions{})

	ops := []func(){
		func() { registry.GetMetrics("some key") },
		func() { registry.GetMetrics("some key").IncInflightRequest() },
	}

	var wg sync.WaitGroup
	wg.Add(len(ops))
	for _, op := range ops {
		// Pass op as an argument so each goroutine runs its own operation.
		go func(run func()) {
			defer wg.Done()
			run()
		}(op)
	}
	wg.Wait()
}

// TestRaceTwoWriters runs an increment and a decrement on the same key in
// two concurrent goroutines and waits for both to finish. It makes no
// assertions; presumably it is meant to be run under the race detector.
func TestRaceTwoWriters(t *testing.T) {
	registry := routing.NewEndpointRegistry(routing.RegistryOptions{})

	ops := []func(){
		func() { registry.GetMetrics("some key").IncInflightRequest() },
		func() { registry.GetMetrics("some key").DecInflightRequest() },
	}

	var wg sync.WaitGroup
	wg.Add(len(ops))
	for _, op := range ops {
		// Pass op as an argument so each goroutine runs its own operation.
		go func(run func()) {
			defer wg.Done()
			run()
		}(op)
	}
	wg.Wait()
}

func printTotalMutexWaitTime(b *testing.B) {
// Name of the metric we want to read.
const myMetric = "/sync/mutex/wait/total:seconds"
Expand Down Expand Up @@ -133,11 +168,12 @@ func benchmarkIncInflightRequests(b *testing.B, name string, goroutines int) {

b.Run(name, func(b *testing.B) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})

for i := 1; i < mapSize; i++ {
r.IncInflightRequest(fmt.Sprintf("foo-%d", i))
r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest()
}
r.IncInflightRequest(key)
r.IncInflightRequest(key)
r.GetMetrics(key).IncInflightRequest()
r.GetMetrics(key).IncInflightRequest()

wg := sync.WaitGroup{}
b.ResetTimer()
Expand All @@ -146,7 +182,7 @@ func benchmarkIncInflightRequests(b *testing.B, name string, goroutines int) {
go func() {
defer wg.Done()
for n := 0; n < b.N/goroutines; n++ {
r.IncInflightRequest(key)
r.GetMetrics(key).IncInflightRequest()
}
}()
}
Expand All @@ -170,10 +206,10 @@ func benchmarkGetInflightRequests(b *testing.B, name string, goroutines int) {
b.Run(name, func(b *testing.B) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})
for i := 1; i < mapSize; i++ {
r.IncInflightRequest(fmt.Sprintf("foo-%d", i))
r.GetMetrics(fmt.Sprintf("foo-%d", i)).IncInflightRequest()
}
r.IncInflightRequest(key)
r.IncInflightRequest(key)
r.GetMetrics(key).IncInflightRequest()
r.GetMetrics(key).IncInflightRequest()

var dummy int64
wg := sync.WaitGroup{}
Expand Down

0 comments on commit 91e0638

Please sign in to comment.