Skip to content

Commit

Permalink
Merge pull request #822 from MStreet3/bug/rescan-data-race
Browse files Browse the repository at this point in the history
chain: fix NeutrinoClient segfault on NotifyReceived call
  • Loading branch information
guggero authored May 10, 2023
2 parents 68f7e23 + 8c31629 commit 4383930
Show file tree
Hide file tree
Showing 6 changed files with 508 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ vendor
.idea
coverage.txt
*.swp
.vscode
40 changes: 40 additions & 0 deletions chain/chainservice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package chain

import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/gcs"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/neutrino"
"github.com/lightninglabs/neutrino/banman"
"github.com/lightninglabs/neutrino/headerfs"
)

// NeutrinoChainService is an interface that encapsulates all the public
// methods of a *neutrino.ChainService
type NeutrinoChainService interface {
Start() error
GetBlock(chainhash.Hash, ...neutrino.QueryOption) (*btcutil.Block, error)
GetBlockHeight(*chainhash.Hash) (int32, error)
BestBlock() (*headerfs.BlockStamp, error)
GetBlockHash(int64) (*chainhash.Hash, error)
GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error)
IsCurrent() bool
SendTransaction(*wire.MsgTx) error
GetCFilter(chainhash.Hash, wire.FilterType,
...neutrino.QueryOption) (*gcs.Filter, error)
GetUtxo(...neutrino.RescanOption) (*neutrino.SpendReport, error)
BanPeer(string, banman.Reason) error
IsBanned(addr string) bool
AddPeer(*neutrino.ServerPeer)
AddBytesSent(uint64)
AddBytesReceived(uint64)
NetTotals() (uint64, uint64)
UpdatePeerHeights(*chainhash.Hash, int32, *neutrino.ServerPeer)
ChainParams() chaincfg.Params
Stop() error
PeerByAddr(string) *neutrino.ServerPeer
}

var _ NeutrinoChainService = (*neutrino.ChainService)(nil)
157 changes: 157 additions & 0 deletions chain/mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package chain

import (
"container/list"
"errors"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/gcs"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/neutrino"
"github.com/lightninglabs/neutrino/banman"
"github.com/lightninglabs/neutrino/headerfs"
)

var (
errNotImplemented = errors.New("not implemented")
testBestBlock = &headerfs.BlockStamp{
Height: 42,
}
)

var (
_ rescanner = (*mockRescanner)(nil)
_ NeutrinoChainService = (*mockChainService)(nil)
)

// newMockNeutrinoClient constructs a neutrino client with a mock chain
// service implementation and mock rescanner interface implementation.
func newMockNeutrinoClient() *NeutrinoClient {
// newRescanFunc returns a mockRescanner
newRescanFunc := func(ro ...neutrino.RescanOption) rescanner {
return &mockRescanner{
updateArgs: list.New(),
}
}

return &NeutrinoClient{
CS: &mockChainService{},
newRescan: newRescanFunc,
}
}

// mockRescanner is a mock implementation of a rescanner interface for use in
// tests. Only the Update method is implemented.
type mockRescanner struct {
updateArgs *list.List
}

func (m *mockRescanner) Update(opts ...neutrino.UpdateOption) error {
m.updateArgs.PushBack(opts)
return nil
}

func (m *mockRescanner) Start() <-chan error {
return nil
}

func (m *mockRescanner) WaitForShutdown() {
// no-op
}

// mockChainService is a mock implementation of a chain service for use in
// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented.
type mockChainService struct {
}

func (m *mockChainService) Start() error {
return nil
}

func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) {
return testBestBlock, nil
}

func (m *mockChainService) GetBlockHeader(
*chainhash.Hash) (*wire.BlockHeader, error) {

return &wire.BlockHeader{}, nil
}

func (m *mockChainService) GetBlock(chainhash.Hash,
...neutrino.QueryOption) (*btcutil.Block, error) {

return nil, errNotImplemented
}

func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) {
return 0, errNotImplemented
}

func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) {
return nil, errNotImplemented
}

func (m *mockChainService) IsCurrent() bool {
return false
}

func (m *mockChainService) SendTransaction(*wire.MsgTx) error {
return errNotImplemented
}

func (m *mockChainService) GetCFilter(chainhash.Hash,
wire.FilterType, ...neutrino.QueryOption) (*gcs.Filter, error) {

return nil, errNotImplemented
}

func (m *mockChainService) GetUtxo(
_ ...neutrino.RescanOption) (*neutrino.SpendReport, error) {

return nil, errNotImplemented
}

func (m *mockChainService) BanPeer(string, banman.Reason) error {
return errNotImplemented
}

func (m *mockChainService) IsBanned(addr string) bool {
panic(errNotImplemented)
}

func (m *mockChainService) AddPeer(*neutrino.ServerPeer) {
panic(errNotImplemented)
}

func (m *mockChainService) AddBytesSent(uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) AddBytesReceived(uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) NetTotals() (uint64, uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash,
int32, *neutrino.ServerPeer,
) {
panic(errNotImplemented)
}

func (m *mockChainService) ChainParams() chaincfg.Params {
panic(errNotImplemented)
}

func (m *mockChainService) Stop() error {
panic(errNotImplemented)
}

