Skip to content

Commit

Permalink
Merge pull request btcsuite#822 from MStreet3/bug/rescan-data-race
Browse files Browse the repository at this point in the history
chain: fix NeutrinoClient segfault on NotifyReceived call
  • Loading branch information
guggero authored and buck54321 committed Apr 21, 2024
1 parent 8d116ac commit 4a9290e
Show file tree
Hide file tree
Showing 6 changed files with 508 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ vendor
.idea
coverage.txt
*.swp
.vscode
40 changes: 40 additions & 0 deletions chain/chainservice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package chain

import (
"github.com/dcrlabs/ltcwallet/spv"
"github.com/dcrlabs/ltcwallet/spv/banman"
"github.com/dcrlabs/ltcwallet/spv/headerfs"
"github.com/ltcsuite/ltcd/chaincfg"
"github.com/ltcsuite/ltcd/chaincfg/chainhash"
"github.com/ltcsuite/ltcd/ltcutil"
"github.com/ltcsuite/ltcd/ltcutil/gcs"
"github.com/ltcsuite/ltcd/wire"
)

// NeutrinoChainService is an interface that encapsulates all the public
// methods of a *neutrino.ChainService
type NeutrinoChainService interface {
Start() error
GetBlock(chainhash.Hash, ...spv.QueryOption) (*ltcutil.Block, error)
GetBlockHeight(*chainhash.Hash) (int32, error)
BestBlock() (*headerfs.BlockStamp, error)
GetBlockHash(int64) (*chainhash.Hash, error)
GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error)
IsCurrent() bool
SendTransaction(*wire.MsgTx) error
GetCFilter(chainhash.Hash, wire.FilterType,
...spv.QueryOption) (*gcs.Filter, error)
GetUtxo(...spv.RescanOption) (*spv.SpendReport, error)
BanPeer(string, banman.Reason) error
IsBanned(addr string) bool
AddPeer(*spv.ServerPeer)
AddBytesSent(uint64)
AddBytesReceived(uint64)
NetTotals() (uint64, uint64)
UpdatePeerHeights(*chainhash.Hash, int32, *spv.ServerPeer)
ChainParams() chaincfg.Params
Stop() error
PeerByAddr(string) *spv.ServerPeer
}

var _ NeutrinoChainService = (*spv.ChainService)(nil)
157 changes: 157 additions & 0 deletions chain/mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package chain

import (
"container/list"
"errors"

"github.com/dcrlabs/ltcwallet/spv"
"github.com/dcrlabs/ltcwallet/spv/banman"
"github.com/dcrlabs/ltcwallet/spv/headerfs"
"github.com/ltcsuite/ltcd/chaincfg"
"github.com/ltcsuite/ltcd/chaincfg/chainhash"
"github.com/ltcsuite/ltcd/ltcutil"
"github.com/ltcsuite/ltcd/ltcutil/gcs"
"github.com/ltcsuite/ltcd/wire"
)

var (
errNotImplemented = errors.New("not implemented")
testBestBlock = &headerfs.BlockStamp{
Height: 42,
}
)

var (
_ rescanner = (*mockRescanner)(nil)
_ NeutrinoChainService = (*mockChainService)(nil)
)

// newMockNeutrinoClient constructs a neutrino client with a mock chain
// service implementation and mock rescanner interface implementation.
func newMockNeutrinoClient() *NeutrinoClient {
// newRescanFunc returns a mockRescanner
newRescanFunc := func(ro ...spv.RescanOption) rescanner {
return &mockRescanner{
updateArgs: list.New(),
}
}

return &NeutrinoClient{
CS: &mockChainService{},
newRescan: newRescanFunc,
}
}

// mockRescanner is a mock implementation of a rescanner interface for use in
// tests. Only the Update method is implemented.
type mockRescanner struct {
updateArgs *list.List
}

func (m *mockRescanner) Update(opts ...spv.UpdateOption) error {
m.updateArgs.PushBack(opts)
return nil
}

func (m *mockRescanner) Start() <-chan error {
return nil
}

func (m *mockRescanner) WaitForShutdown() {
// no-op
}

// mockChainService is a mock implementation of a chain service for use in
// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented.
type mockChainService struct {
}

func (m *mockChainService) Start() error {
return nil
}

func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) {
return testBestBlock, nil
}

func (m *mockChainService) GetBlockHeader(
*chainhash.Hash) (*wire.BlockHeader, error) {

return &wire.BlockHeader{}, nil
}

func (m *mockChainService) GetBlock(chainhash.Hash,
...spv.QueryOption) (*ltcutil.Block, error) {

return nil, errNotImplemented
}

func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) {
return 0, errNotImplemented
}

func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) {
return nil, errNotImplemented
}

func (m *mockChainService) IsCurrent() bool {
return false
}

func (m *mockChainService) SendTransaction(*wire.MsgTx) error {
return errNotImplemented
}

func (m *mockChainService) GetCFilter(chainhash.Hash,
wire.FilterType, ...spv.QueryOption) (*gcs.Filter, error) {

return nil, errNotImplemented
}

func (m *mockChainService) GetUtxo(
_ ...spv.RescanOption) (*spv.SpendReport, error) {

return nil, errNotImplemented
}

func (m *mockChainService) BanPeer(string, banman.Reason) error {
return errNotImplemented
}

func (m *mockChainService) IsBanned(addr string) bool {
panic(errNotImplemented)
}

func (m *mockChainService) AddPeer(*spv.ServerPeer) {
panic(errNotImplemented)
}

func (m *mockChainService) AddBytesSent(uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) AddBytesReceived(uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) NetTotals() (uint64, uint64) {
panic(errNotImplemented)
}

func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash,
int32, *spv.ServerPeer,
) {
panic(errNotImplemented)
}

