[Merged by Bors] - Fix flaky TestQueued #6097

64 changes: 43 additions & 21 deletions p2p/server/server.go
@@ -17,6 +17,7 @@
dto "github.com/prometheus/client_model/go"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"golang.org/x/time/rate"

"github.com/spacemeshos/go-spacemesh/codec"
@@ -156,6 +157,11 @@
decayingTagSpec *DecayingTagSpec
decayingTag connmgr.DecayingTag

limit *rate.Limiter
sem *semaphore.Weighted
queue chan request
stopped chan struct{}

metrics *tracker // metrics can be nil

h Host
@@ -174,6 +180,9 @@
queueSize: 1000,
requestsPerInterval: 100,
interval: time.Second,

queue: make(chan request),
stopped: make(chan struct{}),
}
for _, opt := range opts {
opt(srv)
@@ -195,6 +204,31 @@
}
}

srv.limit = rate.NewLimiter(
rate.Every(srv.interval/time.Duration(srv.requestsPerInterval)),
srv.requestsPerInterval,
)
srv.sem = semaphore.NewWeighted(int64(srv.queueSize))
Contributor:
What do you think about using a goroutine pool like https://github.com/panjf2000/ants? It can automatically control how many tasks are executed concurrently, reuse goroutines instead of creating a new one for each request (which is cheap but not free), scale the workers down when they're idle, etc.

@fasmat (Member, Author) - Jul 4, 2024:
I took a quick look at ants, but I'm not sure it is a good fit here. We have two behaviours that we want to enforce:

  • If the queue is full, incoming requests should be dropped immediately. I don't quite see how to make this possible with Pool.Submit() - it would have to be WithNonblocking(true), but then we don't know whether the submit actually started a routine or not (and if it didn't, we need to close the stream and update the metrics accordingly).
  • A maximum rate of incoming requests: ants.Pool will always process as many requests concurrently as it has capacity for, but we actually want to limit processing to at most x requests per timeframe. If requests come in at a higher rate, they should block until the observed rate drops below the target again.

I don't see ants as a good fit for these requirements - as far as I understand the package, it only ensures that no more than a certain number of requests are in flight at the same time, while any new incoming request either blocks until a running one finishes or fails immediately, without the more fine-grained behaviour we have now 😕

EDIT: the given requirements can probably be met with ants and some extra code, but then we would only be replacing the semaphore with it while still needing the rate limiter.
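
For reference, here is a condensed, self-contained sketch of the two behaviours described above, using the same golang.org/x packages as this PR: a weighted semaphore fails fast when all queue slots are taken, and a rate limiter spaces out the accepted requests. All names and values are illustrative, not the PR's actual code.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
	"golang.org/x/time/rate"
)

func main() {
	const queueSize = 10
	sem := semaphore.NewWeighted(queueSize)
	// Refill one token every 10ms (i.e. 100 requests/s), bursts up to queueSize.
	limit := rate.NewLimiter(rate.Every(10*time.Millisecond), queueSize)

	var wg sync.WaitGroup
	for i := 0; i < 25; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if !sem.TryAcquire(1) {
				fmt.Println("dropped:", id) // queue full: fail fast instead of blocking
				return
			}
			defer sem.Release(1)
			if err := limit.Wait(context.Background()); err != nil {
				return // context canceled
			}
			fmt.Println("served:", id) // admitted at the target rate
		}(i)
	}
	wg.Wait()
}
```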

Contributor:

I was thinking of a goroutine pool only to achieve two goals:

  • to reuse goroutines
  • to control how many streams are processed concurrently (via the pool size)

We would need to use the pool in blocking mode, plus the rate limiter and the queue (for buffering waiting requests), as we do now.
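
A sketch of what that alternative could look like without a third-party dependency: a fixed set of long-lived worker goroutines draining a buffered queue, combined with the same rate limiter. This is illustrative only; runPool, task, and all parameters are made-up names.

```go
package pool

import (
	"context"
	"sync"

	"golang.org/x/time/rate"
)

type task func(ctx context.Context)

// runPool starts `workers` goroutines that are created once and reused for
// every task; the buffered queue holds waiting requests, and the limiter
// caps the processing rate. It returns once the queue is closed and drained.
func runPool(ctx context.Context, workers int, queue <-chan task, limit *rate.Limiter) {
	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			for t := range queue {
				if err := limit.Wait(ctx); err != nil {
					return // context canceled: stop this worker
				}
				t(ctx)
			}
		}()
	}
	wg.Wait()
}
```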

srv.h.SetStreamHandler(protocol.ID(srv.protocol), func(stream network.Stream) {
if !srv.sem.TryAcquire(1) {
if srv.metrics != nil {
srv.metrics.dropped.Inc()
}
stream.Close()
return
}
select {
case <-srv.stopped:
srv.sem.Release(1)
stream.Close()
case srv.queue <- request{stream: stream, received: time.Now()}:
// at most s.queueSize requests block here; the others are dropped via the semaphore
}
})
if srv.metrics != nil {
srv.metrics.targetQueue.Set(float64(srv.queueSize))
srv.metrics.targetRps.Set(float64(srv.limit.Limit()))
}
return srv
}

@@ -204,45 +238,29 @@
}

func (s *Server) Run(ctx context.Context) error {
limit := rate.NewLimiter(rate.Every(s.interval/time.Duration(s.requestsPerInterval)), s.requestsPerInterval)
queue := make(chan request, s.queueSize)
if s.metrics != nil {
s.metrics.targetQueue.Set(float64(s.queueSize))
s.metrics.targetRps.Set(float64(limit.Limit()))
}
s.h.SetStreamHandler(protocol.ID(s.protocol), func(stream network.Stream) {
select {
case queue <- request{stream: stream, received: time.Now()}:
default:
if s.metrics != nil {
s.metrics.dropped.Inc()
}
stream.Close()
}
})

var eg errgroup.Group
eg.SetLimit(s.queueSize * 2)
for {
select {
case <-ctx.Done():
close(s.stopped)
eg.Wait()
return nil
case req := <-queue:
case req := <-s.queue:
if s.metrics != nil {
s.metrics.queue.Set(float64(len(queue)))
s.metrics.queue.Set(float64(s.queueSize))
s.metrics.accepted.Inc()
}
if s.metrics != nil {
s.metrics.inQueueLatency.Observe(time.Since(req.received).Seconds())
}
if err := limit.Wait(ctx); err != nil {
if err := s.limit.Wait(ctx); err != nil {
eg.Wait()
return nil
}
ctx, cancel := context.WithCancel(ctx)
eg.Go(func() error {
<-ctx.Done()
s.sem.Release(1)
req.stream.Close()
return nil
})
@@ -333,6 +351,10 @@
if err := s.StreamRequest(ctx, pid, req, func(ctx context.Context, stream io.ReadWriter) error {
rd := bufio.NewReader(stream)
if _, err := codec.DecodeFrom(rd, &r); err != nil {
if errors.Is(err, io.ErrClosedPipe) && ctx.Err() != nil {
// ensure that a canceled context is returned as the right error
return ctx.Err()
}
return fmt.Errorf("peer %s: %w", pid, err)
}
if r.Error != "" {
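A note on the client-side change above (the io.ErrClosedPipe branch): callers usually match on context.Canceled or context.DeadlineExceeded, and without the remap they would see the transport error instead when the stream was closed because the request's context ended. A minimal illustration of the pattern, with an illustrative remapErr helper:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
)

// remapErr mirrors the pattern in the diff: once the context has ended,
// prefer its error over the transport-level one.
func remapErr(ctx context.Context, err error) error {
	if errors.Is(err, io.ErrClosedPipe) && ctx.Err() != nil {
		return ctx.Err()
	}
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	err := remapErr(ctx, io.ErrClosedPipe)
	fmt.Println(errors.Is(err, context.Canceled)) // true: callers can detect cancellation
}
```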
93 changes: 71 additions & 22 deletions p2p/server/server_test.go
@@ -3,7 +3,7 @@ package server
import (
"context"
"errors"
"sync/atomic"
"sync"
"testing"
"time"

@@ -172,27 +172,28 @@ func TestServer(t *testing.T) {
})
}

func TestQueued(t *testing.T) {
func Test_Queued(t *testing.T) {
mesh, err := mocknet.FullMeshConnected(2)
require.NoError(t, err)

var (
total = 100
proto = "test"
success, failure atomic.Int64
wait = make(chan struct{}, total)
queueSize = 10
proto = "test"
stop = make(chan struct{})
wg sync.WaitGroup
)

wg.Add(queueSize)
client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil)
srv := New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) {
wg.Done()
<-stop
return msg, nil
}),
WithQueueSize(total/3),
WithRequestsPerInterval(50, time.Second),
WithMetrics(),
WithQueueSize(queueSize),
)
var (
eg errgroup.Group
@@ -205,23 +206,71 @@
t.Cleanup(func() {
assert.NoError(t, eg.Wait())
})
for i := 0; i < total; i++ {
eg.Go(func() error {
if _, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")); err != nil {
failure.Add(1)
} else {
success.Add(1)
}
wait <- struct{}{}
var reqEq errgroup.Group
for i := 0; i < queueSize; i++ { // fill the queue with requests
reqEq.Go(func() error {
resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)
return nil
})
}
for i := 0; i < total; i++ {
<-wait
wg.Wait()

for i := 0; i < queueSize; i++ { // queue is full, requests fail
_, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.Error(t, err)
}

close(stop)
require.NoError(t, reqEq.Wait())
}
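
As rewritten, Test_Queued no longer depends on timing: each handler blocks on the stop channel, the WaitGroup guarantees that all queueSize slots are occupied before the second batch of requests is sent (so those must be dropped), and only close(stop) releases the first batch. The original version's success/failure counts depended on scheduling and the rate limiter, which is what made the assertions flaky.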

func Test_RequestInterval(t *testing.T) {
mesh, err := mocknet.FullMeshConnected(2)
require.NoError(t, err)

var (
maxReq = 10
maxReqTime = time.Minute
proto = "test"
)

client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil)
srv := New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) {
return msg, nil
}),
WithRequestsPerInterval(maxReq, maxReqTime),
)
var (
eg errgroup.Group
ctx, cancel = context.WithCancel(context.Background())
)
defer cancel()
eg.Go(func() error {
return srv.Run(ctx)
})
t.Cleanup(func() {
assert.NoError(t, eg.Wait())
})

start := time.Now()
for i := 0; i < maxReq; i++ { // fill the interval with requests (bursts up to maxReq are allowed)
resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)
}
require.NotZero(t, failure.Load())
require.Greater(t, int(success.Load()), total/2)
t.Log(success.Load())

// new request will be delayed by the interval
resp, err := client.Request(context.Background(), mesh.Hosts()[1].ID(), []byte("ping"))
require.NoError(t, err)
require.Equal(t, []byte("ping"), resp)

interval := maxReqTime / time.Duration(maxReq)
require.GreaterOrEqual(t, time.Since(start), interval)
}
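
The timing assertion above relies on the token-bucket semantics of rate.Limiter: a limiter constructed with burst b starts with b tokens, so the first b calls return immediately and the next call waits one refill interval. A standalone sketch with the test's values (illustrative only):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	maxReq, maxReqTime := 10, time.Minute
	// One token every maxReqTime/maxReq = 6s; the bucket initially holds maxReq tokens.
	limit := rate.NewLimiter(rate.Every(maxReqTime/time.Duration(maxReq)), maxReq)

	start := time.Now()
	for i := 0; i < maxReq; i++ {
		_ = limit.Wait(context.Background()) // burst: all return immediately
	}
	fmt.Println("burst:", time.Since(start)) // ~0s

	_ = limit.Wait(context.Background()) // waits for the next token
	fmt.Println("after burst:", time.Since(start)) // >= 6s
}
```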

func FuzzResponseConsistency(f *testing.F) {