Skip to content

Commit

Permalink
chain: add unit tests
Browse files Browse the repository at this point in the history
Add TestNeutrinoClientSequentialStartStop to verify that the
neutrino client can have Start() and Stop() called sequentially.
Test case fails with -race flag enabled.

Add TestNeutrinoClientNotifyReceived to verify that calls to
NotifyReceived call the Update method of the rescan object.  Verify
that there is no race condition on any state variables of the
NeutrinoClient.  Pass test with -race flag enabled.

Add a unit test that demonstrates a segmentation fault exists when
concurrent calls to NotifyReceived, NotifyBlocks and Rescan methods of
the NeutrinoClient are executed.  The rescan property is set to nil
and a segmentation fault arises.
  • Loading branch information
MStreet3 committed Dec 2, 2022
1 parent dafb89e commit 30824ef
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 0 deletions.
171 changes: 171 additions & 0 deletions chain/mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package chain

import (
"container/list"
"errors"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/gcs"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/neutrino"
"github.com/lightninglabs/neutrino/banman"
"github.com/lightninglabs/neutrino/headerfs"
)

var (
ErrNotImplemented = errors.New("not implemented")
testBestBlock = &headerfs.BlockStamp{
Height: 42,
}
)

var (
_ rescanner = (*mockRescanner)(nil)
_ NeutrinoChainService = (*mockChainService)(nil)
)

// newMockNeutrinoClient constructs a neutrino client with a mock chain
// service implementation and mock rescanner interface implementation.
func newMockNeutrinoClient() *NeutrinoClient {
// newRescanFunc returns a mockRescanner
newRescanFunc := func(ro ...neutrino.RescanOption) rescanner {
return &mockRescanner{
updateArgs: list.New(),
}
}

return &NeutrinoClient{
CS: &mockChainService{},
newRescan: newRescanFunc,
}
}

// mockRescanner is a mock implementation of a rescanner interface for use in
// tests. Only the Update method is implemented.
type mockRescanner struct {
updateArgs *list.List
}

func (m *mockRescanner) Start() <-chan error {
errs := make(chan error)
return errs
}

func (m *mockRescanner) WaitForShutdown() {
// no-op
}

func (m *mockRescanner) Update(opts ...neutrino.UpdateOption) error {
m.updateArgs.PushBack(opts)
return nil
}

// mockChainService is a mock implementation of a chain service for use in
// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented.
type mockChainService struct {
bestBlock func() (*headerfs.BlockStamp, error)
}

func (m *mockChainService) Start() error {
return nil
}

func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) {
impl := m.getBestBlock()
return impl()
}

func (m *mockChainService) GetBlockHeader(
*chainhash.Hash) (*wire.BlockHeader, error) {

return &wire.BlockHeader{}, nil
}

// getBestBlock returns a BestBlock implementation that defaults to simply
// returning the value of testBestBlock with no error.
func (m *mockChainService) getBestBlock() func() (*headerfs.BlockStamp, error) {
if m.bestBlock == nil {
m.bestBlock = func() (*headerfs.BlockStamp, error) {
return testBestBlock, nil
}
}
return m.bestBlock
}

func (m *mockChainService) GetBlock(chainhash.Hash,
...neutrino.QueryOption) (*btcutil.Block, error) {

return nil, ErrNotImplemented
}

func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) {
return 0, ErrNotImplemented
}

func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) {
return nil, ErrNotImplemented
}

func (m *mockChainService) IsCurrent() bool {
return false
}

func (m *mockChainService) SendTransaction(*wire.MsgTx) error {
return ErrNotImplemented
}

func (m *mockChainService) GetCFilter(chainhash.Hash,
wire.FilterType, ...neutrino.QueryOption) (*gcs.Filter, error) {

return nil, ErrNotImplemented
}

func (m *mockChainService) GetUtxo(
_ ...neutrino.RescanOption) (*neutrino.SpendReport, error) {

return nil, ErrNotImplemented
}

func (m *mockChainService) BanPeer(string, banman.Reason) error {
return ErrNotImplemented
}

func (m *mockChainService) IsBanned(addr string) bool {
panic(ErrNotImplemented)
}

func (m *mockChainService) AddPeer(*neutrino.ServerPeer) {
panic(ErrNotImplemented)
}

func (m *mockChainService) AddBytesSent(uint64) {
panic(ErrNotImplemented)
}

func (m *mockChainService) AddBytesReceived(uint64) {
panic(ErrNotImplemented)
}

func (m *mockChainService) NetTotals() (uint64, uint64) {
panic(ErrNotImplemented)
}

func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash,
int32, *neutrino.ServerPeer,
) {
panic(ErrNotImplemented)
}

func (m *mockChainService) ChainParams() chaincfg.Params {
panic(ErrNotImplemented)
}

func (m *mockChainService) Stop() error {
panic(ErrNotImplemented)
}

func (m *mockChainService) PeerByAddr(string) *neutrino.ServerPeer {
panic(ErrNotImplemented)
}
212 changes: 212 additions & 0 deletions chain/neutrino_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package chain

import (
"fmt"
"sync"
"testing"
"time"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
// maxDur is the max duration a test has to execute successfully.
maxDur = 5 * time.Second
timeout = time.After(maxDur)
nc = newMockNeutrinoClient()
addrs []btcutil.Address
)