func (m *mockChainService) ChainParams() chaincfg.Params {
panic(errNotImplemented)
}

func (m *mockChainService) Stop() error {
panic(errNotImplemented)
}

func (m *mockChainService) PeerByAddr(string) *spv.ServerPeer {
panic(errNotImplemented)
}
65 changes: 53 additions & 12 deletions chain/neutrino.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@ import (
"github.com/ltcsuite/ltcd/wire"
)

// NeutrinoClient is an implementation of the ltcwallet chain.Interface interface.
// NeutrinoClient is an implementation of the btcwallet chain.Interface interface.
type NeutrinoClient struct {
CS *spv.ChainService
CS NeutrinoChainService

chainParams *chaincfg.Params

// We currently support one rescan/notifiction goroutine per client
rescan *spv.Rescan
// We currently support only one rescan/notification goroutine per client.
// Therefore there can only be one instance of the rescan object and
// the rescanMtx synchronizes its access. Calls to the NotifyReceived
// and Rescan methods of the client must hold the rescan mutex lock for
// the length of their execution to ensure that all operations that
// affect the rescan object are atomic.
rescan rescanner
newRescan newRescanFunc
rescanMtx sync.Mutex

enqueueNotification chan interface{}
dequeueNotification chan interface{}
Expand All @@ -45,6 +52,14 @@ type NeutrinoClient struct {
finished bool
isRescan bool

// The clientMtx synchronizes access to the state variables of the client.
//
// TODO(mstreet3): Currently the clientMtx synchronizes access to the
// rescanQuit and rescanErr channels, which cancel the current rescan
// goroutine when closed and is updated each time a new rescan goroutine
// is created, respectively. All state related to the rescan goroutine
// should ideally be synchronized by the same lock or via some other
// shared mechanism.
clientMtx sync.Mutex
}

Expand All @@ -53,9 +68,21 @@ type NeutrinoClient struct {
func NewNeutrinoClient(chainParams *chaincfg.Params,
chainService *spv.ChainService) *NeutrinoClient {

chainSource := &spv.RescanChainSource{
ChainService: chainService,
}

// Adapt the spv.NewRescan constructor to satisfy the
// newRescanFunc type by closing over the chainSource and
// passing in the rescan options.
newRescan := func(ropts ...spv.RescanOption) rescanner {
return spv.NewRescan(chainSource, ropts...)
}

return &NeutrinoClient{
CS: chainService,
chainParams: chainParams,
newRescan: newRescan,
}
}

Expand All @@ -73,24 +100,36 @@ func (s *NeutrinoClient) Start() error {
s.clientMtx.Lock()
defer s.clientMtx.Unlock()
if !s.started {
// Reset the client state.
s.enqueueNotification = make(chan interface{})
s.dequeueNotification = make(chan interface{})
s.currentBlock = make(chan *waddrmgr.BlockStamp)
s.quit = make(chan struct{})
s.started = true

// Go place a ClientConnected notification onto the queue.
s.wg.Add(1)
go func() {
defer s.wg.Done()

select {
case s.enqueueNotification <- ClientConnected{}:
case <-s.quit:
}
}()

// Go launch the notification handler.
s.wg.Add(1)
go s.notificationHandler()
}
return nil
}

// Stop replicates the RPC client's Stop method.
//
// TODO(mstreet3): The Stop method does not cancel the long-running rescan
// goroutine. This is a memory leak. Stop should shutdown the rescan goroutine
// and reset the scanning state of the NeutrinoClient to false.
func (s *NeutrinoClient) Stop() {
s.clientMtx.Lock()
defer s.clientMtx.Unlock()
Expand Down Expand Up @@ -338,6 +377,10 @@ func (s *NeutrinoClient) pollCFilter(hash *chainhash.Hash) (*gcs.Filter, error)
func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []ltcutil.Address,
outPoints map[wire.OutPoint]ltcutil.Address) error {

// Obtain and hold the rescan mutex lock for the duration of the call.
s.rescanMtx.Lock()
defer s.rescanMtx.Unlock()

s.clientMtx.Lock()
if !s.started {
s.clientMtx.Unlock()
Expand Down Expand Up @@ -411,10 +454,7 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []ltcutil.Addre
}

s.clientMtx.Lock()
newRescan := spv.NewRescan(
&spv.RescanChainSource{
ChainService: s.CS,
},
newRescan := s.newRescan(
spv.NotificationHandlers(rpcclient.NotificationHandlers{
OnBlockConnected: s.onBlockConnected,
OnFilteredBlockConnected: s.onFilteredBlockConnected,
Expand Down Expand Up @@ -449,6 +489,10 @@ func (s *NeutrinoClient) NotifyBlocks() error {

// NotifyReceived replicates the RPC client's NotifyReceived command.
func (s *NeutrinoClient) NotifyReceived(addrs []ltcutil.Address) error {
// Obtain and hold the rescan mutex lock for the duration of the call.
s.rescanMtx.Lock()
defer s.rescanMtx.Unlock()

s.clientMtx.Lock()

// If we have a rescan running, we just need to add the appropriate
Expand All @@ -467,10 +511,7 @@ func (s *NeutrinoClient) NotifyReceived(addrs []ltcutil.Address) error {
s.lastFilteredBlockHeader = nil

// Rescan with just the specified addresses.
newRescan := spv.NewRescan(
&spv.RescanChainSource{
ChainService: s.CS,
},
newRescan := s.newRescan(
spv.NotificationHandlers(rpcclient.NotificationHandlers{
OnBlockConnected: s.onBlockConnected,
OnFilteredBlockConnected: s.onFilteredBlockConnected,
Expand Down
Loading

0 comments on commit 4a9290e

Please sign in to comment.