Skip to content

Commit

Permalink
[p2p/streamManager] write message to stream tests
Browse files Browse the repository at this point in the history
  • Loading branch information
didaunesp committed Sep 3, 2024
1 parent 95bf42c commit 109b12b
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 25 deletions.
157 changes: 157 additions & 0 deletions p2p/mocks/mockedReporter.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions p2p/node/streamManager/streamManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ const (
)

var (
ErrStreamNotFound = errors.New("stream not found")
ErrStreamNotFound = errors.New("stream not found")
ErrStreamMismatch = errors.New("stream mismatch")
ErrorTooManyPendingRequests = errors.New("too many pending requests")
)

type StreamManager interface {
Expand Down Expand Up @@ -202,11 +204,11 @@ func (sm *basicStreamManager) GetHost() host.Host {
func (sm *basicStreamManager) WriteMessageToStream(peerID p2p.PeerID, stream network.Stream, msg []byte, protoversion protocol.ID, reporter libp2pmetrics.Reporter) error {
wrappedStream, found := sm.streamCache.Get(peerID)
if !found {
return errors.New("stream not found")
return ErrStreamNotFound
}
if stream != wrappedStream.stream {
// Indicate an unexpected case where the stream we stored and the stream we are requested to write to are not the same.
return errors.New("stream mismatch")
return ErrStreamMismatch
}

// Attempt to acquire semaphore before proceeding
Expand All @@ -222,9 +224,9 @@ func (sm *basicStreamManager) WriteMessageToStream(peerID p2p.PeerID, stream net
}).Warn("Had to close malfunctioning stream")
// If c_maxPendingRequests have been dropped, the stream is likely in a bad state
sm.CloseStream(peerID)
return errors.New("too many pending requests")
return ErrorTooManyPendingRequests
}
return errors.New("too many pending requests")
return ErrorTooManyPendingRequests
}
defer func() {
<-wrappedStream.semaphore
Expand Down
115 changes: 95 additions & 20 deletions p2p/node/streamManager/streamManager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,102 @@ func TestStreamManager(t *testing.T) {
require.Same(t, newMockHost, sm.host, "Expected new host to be set")
})
}

func TestWriteMessageToStream(t *testing.T) {
ctrl, mockNode, mockHost, sm := setup(t)
defer ctrl.Finish()

peerID := peer.ID("mockPeerID")
mockHost.EXPECT().ID().Return(peerID).Times(2)

t.Run("WriteMessageToStream - Stream not found", func(t *testing.T) {
err := sm.WriteMessageToStream(peerID, nil, nil, protocol.ProtocolVersion, nil)
require.ErrorIs(t, err, ErrStreamNotFound, "Expected error when stream not found")
})

mockLibp2pStream := mock_p2p.NewMockStream(ctrl)
mockHost.EXPECT().NewStream(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockLibp2pStream, nil).AnyTimes()

MockConn := mock_p2p.NewMockConn(ctrl)

mockNode.EXPECT().GetBandwidthCounter().Return(nil).AnyTimes()
MockLibp2pStream.EXPECT().Close().Return(nil).AnyTimes()
MockLibp2pStream.EXPECT().Conn().Return(MockConn).Times(1)
MockLibp2pStream.EXPECT().Protocol().Return(protocol.ProtocolVersion).Times(1)
MockLibp2pStream.EXPECT().Read(gomock.Any()).Return(0, nil).AnyTimes()
MockConn.EXPECT().RemotePeer().Return(peerID)
mockHost.EXPECT().NewStream(gomock.Any(), gomock.Any(), gomock.Any()).Return(MockLibp2pStream, nil).Times(1)

// GetStream Error
entry, err := sm.GetStream(peerID)
require.Error(t, err)
require.Nil(t, entry)

err = sm.OpenStream(mockHost2.ID())

require.NoError(t, err)

// Get Stream Success
entry, err = sm.GetStream(peerID)
require.NoError(t, err)
require.Equal(t, MockLibp2pStream, entry)
mockLibp2pStream.EXPECT().Close().Return(nil).AnyTimes()
mockLibp2pStream.EXPECT().Conn().Return(MockConn).AnyTimes()
mockLibp2pStream.EXPECT().Protocol().Return(protocol.ProtocolVersion).AnyTimes()
mockLibp2pStream.EXPECT().Read(gomock.Any()).Return(0, nil).AnyTimes()
MockConn.EXPECT().RemotePeer().Return(peerID).AnyTimes()

err := sm.OpenStream(mockHost.ID())
require.NoError(t, err, "Expected no error when opening stream")

t.Run("Stream mismatch", func(t *testing.T) {

anotherStream := mock_p2p.NewMockStream(ctrl)
err = sm.WriteMessageToStream(peerID, anotherStream, []byte("message"), protocol.ProtocolVersion, nil)
require.ErrorIs(t, err, ErrStreamMismatch, "Expected error when stream mismatch")
})

t.Run("Too many pending requests", func(t *testing.T) {
// small semaphore to block the stream
wrappedStream := streamWrapper{
stream: mockLibp2pStream,
semaphore: make(chan struct{}, 1),
errCount: 0,
}
// block semaphore
wrappedStream.semaphore <- struct{}{}
sm.streamCache.Add(peerID, wrappedStream)
err := sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, nil)
require.ErrorIs(t, err, ErrorTooManyPendingRequests, "Expected error when too many pending requests")

// errCount is maxed out so it should close stream
wrappedStream.errCount = c_maxPendingRequests
sm.streamCache.Add(peerID, wrappedStream)

// make sure stream exists
entry, err := sm.GetStream(peerID)
require.NoError(t, err, "Expected no error when getting stream")
require.Equal(t, mockLibp2pStream, entry, "Expected correct stream entry")

err = sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, nil)
require.ErrorIs(t, err, ErrorTooManyPendingRequests, "Expected error when too many pending requests")

// check if stream was closed
entry, err = sm.GetStream(peerID)
require.Nil(t, entry, "Expected nil entry")
require.ErrorIs(t, err, ErrStreamNotFound, "Expected error when stream not found")
})

t.Run("Failed to set write deadline", func(t *testing.T) {
err := sm.OpenStream(mockHost.ID())
require.NoError(t, err, "Expected no error when opening stream")

mockLibp2pStream.EXPECT().SetWriteDeadline(gomock.Any()).Return(errors.New("mock error")).Times(1)

err = sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, nil)
require.Error(t, err, "Expected error when failed to set write deadline")
})

t.Run("Failed to write message to stream", func(t *testing.T) {
mockLibp2pStream.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(1)
mockLibp2pStream.EXPECT().Write(gomock.Any()).Return(0, errors.New("mock error")).Times(1)

err = sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, nil)
require.Error(t, err, "Expected error when failed to write message to stream")
})

t.Run("Succes write message to stream", func(t *testing.T) {
mockLibp2pStream.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2)
mockLibp2pStream.EXPECT().Write(gomock.Any()).Return(10, nil).Times(2)

// without reporter
err = sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, nil)
require.NoError(t, err, "Expected no error when writing message to stream")

//with reporter
mockReporter := mock_p2p.NewMockReporter(ctrl)
mockReporter.EXPECT().LogSentMessageStream(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
err = sm.WriteMessageToStream(peerID, mockLibp2pStream, []byte("message"), protocol.ProtocolVersion, mockReporter)
require.NoError(t, err, "Expected no error when writing message to stream")
})
}

0 comments on commit 109b12b

Please sign in to comment.