diff --git a/.gitignore b/.gitignore index c96d00075f..b80eafa5f0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ vendor .idea coverage.txt *.swp +.vscode diff --git a/chain/chainservice.go b/chain/chainservice.go new file mode 100644 index 0000000000..9c8612deb5 --- /dev/null +++ b/chain/chainservice.go @@ -0,0 +1,40 @@ +package chain + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/gcs" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/neutrino" + "github.com/lightninglabs/neutrino/banman" + "github.com/lightninglabs/neutrino/headerfs" +) + +// NeutrinoChainService is an interface that encapsulates all the public +// methods of a *neutrino.ChainService +type NeutrinoChainService interface { + Start() error + GetBlock(chainhash.Hash, ...neutrino.QueryOption) (*btcutil.Block, error) + GetBlockHeight(*chainhash.Hash) (int32, error) + BestBlock() (*headerfs.BlockStamp, error) + GetBlockHash(int64) (*chainhash.Hash, error) + GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error) + IsCurrent() bool + SendTransaction(*wire.MsgTx) error + GetCFilter(chainhash.Hash, wire.FilterType, + ...neutrino.QueryOption) (*gcs.Filter, error) + GetUtxo(...neutrino.RescanOption) (*neutrino.SpendReport, error) + BanPeer(string, banman.Reason) error + IsBanned(addr string) bool + AddPeer(*neutrino.ServerPeer) + AddBytesSent(uint64) + AddBytesReceived(uint64) + NetTotals() (uint64, uint64) + UpdatePeerHeights(*chainhash.Hash, int32, *neutrino.ServerPeer) + ChainParams() chaincfg.Params + Stop() error + PeerByAddr(string) *neutrino.ServerPeer +} + +var _ NeutrinoChainService = (*neutrino.ChainService)(nil) diff --git a/chain/mocks_test.go b/chain/mocks_test.go new file mode 100644 index 0000000000..52b39e88f7 --- /dev/null +++ b/chain/mocks_test.go @@ -0,0 +1,157 @@ +package chain + +import ( + "container/list" + "errors" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/gcs" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/neutrino" + "github.com/lightninglabs/neutrino/banman" + "github.com/lightninglabs/neutrino/headerfs" +) + +var ( + errNotImplemented = errors.New("not implemented") + testBestBlock = &headerfs.BlockStamp{ + Height: 42, + } +) + +var ( + _ rescanner = (*mockRescanner)(nil) + _ NeutrinoChainService = (*mockChainService)(nil) +) + +// newMockNeutrinoClient constructs a neutrino client with a mock chain +// service implementation and mock rescanner interface implementation. +func newMockNeutrinoClient() *NeutrinoClient { + // newRescanFunc returns a mockRescanner + newRescanFunc := func(ro ...neutrino.RescanOption) rescanner { + return &mockRescanner{ + updateArgs: list.New(), + } + } + + return &NeutrinoClient{ + CS: &mockChainService{}, + newRescan: newRescanFunc, + } +} + +// mockRescanner is a mock implementation of a rescanner interface for use in +// tests. Only the Update method is implemented. +type mockRescanner struct { + updateArgs *list.List +} + +func (m *mockRescanner) Update(opts ...neutrino.UpdateOption) error { + m.updateArgs.PushBack(opts) + return nil +} + +func (m *mockRescanner) Start() <-chan error { + return nil +} + +func (m *mockRescanner) WaitForShutdown() { + // no-op +} + +// mockChainService is a mock implementation of a chain service for use in +// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented. +type mockChainService struct { +} + +func (m *mockChainService) Start() error { + return nil +} + +func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) { + return testBestBlock, nil +} + +func (m *mockChainService) GetBlockHeader( + *chainhash.Hash) (*wire.BlockHeader, error) { + + return &wire.BlockHeader{}, nil +} + +func (m *mockChainService) GetBlock(chainhash.Hash, + ...neutrino.QueryOption) (*btcutil.Block, error) { + + return nil, errNotImplemented +} + +func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) { + return 0, errNotImplemented +} + +func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) { + return nil, errNotImplemented +} + +func (m *mockChainService) IsCurrent() bool { + return false +} + +func (m *mockChainService) SendTransaction(*wire.MsgTx) error { + return errNotImplemented +} + +func (m *mockChainService) GetCFilter(chainhash.Hash, + wire.FilterType, ...neutrino.QueryOption) (*gcs.Filter, error) { + + return nil, errNotImplemented +} + +func (m *mockChainService) GetUtxo( + _ ...neutrino.RescanOption) (*neutrino.SpendReport, error) { + + return nil, errNotImplemented +} + +func (m *mockChainService) BanPeer(string, banman.Reason) error { + return errNotImplemented +} + +func (m *mockChainService) IsBanned(addr string) bool { + panic(errNotImplemented) +} + +func (m *mockChainService) AddPeer(*neutrino.ServerPeer) { + panic(errNotImplemented) +} + +func (m *mockChainService) AddBytesSent(uint64) { + panic(errNotImplemented) +} + +func (m *mockChainService) AddBytesReceived(uint64) { + panic(errNotImplemented) +} + +func (m *mockChainService) NetTotals() (uint64, uint64) { + panic(errNotImplemented) +} + +func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash, + int32, *neutrino.ServerPeer, +) { + panic(errNotImplemented) +} + +func (m *mockChainService) ChainParams() chaincfg.Params { + panic(errNotImplemented) +} + +func (m *mockChainService) Stop() error { + panic(errNotImplemented) +} + +func (m *mockChainService) PeerByAddr(string) *neutrino.ServerPeer { + panic(errNotImplemented) +} diff --git a/chain/neutrino.go b/chain/neutrino.go index 512004e317..7079e7d1a5 100644 --- a/chain/neutrino.go +++ b/chain/neutrino.go @@ -20,14 +20,21 @@ import ( "github.com/lightninglabs/neutrino/headerfs" ) -// NeutrinoClient is an implementation of the btcwalet chain.Interface interface. +// NeutrinoClient is an implementation of the btcwallet chain.Interface interface. type NeutrinoClient struct { - CS *neutrino.ChainService + CS NeutrinoChainService chainParams *chaincfg.Params - // We currently support one rescan/notifiction goroutine per client - rescan *neutrino.Rescan + // We currently support only one rescan/notification goroutine per client. + // Therefore there can only be one instance of the rescan object and + // the rescanMtx synchronizes its access. Calls to the NotifyReceived + // and Rescan methods of the client must hold the rescan mutex lock for + // the length of their execution to ensure that all operations that + // affect the rescan object are atomic. + rescan rescanner + newRescan newRescanFunc + rescanMtx sync.Mutex enqueueNotification chan interface{} dequeueNotification chan interface{} @@ -45,6 +52,14 @@ type NeutrinoClient struct { finished bool isRescan bool + // The clientMtx synchronizes access to the state variables of the client. + // + // TODO(mstreet3): Currently the clientMtx synchronizes access to the + // rescanQuit and rescanErr channels, which cancel the current rescan + // goroutine when closed and is updated each time a new rescan goroutine + // is created, respectively. All state related to the rescan goroutine + // should ideally be synchronized by the same lock or via some other + // shared mechanism. clientMtx sync.Mutex } @@ -53,9 +68,21 @@ type NeutrinoClient struct { func NewNeutrinoClient(chainParams *chaincfg.Params, chainService *neutrino.ChainService) *NeutrinoClient { + chainSource := &neutrino.RescanChainSource{ + ChainService: chainService, + } + + // Adapt the neutrino.NewRescan constructor to satisfy the + // newRescanFunc type by closing over the chainSource and + // passing in the rescan options. + newRescan := func(ropts ...neutrino.RescanOption) rescanner { + return neutrino.NewRescan(chainSource, ropts...) + } + return &NeutrinoClient{ CS: chainService, chainParams: chainParams, + newRescan: newRescan, } } @@ -73,24 +100,36 @@ func (s *NeutrinoClient) Start() error { s.clientMtx.Lock() defer s.clientMtx.Unlock() if !s.started { + // Reset the client state. s.enqueueNotification = make(chan interface{}) s.dequeueNotification = make(chan interface{}) s.currentBlock = make(chan *waddrmgr.BlockStamp) s.quit = make(chan struct{}) s.started = true + + // Go place a ClientConnected notification onto the queue. s.wg.Add(1) go func() { + defer s.wg.Done() + select { case s.enqueueNotification <- ClientConnected{}: case <-s.quit: } }() + + // Go launch the notification handler. + s.wg.Add(1) go s.notificationHandler() } return nil } // Stop replicates the RPC client's Stop method. +// +// TODO(mstreet3): The Stop method does not cancel the long-running rescan +// goroutine. This is a memory leak. Stop should shutdown the rescan goroutine +// and reset the scanning state of the NeutrinoClient to false. func (s *NeutrinoClient) Stop() { s.clientMtx.Lock() defer s.clientMtx.Unlock() @@ -338,6 +377,10 @@ func (s *NeutrinoClient) pollCFilter(hash *chainhash.Hash) (*gcs.Filter, error) func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address, outPoints map[wire.OutPoint]btcutil.Address) error { + // Obtain and hold the rescan mutex lock for the duration of the call. + s.rescanMtx.Lock() + defer s.rescanMtx.Unlock() + s.clientMtx.Lock() if !s.started { s.clientMtx.Unlock() @@ -411,10 +454,7 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre } s.clientMtx.Lock() - newRescan := neutrino.NewRescan( - &neutrino.RescanChainSource{ - ChainService: s.CS, - }, + newRescan := s.newRescan( neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ OnBlockConnected: s.onBlockConnected, OnFilteredBlockConnected: s.onFilteredBlockConnected, @@ -448,6 +488,10 @@ func (s *NeutrinoClient) NotifyBlocks() error { // NotifyReceived replicates the RPC client's NotifyReceived command. func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { + // Obtain and hold the rescan mutex lock for the duration of the call. + s.rescanMtx.Lock() + defer s.rescanMtx.Unlock() + s.clientMtx.Lock() // If we have a rescan running, we just need to add the appropriate @@ -466,10 +510,7 @@ func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { s.lastFilteredBlockHeader = nil // Rescan with just the specified addresses. - newRescan := neutrino.NewRescan( - &neutrino.RescanChainSource{ - ChainService: s.CS, - }, + newRescan := s.newRescan( neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ OnBlockConnected: s.onBlockConnected, OnFilteredBlockConnected: s.onFilteredBlockConnected, diff --git a/chain/neutrino_test.go b/chain/neutrino_test.go new file mode 100644 index 0000000000..bb8f8f61d8 --- /dev/null +++ b/chain/neutrino_test.go @@ -0,0 +1,218 @@ +package chain + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// maxDur is the max duration a test has to execute successfully. +var maxDur = 5 * time.Second + +// TestNeutrinoClientSequentialStartStop ensures that the client +// can sequentially Start and Stop without errors or races. +func TestNeutrinoClientSequentialStartStop(t *testing.T) { + var ( + nc = newMockNeutrinoClient() + wantRestarts = 50 + ) + + // callStartStop starts the neutrino client, requires no error on + // startup, immediately stops the client and waits for shutdown. + // The returned channel is closed once shutdown is complete. + callStartStop := func() <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + err := nc.Start() + require.NoError(t, err) + nc.Stop() + nc.WaitForShutdown() + }() + + return done + } + + // For each wanted restart, execute callStartStop and wait until the + // call is done before continuing to the next execution. Waiting for + // a read from done forces all executions of callStartStop to be done + // sequentially. + // + // The test fails if all of the wanted restarts cannot be completed + // sequentially before the timeout is reached. + timeout := time.After(maxDur) + for i := 0; i < wantRestarts; i++ { + select { + case <-timeout: + t.Fatal("timed out") + case <-callStartStop(): + } + } +} + +// TestNeutrinoClientNotifyReceived verifies that a call to NotifyReceived sets +// the client into the scanning state and that subsequent calls while scanning +// will call Update on the client's Rescanner. +func TestNeutrinoClientNotifyReceived(t *testing.T) { + var ( + nc = newMockNeutrinoClient() + wantNotifyReceivedCalls = 50 + wantUpdateCalls = wantNotifyReceivedCalls - 1 + ) + + // executeCalls calls NotifyReceived() synchronously n times without + // blocking the test and requires no error after each call. + executeCalls := func(n int) <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + var addrs []btcutil.Address + for i := 0; i < n; i++ { + err := nc.NotifyReceived(addrs) + require.NoError(t, err) + } + }() + + return done + } + + // Wait for all calls to complete or test to time out. + timeout := time.After(maxDur) + select { + case <-timeout: + t.Fatal("timed out") + case <-executeCalls(wantNotifyReceivedCalls): + // Require that the expected number of calls to Update were made + // once done sending all NotifyReceived calls. + mockRescan := nc.rescan.(*mockRescanner) + gotUpdateCalls := mockRescan.updateArgs.Len() + require.Equal(t, wantUpdateCalls, gotUpdateCalls) + } +} + +// TestNeutrinoClientNotifyReceivedRescan verifies concurrent calls to +// NotifyBlocks, NotifyReceived and Rescan do not result in a data race +// and that there is no panic on replacing the rescan goroutine single instance. +// +// Each successful method call writes a success message to a buffered channel. +// The channel is buffered so that no concurrent reader is needed. The buffer +// size is exactly the number of goroutines launched because each goroutine +// must finish successfully or else this test will fail. Each message is read +// out of the channel to verify the number of messages received is the number +// expected (i.e., wantMsgs == gotMsgs). +func TestNeutrinoClientNotifyReceivedRescan(t *testing.T) { + var ( + addrs []btcutil.Address + nc = newMockNeutrinoClient() + wantMsgs = 100 + gotMsgs = 0 + msgCh = make(chan string, wantMsgs) + msgPrefix = "successfully called" + + // sendMsg writes a message to the buffered message channel. + sendMsg = func(s string) { + msgCh <- fmt.Sprintf("%s %s", msgPrefix, s) + } + ) + + // Define closures to wrap desired neutrino client method calls. + + // cleanup is the shared cleanup function for a closure executing + // a neutrino client method call. It sends a message and then + // decrements the wait group counter. + cleanup := func(wg *sync.WaitGroup, s string) { + defer wg.Done() + sendMsg(s) + } + + // callRescan calls the Rescan() method and asserts it completes + // with no errors. Rescan() is called with the hash of an empty header + // on each call. + startHash := new(wire.BlockHeader).BlockHash() + callRescan := func(wg *sync.WaitGroup) { + defer cleanup(wg, "rescan") + + err := nc.Rescan(&startHash, addrs, nil) + require.NoError(t, err) + } + + // callNotifyReceived calls the NotifyReceived() method and asserts it + // completes with no errors. + callNotifyReceived := func(wg *sync.WaitGroup) { + defer cleanup(wg, "notify received") + + err := nc.NotifyReceived(addrs) + require.NoError(t, err) + } + + // callNotifyBlocks calls the NotifyBlocks() method and asserts it + // completes with no errors. + callNotifyBlocks := func(wg *sync.WaitGroup) { + defer cleanup(wg, "notify blocks") + + err := nc.NotifyBlocks() + require.NoError(t, err) + } + + // executeCalls launches the wanted number of goroutines, waits + // for them to finish and signals all done by closing the returned + // channel. + executeCalls := func(n int) <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(n) + for i := 0; i < n; i++ { + if i%3 == 0 { + go callRescan(&wg) + continue + } + + if i%10 == 0 { + go callNotifyBlocks(&wg) + continue + } + + go callNotifyReceived(&wg) + } + }() + + return done + } + + // Start the client. + err := nc.Start() + require.NoError(t, err) + + // Wait for all calls to complete or test to time out. + timeout := time.After(maxDur) + select { + case <-timeout: + t.Fatal("timed out") + case <-executeCalls(wantMsgs): + // Ensure that exactly wantRoutines number of calls were made + // by counting the results on the message channel. + close(msgCh) + for str := range msgCh { + assert.Contains(t, str, msgPrefix) + gotMsgs++ + } + + require.Equal(t, wantMsgs, gotMsgs) + } +} diff --git a/chain/rescan.go b/chain/rescan.go new file mode 100644 index 0000000000..2b5851d35d --- /dev/null +++ b/chain/rescan.go @@ -0,0 +1,39 @@ +package chain + +import "github.com/lightninglabs/neutrino" + +var _ rescanner = (*neutrino.Rescan)(nil) + +// rescanner is an interface that abstractly defines the public methods of +// a *neutrino.Rescan. The interface is private because it is only ever +// intended to be implemented by a *neutrino.Rescan. +type rescanner interface { + starter + updater + + // WaitForShutdown blocks until the underlying rescan object is shutdown. + // Close the quit channel before calling WaitForShutdown. + WaitForShutdown() +} + +// updater is the interface that wraps the Update method of a rescan object. +type updater interface { + // Update targets a long-running rescan/notification client with + // updateable filters. Attempts to update the filters will fail + // if either the rescan is no longer running or the shutdown signal is + // received prior to sending the update. + Update(...neutrino.UpdateOption) error +} + +// starter is the interface that wraps the Start method of a rescan object. +type starter interface { + // Start initializes the rescan goroutine, which will begin to scan the chain + // according to the specified rescan options. Start returns a channel that + // communicates any startup errors. Attempts to start a running rescan + // goroutine will error. + Start() <-chan error +} + +// newRescanFunc defines a constructor that accepts rescan options and returns +// an object that satisfies rescanner interface. +type newRescanFunc func(...neutrino.RescanOption) rescanner