func (m *mockChainService) PeerByAddr(string) *neutrino.ServerPeer {
panic(errNotImplemented)
}
65 changes: 53 additions & 12 deletions chain/neutrino.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@ import (
"github.com/lightninglabs/neutrino/headerfs"
)

// NeutrinoClient is an implementation of the btcwalet chain.Interface interface.
// NeutrinoClient is an implementation of the btcwallet chain.Interface interface.
type NeutrinoClient struct {
CS *neutrino.ChainService
CS NeutrinoChainService

chainParams *chaincfg.Params

// We currently support one rescan/notifiction goroutine per client
rescan *neutrino.Rescan
// We currently support only one rescan/notification goroutine per client.
// Therefore there can only be one instance of the rescan object and
// the rescanMtx synchronizes its access. Calls to the NotifyReceived
// and Rescan methods of the client must hold the rescan mutex lock for
// the length of their execution to ensure that all operations that
// affect the rescan object are atomic.
rescan rescanner
newRescan newRescanFunc
rescanMtx sync.Mutex

enqueueNotification chan interface{}
dequeueNotification chan interface{}
Expand All @@ -45,6 +52,14 @@ type NeutrinoClient struct {
finished bool
isRescan bool

// The clientMtx synchronizes access to the state variables of the client.
//
// TODO(mstreet3): Currently the clientMtx synchronizes access to the
// rescanQuit and rescanErr channels, which cancel the current rescan
// goroutine when closed and is updated each time a new rescan goroutine
// is created, respectively. All state related to the rescan goroutine
// should ideally be synchronized by the same lock or via some other
// shared mechanism.
clientMtx sync.Mutex
}

Expand All @@ -53,9 +68,21 @@ type NeutrinoClient struct {
func NewNeutrinoClient(chainParams *chaincfg.Params,
chainService *neutrino.ChainService) *NeutrinoClient {

chainSource := &neutrino.RescanChainSource{
ChainService: chainService,
}

// Adapt the neutrino.NewRescan constructor to satisfy the
// newRescanFunc type by closing over the chainSource and
// passing in the rescan options.
newRescan := func(ropts ...neutrino.RescanOption) rescanner {
return neutrino.NewRescan(chainSource, ropts...)
}

return &NeutrinoClient{
CS: chainService,
chainParams: chainParams,
newRescan: newRescan,
}
}

Expand All @@ -73,24 +100,36 @@ func (s *NeutrinoClient) Start() error {
s.clientMtx.Lock()
defer s.clientMtx.Unlock()
if !s.started {
// Reset the client state.
s.enqueueNotification = make(chan interface{})
s.dequeueNotification = make(chan interface{})
s.currentBlock = make(chan *waddrmgr.BlockStamp)
s.quit = make(chan struct{})
s.started = true

// Go place a ClientConnected notification onto the queue.
s.wg.Add(1)
go func() {
defer s.wg.Done()

select {
case s.enqueueNotification <- ClientConnected{}:
case <-s.quit:
}
}()

// Go launch the notification handler.
s.wg.Add(1)
go s.notificationHandler()
}
return nil
}

// Stop replicates the RPC client's Stop method.
//
// TODO(mstreet3): The Stop method does not cancel the long-running rescan
// goroutine. This is a memory leak. Stop should shutdown the rescan goroutine
// and reset the scanning state of the NeutrinoClient to false.
func (s *NeutrinoClient) Stop() {
s.clientMtx.Lock()
defer s.clientMtx.Unlock()
Expand Down Expand Up @@ -338,6 +377,10 @@ func (s *NeutrinoClient) pollCFilter(hash *chainhash.Hash) (*gcs.Filter, error)
func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address,
outPoints map[wire.OutPoint]btcutil.Address) error {

// Obtain and hold the rescan mutex lock for the duration of the call.
s.rescanMtx.Lock()
defer s.rescanMtx.Unlock()

s.clientMtx.Lock()
if !s.started {
s.clientMtx.Unlock()
Expand Down Expand Up @@ -411,10 +454,7 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre
}

s.clientMtx.Lock()
newRescan := neutrino.NewRescan(
&neutrino.RescanChainSource{
ChainService: s.CS,
},
newRescan := s.newRescan(
neutrino.NotificationHandlers(rpcclient.NotificationHandlers{
OnBlockConnected: s.onBlockConnected,
OnFilteredBlockConnected: s.onFilteredBlockConnected,
Expand Down Expand Up @@ -448,6 +488,10 @@ func (s *NeutrinoClient) NotifyBlocks() error {

// NotifyReceived replicates the RPC client's NotifyReceived command.
func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error {
// Obtain and hold the rescan mutex lock for the duration of the call.
s.rescanMtx.Lock()
defer s.rescanMtx.Unlock()

s.clientMtx.Lock()

// If we have a rescan running, we just need to add the appropriate
Expand All @@ -466,10 +510,7 @@ func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error {
s.lastFilteredBlockHeader = nil

// Rescan with just the specified addresses.
newRescan := neutrino.NewRescan(
&neutrino.RescanChainSource{
ChainService: s.CS,
},
newRescan := s.newRescan(
neutrino.NotificationHandlers(rpcclient.NotificationHandlers{
OnBlockConnected: s.onBlockConnected,
OnFilteredBlockConnected: s.onFilteredBlockConnected,
Expand Down
Loading

0 comments on commit 4383930

Please sign in to comment.