From 9aceb849c60ac32c3f0c4fd6d4575945a706a566 Mon Sep 17 00:00:00 2001 From: Matthias <5011972+fasmat@users.noreply.github.com> Date: Thu, 4 Jul 2024 00:04:01 +0000 Subject: [PATCH] Review 2: electric boogaloo --- p2p/server/server.go | 41 ++++++++++++++++++--------------------- p2p/server/server_test.go | 2 -- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/p2p/server/server.go b/p2p/server/server.go index 159f5eb8ef6..3de5e8c5706 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -160,7 +160,7 @@ type Server struct { limit *rate.Limiter sem *semaphore.Weighted queue chan request - started chan struct{} + stopped chan struct{} metrics *tracker // metrics can be nil @@ -182,7 +182,7 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { interval: time.Second, queue: make(chan request), - started: make(chan struct{}), + stopped: make(chan struct{}), } for _, opt := range opts { opt(srv) @@ -204,6 +204,22 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { } } + srv.h.SetStreamHandler(protocol.ID(srv.protocol), func(stream network.Stream) { + if !srv.sem.TryAcquire(1) { + if srv.metrics != nil { + srv.metrics.dropped.Inc() + } + stream.Close() + return + } + select { + case <-srv.stopped: + srv.sem.Release(1) + stream.Close() + case srv.queue <- request{stream: stream, received: time.Now()}: + // at most s.queueSize requests block here, the others are dropped with the semaphore + } + }) srv.limit = rate.NewLimiter( rate.Every(srv.interval/time.Duration(srv.requestsPerInterval)), srv.requestsPerInterval, @@ -221,31 +237,12 @@ type request struct { received time.Time } -func (s *Server) Ready() <-chan struct{} { - return s.started -} - func (s *Server) Run(ctx context.Context) error { - s.h.SetStreamHandler(protocol.ID(s.protocol), func(stream network.Stream) { - if !s.sem.TryAcquire(1) { - if s.metrics != nil { - s.metrics.dropped.Inc() - } - stream.Close() - return - } - select { - case <-ctx.Done(): - case s.queue <- request{stream: stream, received: time.Now()}: - // at most s.queueSize requests block here, the others are dropped with the semaphore - } - }) - close(s.started) - var eg errgroup.Group for { select { case <-ctx.Done(): + close(s.stopped) eg.Wait() return nil case req := <-s.queue: diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 92fdbd4446d..ac8c3e8a50f 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -203,7 +203,6 @@ func Test_Queued(t *testing.T) { eg.Go(func() error { return srv.Run(ctx) }) - <-srv.Ready() t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) @@ -254,7 +253,6 @@ func Test_RequestInterval(t *testing.T) { eg.Go(func() error { return srv.Run(ctx) }) - <-srv.Ready() t.Cleanup(func() { assert.NoError(t, eg.Wait()) })