// TestNeutrinoClientSequentialStartStop ensures that the client
// can sequentially Start and Stop without errors or races.
func TestNeutrinoClientSequentialStartStop(t *testing.T) {
wantRestarts := 50

// callStartStop starts the neutrino client, requires no error on
// startup, immediately stops the client and waits for shutdown.
// The returned channel is closed once shutdown is complete.
callStartStop := func() <-chan struct{} {
done := make(chan struct{})

go func() {
defer close(done)

err := nc.Start()
require.NoError(t, err)
nc.Stop()
nc.WaitForShutdown()
}()

return done
}

// For each wanted restart, execute callStartStop and wait until the
// call is done before continuing to the next execution. Waiting for
// a read from done forces all executions of callStartStop to be done
// sequentially.
//
// The test fails if all of the wanted restarts cannot be completed
// sequentially before the timeout is reached.
for i := 0; i < wantRestarts; i++ {
select {
case <-timeout:
t.Fatal("timed out")
case <-callStartStop():
}
}
}

// TestNeutrinoClientNotifyReceived verifies that a call to NotifyReceived sets
// the client into the scanning state and that subsequent calls while scanning
// will call Update on the client's Rescanner.
func TestNeutrinoClientNotifyReceived(t *testing.T) {
wantNotifyReceivedCalls := 50
wantUpdateCalls := wantNotifyReceivedCalls - 1

// executeCalls calls NotifyReceived() synchronously n times without
// blocking the test and requires no error after each call.
executeCalls := func(n int) <-chan struct{} {
done := make(chan struct{})

go func() {
defer close(done)

for i := 0; i < n; i++ {
err := nc.NotifyReceived(addrs)
require.NoError(t, err)
}
}()

return done
}

// Wait for all calls to complete or test to time out.
select {
case <-timeout:
t.Fatal("timed out")
case <-executeCalls(wantNotifyReceivedCalls):
// Require that the expected number of calls to Update were made
// once done sending all NotifyReceived calls.
mockRescan := nc.rescan.(*mockRescanner)
gotUpdateCalls := mockRescan.updateArgs.Len()
require.Equal(t, wantUpdateCalls, gotUpdateCalls)
}
}

// TestNeutrinoClientNotifyReceivedRescan verifies concurrent calls to
// NotifyBlocks, NotifyReceived and Rescan do not result in a data race
// and that there is no panic on replacing the rescan goroutine single instance.
//
// Each successful method call writes a success message to a buffered channel.
// The channel is buffered so that no concurrent reader is needed. The buffer
// size is exactly the number of goroutines launched because each goroutine
// must finish successfully or else this test will fail. Each message is read
// out of the channel to verify the number of messages received is the number
// expected (i.e., wantMsgs == gotMsgs).
func TestNeutrinoClientNotifyReceivedRescan(t *testing.T) {
var (
wantMsgs = 100
gotMsgs = 0
msgCh = make(chan string, wantMsgs)
msgPrefix = "successfully called"

// sendMsg writes a message to the buffered message channel.
sendMsg = func(s string) {
msgCh <- fmt.Sprintf("%s %s", msgPrefix, s)
}
)

// Define closures to wrap desired neutrino client method calls.

// cleanup is the shared cleanup function for a closure executing
// a neutrino client method call. It sends a message and then
// decrements the wait group counter.
cleanup := func(wg *sync.WaitGroup, s string) {
defer wg.Done()
sendMsg(s)
}

// callRescan calls the Rescan() method and asserts it completes
// with no errors. Rescan() is called with the hash of an empty header
// on each call.
startHash := new(wire.BlockHeader).BlockHash()
callRescan := func(wg *sync.WaitGroup) {
defer cleanup(wg, "rescan")

err := nc.Rescan(&startHash, addrs, nil)
require.NoError(t, err)
}

// callNotifyReceived calls the NotifyReceived() method and asserts it
// completes with no errors.
callNotifyReceived := func(wg *sync.WaitGroup) {
defer cleanup(wg, "notify received")

err := nc.NotifyReceived(addrs)
require.NoError(t, err)
}

// callNotifyBlocks calls the NotifyBlocks() method and asserts it
// completes with no errors.
callNotifyBlocks := func(wg *sync.WaitGroup) {
defer cleanup(wg, "notify blocks")

err := nc.NotifyBlocks()
require.NoError(t, err)
}

// executeCalls launches the wanted number of goroutines, waits
// for them to finish and signals all done by closing the returned
// channel.
executeCalls := func(n int) <-chan struct{} {
done := make(chan struct{})

go func() {
defer close(done)

var wg sync.WaitGroup
defer wg.Wait()

wg.Add(n)
for i := 0; i < n; i++ {
if i%3 == 0 {
go callRescan(&wg)
continue
}

if i%10 == 0 {
go callNotifyBlocks(&wg)
continue
}

go callNotifyReceived(&wg)
}
}()

return done
}

// Start the client.
err := nc.Start()
require.NoError(t, err)

// Wait for all calls to complete or test to time out.
select {
case <-timeout:
t.Fatal("timed out")
case <-executeCalls(wantMsgs):
// Ensure that exactly wantRoutines number of calls were made
// by counting the results on the message channel.
close(msgCh)
for str := range msgCh {
assert.NotNil(t, str)
assert.Contains(t, str, msgPrefix)
gotMsgs++
}

require.Equal(t, wantMsgs, gotMsgs)
}
}

0 comments on commit 30824ef

Please sign in to comment.