spv: Refactor initialSyncHeaders
This moves all of the header-fetching logic out of the old getHeaders
function (now renamed to initialSyncHeaders) and into a separate, new
getHeaders function.

In the future, this will make it easier to refactor initial sync so that
header fetching runs in parallel with the other sync stages.

This is mostly a code-moving commit.
matheusd committed Nov 22, 2023
1 parent cee37c2 commit 51c6c57
Showing 2 changed files with 173 additions and 130 deletions.
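
For orientation before the diff, below is a minimal, self-contained sketch of the control flow this commit introduces (not the actual dcrwallet code; the types and data are simplified stand-ins): getHeaders produces one headersBatch per call, and initialSyncHeaders loops over batches, doing the per-batch work, until the done flag reports that no peer has newer blocks.

package main

import (
	"context"
	"fmt"
)

// headersBatch mirrors the shape introduced by this commit: one batch of
// headers fetched from a single peer, plus a done flag that is set once no
// peer has anything newer. Field types here are simplified stand-ins.
type headersBatch struct {
	done    bool
	headers []string // stands in for []*wallet.BlockNode
	peer    string   // stands in for *p2p.RemotePeer
}

// getHeaders stands in for the new Syncer.getHeaders: it returns the next
// batch, or a batch with done=true once the (fake) remote tip is reached.
func getHeaders(_ context.Context, tip *int, remoteTip int) (*headersBatch, error) {
	if *tip >= remoteTip {
		return &headersBatch{done: true}, nil
	}
	*tip++
	return &headersBatch{
		headers: []string{fmt.Sprintf("header %d", *tip)},
		peer:    "peer-1",
	}, nil
}

// initialSyncHeaders stands in for the renamed Syncer.initialSyncHeaders:
// it repeatedly asks getHeaders for batches, performs the per-batch work
// (fetching cfilters and switching chains in the real code), and stops when
// a batch reports done.
func initialSyncHeaders(ctx context.Context) error {
	tip, remoteTip := 0, 3
	for ctx.Err() == nil {
		batch, err := getHeaders(ctx, &tip, remoteTip)
		if err != nil {
			return err
		}
		if batch.done {
			fmt.Println("initial sync complete")
			return nil
		}
		fmt.Printf("processing %v from %s\n", batch.headers, batch.peer)
	}
	return ctx.Err()
}

func main() {
	if err := initialSyncHeaders(context.Background()); err != nil {
		fmt.Println("sync error:", err)
	}
}

Separating batch production (getHeaders) from batch consumption (initialSyncHeaders) is what the commit message refers to: it later allows the two to run concurrently, for example with getHeaders feeding batches over a channel.
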
157 changes: 153 additions & 4 deletions spv/backend.go
@@ -93,11 +93,12 @@ func (s *Syncer) CFiltersV2(ctx context.Context, blockHashes []*chainhash.Hash)

// cfiltersV2FromNodes fetches cfilters for all the specified nodes from a
// remote peer.
func (s *Syncer) cfiltersV2FromNodes(ctx context.Context, cnet wire.CurrencyNet, rp *p2p.RemotePeer, nodes []*wallet.BlockNode) ([]*gcs.FilterV2, error) {
func (s *Syncer) cfiltersV2FromNodes(ctx context.Context, rp *p2p.RemotePeer, nodes []*wallet.BlockNode) error {
if len(nodes) == 0 {
return nil, nil
return nil
}

cnet := s.wallet.ChainParams().Net
g, ctx := errgroup.WithContext(ctx)
res := make([]*gcs.FilterV2, len(nodes))
for i := range nodes {
@@ -132,9 +133,157 @@ func (s *Syncer) cfiltersV2FromNodes(ctx context.Context, cnet wire.CurrencyNet,
}
err := g.Wait()
if err != nil {
return nil, err
return err
}
return res, nil

s.sidechainMu.Lock()
for i := range nodes {
nodes[i].FilterV2 = res[i]
}
s.sidechainMu.Unlock()
log.Debugf("Fetched %d new cfilters(s) ending at height %d from %v",
len(nodes), nodes[len(nodes)-1].Header.Height, rp)
return nil
}

// headersBatch is a batch of headers fetched during initial sync.
type headersBatch struct {
done bool
nodes []*wallet.BlockNode
bestChain []*wallet.BlockNode
rp *p2p.RemotePeer
}

// getHeaders returns a batch of headers from a remote peer for initial
// syncing.
//
// This function returns a batch with the done flag set to true when no peers
// have more recent blocks for syncing.
func (s *Syncer) getHeaders(ctx context.Context) (*headersBatch, error) {
cnet := s.wallet.ChainParams().Net

nextbatch:
for ctx.Err() == nil {
_, tipHeight := s.wallet.MainChainTip(ctx)

// Determine if there are any peers from which to request newer
// headers.
rp, err := s.waitForRemote(ctx, pickForGetHeaders(tipHeight), false)
if err != nil {
return nil, err
}
if rp == nil {
return &headersBatch{done: true}, nil
}
log.Tracef("Attempting next batch of headers from %v", rp)

// Request headers from the selected peer.
locators, locatorHeight, err := s.wallet.BlockLocators(ctx, nil)
if err != nil {
return nil, err
}
headers, err := rp.Headers(ctx, locators, &hashStop)
if err != nil {
log.Debugf("Unable to fetch headers from %v: %v", rp, err)
continue nextbatch
}

if len(headers) == 0 {
// Ensure that the peer provided headers through the
// height advertised during handshake, unless our own
// locators were up to date (in which case we actually
// do not expect any headers).
if rp.LastHeight() < rp.InitialHeight() && locatorHeight < rp.InitialHeight() {
err := errors.E(errors.Protocol, "peer did not provide "+
"headers through advertised height")
rp.Disconnect(err)
continue nextbatch
}

// Try to pick a different peer with a higher advertised
// height or check there are no such peers (thus we're
// done with fetching headers for initial sync).
log.Tracef("Skipping to next batch due to "+
"len(headers) == 0 from %v", rp)
continue nextbatch
}

nodes := make([]*wallet.BlockNode, len(headers))
for i := range headers {
// Determine the hash of the header. It is safe to use
// PrevBlock (instead of recalculating) because the
// lower p2p level already asserted the headers connect
// to each other.
var hash *chainhash.Hash
if i == len(headers)-1 {
bh := headers[i].BlockHash()
hash = &bh
} else {
hash = &headers[i+1].PrevBlock
}
nodes[i] = wallet.NewBlockNode(headers[i], hash, nil)
if wallet.BadCheckpoint(cnet, hash, int32(headers[i].Height)) {
nodes[i].BadCheckpoint()
}
}

// Verify the sidechain that includes the received headers has
// the correct difficulty.
s.sidechainMu.Lock()
fullsc, err := s.sidechains.FullSideChain(nodes)
if err != nil {
s.sidechainMu.Unlock()
return nil, err
}
_, err = s.wallet.ValidateHeaderChainDifficulties(ctx, fullsc, 0)
if err != nil {
s.sidechainMu.Unlock()
rp.Disconnect(err)
if !errors.Is(err, context.Canceled) {
log.Warnf("Disconnecting from %v due to header "+
"validation error: %v", rp, err)
}
continue nextbatch
}

// Add new headers to the sidechain forest.
var added int
for _, n := range nodes {
haveBlock, _, _ := s.wallet.BlockInMainChain(ctx, n.Hash)
if haveBlock {
continue
}
if s.sidechains.AddBlockNode(n) {
added++
}
}

// Determine if this extends the best known chain.
bestChain, err := s.wallet.EvaluateBestChain(ctx, &s.sidechains)
if err != nil {
s.sidechainMu.Unlock()
rp.Disconnect(err)
continue nextbatch
}
if len(bestChain) == 0 {
s.sidechainMu.Unlock()
continue nextbatch
}

log.Debugf("Fetched %d new header(s) ending at height %d from %v",
added, headers[len(headers)-1].Height, rp)

s.sidechainMu.Unlock()

// Batch fetched.
return &headersBatch{
nodes: nodes,
bestChain: bestChain,
rp: rp,
}, nil
}

return nil, ctx.Err()
}

func (s *Syncer) String() string {
146 changes: 20 additions & 126 deletions spv/sync.go
@@ -370,7 +370,7 @@ func (s *Syncer) Run(ctx context.Context) error {
// Next: fetch headers and cfilters up to mainchain tip.
s.fetchHeadersStart()
log.Debugf("Fetching headers and CFilters...")
err = s.getHeaders(ctx)
err = s.initialSyncHeaders(ctx)
if err != nil {
return err
}
@@ -1318,161 +1318,54 @@ func (s *Syncer) disconnectStragglers(height int32) {
// locators.
var hashStop chainhash.Hash

// getHeaders fetches headers from peers until the wallet is up to date with
// all connected peers. This is part of the startup sync process.
func (s *Syncer) getHeaders(ctx context.Context) error {

cnet := s.wallet.ChainParams().Net

// initialSyncHeaders fetches headers and cfilters from peers until the wallet
// is up to date with all connected peers. This is part of the startup sync
// process.
func (s *Syncer) initialSyncHeaders(ctx context.Context) error {
startTime := time.Now()

nextbatch:
for ctx.Err() == nil {
tipHash, tipHeight := s.wallet.MainChainTip(ctx)

// Determine if there are any peers from which to request newer
// headers.
rp, err := s.waitForRemote(ctx, pickForGetHeaders(tipHeight), false)
// Fetch a batch of headers.
batch, err := s.getHeaders(ctx)
if err != nil {
return err
}
if rp == nil {
if batch.done {
// All done.
log.Infof("Initial sync to block %s at height %d completed in %s",
tipHash, tipHeight, time.Since(startTime).Round(time.Second))
log.Debugf("Initial sync completed in %s",
time.Since(startTime).Round(time.Second))
return nil
}
log.Tracef("Attempting next batch of headers from %v", rp)

// Request headers from the selected peer.
locators, locatorHeight, err := s.wallet.BlockLocators(ctx, nil)
if err != nil {
return err
}
headers, err := rp.Headers(ctx, locators, &hashStop)
if err != nil {
log.Debugf("Unable to fetch headers from %v: %v", rp, err)
continue nextbatch
}

if len(headers) == 0 {
// Ensure that the peer provided headers through the
// height advertised during handshake, unless our own
// locators were up to date (in which case we actually
// do not expect any headers).
if rp.LastHeight() < rp.InitialHeight() && locatorHeight < rp.InitialHeight() {
err := errors.E(errors.Protocol, "peer did not provide "+
"headers through advertised height")
rp.Disconnect(err)
continue nextbatch
}

// Try to pick a different peer with a higher advertised
// height or check there are no such peers (thus we're
// done with fetching headers for initial sync).
log.Tracef("Skipping to next batch due to "+
"len(headers) == 0 from %v", rp)
continue nextbatch
}

nodes := make([]*wallet.BlockNode, len(headers))
for i := range headers {
// Determine the hash of the header. It is safe to use
// PrevBlock (instead of recalculating) because the
// lower p2p level already asserted the headers connect
// to each other.
var hash *chainhash.Hash
if i == len(headers)-1 {
bh := headers[i].BlockHash()
hash = &bh
} else {
hash = &headers[i+1].PrevBlock
}
nodes[i] = wallet.NewBlockNode(headers[i], hash, nil)
if wallet.BadCheckpoint(cnet, hash, int32(headers[i].Height)) {
nodes[i].BadCheckpoint()
}
}
bestChain := batch.bestChain

// Verify the sidechain that includes the received headers has
// the correct difficulty.
// Determine which nodes don't have cfilters yet.
s.sidechainMu.Lock()
fullsc, err := s.sidechains.FullSideChain(nodes)
if err != nil {
s.sidechainMu.Unlock()
return err
}
_, err = s.wallet.ValidateHeaderChainDifficulties(ctx, fullsc, 0)
if err != nil {
s.sidechainMu.Unlock()
rp.Disconnect(err)
if !errors.Is(err, context.Canceled) {
log.Warnf("Disconnecting from %v due to header "+
"validation error: %v", rp, err)
}
continue nextbatch
}

// Add new headers to the sidechain forest.
var added int
for _, n := range nodes {
haveBlock, _, _ := s.wallet.BlockInMainChain(ctx, n.Hash)
if haveBlock {
continue
}
if s.sidechains.AddBlockNode(n) {
added++
}
}

// Determine if this extends the best known chain.
bestChain, err := s.wallet.EvaluateBestChain(ctx, &s.sidechains)
if err != nil {
s.sidechainMu.Unlock()
rp.Disconnect(err)
continue nextbatch
}
if len(bestChain) == 0 {
s.sidechainMu.Unlock()
continue nextbatch
}

s.fetchHeadersProgress(headers[len(headers)-1])
log.Debugf("Fetched %d new header(s) ending at height %d from %v",
added, headers[len(headers)-1].Height, rp)

// Fetch cfilters for nodes which don't yet have them.
var missingCFNodes []*wallet.BlockNode
var missingCfilter []*wallet.BlockNode
for i := range bestChain {
if bestChain[i].FilterV2 == nil {
missingCFNodes = bestChain[i:]
missingCfilter = bestChain[i:]
break
}
}
s.sidechainMu.Unlock()
filters, err := s.cfiltersV2FromNodes(ctx, cnet, rp, missingCFNodes)

// Fetch Missing CFilters.
err = s.cfiltersV2FromNodes(ctx, batch.rp, missingCfilter)
if err != nil {
log.Debugf("Unable to fetch missing cfilters from %v: %v",
rp, err)
batch.rp, err)
continue nextbatch
}
if len(missingCFNodes) > 0 {
log.Debugf("Fetched %d new cfilters(s) ending at height %d from %v",
len(missingCFNodes),
missingCFNodes[len(missingCFNodes)-1].Header.Height,
rp)
}

// Switch the best chain, now that all cfilters have been
// fetched for it.
s.sidechainMu.Lock()
for i := range missingCFNodes {
missingCFNodes[i].FilterV2 = filters[i]
}
prevChain, err := s.wallet.ChainSwitch(ctx, &s.sidechains, bestChain, nil)
if err != nil {
s.sidechainMu.Unlock()
rp.Disconnect(err)
batch.rp.Disconnect(err)
continue nextbatch
}

@@ -1490,6 +1383,7 @@ nextbatch:
log.Infof("Connected %d blocks, new tip %v, height %d, date %v",
len(bestChain), tip.Hash, tip.Header.Height, tip.Header.Timestamp)
}
s.fetchHeadersProgress(tip.Header)

s.sidechainMu.Unlock()

