Skip to content

Commit

Permalink
Simplify code in fetch and syncer pkgs (#5453)
Browse files Browse the repository at this point in the history
## Motivation
The `requester::Request(...)` interface from the `fetch` package seems unnecessarily complicated with its callback methods. It can be made blocking instead, simplifying the code using this interface.

## Changes
- make `fetch::requester::Request(...)` blocking and remove the callbacks; it now returns the result directly,
- make `GetMaliciousIDs`, `GetLayerData` and `GetLayerOpinions` from the `syncer::fetcher` interface blocking,
- remove `fetch::poll(...)`,
- refactor the code and tests to use the new approach.

## Test Plan
Existing tests pass.
  • Loading branch information
poszu committed Feb 2, 2024
1 parent 76ac1d0 commit 5b27845
Show file tree
Hide file tree
Showing 12 changed files with 668 additions and 1,061 deletions.
86 changes: 36 additions & 50 deletions fetch/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ type Fetch struct {
// unprocessed contains requests that are not processed
unprocessed map[types.Hash32]*request
// ongoing contains requests that have been processed and are waiting for responses
ongoing map[types.Hash32]*request
// batched contains batched ongoing requests.
batched map[types.Hash32]*batchInfo
ongoing map[types.Hash32]*request
batchTimeout *time.Ticker
mu sync.Mutex
onlyOnce sync.Once
Expand Down Expand Up @@ -240,7 +238,6 @@ func NewFetch(
servers: map[string]requester{},
unprocessed: make(map[types.Hash32]*request),
ongoing: make(map[types.Hash32]*request),
batched: make(map[types.Hash32]*batchInfo),
hashToPeers: NewHashPeersCache(cacheSize),
}
for _, opt := range opts {
Expand Down Expand Up @@ -404,7 +401,7 @@ func (f *Fetch) loop() {
}

// receive Data from message server and call response handlers accordingly.
func (f *Fetch) receiveResponse(data []byte) {
func (f *Fetch) receiveResponse(data []byte, batch *batchInfo) {
if f.stopped() {
return
}
Expand All @@ -419,14 +416,13 @@ func (f *Fetch) receiveResponse(data []byte) {
log.Stringer("batch_hash", response.ID),
log.Int("num_hashes", len(response.Responses)),
)
f.mu.Lock()
batch, ok := f.batched[response.ID]
delete(f.batched, response.ID)
f.mu.Unlock()

if !ok {
f.logger.With().Warning("unknown batch response received, or already received",
log.Stringer("batch_hash", response.ID))
if batch.ID != response.ID {
f.logger.With().Warning(
"unknown batch response received",
log.Stringer("expected", batch.ID),
log.Stringer("response", response.ID),
)
return
}

Expand Down Expand Up @@ -549,6 +545,7 @@ func (f *Fetch) send(requests []RequestMessage) {

peer2batches := f.organizeRequests(requests)
for peer, peerBatches := range peer2batches {
peer := peer
for _, reqs := range peerBatches {
batch := &batchInfo{
RequestBatch: RequestBatch{
Expand All @@ -557,7 +554,20 @@ func (f *Fetch) send(requests []RequestMessage) {
peer: peer,
}
batch.setID()
f.sendBatch(peer, batch)
go func() {
data, err := f.sendBatch(peer, batch)
if err != nil {
f.logger.With().Warning(
"failed to send batch request",
log.Stringer("batch", batch.ID),
log.Stringer("peer", peer),
log.Err(err),
)
f.handleHashError(batch, err)
} else {
f.receiveResponse(data, batch)
}
}()
}
}
}
Expand Down Expand Up @@ -621,71 +631,47 @@ func (f *Fetch) organizeRequests(requests []RequestMessage) map[p2p.Peer][][]Req
}

// sendBatch dispatches batched request messages to provided peer.
func (f *Fetch) sendBatch(peer p2p.Peer, batch *batchInfo) {
func (f *Fetch) sendBatch(peer p2p.Peer, batch *batchInfo) ([]byte, error) {
if f.stopped() {
return
return nil, f.shutdownCtx.Err()
}
f.mu.Lock()
f.batched[batch.ID] = batch
f.mu.Unlock()
f.logger.With().Debug("sending batched request to peer",
log.Stringer("batch_hash", batch.ID),
log.Int("num_requests", len(batch.Requests)),
log.Stringer("peer", peer),
)
// Request is asynchronous,
// Request is synchronous,
// it will return errors only if size of the bytes buffer is large
// or target peer is not connected
start := time.Now()
errf := func(err error) {
f.logger.With().Warning("failed to send batch",
log.Stringer("batch_hash", peer), log.Err(err),
)
f.peers.OnFailure(peer)
f.handleHashError(batch.ID, err)
}
err := f.servers[hashProtocol].Request(
f.shutdownCtx,
peer,
codec.MustEncode(&batch.RequestBatch),
func(buf []byte) {
f.peers.OnLatency(peer, time.Since(start))
f.receiveResponse(buf)
},
errf,
)
req := codec.MustEncode(&batch.RequestBatch)
data, err := f.servers[hashProtocol].Request(f.shutdownCtx, peer, req)
if err != nil {
errf(err)
f.peers.OnFailure(peer)
return nil, err
}
f.peers.OnLatency(peer, time.Since(start))
return data, nil
}

// handleHashError is called when an error occurred processing batches of the following hashes.
func (f *Fetch) handleHashError(batchHash types.Hash32, err error) {
func (f *Fetch) handleHashError(batch *batchInfo, err error) {
f.mu.Lock()
defer f.mu.Unlock()

f.logger.With().Debug("failed batch fetch", log.Stringer("batch_hash", batchHash), log.Err(err))
batch, ok := f.batched[batchHash]
if !ok {
f.logger.With().Error("batch not found", log.Stringer("batch_hash", batchHash))
return
}
for _, br := range batch.Requests {
req, ok := f.ongoing[br.Hash]
if !ok {
f.logger.With().
Warning("hash missing from ongoing requests", log.Stringer("hash", br.Hash))
f.logger.With().Warning("hash missing from ongoing requests", log.Stringer("hash", br.Hash))
continue
}
f.logger.WithContext(req.ctx).With().Warning("hash request failed",
log.Stringer("hash", req.hash),
log.Err(err))
f.logger.WithContext(req.ctx).With().
Warning("hash request failed", log.Stringer("hash", req.hash), log.Err(err))
req.promise.err = err
peerErrors.WithLabelValues(string(req.hint)).Inc()
close(req.promise.completed)
delete(f.ongoing, req.hash)
}
delete(f.batched, batchHash)
}

// getHash is the regular buffered call to get a specific hash, using provided hash, h as hint the receiving end will
Expand Down
76 changes: 36 additions & 40 deletions fetch/fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,24 +187,22 @@ func TestFetch_RequestHashBatchFromPeers(t *testing.T) {
Data: []byte("b"),
}
f.mHashS.EXPECT().
Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(
func(_ context.Context, _ p2p.Peer, req []byte, okFunc func([]byte), _ func(error)) error {
if tc.nErr != nil {
return tc.nErr
}
var rb RequestBatch
err := codec.Decode(req, &rb)
require.NoError(t, err)
resBatch := ResponseBatch{
ID: rb.ID,
Responses: []ResponseMessage{res0, res1},
}
bts, err := codec.Encode(&resBatch)
require.NoError(t, err)
okFunc(bts)
return nil
})
Request(gomock.Any(), peer, gomock.Any()).
DoAndReturn(func(_ context.Context, _ p2p.Peer, req []byte) ([]byte, error) {
if tc.nErr != nil {
return nil, tc.nErr
}
var rb RequestBatch
err := codec.Decode(req, &rb)
require.NoError(t, err)
resBatch := ResponseBatch{
ID: rb.ID,
Responses: []ResponseMessage{res0, res1},
}
bts, err := codec.Encode(&resBatch)
require.NoError(t, err)
return bts, nil
})

var p0, p1 []*promise
// query each hash twice
Expand Down Expand Up @@ -254,28 +252,26 @@ func TestFetch_Loop_BatchRequestMax(t *testing.T) {
h2 := types.RandomHash()
h3 := types.RandomHash()
f.mHashS.EXPECT().
Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(
func(_ context.Context, _ p2p.Peer, req []byte, okFunc func([]byte), _ func(error)) error {
var rb RequestBatch
err := codec.Decode(req, &rb)
require.NoError(t, err)
resps := make([]ResponseMessage, 0, len(rb.Requests))
for _, r := range rb.Requests {
resps = append(resps, ResponseMessage{
Hash: r.Hash,
Data: []byte("a"),
})
}
resBatch := ResponseBatch{
ID: rb.ID,
Responses: resps,
}
bts, err := codec.Encode(&resBatch)
require.NoError(t, err)
okFunc(bts)
return nil
}).
Request(gomock.Any(), peer, gomock.Any()).
DoAndReturn(func(_ context.Context, _ p2p.Peer, req []byte) ([]byte, error) {
var rb RequestBatch
err := codec.Decode(req, &rb)
require.NoError(t, err)
resps := make([]ResponseMessage, 0, len(rb.Requests))
for _, r := range rb.Requests {
resps = append(resps, ResponseMessage{
Hash: r.Hash,
Data: []byte("a"),
})
}
resBatch := ResponseBatch{
ID: rb.ID,
Responses: resps,
}
bts, err := codec.Encode(&resBatch)
require.NoError(t, err)
return bts, nil
}).
Times(2)
// 3 requests with batch size 2 -> 2 sends

Expand Down
2 changes: 1 addition & 1 deletion fetch/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

type requester interface {
Run(context.Context) error
Request(context.Context, p2p.Peer, []byte, func([]byte), func(error)) error
Request(context.Context, p2p.Peer, []byte) ([]byte, error)
}

// The ValidatorFunc type is an adapter to allow the use of functions as
Expand Down
Loading

0 comments on commit 5b27845

Please sign in to comment.