diff --git a/go.mod b/go.mod index 34b260bd7..67d78342d 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 - github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // Don't upgrade it! due to memory leak issue. + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 github.com/tyler-smith/go-bip39 v1.1.0 go.nanomsg.org/mangos/v3 v3.4.2 golang.org/x/crypto v0.27.0 diff --git a/network/config.go b/network/config.go index 9efb473a6..8444dd953 100644 --- a/network/config.go +++ b/network/config.go @@ -2,6 +2,7 @@ package network import ( "fmt" + "time" lp2pcore "github.com/libp2p/go-libp2p/core" lp2ppeer "github.com/libp2p/go-libp2p/core/peer" @@ -25,11 +26,12 @@ type Config struct { ForcePrivateNetwork bool `toml:"force_private_network"` // Private configs - NetworkName string `toml:"-"` - DefaultPort int `toml:"-"` - DefaultBootstrapAddrStrings []string `toml:"-"` - IsBootstrapper bool `toml:"-"` - PeerStorePath string `toml:"-"` + NetworkName string `toml:"-"` + DefaultPort int `toml:"-"` + DefaultBootstrapAddrStrings []string `toml:"-"` + IsBootstrapper bool `toml:"-"` + PeerStorePath string `toml:"-"` + StreamTimeout time.Duration `toml:"-"` } func DefaultConfig() *Config { @@ -50,6 +52,7 @@ func DefaultConfig() *Config { DefaultPort: 0, IsBootstrapper: false, PeerStorePath: "peers.json", + StreamTimeout: 20 * time.Second, } } diff --git a/network/gossip.go b/network/gossip.go index cd05934cc..d8cf71145 100644 --- a/network/gossip.go +++ b/network/gossip.go @@ -27,8 +27,8 @@ type gossipService struct { logger *logger.SubLogger } -func newGossipService(ctx context.Context, host lp2phost.Host, eventCh chan Event, - conf *Config, log *logger.SubLogger, +func newGossipService(ctx context.Context, host lp2phost.Host, conf *Config, + eventCh chan Event, log *logger.SubLogger, ) *gossipService { opts := []lp2pps.Option{ lp2pps.WithFloodPublish(true), diff --git a/network/network.go b/network/network.go index c842a236a..a84264137 100644 --- a/network/network.go +++ b/network/network.go @@ -254,11 +254,11 @@ func makeNetwork(conf *Config, log *logger.SubLogger, opts []lp2p.Option) (*netw self.mdns = newMdnsService(ctx, self.host, self.logger) } - self.dht = newDHTService(self.ctx, self.host, kadProtocolID, conf, self.logger) self.peerMgr = newPeerMgr(ctx, host, conf, self.logger) - self.stream = newStreamService(ctx, self.host, streamProtocolID, self.eventChannel, self.logger) - self.gossip = newGossipService(ctx, self.host, self.eventChannel, conf, self.logger) - self.notifee = newNotifeeService(ctx, self.host, self.eventChannel, self.peerMgr, streamProtocolID, self.logger) + self.dht = newDHTService(ctx, host, kadProtocolID, conf, self.logger) + self.stream = newStreamService(ctx, host, conf, streamProtocolID, self.eventChannel, self.logger) + self.gossip = newGossipService(ctx, host, conf, self.eventChannel, self.logger) + self.notifee = newNotifeeService(ctx, host, self.eventChannel, self.peerMgr, streamProtocolID, self.logger) self.logger.Info("network setup", "id", self.host.ID(), "name", conf.NetworkName, @@ -372,7 +372,7 @@ func (n *network) Protect(pid lp2pcore.PeerID, tag string) { // It uses a goroutine to ensure that if sending is blocked, receiving messages won't be blocked. func (n *network) SendTo(msg []byte, pid lp2pcore.PeerID) { go func() { - err := n.stream.SendRequest(msg, pid) + _, err := n.stream.SendRequest(msg, pid) if err != nil { n.logger.Warn("error on sending msg", "pid", pid, "error", err) } diff --git a/network/network_test.go b/network/network_test.go index 21988317d..085cc8525 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -49,19 +49,20 @@ func testConfig() *Config { EnableMdns: false, ForcePrivateNetwork: true, NetworkName: "test", - DefaultPort: 12345, + DefaultPort: FindFreePort(), PeerStorePath: util.TempFilePath(), + StreamTimeout: 10 * time.Second, } } func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event { t.Helper() - timeout := time.NewTimer(10 * time.Second) + timer := time.NewTimer(10 * time.Second) for { select { - case <-timeout.C: + case <-timer.C: require.NoError(t, fmt.Errorf("shouldReceiveEvent Timeout, test: %v id:%s", t.Name(), net.SelfID().String())) return nil @@ -77,11 +78,11 @@ func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event { func shouldNotReceiveEvent(t *testing.T, net *network) { t.Helper() - timeout := time.NewTimer(100 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) for { select { - case <-timeout.C: + case <-timer.C: return case <-net.EventChannel(): @@ -131,20 +132,17 @@ func TestStoppingNetwork(t *testing.T) { func TestNetwork(t *testing.T) { ts := testsuite.NewTestSuite(t) - bootstrapPort := ts.RandInt32(9999) + 10000 - publicPort := ts.RandInt32(9999) + 10000 - // Bootstrap node confB := testConfig() confB.ListenAddrStrings = []string{ - fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", bootstrapPort), + fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", confB.DefaultPort), } fmt.Println("Starting Bootstrap node") networkB := makeTestNetwork(t, confB, []lp2p.Option{ lp2p.ForceReachabilityPublic(), }) bootstrapAddresses := []string{ - fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%v", bootstrapPort, networkB.SelfID().String()), + fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%v", confB.DefaultPort, networkB.SelfID().String()), } // Public and relay node @@ -153,14 +151,14 @@ func TestNetwork(t *testing.T) { confP.EnableRelay = false confP.EnableRelayService = true confP.ListenAddrStrings = []string{ - fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", publicPort), + fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", confP.DefaultPort), } fmt.Println("Starting Public node") networkP := makeTestNetwork(t, confP, []lp2p.Option{ lp2p.ForceReachabilityPublic(), }) publicAddrInfo, _ := lp2ppeer.AddrInfoFromString( - fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%s", publicPort, networkP.SelfID())) + fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%s", confP.DefaultPort, networkP.SelfID())) // Private node M confM := testConfig() @@ -215,57 +213,57 @@ func TestNetwork(t *testing.T) { t.Run("Supported Protocols", func(t *testing.T) { fmt.Printf("Running %s\n", t.Name()) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { protos := networkM.Protocols() - assert.Contains(t, protos, lp2pproto.ProtoIDv2Stop) - assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop) + assert.Contains(c, protos, lp2pproto.ProtoIDv2Stop) + assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { protos := networkN.Protocols() - assert.Contains(t, protos, lp2pproto.ProtoIDv2Stop) - assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop) + assert.Contains(c, protos, lp2pproto.ProtoIDv2Stop) + assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { protos := networkP.Protocols() - assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop) - assert.Contains(t, protos, lp2pproto.ProtoIDv2Hop) + assert.NotContains(c, protos, lp2pproto.ProtoIDv2Stop) + assert.Contains(c, protos, lp2pproto.ProtoIDv2Hop) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { protos := networkX.Protocols() - assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop) - assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop) + assert.NotContains(c, protos, lp2pproto.ProtoIDv2Stop) + assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop) }, time.Second, 100*time.Millisecond) }) t.Run("Reachability", func(t *testing.T) { fmt.Printf("Running %s\n", t.Name()) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { reachability := networkB.ReachabilityStatus() - assert.Equal(t, "Public", reachability) + assert.Equal(c, "Public", reachability) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { reachability := networkM.ReachabilityStatus() - assert.Equal(t, "Private", reachability) + assert.Equal(c, "Private", reachability) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { reachability := networkN.ReachabilityStatus() - assert.Equal(t, "Private", reachability) + assert.Equal(c, "Private", reachability) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { reachability := networkP.ReachabilityStatus() - assert.Equal(t, "Public", reachability) + assert.Equal(c, "Public", reachability) }, time.Second, 100*time.Millisecond) - require.EventuallyWithT(t, func(_ *assert.CollectT) { + require.EventuallyWithT(t, func(c *assert.CollectT) { reachability := networkP.ReachabilityStatus() - assert.Equal(t, "Public", reachability) + assert.Equal(c, "Public", reachability) }, time.Second, 100*time.Millisecond) }) @@ -421,23 +419,20 @@ func TestNetwork(t *testing.T) { func TestConnections(t *testing.T) { t.Parallel() // run the tests in parallel - ts := testsuite.NewTestSuite(t) - tests := []struct { bootstrapAddr string peerAddr string }{ {"/ip4/127.0.0.1/tcp/%d", "/ip4/127.0.0.1/tcp/0"}, - {"/ip4/127.0.0.1/udp/%d/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1"}, {"/ip6/::1/tcp/%d", "/ip6/::1/tcp/0"}, + {"/ip4/127.0.0.1/udp/%d/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1"}, {"/ip6/::1/udp/%d/quic-v1", "/ip6/::1/udp/0/quic-v1"}, } for i, test := range tests { // Bootstrap node confB := testConfig() - bootstrapPort := ts.RandInt32(9999) + 10000 - bootstrapAddr := fmt.Sprintf(test.bootstrapAddr, bootstrapPort) + bootstrapAddr := fmt.Sprintf(test.bootstrapAddr, confB.DefaultPort) confB.ListenAddrStrings = []string{bootstrapAddr} fmt.Println("Starting Bootstrap node") networkB := makeTestNetwork(t, confB, []lp2p.Option{ @@ -456,7 +451,7 @@ func TestConnections(t *testing.T) { }) t.Run(fmt.Sprintf("Running test %d: %s <-> %s ... ", - i, test.bootstrapAddr, test.peerAddr), func(t *testing.T) { + i, bootstrapAddr, test.peerAddr), func(t *testing.T) { t.Parallel() // run the tests in parallel testConnection(t, networkP, networkB) @@ -467,20 +462,12 @@ func TestConnections(t *testing.T) { func testConnection(t *testing.T, networkP, networkB *network) { t.Helper() - // Ensure that peers are connected to each other - for i := 0; i < 20; i++ { - if networkP.NumConnectedPeers() >= 1 && - networkB.NumConnectedPeers() >= 1 { - break - } - time.Sleep(100 * time.Millisecond) - } - - assert.Equal(t, 1, networkB.NumConnectedPeers()) - assert.Equal(t, 1, networkP.NumConnectedPeers()) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.GreaterOrEqual(c, networkP.NumConnectedPeers(), 1) + assert.GreaterOrEqual(c, networkB.NumConnectedPeers(), 1) + }, 5*time.Second, 100*time.Millisecond) msg := []byte("test-msg") - networkP.SendTo(msg, networkB.SelfID()) e := shouldReceiveEvent(t, networkB, EventTypeStream).(*StreamMessage) assert.Equal(t, networkP.SelfID(), e.From) diff --git a/network/notifee.go b/network/notifee.go index d54637926..883547b04 100644 --- a/network/notifee.go +++ b/network/notifee.go @@ -116,7 +116,7 @@ func (s *NotifeeService) Listen(_ lp2pnetwork.Network, ma multiaddr.Multiaddr) { s.logger.Debug("notifee Listen event emitted", "addr", ma.String()) } -// ListenClose is called when your node stops listening on an address. +// ListenClose is called when the peer stops listening on an address. func (s *NotifeeService) ListenClose(_ lp2pnetwork.Network, ma multiaddr.Multiaddr) { // Handle listen close event if needed. s.logger.Debug("notifee ListenClose event emitted", "addr", ma.String()) diff --git a/network/stream.go b/network/stream.go index 4203f7444..883fe6ad4 100644 --- a/network/stream.go +++ b/network/stream.go @@ -15,17 +15,19 @@ type streamService struct { ctx context.Context host lp2phost.Host protocolID lp2pcore.ProtocolID + timeout time.Duration eventCh chan Event logger *logger.SubLogger } -func newStreamService(ctx context.Context, host lp2phost.Host, +func newStreamService(ctx context.Context, host lp2phost.Host, conf *Config, protocolID lp2pcore.ProtocolID, eventCh chan Event, log *logger.SubLogger, ) *streamService { s := &streamService{ ctx: ctx, host: host, protocolID: protocolID, + timeout: conf.StreamTimeout, eventCh: eventCh, logger: log, } @@ -42,7 +44,7 @@ func (*streamService) Stop() {} func (s *streamService) handleStream(stream lp2pnetwork.Stream) { from := stream.Conn().RemotePeer() - s.logger.Trace("receiving stream", "from", from) + s.logger.Debug("receiving stream", "from", from) event := &StreamMessage{ From: from, Reader: stream, @@ -51,41 +53,65 @@ func (s *streamService) handleStream(stream lp2pnetwork.Stream) { s.eventCh <- event } -// SendRequest sends a message to a specific peer. -// If a direct connection can't be established, it attempts to connect via a relay node. -// Returns an error if the sending process fails. -func (s *streamService) SendRequest(msg []byte, pid lp2peer.ID) error { +// SendRequest sends a message to a specific peer, assuming there is already a direct connection. +// +// For simplicity, we do not use bi-directional streams. +// Each time a peer wants to send a message, it creates a new stream. +// +// For more details on stream multiplexing, refer to: https://docs.libp2p.io/concepts/multiplex/overview/ +func (s *streamService) SendRequest(msg []byte, pid lp2peer.ID) (lp2pnetwork.Stream, error) { s.logger.Trace("sending stream", "to", pid) _, err := s.host.Peerstore().SupportsProtocols(pid, s.protocolID) if err != nil { - return LibP2PError{Err: err} + return nil, LibP2PError{Err: err} } // To prevent a broken stream from being open forever. - ctxWithTimeout, cancel := context.WithTimeout(s.ctx, 20*time.Second) + ctxWithTimeout, cancel := context.WithTimeout(s.ctx, 5*time.Second) defer cancel() - // Attempt to open a new stream to the target peer assuming there's already direct a connection + // Attempt to open a new stream to the peer, assuming there's already a direct connection. stream, err := s.host.NewStream( lp2pnetwork.WithNoDial(ctxWithTimeout, "should already have connection"), pid, s.protocolID) if err != nil { - return LibP2PError{Err: err} + return nil, LibP2PError{Err: err} } - deadline, _ := ctxWithTimeout.Deadline() - _ = stream.SetDeadline(deadline) - _, err = stream.Write(msg) if err != nil { _ = stream.Reset() - return LibP2PError{Err: err} + return nil, LibP2PError{Err: err} } err = stream.CloseWrite() if err != nil { - return LibP2PError{Err: err} + return nil, LibP2PError{Err: err} } - return nil + // We need to close the stream once it is read by the receiver. + // If, for any reason, the receiver doesn't close the stream, we need to close it after a timeout. + go func() { + timer := time.NewTimer(s.timeout) + closed := make(chan bool) + + go func() { + // We need only one byte to read the EOF. + buf := make([]byte, 1) + _, _ = stream.Read(buf) + closed <- true + }() + + select { + case <-timer.C: + s.logger.Warn("stream timeout", "to", pid) + _ = stream.Close() + + case <-closed: + s.logger.Debug("stream closed", "to", pid) + _ = stream.Close() + } + }() + + return stream, nil } diff --git a/network/stream_test.go b/network/stream_test.go new file mode 100644 index 000000000..37c3731ac --- /dev/null +++ b/network/stream_test.go @@ -0,0 +1,68 @@ +package network + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCloseStream(t *testing.T) { + confA := testConfig() + confA.StreamTimeout = 1 * time.Second // Reduce timeout for testing + confA.EnableMdns = true + networkA := makeTestNetwork(t, confA, nil) + + confB := testConfig() + confB.EnableMdns = true + networkB := makeTestNetwork(t, confB, nil) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + e := <-networkA.EventChannel() + assert.Equal(c, EventTypeConnect, e.Type()) + }, 5*time.Second, 100*time.Millisecond) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + e := <-networkB.EventChannel() + assert.Equal(c, EventTypeConnect, e.Type()) + }, 5*time.Second, 100*time.Millisecond) + + t.Run("Stream timeout", func(t *testing.T) { + stream, err := networkA.stream.SendRequest([]byte("test-1"), networkB.SelfID()) + require.NoError(t, err) + + // NetworkB doesn't close the stream. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + e := <-networkB.EventChannel() + _, ok := e.(*StreamMessage) + assert.True(c, ok) + }, 5*time.Second, 100*time.Millisecond) + + // Wait fot the steam timeout. + time.Sleep(2 * confA.StreamTimeout) + + _, err = stream.Write([]byte("should-be-closed")) + assert.ErrorContains(t, err, "write on closed stream") + }) + + t.Run("Stream closed", func(t *testing.T) { + stream, err := networkA.stream.SendRequest([]byte("test-2"), networkB.SelfID()) + require.NoError(t, err) + + // NetworkB close the stream. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + e := <-networkB.EventChannel() + s, ok := e.(*StreamMessage) + assert.True(c, ok) + + if ok { + err := s.Reader.Close() + assert.NoError(t, err) + } + }, 5*time.Second, 100*time.Millisecond) + + _, err = stream.Write([]byte("should-be-closed")) + assert.ErrorContains(t, err, "write on closed stream") + }) +} diff --git a/network/utils.go b/network/utils.go index da7b2e090..1bc36c263 100644 --- a/network/utils.go +++ b/network/utils.go @@ -139,3 +139,12 @@ func MessageIDFunc(m *lp2pspb.Message) string { return string(h[:20]) } + +func FindFreePort() int { + listener, _ := net.Listen("tcp", "localhost:0") + defer func() { + _ = listener.Close() + }() + + return listener.Addr().(*net.TCPAddr).Port +} diff --git a/sync/sync_test.go b/sync/sync_test.go index d524e1f5d..5ff2ff76a 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -103,11 +103,11 @@ func setup(t *testing.T, config *Config) *testData { func shouldPublishMessageWithThisType(t *testing.T, net *network.MockNetwork, msgType message.Type) *bundle.Bundle { t.Helper() - timeout := time.NewTimer(3 * time.Second) + timer := time.NewTimer(3 * time.Second) for { select { - case <-timeout.C: + case <-timer.C: require.NoError(t, fmt.Errorf("shouldPublishMessageWithThisType %v: Timeout, test: %v", msgType, t.Name())) return nil @@ -157,11 +157,11 @@ func (td *testData) shouldPublishMessageWithThisType(t *testing.T, msgType messa func shouldNotPublishMessageWithThisType(t *testing.T, net *network.MockNetwork, msgType message.Type) { t.Helper() - timeout := time.NewTimer(3 * time.Millisecond) + timer := time.NewTimer(3 * time.Millisecond) for { select { - case <-timeout.C: + case <-timer.C: return case b := <-net.PublishCh: diff --git a/txpool/txpool_test.go b/txpool/txpool_test.go index 14e3ef72d..beda5952c 100644 --- a/txpool/txpool_test.go +++ b/txpool/txpool_test.go @@ -56,11 +56,11 @@ func setup(t *testing.T) *testData { func (td *testData) shouldPublishTransaction(t *testing.T, id tx.ID) { t.Helper() - timeout := time.NewTimer(1 * time.Second) + timer := time.NewTimer(1 * time.Second) for { select { - case <-timeout.C: + case <-timer.C: require.NoError(t, fmt.Errorf("Timeout")) return