diff --git a/chain/mocks_test.go b/chain/mocks_test.go new file mode 100644 index 0000000000..6a53cdefb3 --- /dev/null +++ b/chain/mocks_test.go @@ -0,0 +1,171 @@ +package chain + +import ( + "container/list" + "errors" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/gcs" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/neutrino" + "github.com/lightninglabs/neutrino/banman" + "github.com/lightninglabs/neutrino/headerfs" +) + +var ( + ErrNotImplemented = errors.New("not implemented") + testBestBlock = &headerfs.BlockStamp{ + Height: 42, + } +) + +var ( + _ rescanner = (*mockRescanner)(nil) + _ NeutrinoChainService = (*mockChainService)(nil) +) + +// newMockNeutrinoClient constructs a neutrino client with a mock chain +// service implementation and mock rescanner interface implementation. +func newMockNeutrinoClient() *NeutrinoClient { + // newRescanFunc returns a mockRescanner + newRescanFunc := func(ro ...neutrino.RescanOption) rescanner { + return &mockRescanner{ + updateArgs: list.New(), + } + } + + return &NeutrinoClient{ + CS: &mockChainService{}, + newRescan: newRescanFunc, + } +} + +// mockRescanner is a mock implementation of a rescanner interface for use in +// tests. Only the Update method is implemented. +type mockRescanner struct { + updateArgs *list.List +} + +func (m *mockRescanner) Start() <-chan error { + errs := make(chan error) + return errs +} + +func (m *mockRescanner) WaitForShutdown() { + // no-op +} + +func (m *mockRescanner) Update(opts ...neutrino.UpdateOption) error { + m.updateArgs.PushBack(opts) + return nil +} + +// mockChainService is a mock implementation of a chain service for use in +// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented. +type mockChainService struct { + bestBlock func() (*headerfs.BlockStamp, error) +} + +func (m *mockChainService) Start() error { + return nil +} + +func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) { + impl := m.getBestBlock() + return impl() +} + +func (m *mockChainService) GetBlockHeader( + *chainhash.Hash) (*wire.BlockHeader, error) { + + return &wire.BlockHeader{}, nil +} + +// getBestBlock returns a BestBlock implementation that defaults to simply +// returning the value of testBestBlock with no error. +func (m *mockChainService) getBestBlock() func() (*headerfs.BlockStamp, error) { + if m.bestBlock == nil { + m.bestBlock = func() (*headerfs.BlockStamp, error) { + return testBestBlock, nil + } + } + return m.bestBlock +} + +func (m *mockChainService) GetBlock(chainhash.Hash, + ...neutrino.QueryOption) (*btcutil.Block, error) { + + return nil, ErrNotImplemented +} + +func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) { + return 0, ErrNotImplemented +} + +func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) { + return nil, ErrNotImplemented +} + +func (m *mockChainService) IsCurrent() bool { + return false +} + +func (m *mockChainService) SendTransaction(*wire.MsgTx) error { + return ErrNotImplemented +} + +func (m *mockChainService) GetCFilter(chainhash.Hash, + wire.FilterType, ...neutrino.QueryOption) (*gcs.Filter, error) { + + return nil, ErrNotImplemented +} + +func (m *mockChainService) GetUtxo( + _ ...neutrino.RescanOption) (*neutrino.SpendReport, error) { + + return nil, ErrNotImplemented +} + +func (m *mockChainService) BanPeer(string, banman.Reason) error { + return ErrNotImplemented +} + +func (m *mockChainService) IsBanned(addr string) bool { + panic(ErrNotImplemented) +} + +func (m *mockChainService) AddPeer(*neutrino.ServerPeer) { + panic(ErrNotImplemented) +} + +func (m *mockChainService) AddBytesSent(uint64) { + panic(ErrNotImplemented) +} + +func (m *mockChainService) AddBytesReceived(uint64) { + panic(ErrNotImplemented) +} + +func (m *mockChainService) NetTotals() (uint64, uint64) { + panic(ErrNotImplemented) +} + +func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash, + int32, *neutrino.ServerPeer, +) { + panic(ErrNotImplemented) +} + +func (m *mockChainService) ChainParams() chaincfg.Params { + panic(ErrNotImplemented) +} + +func (m *mockChainService) Stop() error { + panic(ErrNotImplemented) +} + +func (m *mockChainService) PeerByAddr(string) *neutrino.ServerPeer { + panic(ErrNotImplemented) +} diff --git a/chain/neutrino_test.go b/chain/neutrino_test.go new file mode 100644 index 0000000000..316871b4bb --- /dev/null +++ b/chain/neutrino_test.go @@ -0,0 +1,212 @@ +package chain + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + // maxDur is the max duration a test has to execute successfully. + maxDur = 5 * time.Second + timeout = time.After(maxDur) + nc = newMockNeutrinoClient() + addrs []btcutil.Address +) + +// TestNeutrinoClientSequentialStartStop ensures that the client +// can sequentially Start and Stop without errors or races. +func TestNeutrinoClientSequentialStartStop(t *testing.T) { + wantRestarts := 50 + + // callStartStop starts the neutrino client, requires no error on + // startup, immediately stops the client and waits for shutdown. + // The returned channel is closed once shutdown is complete. + callStartStop := func() <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + err := nc.Start() + require.NoError(t, err) + nc.Stop() + nc.WaitForShutdown() + }() + + return done + } + + // For each wanted restart, execute callStartStop and wait until the + // call is done before continuing to the next execution. Waiting for + // a read from done forces all executions of callStartStop to be done + // sequentially. + // + // The test fails if all of the wanted restarts cannot be completed + // sequentially before the timeout is reached. + for i := 0; i < wantRestarts; i++ { + select { + case <-timeout: + t.Fatal("timed out") + case <-callStartStop(): + } + } +} + +// TestNeutrinoClientNotifyReceived verifies that a call to NotifyReceived sets +// the client into the scanning state and that subsequent calls while scanning +// will call Update on the client's Rescanner. +func TestNeutrinoClientNotifyReceived(t *testing.T) { + wantNotifyReceivedCalls := 50 + wantUpdateCalls := wantNotifyReceivedCalls - 1 + + // executeCalls calls NotifyReceived() synchronously n times without + // blocking the test and requires no error after each call. + executeCalls := func(n int) <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + for i := 0; i < n; i++ { + err := nc.NotifyReceived(addrs) + require.NoError(t, err) + } + }() + + return done + } + + // Wait for all calls to complete or test to time out. + select { + case <-timeout: + t.Fatal("timed out") + case <-executeCalls(wantNotifyReceivedCalls): + // Require that the expected number of calls to Update were made + // once done sending all NotifyReceived calls. + mockRescan := nc.rescan.(*mockRescanner) + gotUpdateCalls := mockRescan.updateArgs.Len() + require.Equal(t, wantUpdateCalls, gotUpdateCalls) + } +} + +// TestNeutrinoClientNotifyReceivedRescan verifies concurrent calls to +// NotifyBlocks, NotifyReceived and Rescan do not result in a data race +// and that there is no panic on replacing the rescan goroutine single instance. +// +// Each successful method call writes a success message to a buffered channel. +// The channel is buffered so that no concurrent reader is needed. The buffer +// size is exactly the number of goroutines launched because each goroutine +// must finish successfully or else this test will fail. Each message is read +// out of the channel to verify the number of messages received is the number +// expected (i.e., wantMsgs == gotMsgs). +func TestNeutrinoClientNotifyReceivedRescan(t *testing.T) { + var ( + wantMsgs = 100 + gotMsgs = 0 + msgCh = make(chan string, wantMsgs) + msgPrefix = "successfully called" + + // sendMsg writes a message to the buffered message channel. + sendMsg = func(s string) { + msgCh <- fmt.Sprintf("%s %s", msgPrefix, s) + } + ) + + // Define closures to wrap desired neutrino client method calls. + + // cleanup is the shared cleanup function for a closure executing + // a neutrino client method call. It sends a message and then + // decrements the wait group counter. + cleanup := func(wg *sync.WaitGroup, s string) { + defer wg.Done() + sendMsg(s) + } + + // callRescan calls the Rescan() method and asserts it completes + // with no errors. Rescan() is called with the hash of an empty header + // on each call. + startHash := new(wire.BlockHeader).BlockHash() + callRescan := func(wg *sync.WaitGroup) { + defer cleanup(wg, "rescan") + + err := nc.Rescan(&startHash, addrs, nil) + require.NoError(t, err) + } + + // callNotifyReceived calls the NotifyReceived() method and asserts it + // completes with no errors. + callNotifyReceived := func(wg *sync.WaitGroup) { + defer cleanup(wg, "notify received") + + err := nc.NotifyReceived(addrs) + require.NoError(t, err) + } + + // callNotifyBlocks calls the NotifyBlocks() method and asserts it + // completes with no errors. + callNotifyBlocks := func(wg *sync.WaitGroup) { + defer cleanup(wg, "notify blocks") + + err := nc.NotifyBlocks() + require.NoError(t, err) + } + + // executeCalls launches the wanted number of goroutines, waits + // for them to finish and signals all done by closing the returned + // channel. + executeCalls := func(n int) <-chan struct{} { + done := make(chan struct{}) + + go func() { + defer close(done) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(n) + for i := 0; i < n; i++ { + if i%3 == 0 { + go callRescan(&wg) + continue + } + + if i%10 == 0 { + go callNotifyBlocks(&wg) + continue + } + + go callNotifyReceived(&wg) + } + }() + + return done + } + + // Start the client. + err := nc.Start() + require.NoError(t, err) + + // Wait for all calls to complete or test to time out. + select { + case <-timeout: + t.Fatal("timed out") + case <-executeCalls(wantMsgs): + // Ensure that exactly wantRoutines number of calls were made + // by counting the results on the message channel. + close(msgCh) + for str := range msgCh { + assert.NotNil(t, str) + assert.Contains(t, str, msgPrefix) + gotMsgs++ + } + + require.Equal(t, wantMsgs, gotMsgs) + } +}