Skip to content

Commit

Permalink
Fix flaky TestQueued (#6097)
Browse files Browse the repository at this point in the history
## Motivation

This should once and for all fix the flaky `TestQueued` test that keeps preventing merges of PRs. 🙂
  • Loading branch information
fasmat committed Jul 4, 2024
1 parent 3c9a9be commit 61e390f
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 43 deletions.
64 changes: 43 additions & 21 deletions p2p/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dto "github.com/prometheus/client_model/go"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"golang.org/x/time/rate"

"github.com/spacemeshos/go-spacemesh/codec"
Expand Down Expand Up @@ -156,6 +157,11 @@ type Server struct {
decayingTagSpec *DecayingTagSpec
decayingTag connmgr.DecayingTag

limit *rate.Limiter
sem *semaphore.Weighted
queue chan request
stopped chan struct{}

metrics *tracker // metrics can be nil

h Host
Expand All @@ -174,6 +180,9 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server {
queueSize: 1000,
requestsPerInterval: 100,
interval: time.Second,

queue: make(chan request),
stopped: make(chan struct{}),
}
for _, opt := range opts {
opt(srv)
Expand All @@ -195,6 +204,31 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server {
}
}

srv.limit = rate.NewLimiter(
rate.Every(srv.interval/time.Duration(srv.requestsPerInterval)),
srv.requestsPerInterval,
)
srv.sem = semaphore.NewWeighted(int64(srv.queueSize))
srv.h.SetStreamHandler(protocol.ID(srv.protocol), func(stream network.Stream) {
if !srv.sem.TryAcquire(1) {
if srv.metrics != nil {
srv.metrics.dropped.Inc()
}
stream.Close()
return
}
select {
case <-srv.stopped:
srv.sem.Release(1)
stream.Close()
case srv.queue <- request{stream: stream, received: time.Now()}:
// at most s.queueSize requests block here, the others are dropped with the semaphore
}
})
if srv.metrics != nil {
srv.metrics.targetQueue.Set(float64(srv.queueSize))
srv.metrics.targetRps.Set(float64(srv.limit.Limit()))
}
return srv
}

Expand All @@ -204,45 +238,29 @@ type request struct {
}

func (s *Server) Run(ctx context.Context) error {
limit := rate.NewLimiter(rate.Every(s.interval/time.Duration(s.requestsPerInterval)), s.requestsPerInterval)
queue := make(chan request, s.queueSize)
if s.metrics != nil {
s.metrics.targetQueue.Set(float64(s.queueSize))
s.metrics.targetRps.Set(float64(limit.Limit()))
}
s.h.SetStreamHandler(protocol.ID(s.protocol), func(stream network.Stream) {
select {
case queue <- request{stream: stream, received: time.Now()}:
default:
if s.metrics != nil {
s.metrics.dropped.Inc()
}
stream.Close()
}
})

var eg errgroup.Group
eg.SetLimit(s.queueSize * 2)
for {
select {
case <-ctx.Done():
close(s.stopped)
eg.Wait()
return nil
case req := <-queue:
case req := <-s.queue:
if s.metrics != nil {
s.metrics.queue.Set(float64(len(queue)))
s.metrics.queue.Set(float64(s.queueSize))
s.metrics.accepted.Inc()
}
if s.metrics != nil {
s.metrics.inQueueLatency.Observe(time.Since(req.received).Seconds())
}
if err := limit.Wait(ctx); err != nil {
if err := s.limit.Wait(ctx); err != nil {
eg.Wait()
return nil
}
ctx, cancel := context.WithCancel(ctx)
eg.Go(func() error {
<-ctx.Done()
s.sem.Release(1)
req.stream.Close()
return nil
})
Expand Down Expand Up @@ -333,6 +351,10 @@ func (s *Server) Request(ctx context.Context, pid peer.ID, req []byte, extraProt
if err := s.StreamRequest(ctx, pid, req, func(ctx context.Context, stream io.ReadWriter) error {
rd := bufio.NewReader(stream)
if _, err := codec.DecodeFrom(rd, &r); err != nil {
if errors.Is(err, io.ErrClosedPipe) && ctx.Err() != nil {
// ensure that a canceled context is returned as the right error
return ctx.Err()
}
return fmt.Errorf("peer %s: %w", pid, err)
}
if r.Error != "" {
Expand Down
93 changes: 71 additions & 22 deletions p2p/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package server
import (
"context"
"errors"
"sync/atomic"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -172,27 +172,28 @@ func TestServer(t *testing.T) {
})
}

func TestQueued(t *testing.T) {
func Test_Queued(t *testing.T) {
mesh, err := mocknet.FullMeshConnected(2)
require.NoError(t, err)

var (
total = 100
proto = "test"
success, failure atomic.Int64
wait = make(chan struct{}, total)
queueSize = 10
proto = "test"
stop = make(chan struct{})
wg sync.WaitGroup
)

wg.Add(queueSize)
client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil)
srv := New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) {
wg.Done()
<-stop
return msg, nil
}),
WithQueueSize(total/3),
WithRequestsPerInterval(50, time.Second),
WithMetrics(),
WithQueueSize(queueSize),
)
var (
eg errgroup.Group
Expand All @@ -205,23 +206,71 @@ func TestQueued(t *testing.T) {
t.Cleanup(func() {
assert.NoError(t, eg.Wait())
})
for i := 0; i < total; i++ {
eg.Go(func() error {
if _, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")); err != nil {
failure.Add(1)
} else {
success.Add(1)
}
wait <- struct{}{}
var reqEq errgroup.Group
for i := 0; i < queueSize; i++ { // fill the queue with requests
reqEq.Go(func() error {
resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)
return nil
})
}
for i := 0; i < total; i++ {
<-wait
wg.Wait()

for i := 0; i < queueSize; i++ { // queue is full, requests fail
_, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.Error(t, err)
}

close(stop)
require.NoError(t, reqEq.Wait())
}

func Test_RequestInterval(t *testing.T) {
mesh, err := mocknet.FullMeshConnected(2)
require.NoError(t, err)

var (
maxReq = 10
maxReqTime = time.Minute
proto = "test"
)

client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil)
srv := New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) {
return msg, nil
}),
WithRequestsPerInterval(maxReq, maxReqTime),
)
var (
eg errgroup.Group
ctx, cancel = context.WithCancel(context.Background())
)
defer cancel()
eg.Go(func() error {
return srv.Run(ctx)
})
t.Cleanup(func() {
assert.NoError(t, eg.Wait())
})

start := time.Now()
for i := 0; i < maxReq; i++ { // fill the interval with requests (bursts up to maxReq are allowed)
resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)
}
require.NotZero(t, failure.Load())
require.Greater(t, int(success.Load()), total/2)
t.Log(success.Load())

// new request will be delayed by the interval
resp, err := client.Request(context.Background(), mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)

interval := maxReqTime / time.Duration(maxReq)
require.GreaterOrEqual(t, time.Since(start), interval)
}

func FuzzResponseConsistency(f *testing.F) {
Expand Down

0 comments on commit 61e390f

Please sign in to comment.