Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context wrapping for syncer disconnections #2432

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions chain/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,24 @@ func (s *Syncer) ExistsLiveTickets(ctx context.Context, tickets []*chainhash.Has
func (s *Syncer) UsedAddresses(ctx context.Context, addrs []stdaddr.Address) (bitset.Bytes, error) {
return s.rpc.UsedAddresses(ctx, addrs)
}

func (s *Syncer) Done() <-chan struct{} {
s.doneMu.Lock()
c := s.done
s.doneMu.Unlock()
return c
}

func (s *Syncer) Err() error {
s.doneMu.Lock()
c := s.done
err := s.err
s.doneMu.Unlock()

select {
case <-c:
return err
default:
return nil
}
}
15 changes: 15 additions & 0 deletions chain/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type Syncer struct {
relevantTxs map[chainhash.Hash][]*wire.MsgTx

cb *Callbacks

done chan struct{}
err error
doneMu sync.Mutex
}

// RPCOptions specifies the network and security settings for establishing a
Expand Down Expand Up @@ -525,6 +529,17 @@ func (s *Syncer) Run(ctx context.Context) (err error) {
}
}()

s.doneMu.Lock()
s.done = make(chan struct{})
s.err = nil
s.doneMu.Unlock()
defer func() {
s.doneMu.Lock()
close(s.done)
s.err = err
s.doneMu.Unlock()
}()

params := s.wallet.ChainParams()

s.notifier = &notifier{
Expand Down
21 changes: 21 additions & 0 deletions spv/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,24 @@ func (s *Syncer) Rescan(ctx context.Context, blockHashes []chainhash.Hash, save
func (s *Syncer) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) {
return 0, errors.E(errors.Invalid, "stake difficulty is not queryable over wire protocol")
}

func (s *Syncer) Done() <-chan struct{} {
s.doneMu.Lock()
c := s.done
s.doneMu.Unlock()
return c
}

func (s *Syncer) Err() error {
s.doneMu.Lock()
c := s.done
err := s.err
s.doneMu.Unlock()

select {
case <-c:
return err
default:
return nil
}
}
17 changes: 16 additions & 1 deletion spv/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ type Syncer struct {
// Mempool for non-wallet-relevant transactions.
mempool sync.Map // k=chainhash.Hash v=*wire.MsgTx
mempoolAdds chan *chainhash.Hash

done chan struct{}
err error
doneMu sync.Mutex
}

// Notifications struct to contain all of the upcoming callbacks that will
Expand Down Expand Up @@ -318,7 +322,18 @@ func (s *Syncer) setRequiredHeight(tipHeight int32) {

// Run synchronizes the wallet, returning when synchronization fails or the
// context is cancelled.
func (s *Syncer) Run(ctx context.Context) error {
func (s *Syncer) Run(ctx context.Context) (err error) {
s.doneMu.Lock()
s.done = make(chan struct{})
s.err = nil
s.doneMu.Unlock()
defer func() {
s.doneMu.Lock()
close(s.done)
s.err = err
s.doneMu.Unlock()
}()

tipHash, tipHeight := s.wallet.MainChainTip(ctx)
s.setRequiredHeight(tipHeight)
rescanPoint, err := s.wallet.RescanPoint(ctx)
Expand Down
2 changes: 2 additions & 0 deletions ticketbuyer/tb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ func (tb *TB) buy(ctx context.Context, passphrase []byte, tip *wire.BlockHeader,
if err != nil {
return err
}
ctx, cancel := wallet.WrapNetworkBackendContext(n, ctx)
defer cancel()

if len(passphrase) > 0 {
// Ensure wallet is unlocked with the current passphrase. If the passphase
Expand Down
7 changes: 7 additions & 0 deletions wallet/mixing.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ func (w *Wallet) MixOutput(ctx context.Context, output *wire.OutPoint, changeAcc
return errors.E(op, errors.Invalid, s)
}

nb, err := w.NetworkBackend()
if err != nil {
return err
}
ctx, cancel := WrapNetworkBackendContext(nb, ctx)
defer cancel()

sdiff, err := w.NextStakeDifficulty(ctx)
if err != nil {
return errors.E(op, err)
Expand Down
62 changes: 62 additions & 0 deletions wallet/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package wallet

import (
"context"
"sync"

"decred.org/dcrwallet/v5/errors"
"github.com/decred/dcrd/chaincfg/chainhash"
Expand Down Expand Up @@ -49,6 +50,12 @@ type NetworkBackend interface {
// the wallet to the underlying network, and if not, it returns the
// target height that it is attempting to sync to.
Synced(ctx context.Context) (bool, int32)

// Done return a channel that is closed after the syncer disconnects.
// The error (if any) can be returned via Err.
// These semantics match that of context.Context.
Done() <-chan struct{}
Err() error
}

// NetworkBackend returns the currently associated network backend of the
Expand All @@ -73,6 +80,47 @@ func (w *Wallet) SetNetworkBackend(n NetworkBackend) {
w.networkBackendMu.Unlock()
}

type networkContext struct {
context.Context
err error
mu sync.Mutex
}

func (c *networkContext) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()

if err != nil {
return err
}
return c.Context.Err()
}

// WrapNetworkBackendContext returns a derived context that is canceled when
// the NetworkBackend is disconnected. The cancel func must be called
// (e.g. using defer) otherwise a goroutine leak may occur.
func WrapNetworkBackendContext(nb NetworkBackend, ctx context.Context) (context.Context, context.CancelFunc) {
childCtx, cancel := context.WithCancel(ctx)
nbContext := &networkContext{
Context: childCtx,
}

go func() {
select {
case <-nb.Done():
err := nb.Err()
nbContext.mu.Lock()
nbContext.err = err
nbContext.mu.Unlock()
case <-childCtx.Done():
}
cancel()
}()

return nbContext, cancel
}

// Caller provides a client interface to perform remote procedure calls.
// Serialization and calling conventions are implementation-specific.
type Caller interface {
Expand Down Expand Up @@ -122,6 +170,20 @@ func (o OfflineNetworkBackend) Synced(ctx context.Context) (bool, int32) {
return true, 0
}

var closedDone = make(chan struct{})

func init() {
close(closedDone)
}

func (o OfflineNetworkBackend) Done() <-chan struct{} {
return closedDone
}

func (o OfflineNetworkBackend) Err() error {
return errors.E("offline")
}

// Compile time check to ensure OfflineNetworkBackend fulfills the
// NetworkBackend interface.
var _ NetworkBackend = OfflineNetworkBackend{}
2 changes: 2 additions & 0 deletions wallet/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ func (mockNetwork) Rescan(ctx context.Context, blocks []chainhash.Hash, save fun
}
func (mockNetwork) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) { return 0, nil }
func (mockNetwork) Synced(ctx context.Context) (bool, int32) { return false, 0 }
func (mockNetwork) Done() <-chan struct{} { return nil }
func (mockNetwork) Err() error { return nil }
3 changes: 3 additions & 0 deletions wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,9 @@ func (w *Wallet) PurchaseTickets(ctx context.Context, n NetworkBackend,

const op errors.Op = "wallet.PurchaseTickets"

ctx, cancel := WrapNetworkBackendContext(n, ctx)
defer cancel()

resp, err := w.purchaseTickets(ctx, op, n, req)
if err == nil || !errors.Is(err, errVSPFeeRequiresUTXOSplit) || req.DontSignTx {
return resp, err
Expand Down
Loading