diff --git a/p2p/server/server.go b/p2p/server/server.go index 51a3a84168..38d52dabbf 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -17,6 +17,7 @@ import ( dto "github.com/prometheus/client_model/go" "go.uber.org/zap" "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" "golang.org/x/time/rate" "github.com/spacemeshos/go-spacemesh/codec" @@ -156,6 +157,11 @@ type Server struct { decayingTagSpec *DecayingTagSpec decayingTag connmgr.DecayingTag + limit *rate.Limiter + sem *semaphore.Weighted + queue chan request + stopped chan struct{} + metrics *tracker // metrics can be nil h Host @@ -174,6 +180,9 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { queueSize: 1000, requestsPerInterval: 100, interval: time.Second, + + queue: make(chan request), + stopped: make(chan struct{}), } for _, opt := range opts { opt(srv) @@ -195,6 +204,31 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { } } + srv.limit = rate.NewLimiter( + rate.Every(srv.interval/time.Duration(srv.requestsPerInterval)), + srv.requestsPerInterval, + ) + srv.sem = semaphore.NewWeighted(int64(srv.queueSize)) + srv.h.SetStreamHandler(protocol.ID(srv.protocol), func(stream network.Stream) { + if !srv.sem.TryAcquire(1) { + if srv.metrics != nil { + srv.metrics.dropped.Inc() + } + stream.Close() + return + } + select { + case <-srv.stopped: + srv.sem.Release(1) + stream.Close() + case srv.queue <- request{stream: stream, received: time.Now()}: + // at most s.queueSize requests block here, the others are dropped with the semaphore + } + }) + if srv.metrics != nil { + srv.metrics.targetQueue.Set(float64(srv.queueSize)) + srv.metrics.targetRps.Set(float64(srv.limit.Limit())) + } return srv } @@ -204,45 +238,29 @@ type request struct { } func (s *Server) Run(ctx context.Context) error { - limit := rate.NewLimiter(rate.Every(s.interval/time.Duration(s.requestsPerInterval)), s.requestsPerInterval) - queue := make(chan request, s.queueSize) - if s.metrics != nil { - s.metrics.targetQueue.Set(float64(s.queueSize)) - s.metrics.targetRps.Set(float64(limit.Limit())) - } - s.h.SetStreamHandler(protocol.ID(s.protocol), func(stream network.Stream) { - select { - case queue <- request{stream: stream, received: time.Now()}: - default: - if s.metrics != nil { - s.metrics.dropped.Inc() - } - stream.Close() - } - }) - var eg errgroup.Group - eg.SetLimit(s.queueSize * 2) for { select { case <-ctx.Done(): + close(s.stopped) eg.Wait() return nil - case req := <-queue: + case req := <-s.queue: if s.metrics != nil { - s.metrics.queue.Set(float64(len(queue))) + s.metrics.queue.Set(float64(s.queueSize)) s.metrics.accepted.Inc() } if s.metrics != nil { s.metrics.inQueueLatency.Observe(time.Since(req.received).Seconds()) } - if err := limit.Wait(ctx); err != nil { + if err := s.limit.Wait(ctx); err != nil { eg.Wait() return nil } ctx, cancel := context.WithCancel(ctx) eg.Go(func() error { <-ctx.Done() + s.sem.Release(1) req.stream.Close() return nil }) @@ -333,6 +351,10 @@ func (s *Server) Request(ctx context.Context, pid peer.ID, req []byte, extraProt if err := s.StreamRequest(ctx, pid, req, func(ctx context.Context, stream io.ReadWriter) error { rd := bufio.NewReader(stream) if _, err := codec.DecodeFrom(rd, &r); err != nil { + if errors.Is(err, io.ErrClosedPipe) && ctx.Err() != nil { + // ensure that a canceled context is returned as the right error + return ctx.Err() + } return fmt.Errorf("peer %s: %w", pid, err) } if r.Error != "" { diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index b8d5a07732..ac8c3e8a50 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -3,7 +3,7 @@ package server import ( "context" "errors" - "sync/atomic" + "sync" "testing" "time" @@ -172,27 +172,28 @@ func TestServer(t *testing.T) { }) } -func TestQueued(t *testing.T) { +func Test_Queued(t *testing.T) { mesh, err := mocknet.FullMeshConnected(2) require.NoError(t, err) var ( - total = 100 - proto = "test" - success, failure atomic.Int64 - wait = make(chan struct{}, total) + queueSize = 10 + proto = "test" + stop = make(chan struct{}) + wg sync.WaitGroup ) + wg.Add(queueSize) client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil) srv := New( wrapHost(t, mesh.Hosts()[1]), proto, WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) { + wg.Done() + <-stop return msg, nil }), - WithQueueSize(total/3), - WithRequestsPerInterval(50, time.Second), - WithMetrics(), + WithQueueSize(queueSize), ) var ( eg errgroup.Group @@ -205,23 +206,71 @@ func TestQueued(t *testing.T) { t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) - for i := 0; i < total; i++ { - eg.Go(func() error { - if _, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")); err != nil { - failure.Add(1) - } else { - success.Add(1) - } - wait <- struct{}{} + var reqEq errgroup.Group + for i := 0; i < queueSize; i++ { // fill the queue with requests + reqEq.Go(func() error { + resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")) + require.NoError(t, err) + require.Equal(t, []byte("ping"), resp) return nil }) } - for i := 0; i < total; i++ { - <-wait + wg.Wait() + + for i := 0; i < queueSize; i++ { // queue is full, requests fail + _, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")) + require.Error(t, err) + } + + close(stop) + require.NoError(t, reqEq.Wait()) +} + +func Test_RequestInterval(t *testing.T) { + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + + var ( + maxReq = 10 + maxReqTime = time.Minute + proto = "test" + ) + + client := New(wrapHost(t, mesh.Hosts()[0]), proto, nil) + srv := New( + wrapHost(t, mesh.Hosts()[1]), + proto, + WrapHandler(func(_ context.Context, msg []byte) ([]byte, error) { + return msg, nil + }), + WithRequestsPerInterval(maxReq, maxReqTime), + ) + var ( + eg errgroup.Group + ctx, cancel = context.WithCancel(context.Background()) + ) + defer cancel() + eg.Go(func() error { + return srv.Run(ctx) + }) + t.Cleanup(func() { + assert.NoError(t, eg.Wait()) + }) + + start := time.Now() + for i := 0; i < maxReq; i++ { // fill the interval with requests (bursts up to maxReq are allowed) + resp, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")) + require.NoError(t, err) + require.Equal(t, []byte("ping"), resp) } - require.NotZero(t, failure.Load()) - require.Greater(t, int(success.Load()), total/2) - t.Log(success.Load()) + + // new request will be delayed by the interval + resp, err := client.Request(context.Background(), mesh.Hosts()[1].ID(), []byte("ping")) + require.NoError(t, err) + require.Equal(t, []byte("ping"), resp) + + interval := maxReqTime / time.Duration(maxReq) + require.GreaterOrEqual(t, time.Since(start), interval) } func FuzzResponseConsistency(f *testing.F) {