diff --git a/p2p/server/server.go b/p2p/server/server.go index 159f5eb8ef..38d52dabbf 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -160,7 +160,7 @@ type Server struct { limit *rate.Limiter sem *semaphore.Weighted queue chan request - started chan struct{} + stopped chan struct{} metrics *tracker // metrics can be nil @@ -182,7 +182,7 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { interval: time.Second, queue: make(chan request), - started: make(chan struct{}), + stopped: make(chan struct{}), } for _, opt := range opts { opt(srv) @@ -209,6 +209,22 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { srv.requestsPerInterval, ) srv.sem = semaphore.NewWeighted(int64(srv.queueSize)) + srv.h.SetStreamHandler(protocol.ID(srv.protocol), func(stream network.Stream) { + if !srv.sem.TryAcquire(1) { + if srv.metrics != nil { + srv.metrics.dropped.Inc() + } + stream.Close() + return + } + select { + case <-srv.stopped: + srv.sem.Release(1) + stream.Close() + case srv.queue <- request{stream: stream, received: time.Now()}: + // at most s.queueSize requests block here, the others are dropped with the semaphore + } + }) if srv.metrics != nil { srv.metrics.targetQueue.Set(float64(srv.queueSize)) srv.metrics.targetRps.Set(float64(srv.limit.Limit())) @@ -221,31 +237,12 @@ type request struct { received time.Time } -func (s *Server) Ready() <-chan struct{} { - return s.started -} - func (s *Server) Run(ctx context.Context) error { - s.h.SetStreamHandler(protocol.ID(s.protocol), func(stream network.Stream) { - if !s.sem.TryAcquire(1) { - if s.metrics != nil { - s.metrics.dropped.Inc() - } - stream.Close() - return - } - select { - case <-ctx.Done(): - case s.queue <- request{stream: stream, received: time.Now()}: - // at most s.queueSize requests block here, the others are dropped with the semaphore - } - }) - close(s.started) - var eg errgroup.Group for { select { case <-ctx.Done(): + close(s.stopped) eg.Wait() return nil case req := <-s.queue: diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 92fdbd4446..ac8c3e8a50 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -203,7 +203,6 @@ func Test_Queued(t *testing.T) { eg.Go(func() error { return srv.Run(ctx) }) - <-srv.Ready() t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) @@ -254,7 +253,6 @@ func Test_RequestInterval(t *testing.T) { eg.Go(func() error { return srv.Run(ctx) }) - <-srv.Ready() t.Cleanup(func() { assert.NoError(t, eg.Wait()) })