From 5b27845b9a3d251891bd68e40dd44256f1ab2f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Fri, 2 Feb 2024 12:31:29 +0000 Subject: [PATCH] Simplify code in `fetch` and `syncer` pkgs (#5453) ## Motivation The `requester::Request(...)` interface from the `fetch` package seems unnecessarily complicated with its callback methods. It can be made blocking instead, simplifying the code using this interface. ## Changes - make `fetch::requester::Request(...)` blocking and remove the callbacks. It returns the result directly now, - make `GetMaliciousIDs`, `GetLayerData` and `GetLayerOpinions` from `syncer::fetcher` interface blocking, - remove `fetch::poll(...)` - refactor the code and tests to use the new approach ## Test Plan existing tests pass --- fetch/fetch.go | 86 +++---- fetch/fetch_test.go | 76 +++--- fetch/interface.go | 2 +- fetch/mesh_data.go | 177 ++++--------- fetch/mesh_data_test.go | 268 +++++++------------- fetch/mocks/mocks.go | 21 +- p2p/server/server.go | 61 ++--- p2p/server/server_test.go | 100 ++------ syncer/data_fetch.go | 521 +++++++++++++------------------------- syncer/data_fetch_test.go | 268 ++++++++++++-------- syncer/interface.go | 23 +- syncer/mocks/mocks.go | 126 ++++----- 12 files changed, 668 insertions(+), 1061 deletions(-) diff --git a/fetch/fetch.go b/fetch/fetch.go index 3a411f7032..cf06cb8a86 100644 --- a/fetch/fetch.go +++ b/fetch/fetch.go @@ -209,9 +209,7 @@ type Fetch struct { // unprocessed contains requests that are not processed unprocessed map[types.Hash32]*request // ongoing contains requests that have been processed and are waiting for responses - ongoing map[types.Hash32]*request - // batched contains batched ongoing requests. - batched map[types.Hash32]*batchInfo + ongoing map[types.Hash32]*request batchTimeout *time.Ticker mu sync.Mutex onlyOnce sync.Once @@ -240,7 +238,6 @@ func NewFetch( servers: map[string]requester{}, unprocessed: make(map[types.Hash32]*request), ongoing: make(map[types.Hash32]*request), - batched: make(map[types.Hash32]*batchInfo), hashToPeers: NewHashPeersCache(cacheSize), } for _, opt := range opts { @@ -404,7 +401,7 @@ func (f *Fetch) loop() { } // receive Data from message server and call response handlers accordingly. -func (f *Fetch) receiveResponse(data []byte) { +func (f *Fetch) receiveResponse(data []byte, batch *batchInfo) { if f.stopped() { return } @@ -419,14 +416,13 @@ func (f *Fetch) receiveResponse(data []byte) { log.Stringer("batch_hash", response.ID), log.Int("num_hashes", len(response.Responses)), ) - f.mu.Lock() - batch, ok := f.batched[response.ID] - delete(f.batched, response.ID) - f.mu.Unlock() - if !ok { - f.logger.With().Warning("unknown batch response received, or already received", - log.Stringer("batch_hash", response.ID)) + if batch.ID != response.ID { + f.logger.With().Warning( + "unknown batch response received", + log.Stringer("expected", batch.ID), + log.Stringer("response", response.ID), + ) return } @@ -549,6 +545,7 @@ func (f *Fetch) send(requests []RequestMessage) { peer2batches := f.organizeRequests(requests) for peer, peerBatches := range peer2batches { + peer := peer for _, reqs := range peerBatches { batch := &batchInfo{ RequestBatch: RequestBatch{ @@ -557,7 +554,20 @@ func (f *Fetch) send(requests []RequestMessage) { peer: peer, } batch.setID() - f.sendBatch(peer, batch) + go func() { + data, err := f.sendBatch(peer, batch) + if err != nil { + f.logger.With().Warning( + "failed to send batch request", + log.Stringer("batch", batch.ID), + log.Stringer("peer", peer), + log.Err(err), + ) + f.handleHashError(batch, err) + } else { + f.receiveResponse(data, batch) + } + }() } } } @@ -621,71 +631,47 @@ func (f *Fetch) organizeRequests(requests []RequestMessage) map[p2p.Peer][][]Req } // sendBatch dispatches batched request messages to provided peer. -func (f *Fetch) sendBatch(peer p2p.Peer, batch *batchInfo) { +func (f *Fetch) sendBatch(peer p2p.Peer, batch *batchInfo) ([]byte, error) { if f.stopped() { - return + return nil, f.shutdownCtx.Err() } - f.mu.Lock() - f.batched[batch.ID] = batch - f.mu.Unlock() f.logger.With().Debug("sending batched request to peer", log.Stringer("batch_hash", batch.ID), log.Int("num_requests", len(batch.Requests)), log.Stringer("peer", peer), ) - // Request is asynchronous, + // Request is synchronous, // it will return errors only if size of the bytes buffer is large // or target peer is not connected start := time.Now() - errf := func(err error) { - f.logger.With().Warning("failed to send batch", - log.Stringer("batch_hash", peer), log.Err(err), - ) - f.peers.OnFailure(peer) - f.handleHashError(batch.ID, err) - } - err := f.servers[hashProtocol].Request( - f.shutdownCtx, - peer, - codec.MustEncode(&batch.RequestBatch), - func(buf []byte) { - f.peers.OnLatency(peer, time.Since(start)) - f.receiveResponse(buf) - }, - errf, - ) + req := codec.MustEncode(&batch.RequestBatch) + data, err := f.servers[hashProtocol].Request(f.shutdownCtx, peer, req) if err != nil { - errf(err) + f.peers.OnFailure(peer) + return nil, err } + f.peers.OnLatency(peer, time.Since(start)) + return data, nil } // handleHashError is called when an error occurred processing batches of the following hashes. -func (f *Fetch) handleHashError(batchHash types.Hash32, err error) { +func (f *Fetch) handleHashError(batch *batchInfo, err error) { f.mu.Lock() defer f.mu.Unlock() - f.logger.With().Debug("failed batch fetch", log.Stringer("batch_hash", batchHash), log.Err(err)) - batch, ok := f.batched[batchHash] - if !ok { - f.logger.With().Error("batch not found", log.Stringer("batch_hash", batchHash)) - return - } for _, br := range batch.Requests { req, ok := f.ongoing[br.Hash] if !ok { - f.logger.With(). - Warning("hash missing from ongoing requests", log.Stringer("hash", br.Hash)) + f.logger.With().Warning("hash missing from ongoing requests", log.Stringer("hash", br.Hash)) continue } - f.logger.WithContext(req.ctx).With().Warning("hash request failed", - log.Stringer("hash", req.hash), - log.Err(err)) + f.logger.WithContext(req.ctx).With(). + Warning("hash request failed", log.Stringer("hash", req.hash), log.Err(err)) req.promise.err = err peerErrors.WithLabelValues(string(req.hint)).Inc() close(req.promise.completed) delete(f.ongoing, req.hash) } - delete(f.batched, batchHash) } // getHash is the regular buffered call to get a specific hash, using provided hash, h as hint the receiving end will diff --git a/fetch/fetch_test.go b/fetch/fetch_test.go index 7d644bf829..3595291654 100644 --- a/fetch/fetch_test.go +++ b/fetch/fetch_test.go @@ -187,24 +187,22 @@ func TestFetch_RequestHashBatchFromPeers(t *testing.T) { Data: []byte("b"), } f.mHashS.EXPECT(). - Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, req []byte, okFunc func([]byte), _ func(error)) error { - if tc.nErr != nil { - return tc.nErr - } - var rb RequestBatch - err := codec.Decode(req, &rb) - require.NoError(t, err) - resBatch := ResponseBatch{ - ID: rb.ID, - Responses: []ResponseMessage{res0, res1}, - } - bts, err := codec.Encode(&resBatch) - require.NoError(t, err) - okFunc(bts) - return nil - }) + Request(gomock.Any(), peer, gomock.Any()). + DoAndReturn(func(_ context.Context, _ p2p.Peer, req []byte) ([]byte, error) { + if tc.nErr != nil { + return nil, tc.nErr + } + var rb RequestBatch + err := codec.Decode(req, &rb) + require.NoError(t, err) + resBatch := ResponseBatch{ + ID: rb.ID, + Responses: []ResponseMessage{res0, res1}, + } + bts, err := codec.Encode(&resBatch) + require.NoError(t, err) + return bts, nil + }) var p0, p1 []*promise // query each hash twice @@ -254,28 +252,26 @@ func TestFetch_Loop_BatchRequestMax(t *testing.T) { h2 := types.RandomHash() h3 := types.RandomHash() f.mHashS.EXPECT(). - Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, req []byte, okFunc func([]byte), _ func(error)) error { - var rb RequestBatch - err := codec.Decode(req, &rb) - require.NoError(t, err) - resps := make([]ResponseMessage, 0, len(rb.Requests)) - for _, r := range rb.Requests { - resps = append(resps, ResponseMessage{ - Hash: r.Hash, - Data: []byte("a"), - }) - } - resBatch := ResponseBatch{ - ID: rb.ID, - Responses: resps, - } - bts, err := codec.Encode(&resBatch) - require.NoError(t, err) - okFunc(bts) - return nil - }). + Request(gomock.Any(), peer, gomock.Any()). + DoAndReturn(func(_ context.Context, _ p2p.Peer, req []byte) ([]byte, error) { + var rb RequestBatch + err := codec.Decode(req, &rb) + require.NoError(t, err) + resps := make([]ResponseMessage, 0, len(rb.Requests)) + for _, r := range rb.Requests { + resps = append(resps, ResponseMessage{ + Hash: r.Hash, + Data: []byte("a"), + }) + } + resBatch := ResponseBatch{ + ID: rb.ID, + Responses: resps, + } + bts, err := codec.Encode(&resBatch) + require.NoError(t, err) + return bts, nil + }). Times(2) // 3 requests with batch size 2 -> 2 sends diff --git a/fetch/interface.go b/fetch/interface.go index 1ab00bf153..b01ebe9eab 100644 --- a/fetch/interface.go +++ b/fetch/interface.go @@ -12,7 +12,7 @@ import ( type requester interface { Run(context.Context) error - Request(context.Context, p2p.Peer, []byte, func([]byte), func(error)) error + Request(context.Context, p2p.Peer, []byte) ([]byte, error) } // The ValidatorFunc type is an adapter to allow the use of functions as diff --git a/fetch/mesh_data.go b/fetch/mesh_data.go index 47dc441162..005912c705 100644 --- a/fetch/mesh_data.go +++ b/fetch/mesh_data.go @@ -203,7 +203,7 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error { return nil case errors.Is(pm.err, activation.ErrObjectExists): // PoET proofs are concurrently stored in DB in two places: - // fetcher and nipost builder. Hence it might happen that + // fetcher and nipost builder. Hence, it might happen that // a proof had been inserted into the DB while the fetcher // was fetching. return nil @@ -216,68 +216,28 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error { } } -func (f *Fetch) GetMaliciousIDs( - ctx context.Context, - peers []p2p.Peer, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), -) error { - return poll(ctx, f.servers[malProtocol], peers, []byte{}, okCB, errCB) +func (f *Fetch) GetMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]byte, error) { + return f.servers[malProtocol].Request(ctx, peer, []byte{}) } // GetLayerData get layer data from peers. -func (f *Fetch) GetLayerData( - ctx context.Context, - peers []p2p.Peer, - lid types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), -) error { +func (f *Fetch) GetLayerData(ctx context.Context, peer p2p.Peer, lid types.LayerID) ([]byte, error) { lidBytes, err := codec.Encode(&lid) if err != nil { - return err + return nil, err } - return poll(ctx, f.servers[lyrDataProtocol], peers, lidBytes, okCB, errCB) + return f.servers[lyrDataProtocol].Request(ctx, peer, lidBytes) } -func (f *Fetch) GetLayerOpinions( - ctx context.Context, - peers []p2p.Peer, - lid types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), -) error { +func (f *Fetch) GetLayerOpinions(ctx context.Context, peer p2p.Peer, lid types.LayerID) ([]byte, error) { req := OpinionRequest{ Layer: lid, } reqData, err := codec.Encode(&req) if err != nil { - return err - } - return poll(ctx, f.servers[OpnProtocol], peers, reqData, okCB, errCB) -} - -func poll( - ctx context.Context, - srv requester, - peers []p2p.Peer, - req []byte, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), -) error { - for _, p := range peers { - peer := p - okFunc := func(data []byte) { - okCB(data, peer) - } - errFunc := func(err error) { - errCB(err, peer) - } - if err := srv.Request(ctx, peer, req, okFunc, errFunc); err != nil { - errFunc(err) - } + return nil, err } - return nil + return f.servers[OpnProtocol].Request(ctx, peer, reqData) } // PeerEpochInfo get the epoch info published in the given epoch from the specified peer. @@ -286,34 +246,21 @@ func (f *Fetch) PeerEpochInfo(ctx context.Context, peer p2p.Peer, epoch types.Ep log.Stringer("peer", peer), log.Stringer("epoch", epoch)) - var ( - done = make(chan error, 1) - ed EpochData - ) - okCB := func(data []byte) { - done <- codec.Decode(data, &ed) - } - errCB := func(perr error) { - done <- perr - } epochBytes, err := codec.Encode(epoch) if err != nil { return nil, err } - if err := f.servers[atxProtocol].Request(ctx, peer, epochBytes, okCB, errCB); err != nil { + data, err := f.servers[atxProtocol].Request(ctx, peer, epochBytes) + if err != nil { return nil, err } - select { - case err := <-done: - if err != nil { - return nil, err - } - f.RegisterPeerHashes(peer, types.ATXIDsToHashes(ed.AtxIDs)) - return &ed, nil - case <-ctx.Done(): - f.logger.WithContext(ctx).With().Debug("context done") - return nil, ctx.Err() + + var ed EpochData + if err := codec.Decode(data, &ed); err != nil { + return nil, fmt.Errorf("decoding epoch data: %w", err) } + f.RegisterPeerHashes(peer, types.ATXIDsToHashes(ed.AtxIDs)) + return &ed, nil } func (f *Fetch) PeerMeshHashes(ctx context.Context, peer p2p.Peer, req *MeshHashRequest) (*MeshHashes, error) { @@ -322,38 +269,22 @@ func (f *Fetch) PeerMeshHashes(ctx context.Context, peer p2p.Peer, req *MeshHash log.Object("req", req), ) - var ( - done = make(chan error, 1) - hashes []types.Hash32 - reqData []byte - ) reqData, err := codec.Encode(req) if err != nil { f.logger.With().Fatal("failed to encode mesh hash request", log.Err(err)) } - okCB := func(data []byte) { - h, err := codec.DecodeSlice[types.Hash32](data) - hashes = h - done <- err - } - errCB := func(perr error) { - done <- perr - } - if err = f.servers[meshHashProtocol].Request(ctx, peer, reqData, okCB, errCB); err != nil { + data, err := f.servers[meshHashProtocol].Request(ctx, peer, reqData) + if err != nil { return nil, err } - select { - case err := <-done: - if err != nil { - return nil, err - } - return &MeshHashes{ - Hashes: hashes, - }, nil - case <-ctx.Done(): - return nil, ctx.Err() + hashes, err := codec.DecodeSlice[types.Hash32](data) + if err != nil { + return nil, fmt.Errorf("decoding hashes response: %w", err) } + return &MeshHashes{ + Hashes: hashes, + }, nil } func (f *Fetch) GetCert( @@ -368,49 +299,33 @@ func (f *Fetch) GetCert( Layer: lid, Block: &bid, } - reqData, err := codec.Encode(req) - if err != nil { - f.logger.With().Fatal("failed to encode cert request", log.Err(err)) - } + reqData := codec.MustEncode(req) - out := make(chan *types.Certificate, 1) for _, peer := range peers { - done := make(chan error, 1) - okCB := func(data []byte) { - var peerCert types.Certificate - if err = codec.Decode(data, &peerCert); err != nil { - done <- err - return - } - // for generic data fetches by hash (ID for atx/block/proposal/ballot/tx), the check on whether the returned - // data matching the hash was done on the data handlers' path. for block certificate, there is no ID associated - // with it, hence the check here. - // however, certificate doesn't go through that path. it's requested by a separate protocol because a block - // certificate doesn't have an ID. - if peerCert.BlockID != bid { - done <- fmt.Errorf("peer %v served wrong cert. want %s got %s", peer, bid.String(), peerCert.BlockID.String()) - return - } - out <- &peerCert - } - errCB := func(perr error) { - done <- perr + data, err := f.servers[OpnProtocol].Request(ctx, peer, reqData) + if err != nil { + f.logger.With().Debug("failed to get cert", log.Stringer("peer", peer), log.Err(err)) + continue } - if err := f.servers[OpnProtocol].Request(ctx, peer, reqData, okCB, errCB); err != nil { - done <- err + var peerCert types.Certificate + if err = codec.Decode(data, &peerCert); err != nil { + f.logger.With().Debug("failed to decode cert", log.Stringer("peer", peer), log.Err(err)) + continue } - select { - case err := <-done: - f.logger.With().Debug("failed to get cert from peer", + // for generic data fetches by hash (ID for atx/block/proposal/ballot/tx), the check on whether the returned + // data matching the hash was done on the data handlers' path. for block certificate, there is no ID associated + // with it, hence the check here. + // however, certificate doesn't go through that path. it's requested by a separate protocol because a block + // certificate doesn't have an ID. + if peerCert.BlockID != bid { + f.logger.With().Debug( + "peer served wrong cert", + log.Stringer("want", bid), + log.Stringer("got", peerCert.BlockID), log.Stringer("peer", peer), - log.Err(err), ) - continue - case cert := <-out: - return cert, nil - case <-ctx.Done(): - return nil, ctx.Err() } + return &peerCert, nil } - return nil, fmt.Errorf("failed to get cert %v/%s from %d peers", lid, bid.String(), len(peers)) + return nil, fmt.Errorf("failed to get cert %v/%s from %d peers: %w", lid, bid.String(), len(peers), ctx.Err()) } diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index e3b25876f0..be320bb35f 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "sync" "testing" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" @@ -33,15 +31,6 @@ const ( txsForProposal = iota ) -const layersPerEpoch = 3 - -func TestMain(m *testing.M) { - types.SetLayersPerEpoch(layersPerEpoch) - - res := m.Run() - os.Exit(res) -} - func (f *testFetch) withMethod(method int) *testFetch { f.method = method return f @@ -159,7 +148,7 @@ func TestFetch_getHashes(t *testing.T) { for _, peer := range peers { f.peers.Add(peer) } - f.mh.EXPECT().ID().Return(p2p.Peer("self")).AnyTimes() + f.mh.EXPECT().ID().Return("self").AnyTimes() f.RegisterPeerHashes(peers[0], hashes[:2]) f.RegisterPeerHashes(peers[1], hashes[2:]) @@ -172,9 +161,9 @@ func TestFetch_getHashes(t *testing.T) { responses[h] = res } f.mHashS.EXPECT(). - Request(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Request(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn( - func(_ context.Context, p p2p.Peer, req []byte, okFunc func([]byte), _ func(error)) error { + func(_ context.Context, p p2p.Peer, req []byte) ([]byte, error) { var rb RequestBatch err := codec.Decode(req, &rb) require.NoError(t, err) @@ -194,8 +183,8 @@ func TestFetch_getHashes(t *testing.T) { } bts, err := codec.Encode(&resBatch) require.NoError(t, err) - okFunc(bts) - return nil + + return bts, nil }). Times(len(peers)) @@ -479,134 +468,66 @@ func TestGetPoetProof(t *testing.T) { } func TestFetch_GetMaliciousIDs(t *testing.T) { - peers := []p2p.Peer{"p0", "p1", "p3", "p4"} - errUnknown := errors.New("unknown") - tt := []struct { - name string - errs []error - }{ - { - name: "all peers returns", - errs: []error{nil, nil, nil, nil}, - }, - { - name: "some peers errors", - errs: []error{nil, errUnknown, nil, errUnknown}, - }, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + t.Run("success", func(t *testing.T) { + t.Parallel() + f := createFetch(t) + expectedIds := generateMaliciousIDs(t) + f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(expectedIds, nil) + ids, err := f.GetMaliciousIDs(context.Background(), "p0") + require.NoError(t, err) + require.Equal(t, expectedIds, ids) + }) + t.Run("failure", func(t *testing.T) { + t.Parallel() + errUnknown := errors.New("unknown") + f := createFetch(t) + f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(nil, errUnknown) + ids, err := f.GetMaliciousIDs(context.Background(), "p0") + require.ErrorIs(t, err, errUnknown) + require.Nil(t, ids) + }) +} - require.Equal(t, len(peers), len(tc.errs)) - f := createFetch(t) - oks := make(chan struct{}, len(peers)) - errs := make(chan struct{}, len(peers)) - var wg sync.WaitGroup - wg.Add(len(peers)) - okFunc := func([]byte, p2p.Peer) { - oks <- struct{}{} - wg.Done() - } - errFunc := func(error, p2p.Peer) { - errs <- struct{}{} - wg.Done() - } - var expOk, expErr int - for i, p := range peers { - if tc.errs[i] == nil { - expOk++ - } else { - expErr++ - } - idx := i - f.mMalS.EXPECT(). - Request(gomock.Any(), p, []byte{}, gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, _ []byte, okCB func([]byte), errCB func(error)) error { - if tc.errs[idx] == nil { - go okCB(generateMaliciousIDs(t)) - } else { - go errCB(tc.errs[idx]) - } - return nil - }) - } - require.NoError(t, f.GetMaliciousIDs(context.Background(), peers, okFunc, errFunc)) - wg.Wait() - require.Len(t, oks, expOk) - require.Len(t, errs, expErr) - }) - } +func TestFetch_GetLayerOpinions(t *testing.T) { + t.Run("success", func(t *testing.T) { + t.Parallel() + f := createFetch(t) + expected := generateLayerContent(t) + f.mOpn2S.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), gomock.Any()).Return(expected, nil) + res, err := f.GetLayerOpinions(context.Background(), "p0", 7) + require.NoError(t, err) + require.Equal(t, expected, res) + }) + t.Run("failure", func(t *testing.T) { + t.Parallel() + errUnknown := errors.New("unknown") + f := createFetch(t) + f.mOpn2S.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), gomock.Any()).Return(nil, errUnknown) + res, err := f.GetLayerOpinions(context.Background(), "p0", 7) + require.ErrorIs(t, err, errUnknown) + require.Nil(t, res) + }) } func TestFetch_GetLayerData(t *testing.T) { - peers := []p2p.Peer{"p0", "p1", "p3", "p4"} - errUnknown := errors.New("unknown") - tt := []struct { - name string - errs []error - }{ - { - name: "all peers returns", - errs: []error{nil, nil, nil, nil}, - }, - { - name: "some peers errors", - errs: []error{nil, errUnknown, nil, errUnknown}, - }, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - require.Equal(t, len(peers), len(tc.errs)) - f := createFetch(t) - oks := make(chan struct{}, len(peers)) - errs := make(chan struct{}, len(peers)) - var wg sync.WaitGroup - wg.Add(len(peers)) - okFunc := func(data []byte, peer p2p.Peer) { - oks <- struct{}{} - wg.Done() - } - errFunc := func(err error, peer p2p.Peer) { - errs <- struct{}{} - wg.Done() - } - var expOk, expErr int - for i, p := range peers { - if tc.errs[i] == nil { - expOk++ - } else { - expErr++ - } - idx := i - f.mLyrS.EXPECT(). - Request(gomock.Any(), p, gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, _ []byte, okCB func([]byte), errCB func(error)) error { - if tc.errs[idx] == nil { - go okCB(generateLayerContent(t)) - } else { - go errCB(tc.errs[idx]) - } - return nil - }) - } - require.NoError( - t, - f.GetLayerData(context.Background(), peers, types.LayerID(111), okFunc, errFunc), - ) - wg.Wait() - require.Len(t, oks, expOk) - require.Len(t, errs, expErr) - }) - } + t.Run("success", func(t *testing.T) { + t.Parallel() + f := createFetch(t) + expected := generateLayerContent(t) + f.mLyrS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), gomock.Any()).Return(expected, nil) + res, err := f.GetLayerData(context.Background(), "p0", 7) + require.NoError(t, err) + require.Equal(t, expected, res) + }) + t.Run("failure", func(t *testing.T) { + t.Parallel() + errUnknown := errors.New("unknown") + f := createFetch(t) + f.mLyrS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), gomock.Any()).Return(nil, errUnknown) + res, err := f.GetLayerData(context.Background(), "p0", 7) + require.ErrorIs(t, err, errUnknown) + require.Nil(t, res) + }) } func generateEpochData(t *testing.T) (*EpochData, []byte) { @@ -641,20 +562,18 @@ func Test_PeerEpochInfo(t *testing.T) { t.Parallel() f := createFetch(t) - f.mh.EXPECT().ID().Return(p2p.Peer("self")).AnyTimes() + f.mh.EXPECT().ID().Return("self").AnyTimes() var expected *EpochData f.mAtxS.EXPECT(). - Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()). + Request(gomock.Any(), peer, gomock.Any()). DoAndReturn( - func(_ context.Context, _ p2p.Peer, req []byte, okCB func([]byte), errCB func(error)) error { + func(_ context.Context, _ p2p.Peer, req []byte) ([]byte, error) { if tc.err == nil { var data []byte expected, data = generateEpochData(t) - okCB(data) - } else { - errCB(tc.err) + return data, nil } - return nil + return nil, tc.err }) got, err := f.PeerEpochInfo(context.Background(), peer, types.EpochID(111)) require.ErrorIs(t, err, tc.err) @@ -708,19 +627,16 @@ func TestFetch_GetMeshHashes(t *testing.T) { reqData, err := codec.Encode(req) require.NoError(t, err) f.mMHashS.EXPECT(). - Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, gotReq []byte, okCB func([]byte), errCB func(error)) error { - require.Equal(t, reqData, gotReq) - if tc.err == nil { - data, err := codec.EncodeSlice(expected.Hashes) - require.NoError(t, err) - okCB(data) - } else { - errCB(tc.err) - } - return nil - }) + Request(gomock.Any(), peer, gomock.Any()). + DoAndReturn(func(_ context.Context, _ p2p.Peer, gotReq []byte) ([]byte, error) { + require.Equal(t, reqData, gotReq) + if tc.err == nil { + data, err := codec.EncodeSlice(expected.Hashes) + require.NoError(t, err) + return data, nil + } + return nil, tc.err + }) got, err := f.PeerMeshHashes(context.Background(), peer, req) if tc.err == nil { require.NoError(t, err) @@ -738,18 +654,16 @@ func TestFetch_GetCert(t *testing.T) { tt := []struct { name string results [3]error - stop int - err bool + + err bool }{ { name: "success", results: [3]error{errUnknown, nil, nil}, - stop: 1, }, { name: "failure", results: [3]error{errUnknown, errUnknown, errUnknown}, - stop: -1, err: true, }, } @@ -770,22 +684,20 @@ func TestFetch_GetCert(t *testing.T) { reqData, err := codec.Encode(req) require.NoError(t, err) for i, peer := range peers { + p := peer ith := i f.mOpn2S.EXPECT(). - Request(gomock.Any(), peer, gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn( - func(_ context.Context, _ p2p.Peer, gotReq []byte, okCB func([]byte), errCB func(error)) error { - require.Equal(t, reqData, gotReq) - if tc.results[ith] == nil { - data, err := codec.Encode(&expected) - require.NoError(t, err) - okCB(data) - } else { - errCB(tc.results[ith]) - } - return nil - }) - if tc.stop > 0 && tc.stop == i { + Request(gomock.Any(), p, gomock.Any()). + DoAndReturn(func(_ context.Context, _ p2p.Peer, gotReq []byte) ([]byte, error) { + require.Equal(t, reqData, gotReq) + if tc.results[ith] == nil { + data, err := codec.Encode(&expected) + require.NoError(t, err) + return data, nil + } + return nil, tc.results[ith] + }) + if tc.results[ith] == nil { break } } diff --git a/fetch/mocks/mocks.go b/fetch/mocks/mocks.go index 8c60c2fd42..c7f74bde6e 100644 --- a/fetch/mocks/mocks.go +++ b/fetch/mocks/mocks.go @@ -41,17 +41,18 @@ func (m *Mockrequester) EXPECT() *MockrequesterMockRecorder { } // Request mocks base method. -func (m *Mockrequester) Request(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 func([]byte), arg4 func(error)) error { +func (m *Mockrequester) Request(arg0 context.Context, arg1 p2p.Peer, arg2 []byte) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Request", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Request", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Request indicates an expected call of Request. -func (mr *MockrequesterMockRecorder) Request(arg0, arg1, arg2, arg3, arg4 any) *requesterRequestCall { +func (mr *MockrequesterMockRecorder) Request(arg0, arg1, arg2 any) *requesterRequestCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Request", reflect.TypeOf((*Mockrequester)(nil).Request), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Request", reflect.TypeOf((*Mockrequester)(nil).Request), arg0, arg1, arg2) return &requesterRequestCall{Call: call} } @@ -61,19 +62,19 @@ type requesterRequestCall struct { } // Return rewrite *gomock.Call.Return -func (c *requesterRequestCall) Return(arg0 error) *requesterRequestCall { - c.Call = c.Call.Return(arg0) +func (c *requesterRequestCall) Return(arg0 []byte, arg1 error) *requesterRequestCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *requesterRequestCall) Do(f func(context.Context, p2p.Peer, []byte, func([]byte), func(error)) error) *requesterRequestCall { +func (c *requesterRequestCall) Do(f func(context.Context, p2p.Peer, []byte) ([]byte, error)) *requesterRequestCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *requesterRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, func([]byte), func(error)) error) *requesterRequestCall { +func (c *requesterRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte) ([]byte, error)) *requesterRequestCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/p2p/server/server.go b/p2p/server/server.go index bc64bfd98b..033e95a431 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -276,49 +276,40 @@ func (s *Server) queueHandler(ctx context.Context, stream network.Stream) bool { // Request sends a binary request to the peer. Request is executed in the background, one of the callbacks // is guaranteed to be called on success/error. -func (s *Server) Request( - ctx context.Context, - pid peer.ID, - req []byte, - resp func([]byte), - failure func(error), -) error { +func (s *Server) Request(ctx context.Context, pid peer.ID, req []byte) ([]byte, error) { start := time.Now() if len(req) > s.requestLimit { - return fmt.Errorf("request length (%d) is longer than limit %d", len(req), s.requestLimit) + return nil, fmt.Errorf("request length (%d) is longer than limit %d", len(req), s.requestLimit) } if s.h.Network().Connectedness(pid) != network.Connected { - return fmt.Errorf("%w: %s", ErrNotConnected, pid) + return nil, fmt.Errorf("%w: %s", ErrNotConnected, pid) } - go func() { - data, err := s.request(ctx, pid, req) - if err != nil { - failure(err) - } else if len(data.Error) > 0 { - failure(errors.New(data.Error)) - } else { - resp(data.Data) - } - s.logger.WithContext(ctx).With().Debug("request execution time", - log.String("protocol", s.protocol), - log.Duration("duration", time.Since(start)), - log.Err(err), - ) - switch { - case s.metrics == nil: - return - case err != nil: + data, err := s.request(ctx, pid, req) + s.logger.WithContext(ctx).With().Debug("request execution time", + log.String("protocol", s.protocol), + log.Duration("duration", time.Since(start)), + log.Err(err), + ) + + took := time.Since(start).Seconds() + switch { + case err != nil: + if s.metrics != nil { s.metrics.clientFailed.Inc() - s.metrics.clientLatencyFailure.Observe(time.Since(start).Seconds()) - case len(data.Error) > 0: + s.metrics.clientLatencyFailure.Observe(took) + } + return nil, err + case len(data.Error) > 0: + if s.metrics != nil { s.metrics.clientServerError.Inc() - s.metrics.clientLatency.Observe(time.Since(start).Seconds()) - case err == nil: - s.metrics.clientSucceeded.Inc() - s.metrics.clientLatency.Observe(time.Since(start).Seconds()) + s.metrics.clientLatency.Observe(took) } - }() - return nil + return nil, errors.New(data.Error) + case s.metrics != nil: + s.metrics.clientSucceeded.Inc() + s.metrics.clientLatency.Observe(took) + } + return data.Data, nil } func (s *Server) request(ctx context.Context, pid peer.ID, req []byte) (*Response, error) { diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 22183342be..aa4c34e807 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -9,6 +9,7 @@ import ( mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/spacemeshos/go-scale/tester" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -23,8 +24,6 @@ func TestServer(t *testing.T) { proto := "test" request := []byte("test request") testErr := errors.New("test error") - errch := make(chan error, 1) - respch := make(chan []byte, 1) handler := func(_ context.Context, msg []byte) ([]byte, error) { return msg, nil @@ -59,79 +58,32 @@ func TestServer(t *testing.T) { cancel() eg.Wait() }) - respHandler := func(msg []byte) { - select { - case <-ctx.Done(): - case respch <- msg: - } - } - respErrHandler := func(err error) { - select { - case <-ctx.Done(): - case errch <- err: - } - } + t.Run("ReceiveMessage", func(t *testing.T) { - require.NoError( - t, - client.Request(ctx, mesh.Hosts()[1].ID(), request, respHandler, respErrHandler), - ) - select { - case <-time.After(time.Second): - require.FailNow(t, "timed out while waiting for message response") - case response := <-respch: - require.Equal(t, request, response) - require.NotEmpty(t, mesh.Hosts()[2].Network().ConnsToPeer(mesh.Hosts()[0].ID())) - } + response, err := client.Request(ctx, mesh.Hosts()[1].ID(), request) + require.NoError(t, err) + require.Equal(t, request, response) + require.NotEmpty(t, mesh.Hosts()[2].Network().ConnsToPeer(mesh.Hosts()[0].ID())) }) t.Run("ReceiveError", func(t *testing.T) { - require.NoError( - t, - client.Request(ctx, mesh.Hosts()[2].ID(), request, respHandler, respErrHandler), - ) - select { - case <-time.After(time.Second): - require.FailNow(t, "timed out while waiting for error response") - case err := <-errch: - require.Equal(t, testErr, err) - } + _, err := client.Request(ctx, mesh.Hosts()[2].ID(), request) + require.Equal(t, err, testErr) }) t.Run("DialError", func(t *testing.T) { - require.NoError( - t, - client.Request(ctx, mesh.Hosts()[3].ID(), request, respHandler, respErrHandler), - ) - select { - case <-time.After(time.Second): - require.FailNow(t, "timed out while waiting for dial error") - case err := <-errch: - require.Error(t, err) - } + _, err := client.Request(ctx, mesh.Hosts()[2].ID(), request) + require.Error(t, err) }) t.Run("NotConnected", func(t *testing.T) { - require.ErrorIs( - t, - client.Request(ctx, "unknown", request, respHandler, respErrHandler), - ErrNotConnected, - ) + _, err := client.Request(ctx, "unknown", request) + require.ErrorIs(t, err, ErrNotConnected) }) t.Run("limit overflow", func(t *testing.T) { - require.NoError( - t, - client.Request( - ctx, - mesh.Hosts()[2].ID(), - make([]byte, limit+1), - respHandler, - respErrHandler, - ), + _, err := client.Request( + ctx, + mesh.Hosts()[2].ID(), + make([]byte, limit+1), ) - select { - case <-time.After(time.Second): - require.FailNow(t, "timed out while waiting for error response") - case err := <-errch: - require.Error(t, err) - } + require.Error(t, err) }) } @@ -166,18 +118,18 @@ func TestQueued(t *testing.T) { return srv.Run(ctx) }) t.Cleanup(func() { - eg.Wait() + assert.NoError(t, eg.Wait()) }) for i := 0; i < total; i++ { - require.NoError(t, client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping"), - func(b []byte) { - success.Add(1) - wait <- struct{}{} - }, func(err error) { + eg.Go(func() error { + if _, err := client.Request(ctx, mesh.Hosts()[1].ID(), []byte("ping")); err != nil { failure.Add(1) - wait <- struct{}{} - }, - )) + } else { + success.Add(1) + } + wait <- struct{}{} + return nil + }) } for i := 0; i < total; i++ { <-wait diff --git a/syncer/data_fetch.go b/syncer/data_fetch.go index 41d07722f7..f38e274364 100644 --- a/syncer/data_fetch.go +++ b/syncer/data_fetch.go @@ -6,6 +6,9 @@ import ( "fmt" "sync" + "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/fetch" @@ -13,42 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" ) -var ( - errNoPeers = errors.New("no peers") - errTimeout = errors.New("request timeout") -) - -type peerResult[T any] struct { - peer p2p.Peer - data *T - err error -} - -type request[T any, R any] struct { - lid types.LayerID - peers []p2p.Peer - response R - ch chan peerResult[T] - peerResults map[p2p.Peer]peerResult[T] -} - -type dataResponse struct { - ballots map[types.BallotID]struct{} -} - -type opinionResponse struct { - opinions []*fetch.LayerOpinion -} - -type maliciousIDResponse struct { - ids map[types.NodeID]struct{} -} - -type ( - dataRequest request[fetch.LayerData, dataResponse] - opinionRequest request[fetch.LayerOpinion, opinionResponse] - maliciousIDRequest request[fetch.MaliciousIDs, maliciousIDResponse] -) +var errNoPeers = errors.New("no peers") // DataFetch contains the logic of fetching mesh data. type DataFetch struct { @@ -81,371 +49,222 @@ func NewDataFetch( } } +type threadSafeErr struct { + err error + mu sync.Mutex +} + +func (e *threadSafeErr) join(err error) { + e.mu.Lock() + defer e.mu.Unlock() + e.err = errors.Join(e.err, err) +} + // PollMaliciousProofs polls all peers for malicious NodeIDs. func (d *DataFetch) PollMaliciousProofs(ctx context.Context) error { peers := d.fetcher.SelectBestShuffled(fetch.RedundantPeers) logger := d.logger.WithContext(ctx) - req := &maliciousIDRequest{ - peers: peers, - response: maliciousIDResponse{ - ids: map[types.NodeID]struct{}{}, - }, - ch: make(chan peerResult[fetch.MaliciousIDs], len(peers)), - } - okFunc := func(data []byte, peer p2p.Peer) { - d.receiveMaliciousIDs(ctx, req, peer, data, nil) - } - errFunc := func(err error, peer p2p.Peer) { - d.receiveMaliciousIDs(ctx, req, peer, nil, err) - malPeerError.Inc() - } - if err := d.fetcher.GetMaliciousIDs(ctx, peers, okFunc, errFunc); err != nil { - return err - } - req.peerResults = map[p2p.Peer]peerResult[fetch.MaliciousIDs]{} - var ( - success bool - candidateErr error - ) - for { - select { - case res := <-req.ch: - logger.Debug("received malicious IDs") - req.peerResults[res.peer] = res - if res.err == nil { - success = true - fetchMalfeasanceProof(ctx, logger, d.ids, d.fetcher, req, res.data) - } else if candidateErr == nil { - candidateErr = res.err - } - if len(req.peerResults) < len(req.peers) { - break + maliciousIDs := make(chan fetch.MaliciousIDs, len(peers)) + var eg errgroup.Group + fetchErr := threadSafeErr{} + for _, peer := range peers { + peer := peer + eg.Go(func() error { + data, err := d.fetcher.GetMaliciousIDs(ctx, peer) + if err != nil { + malPeerError.Inc() + logger.With().Debug("failed to get malicious IDs", log.Err(err), log.Stringer("peer", peer)) + fetchErr.join(err) + return nil } - // all peer responded - if success { - candidateErr = nil + var malIDs fetch.MaliciousIDs + if err := codec.Decode(data, &malIDs); err != nil { + logger.With().Debug("failed to decode", log.Err(err)) + fetchErr.join(err) + return nil } - return candidateErr - case <-ctx.Done(): - logger.Warning("request timed out") - return errTimeout + logger.With().Debug("received malicious id from peer", log.Stringer("peer", peer)) + maliciousIDs <- malIDs + return nil + }) + } + _ = eg.Wait() + close(maliciousIDs) + + allIds := make(map[types.NodeID]struct{}) + success := false + for ids := range maliciousIDs { + success = true + for _, id := range ids.NodeIDs { + allIds[id] = struct{}{} } } + if !success { + return fetchErr.err + } + + var idsToFetch []types.NodeID + for nodeID := range allIds { + if exists, err := d.ids.IdentityExists(nodeID); err != nil { + logger.With().Error("failed to check identity", log.Err(err)) + continue + } else if !exists { + logger.With().Info("malicious identity does not exist", log.Stringer("identity", nodeID)) + continue + } + idsToFetch = append(idsToFetch, nodeID) + } + + if err := d.fetcher.GetMalfeasanceProofs(ctx, idsToFetch); err != nil { + return fmt.Errorf("getting malfeasance proofs: %w", err) + } + return nil } // PollLayerData polls all peers for data in the specified layer. func (d *DataFetch) PollLayerData(ctx context.Context, lid types.LayerID, peers ...p2p.Peer) error { if len(peers) == 0 { peers = d.fetcher.SelectBestShuffled(fetch.RedundantPeers) - } - if len(peers) == 0 { - return errNoPeers + if len(peers) == 0 { + return errNoPeers + } } logger := d.logger.WithContext(ctx).WithFields(lid) - req := &dataRequest{ - lid: lid, - peers: peers, - response: dataResponse{ - ballots: map[types.BallotID]struct{}{}, - }, - ch: make(chan peerResult[fetch.LayerData], len(peers)), - } - okFunc := func(data []byte, peer p2p.Peer) { - d.receiveData(ctx, req, peer, data, nil) - } - errFunc := func(err error, peer p2p.Peer) { - d.receiveData(ctx, req, peer, nil, err) - layerPeerError.Inc() - } - if err := d.fetcher.GetLayerData(ctx, peers, lid, okFunc, errFunc); err != nil { - return err - } - - req.peerResults = map[p2p.Peer]peerResult[fetch.LayerData]{} - var ( - success bool - candidateErr error - ) - for { - select { - case res := <-req.ch: - logger.Debug("received layer data") - req.peerResults[res.peer] = res - if res.err == nil { - success = true - logger.Debug("fetching layer data") - fetchLayerData(ctx, logger, d.fetcher, req, res.data) - logger.Debug("fetched layer data") - } else if candidateErr == nil { - candidateErr = res.err + layerData := make(chan fetch.LayerData, len(peers)) + var eg errgroup.Group + fetchErr := threadSafeErr{} + for _, peer := range peers { + peer := peer + eg.Go(func() error { + data, err := d.fetcher.GetLayerData(ctx, peer, lid) + if err != nil { + layerPeerError.Inc() + logger.With().Debug("failed to get layer data", log.Err(err), log.Stringer("peer", peer)) + fetchErr.join(err) + return nil } - if len(req.peerResults) < len(req.peers) { - break + var ld fetch.LayerData + if err := codec.Decode(data, &ld); err != nil { + logger.With().Debug("failed to decode", log.Err(err)) + fetchErr.join(err) + return nil } - // all peer responded - if success { - candidateErr = nil - } - return candidateErr - case <-ctx.Done(): - logger.Warning("request timed out") - return errTimeout + logger.With().Debug("received layer data from peer", log.Stringer("peer", peer)) + registerLayerHashes(d.fetcher, peer, &ld) + layerData <- ld + return nil + }) + } + _ = eg.Wait() + close(layerData) + + allBallots := make(map[types.BallotID]struct{}) + success := false + for ld := range layerData { + success = true + for _, id := range ld.Ballots { + allBallots[id] = struct{}{} } } -} - -func (d *DataFetch) receiveMaliciousIDs( - ctx context.Context, - req *maliciousIDRequest, - peer p2p.Peer, - data []byte, - peerErr error, -) { - logger := d.logger.WithContext(ctx).WithFields(log.Stringer("peer", peer)) - logger.Debug("received malicious id from peer") - var ( - result = peerResult[fetch.MaliciousIDs]{peer: peer, err: peerErr} - malIDs fetch.MaliciousIDs - ) - if peerErr != nil { - logger.With().Debug("received peer error for layer data", req.lid, log.Err(peerErr)) - } else if result.err = codec.Decode(data, &malIDs); result.err != nil { - logger.With().Debug("error converting bytes to LayerData", log.Err(result.err)) - } else { - result.data = &malIDs - } - select { - case req.ch <- result: - case <-ctx.Done(): - logger.Warning("request timed out") + if !success { + return fetchErr.err } -} -func (d *DataFetch) receiveData( - ctx context.Context, - req *dataRequest, - peer p2p.Peer, - data []byte, - peerErr error, -) { - logger := d.logger.WithContext(ctx).WithFields(req.lid, log.Stringer("peer", peer)) - logger.Debug("received layer data from peer") - var ( - result = peerResult[fetch.LayerData]{peer: peer, err: peerErr} - ld fetch.LayerData - ) - if peerErr != nil { - logger.With().Debug("received peer error for layer data", req.lid, log.Err(peerErr)) - } else if result.err = codec.Decode(data, &ld); result.err != nil { - logger.With().Debug("error converting bytes to LayerData", log.Err(result.err)) - } else { - result.data = &ld - registerLayerHashes(d.fetcher, peer, result.data) - } - select { - case req.ch <- result: - case <-ctx.Done(): - logger.Warning("request timed out") + if err := d.fetcher.GetBallots(ctx, maps.Keys(allBallots)); err != nil { + return fmt.Errorf("getting ballots: %w", err) } + return nil } // registerLayerHashes registers hashes with the peer that provides these hashes. func registerLayerHashes(fetcher fetcher, peer p2p.Peer, data *fetch.LayerData) { - if data == nil { + if len(data.Ballots) == 0 { return } var layerHashes []types.Hash32 for _, ballotID := range data.Ballots { layerHashes = append(layerHashes, ballotID.AsHash32()) } - if len(layerHashes) == 0 { - return - } fetcher.RegisterPeerHashes(peer, layerHashes) } -func fetchMalfeasanceProof( - ctx context.Context, - logger log.Log, - ids idProvider, - fetcher fetcher, - req *maliciousIDRequest, - data *fetch.MaliciousIDs, -) { - var idsToFetch []types.NodeID - for _, nodeID := range data.NodeIDs { - if _, ok := req.response.ids[nodeID]; !ok { - // check if the NodeID exists - if exists, err := ids.IdentityExists(nodeID); err != nil { - logger.With().Error("failed to check identity", log.Err(err)) - continue - } else if !exists { - logger.With().Warning("malicious identity does not exist", - log.String("identity", nodeID.String())) - continue - } - // not yet fetched - req.response.ids[nodeID] = struct{}{} - idsToFetch = append(idsToFetch, nodeID) - } - } - if len(idsToFetch) > 0 { - logger.With().Info("fetching malfeasance proofs", log.Int("to_fetch", len(idsToFetch))) - if err := fetcher.GetMalfeasanceProofs(ctx, idsToFetch); err != nil { - logger.With().Warning("failed fetching malfeasance proofs", - log.Array("malicious_ids", log.ArrayMarshalerFunc(func(encoder log.ArrayEncoder) error { - for _, nodeID := range idsToFetch { - encoder.AppendString(nodeID.String()) - } - return nil - })), - log.Err(err)) - } - } -} - -func fetchLayerData( - ctx context.Context, - logger log.Log, - fetcher fetcher, - req *dataRequest, - data *fetch.LayerData, -) { - var ballotsToFetch []types.BallotID - for _, ballotID := range data.Ballots { - if _, ok := req.response.ballots[ballotID]; !ok { - // not yet fetched - req.response.ballots[ballotID] = struct{}{} - ballotsToFetch = append(ballotsToFetch, ballotID) - } - } - - if len(ballotsToFetch) > 0 { - logger.With().Debug("fetching new ballots", log.Int("to_fetch", len(ballotsToFetch))) - if err := fetcher.GetBallots(ctx, ballotsToFetch); err != nil { - logger.With().Warning("failed fetching new ballots", - log.Array("ballot_ids", log.ArrayMarshalerFunc(func(encoder log.ArrayEncoder) error { - for _, bid := range ballotsToFetch { - encoder.AppendString(bid.String()) - } - return nil - })), - log.Err(err)) - - // syntactically invalid ballots are expected from malicious peers - } - } -} - func (d *DataFetch) PollLayerOpinions( ctx context.Context, lid types.LayerID, needCert bool, peers []p2p.Peer, ) ([]*fetch.LayerOpinion, []*types.Certificate, error) { - req := &opinionRequest{ - lid: lid, - peers: peers, - ch: make(chan peerResult[fetch.LayerOpinion], len(peers)), - } - okFunc := func(data []byte, peer p2p.Peer) { - d.receiveOpinions(ctx, req, peer, data, nil) - } - errFunc := func(err error, peer p2p.Peer) { - d.receiveOpinions(ctx, req, peer, nil, err) - opnsPeerError.Inc() - } - if err := d.fetcher.GetLayerOpinions(ctx, peers, lid, okFunc, errFunc); err != nil { - return nil, nil, err - } - req.peerResults = map[p2p.Peer]peerResult[fetch.LayerOpinion]{} - var ( - success bool - candidateErr error - ) - for { - select { - case res := <-req.ch: - req.peerResults[res.peer] = res - if res.err == nil { - success = true - req.response.opinions = append(req.response.opinions, res.data) - } else if candidateErr == nil { - candidateErr = res.err + logger := d.logger.WithContext(ctx).WithFields(lid) + opinions := make(chan *fetch.LayerOpinion, len(peers)) + var eg errgroup.Group + fetchErr := threadSafeErr{} + for _, peer := range peers { + peer := peer + eg.Go(func() error { + data, err := d.fetcher.GetLayerOpinions(ctx, peer, lid) + if err != nil { + opnsPeerError.Inc() + logger.With().Debug("received peer error for layer opinions", log.Err(err), log.Stringer("peer", peer)) + fetchErr.join(err) + return nil } - if len(req.peerResults) < len(req.peers) { - break + var lo fetch.LayerOpinion + if err := codec.Decode(data, &lo); err != nil { + logger.With().Debug("failed to decode layer opinion", log.Err(err)) + fetchErr.join(err) + return nil } - // all peer responded - if success { - candidateErr = nil + logger.With().Debug("received layer opinion", log.Stringer("peer", peer)) + lo.SetPeer(peer) + opinions <- &lo + return nil + }) + } + _ = eg.Wait() + close(opinions) + + var allOpinions []*fetch.LayerOpinion + success := false + for op := range opinions { + success = true + allOpinions = append(allOpinions, op) + } + if !success { + return nil, nil, fetchErr.err + } + + certs := make([]*types.Certificate, 0, len(allOpinions)) + if needCert { + peerCerts := map[types.BlockID][]p2p.Peer{} + for _, opinion := range allOpinions { + if opinion.Certified == nil { + continue } - certs := make([]*types.Certificate, 0, len(req.response.opinions)) - if needCert { - peerCerts := map[types.BlockID][]p2p.Peer{} - for _, opns := range req.response.opinions { - if opns.Certified == nil { - continue - } - if _, ok := peerCerts[*opns.Certified]; !ok { - peerCerts[*opns.Certified] = []p2p.Peer{} - } - peerCerts[*opns.Certified] = append(peerCerts[*opns.Certified], opns.Peer()) - // note that we want to fetch block certificate for types.EmptyBlockID as well - // but we don't need to register hash for the actual block fetching - if *opns.Certified != types.EmptyBlockID { - d.fetcher.RegisterPeerHashes( - opns.Peer(), - []types.Hash32{opns.Certified.AsHash32()}, - ) - } - } - for bid, bidPeers := range peerCerts { - cert, err := d.fetcher.GetCert(ctx, lid, bid, bidPeers) - if err != nil { - certPeerError.Inc() - continue - } - certs = append(certs, cert) - } + if _, ok := peerCerts[*opinion.Certified]; !ok { + peerCerts[*opinion.Certified] = []p2p.Peer{} + } + peerCerts[*opinion.Certified] = append(peerCerts[*opinion.Certified], opinion.Peer()) + // note that we want to fetch block certificate for types.EmptyBlockID as well, + // but we don't need to register hash for the actual block fetching + if *opinion.Certified != types.EmptyBlockID { + d.fetcher.RegisterPeerHashes( + opinion.Peer(), + []types.Hash32{opinion.Certified.AsHash32()}, + ) } - return req.response.opinions, certs, candidateErr - case <-ctx.Done(): - d.logger.WithContext(ctx).Debug("request timed out", lid) - return nil, nil, errTimeout + } + for bid, bidPeers := range peerCerts { + cert, err := d.fetcher.GetCert(ctx, lid, bid, bidPeers) + if err != nil { + certPeerError.Inc() + continue + } + certs = append(certs, cert) } } -} - -func (d *DataFetch) receiveOpinions( - ctx context.Context, - req *opinionRequest, - peer p2p.Peer, - data []byte, - peerErr error, -) { - logger := d.logger.WithContext(ctx).WithFields(req.lid, log.Stringer("peer", peer)) - logger.Debug("received layer opinions from peer") - - var ( - result = peerResult[fetch.LayerOpinion]{peer: peer, err: peerErr} - lo fetch.LayerOpinion - ) - if peerErr != nil { - logger.With().Debug("received peer error for layer opinions", log.Err(peerErr)) - } else if result.err = codec.Decode(data, &lo); result.err != nil { - logger.With().Debug("error decoding LayerOpinion", log.Err(result.err)) - } else { - lo.SetPeer(peer) - result.data = &lo - } - select { - case req.ch <- result: - case <-ctx.Done(): - logger.Debug("request timed out") - } + return allOpinions, certs, nil } func (d *DataFetch) pickAtxPeer(epoch types.EpochID, peers []p2p.Peer) p2p.Peer { diff --git a/syncer/data_fetch_test.go b/syncer/data_fetch_test.go index 67c1f5d421..0835075130 100644 --- a/syncer/data_fetch_test.go +++ b/syncer/data_fetch_test.go @@ -55,7 +55,7 @@ func generateMaliciousIDs(t *testing.T) ([]types.NodeID, []byte) { return malicious.NodeIDs, data } -func generateLayerOpinions2(t *testing.T, bid *types.BlockID) []byte { +func generateLayerOpinions(t *testing.T, bid *types.BlockID) []byte { t.Helper() lo := &fetch.LayerOpinion{ PrevAggHash: types.RandomHash(), @@ -101,39 +101,92 @@ func TestDataFetch_PollMaliciousIDs(t *testing.T) { numPeers := 4 peers := GenPeers(numPeers) errUnknown := errors.New("unknown") - newTestDataFetchWithMocks := func(_ *testing.T, exits bool) *testDataFetch { + newTestDataFetchWithMocks := func(_ *testing.T, exists bool) *testDataFetch { td := newTestDataFetch(t) td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), peers, gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ []p2p.Peer, okCB func([]byte, p2p.Peer), errCB func(error, p2p.Peer)) error { - for _, peer := range peers { + for _, peer := range peers { + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), peer).DoAndReturn( + func(_ context.Context, peer p2p.Peer) ([]byte, error) { ids, data := generateMaliciousIDs(t) for _, id := range ids { - td.mIDs.EXPECT().IdentityExists(id).Return(exits, nil) + td.mIDs.EXPECT().IdentityExists(id).Return(exists, nil) } - okCB(data, peer) - } - return nil - }) + return data, nil + }) + } return td } - t.Run("all peers have malfeasance proofs", func(t *testing.T) { + t.Run("getting malfeasance proofs success", func(t *testing.T) { t.Parallel() td := newTestDataFetchWithMocks(t, true) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers) - require.NoError(t, td.PollMaliciousProofs(context.TODO())) + td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) + require.NoError(t, td.PollMaliciousProofs(context.Background())) }) - t.Run("proof failure ignored", func(t *testing.T) { + t.Run("getting proofs failure", func(t *testing.T) { t.Parallel() td := newTestDataFetchWithMocks(t, true) td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()).Return(errUnknown) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers - 1) - require.NoError(t, td.PollMaliciousProofs(context.TODO())) + require.ErrorIs(t, td.PollMaliciousProofs(context.Background()), errUnknown) }) t.Run("ids do not exist", func(t *testing.T) { t.Parallel() td := newTestDataFetchWithMocks(t, false) - require.NoError(t, td.PollMaliciousProofs(context.TODO())) + td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), nil) + require.NoError(t, td.PollMaliciousProofs(context.Background())) + }) +} + +func TestDataFetch_PollMaliciousIDs_PeerErrors(t *testing.T) { + t.Run("malformed data in response", func(t *testing.T) { + t.Parallel() + peers := []p2p.Peer{"p0"} + td := newTestDataFetch(t) + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return([]byte("malformed"), nil) + err := td.PollMaliciousProofs(context.Background()) + require.ErrorContains(t, err, "decode") + }) + t.Run("peer fails", func(t *testing.T) { + t.Parallel() + peers := []p2p.Peer{"p0"} + expectedErr := errors.New("peer failure") + td := newTestDataFetch(t) + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(nil, expectedErr) + err := td.PollMaliciousProofs(context.Background()) + require.ErrorIs(t, err, expectedErr) + }) + t.Run("one peer sends malformed data (succeed anyway)", func(t *testing.T) { + t.Parallel() + peers := []p2p.Peer{"p0", "p1"} + td := newTestDataFetch(t) + maliciousIds, data := generateMaliciousIDs(t) + for _, id := range maliciousIds { + td.mIDs.EXPECT().IdentityExists(id).Return(true, nil) + } + + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(data, nil) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p1")).Return([]byte("malformed"), nil) + td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) + err := td.PollMaliciousProofs(context.Background()) + require.NoError(t, err) + }) + t.Run("one peer fails (succeed anyway)", func(t *testing.T) { + t.Parallel() + peers := []p2p.Peer{"p0", "p1"} + expectedErr := errors.New("peer failure") + td := newTestDataFetch(t) + maliciousIds, data := generateMaliciousIDs(t) + for _, id := range maliciousIds { + td.mIDs.EXPECT().IdentityExists(id).Return(true, nil) + } + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(data, nil) + td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p1")).Return(nil, expectedErr) + td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) + err := td.PollMaliciousProofs(context.Background()) + require.NoError(t, err) }) } @@ -145,89 +198,75 @@ func TestDataFetch_PollLayerData(t *testing.T) { newTestDataFetchWithMocks := func(*testing.T) *testDataFetch { td := newTestDataFetch(t) td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers, layerID, gomock.Any(), gomock.Any()). - DoAndReturn(func( - _ context.Context, - _ []p2p.Peer, - _ types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), - ) error { - for _, peer := range peers { - td.mFetcher.EXPECT().RegisterPeerHashes(peer, gomock.Any()) - okCB(generateLayerContent(t), peer) - } - return nil - }) + for _, peer := range peers { + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peer, layerID).Return(generateLayerContent(t), nil) + td.mFetcher.EXPECT().RegisterPeerHashes(peer, gomock.Any()) + } return td } t.Run("all peers have layer data", func(t *testing.T) { t.Parallel() td := newTestDataFetchWithMocks(t) - td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers) - require.NoError(t, td.PollLayerData(context.TODO(), layerID)) + td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()) + require.NoError(t, td.PollLayerData(context.Background(), layerID)) }) - t.Run("ballots failure ignored", func(t *testing.T) { + t.Run("GetBallots failure", func(t *testing.T) { t.Parallel() td := newTestDataFetchWithMocks(t) td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()).Return(errUnknown) - td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers - 1) - require.NoError(t, td.PollLayerData(context.TODO(), layerID)) - }) - t.Run("blocks failure ignored", func(t *testing.T) { - t.Parallel() - td := newTestDataFetchWithMocks(t) - td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers) - require.NoError(t, td.PollLayerData(context.TODO(), layerID)) + require.ErrorIs(t, td.PollLayerData(context.Background(), layerID), errUnknown) }) } +func TestDataFetch_PollLayerData_FailToRequest(t *testing.T) { + t.Parallel() + peers := GenPeers(3) + expectedErr := errors.New("failed to request") + td := newTestDataFetch(t) + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + for _, peer := range peers { + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peer, types.LayerID(7)).Return(nil, expectedErr) + } + require.ErrorIs(t, td.PollLayerData(context.Background(), 7), expectedErr) +} + func TestDataFetch_PollLayerData_PeerErrors(t *testing.T) { numPeers := 4 peers := GenPeers(numPeers) - layerID := types.LayerID(10) + lid := types.LayerID(10) t.Run("only one peer has data", func(t *testing.T) { t.Parallel() td := newTestDataFetch(t) td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers, layerID, gomock.Any(), gomock.Any()). - DoAndReturn(func( - _ context.Context, - _ []p2p.Peer, - _ types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), - ) error { - td.mFetcher.EXPECT().RegisterPeerHashes(peers[0], gomock.Any()) - okCB(generateLayerContent(t), peers[0]) - for i := 1; i < numPeers; i++ { - errCB(errors.New("not available"), peers[i]) - } - return nil - }) - td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers) - td.mFetcher.EXPECT().GetBlocks(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(numPeers) - require.NoError(t, td.PollLayerData(context.TODO(), layerID)) + td.mFetcher.EXPECT().RegisterPeerHashes(peers[0], gomock.Any()) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers[0], lid).Return(generateLayerContent(t), nil) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), gomock.Any(), lid).Return(nil, errors.New("na")).Times(numPeers - 1) + td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()) + require.NoError(t, td.PollLayerData(context.Background(), lid)) }) t.Run("only one peer has empty layer", func(t *testing.T) { t.Parallel() td := newTestDataFetch(t) td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers, layerID, gomock.Any(), gomock.Any()). - DoAndReturn(func( - _ context.Context, - _ []p2p.Peer, - _ types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), - ) error { - okCB(generateEmptyLayer(t), peers[0]) - for i := 1; i < numPeers; i++ { - errCB(errors.New("not available"), peers[i]) - } - return nil - }) - require.NoError(t, td.PollLayerData(context.TODO(), layerID)) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers[0], lid).Return(generateEmptyLayer(t), nil) + for i := 1; i < numPeers; i++ { + td.mFetcher.EXPECT().RegisterPeerHashes(peers[i], gomock.Any()) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers[i], lid).Return(generateLayerContent(t), nil) + } + td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()) + require.NoError(t, td.PollLayerData(context.Background(), lid)) + }) + t.Run("one peer sends malformed data", func(t *testing.T) { + t.Parallel() + td := newTestDataFetch(t) + td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers[0], lid).Return([]byte("malformed"), nil) + for i := 1; i < numPeers; i++ { + td.mFetcher.EXPECT().RegisterPeerHashes(peers[i], gomock.Any()) + td.mFetcher.EXPECT().GetLayerData(gomock.Any(), peers[i], lid).Return(generateLayerContent(t), nil) + } + td.mFetcher.EXPECT().GetBallots(gomock.Any(), gomock.Any()) + require.NoError(t, td.PollLayerData(context.Background(), lid)) }) } @@ -290,44 +329,34 @@ func TestDataFetch_PollLayerOpinions(t *testing.T) { t.Parallel() td := newTestDataFetch(t) - td.mFetcher.EXPECT().GetLayerOpinions(gomock.Any(), peers, lid, gomock.Any(), gomock.Any()). - DoAndReturn(func( - _ context.Context, - _ []p2p.Peer, - _ types.LayerID, - okCB func([]byte, p2p.Peer), - errCB func(error, p2p.Peer), - ) error { - for i, peer := range peers { - if tc.pErrs[i] != nil { - errCB(tc.pErrs[i], peer) - } else { - if tc.needCert && len(tc.hasCert) > 0 { - p := peer - td.mFetcher.EXPECT().RegisterPeerHashes(p, []types.Hash32{tc.hasCert[i].AsHash32()}) - } - var certified *types.BlockID - if len(tc.hasCert) > 0 { - certified = &tc.hasCert[i] - } - okCB(generateLayerOpinions2(t, certified), peer) - } + for i, peer := range peers { + if tc.pErrs[i] != nil { + td.mFetcher.EXPECT().GetLayerOpinions(gomock.Any(), peer, lid).Return(nil, tc.pErrs[i]) + } else { + if tc.needCert && len(tc.hasCert) > 0 { + td.mFetcher.EXPECT().RegisterPeerHashes(peer, []types.Hash32{tc.hasCert[i].AsHash32()}) } - for _, bid := range tc.queried { - td.mFetcher.EXPECT().GetCert(gomock.Any(), lid, bid, gomock.Any()).DoAndReturn( - func(_ context.Context, _ types.LayerID, bid types.BlockID, peers []p2p.Peer) (*types.Certificate, error) { - require.Len(t, peers, 2) - if tc.cErr == nil { - return &types.Certificate{BlockID: bid}, nil - } else { - return nil, tc.cErr - } - }) + var certified *types.BlockID + if len(tc.hasCert) > 0 { + certified = &tc.hasCert[i] } - return nil - }) + op := generateLayerOpinions(t, certified) + td.mFetcher.EXPECT().GetLayerOpinions(gomock.Any(), peer, lid).Return(op, nil) + } + } + for _, bid := range tc.queried { + td.mFetcher.EXPECT().GetCert(gomock.Any(), lid, bid, gomock.Any()).DoAndReturn( + func(_ context.Context, _ types.LayerID, bid types.BlockID, peers []p2p.Peer) (*types.Certificate, error) { + require.Len(t, peers, 2) + if tc.cErr == nil { + return &types.Certificate{BlockID: bid}, nil + } else { + return nil, tc.cErr + } + }) + } - got, certs, err := td.PollLayerOpinions(context.TODO(), lid, tc.needCert, peers) + got, certs, err := td.PollLayerOpinions(context.Background(), lid, tc.needCert, peers) require.ErrorIs(t, err, tc.err) if err == nil { require.NotEmpty(t, got) @@ -343,6 +372,23 @@ func TestDataFetch_PollLayerOpinions(t *testing.T) { } } +func TestDataFetch_PollLayerOpinions_FailToRequest(t *testing.T) { + peers := []p2p.Peer{"p0"} + td := newTestDataFetch(t) + expectedErr := errors.New("failed to request") + td.mFetcher.EXPECT().GetLayerOpinions(gomock.Any(), peers[0], types.LayerID(10)).Return(nil, expectedErr) + _, _, err := td.PollLayerOpinions(context.Background(), 10, false, peers) + require.ErrorIs(t, err, expectedErr) +} + +func TestDataFetch_PollLayerOpinions_MalformedData(t *testing.T) { + peers := []p2p.Peer{"p0"} + td := newTestDataFetch(t) + td.mFetcher.EXPECT().GetLayerOpinions(gomock.Any(), peers[0], types.LayerID(10)).Return([]byte("malformed"), nil) + _, _, err := td.PollLayerOpinions(context.Background(), 10, false, peers) + require.ErrorContains(t, err, "decode") +} + func TestDataFetch_GetEpochATXs(t *testing.T) { const numPeers = 4 peers := GenPeers(numPeers) diff --git a/syncer/interface.go b/syncer/interface.go index b9eff0770f..ef1e264bc7 100644 --- a/syncer/interface.go +++ b/syncer/interface.go @@ -42,26 +42,9 @@ type fetchLogic interface { // fetcher is the interface to the low-level fetching. type fetcher interface { - GetMaliciousIDs( - context.Context, - []p2p.Peer, - func([]byte, p2p.Peer), - func(error, p2p.Peer), - ) error - GetLayerData( - context.Context, - []p2p.Peer, - types.LayerID, - func([]byte, p2p.Peer), - func(error, p2p.Peer), - ) error - GetLayerOpinions( - context.Context, - []p2p.Peer, - types.LayerID, - func([]byte, p2p.Peer), - func(error, p2p.Peer), - ) error + GetMaliciousIDs(context.Context, p2p.Peer) ([]byte, error) + GetLayerData(context.Context, p2p.Peer, types.LayerID) ([]byte, error) + GetLayerOpinions(context.Context, p2p.Peer, types.LayerID) ([]byte, error) GetCert(context.Context, types.LayerID, types.BlockID, []p2p.Peer) (*types.Certificate, error) GetMalfeasanceProofs(context.Context, []types.NodeID) error diff --git a/syncer/mocks/mocks.go b/syncer/mocks/mocks.go index 84ee9a382d..69862ba6db 100644 --- a/syncer/mocks/mocks.go +++ b/syncer/mocks/mocks.go @@ -421,17 +421,18 @@ func (c *fetchLogicGetEpochATXsCall) DoAndReturn(f func(context.Context, types.E } // GetLayerData mocks base method. -func (m *MockfetchLogic) GetLayerData(arg0 context.Context, arg1 []p2p.Peer, arg2 types.LayerID, arg3 func([]byte, p2p.Peer), arg4 func(error, p2p.Peer)) error { +func (m *MockfetchLogic) GetLayerData(arg0 context.Context, arg1 p2p.Peer, arg2 types.LayerID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLayerData", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetLayerData", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetLayerData indicates an expected call of GetLayerData. -func (mr *MockfetchLogicMockRecorder) GetLayerData(arg0, arg1, arg2, arg3, arg4 any) *fetchLogicGetLayerDataCall { +func (mr *MockfetchLogicMockRecorder) GetLayerData(arg0, arg1, arg2 any) *fetchLogicGetLayerDataCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerData", reflect.TypeOf((*MockfetchLogic)(nil).GetLayerData), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerData", reflect.TypeOf((*MockfetchLogic)(nil).GetLayerData), arg0, arg1, arg2) return &fetchLogicGetLayerDataCall{Call: call} } @@ -441,35 +442,36 @@ type fetchLogicGetLayerDataCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetchLogicGetLayerDataCall) Return(arg0 error) *fetchLogicGetLayerDataCall { - c.Call = c.Call.Return(arg0) +func (c *fetchLogicGetLayerDataCall) Return(arg0 []byte, arg1 error) *fetchLogicGetLayerDataCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetchLogicGetLayerDataCall) Do(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetLayerDataCall { +func (c *fetchLogicGetLayerDataCall) Do(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetchLogicGetLayerDataCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetchLogicGetLayerDataCall) DoAndReturn(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetLayerDataCall { +func (c *fetchLogicGetLayerDataCall) DoAndReturn(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetchLogicGetLayerDataCall { c.Call = c.Call.DoAndReturn(f) return c } // GetLayerOpinions mocks base method. -func (m *MockfetchLogic) GetLayerOpinions(arg0 context.Context, arg1 []p2p.Peer, arg2 types.LayerID, arg3 func([]byte, p2p.Peer), arg4 func(error, p2p.Peer)) error { +func (m *MockfetchLogic) GetLayerOpinions(arg0 context.Context, arg1 p2p.Peer, arg2 types.LayerID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLayerOpinions", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetLayerOpinions", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetLayerOpinions indicates an expected call of GetLayerOpinions. -func (mr *MockfetchLogicMockRecorder) GetLayerOpinions(arg0, arg1, arg2, arg3, arg4 any) *fetchLogicGetLayerOpinionsCall { +func (mr *MockfetchLogicMockRecorder) GetLayerOpinions(arg0, arg1, arg2 any) *fetchLogicGetLayerOpinionsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerOpinions", reflect.TypeOf((*MockfetchLogic)(nil).GetLayerOpinions), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerOpinions", reflect.TypeOf((*MockfetchLogic)(nil).GetLayerOpinions), arg0, arg1, arg2) return &fetchLogicGetLayerOpinionsCall{Call: call} } @@ -479,19 +481,19 @@ type fetchLogicGetLayerOpinionsCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetchLogicGetLayerOpinionsCall) Return(arg0 error) *fetchLogicGetLayerOpinionsCall { - c.Call = c.Call.Return(arg0) +func (c *fetchLogicGetLayerOpinionsCall) Return(arg0 []byte, arg1 error) *fetchLogicGetLayerOpinionsCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetchLogicGetLayerOpinionsCall) Do(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetLayerOpinionsCall { +func (c *fetchLogicGetLayerOpinionsCall) Do(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetchLogicGetLayerOpinionsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetchLogicGetLayerOpinionsCall) DoAndReturn(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetLayerOpinionsCall { +func (c *fetchLogicGetLayerOpinionsCall) DoAndReturn(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetchLogicGetLayerOpinionsCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -535,17 +537,18 @@ func (c *fetchLogicGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, } // GetMaliciousIDs mocks base method. -func (m *MockfetchLogic) GetMaliciousIDs(arg0 context.Context, arg1 []p2p.Peer, arg2 func([]byte, p2p.Peer), arg3 func(error, p2p.Peer)) error { +func (m *MockfetchLogic) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetMaliciousIDs indicates an expected call of GetMaliciousIDs. -func (mr *MockfetchLogicMockRecorder) GetMaliciousIDs(arg0, arg1, arg2, arg3 any) *fetchLogicGetMaliciousIDsCall { +func (mr *MockfetchLogicMockRecorder) GetMaliciousIDs(arg0, arg1 any) *fetchLogicGetMaliciousIDsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*MockfetchLogic)(nil).GetMaliciousIDs), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*MockfetchLogic)(nil).GetMaliciousIDs), arg0, arg1) return &fetchLogicGetMaliciousIDsCall{Call: call} } @@ -555,19 +558,19 @@ type fetchLogicGetMaliciousIDsCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetchLogicGetMaliciousIDsCall) Return(arg0 error) *fetchLogicGetMaliciousIDsCall { - c.Call = c.Call.Return(arg0) +func (c *fetchLogicGetMaliciousIDsCall) Return(arg0 []byte, arg1 error) *fetchLogicGetMaliciousIDsCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetchLogicGetMaliciousIDsCall) Do(f func(context.Context, []p2p.Peer, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetMaliciousIDsCall { +func (c *fetchLogicGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]byte, error)) *fetchLogicGetMaliciousIDsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetchLogicGetMaliciousIDsCall) DoAndReturn(f func(context.Context, []p2p.Peer, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetchLogicGetMaliciousIDsCall { +func (c *fetchLogicGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]byte, error)) *fetchLogicGetMaliciousIDsCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1027,17 +1030,18 @@ func (c *fetcherGetCertCall) DoAndReturn(f func(context.Context, types.LayerID, } // GetLayerData mocks base method. -func (m *Mockfetcher) GetLayerData(arg0 context.Context, arg1 []p2p.Peer, arg2 types.LayerID, arg3 func([]byte, p2p.Peer), arg4 func(error, p2p.Peer)) error { +func (m *Mockfetcher) GetLayerData(arg0 context.Context, arg1 p2p.Peer, arg2 types.LayerID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLayerData", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetLayerData", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetLayerData indicates an expected call of GetLayerData. -func (mr *MockfetcherMockRecorder) GetLayerData(arg0, arg1, arg2, arg3, arg4 any) *fetcherGetLayerDataCall { +func (mr *MockfetcherMockRecorder) GetLayerData(arg0, arg1, arg2 any) *fetcherGetLayerDataCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerData", reflect.TypeOf((*Mockfetcher)(nil).GetLayerData), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerData", reflect.TypeOf((*Mockfetcher)(nil).GetLayerData), arg0, arg1, arg2) return &fetcherGetLayerDataCall{Call: call} } @@ -1047,35 +1051,36 @@ type fetcherGetLayerDataCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetcherGetLayerDataCall) Return(arg0 error) *fetcherGetLayerDataCall { - c.Call = c.Call.Return(arg0) +func (c *fetcherGetLayerDataCall) Return(arg0 []byte, arg1 error) *fetcherGetLayerDataCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetcherGetLayerDataCall) Do(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetLayerDataCall { +func (c *fetcherGetLayerDataCall) Do(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetcherGetLayerDataCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetcherGetLayerDataCall) DoAndReturn(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetLayerDataCall { +func (c *fetcherGetLayerDataCall) DoAndReturn(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetcherGetLayerDataCall { c.Call = c.Call.DoAndReturn(f) return c } // GetLayerOpinions mocks base method. -func (m *Mockfetcher) GetLayerOpinions(arg0 context.Context, arg1 []p2p.Peer, arg2 types.LayerID, arg3 func([]byte, p2p.Peer), arg4 func(error, p2p.Peer)) error { +func (m *Mockfetcher) GetLayerOpinions(arg0 context.Context, arg1 p2p.Peer, arg2 types.LayerID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLayerOpinions", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetLayerOpinions", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetLayerOpinions indicates an expected call of GetLayerOpinions. -func (mr *MockfetcherMockRecorder) GetLayerOpinions(arg0, arg1, arg2, arg3, arg4 any) *fetcherGetLayerOpinionsCall { +func (mr *MockfetcherMockRecorder) GetLayerOpinions(arg0, arg1, arg2 any) *fetcherGetLayerOpinionsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerOpinions", reflect.TypeOf((*Mockfetcher)(nil).GetLayerOpinions), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLayerOpinions", reflect.TypeOf((*Mockfetcher)(nil).GetLayerOpinions), arg0, arg1, arg2) return &fetcherGetLayerOpinionsCall{Call: call} } @@ -1085,19 +1090,19 @@ type fetcherGetLayerOpinionsCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetcherGetLayerOpinionsCall) Return(arg0 error) *fetcherGetLayerOpinionsCall { - c.Call = c.Call.Return(arg0) +func (c *fetcherGetLayerOpinionsCall) Return(arg0 []byte, arg1 error) *fetcherGetLayerOpinionsCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetcherGetLayerOpinionsCall) Do(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetLayerOpinionsCall { +func (c *fetcherGetLayerOpinionsCall) Do(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetcherGetLayerOpinionsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetcherGetLayerOpinionsCall) DoAndReturn(f func(context.Context, []p2p.Peer, types.LayerID, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetLayerOpinionsCall { +func (c *fetcherGetLayerOpinionsCall) DoAndReturn(f func(context.Context, p2p.Peer, types.LayerID) ([]byte, error)) *fetcherGetLayerOpinionsCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1141,17 +1146,18 @@ func (c *fetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, [] } // GetMaliciousIDs mocks base method. -func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 []p2p.Peer, arg2 func([]byte, p2p.Peer), arg3 func(error, p2p.Peer)) error { +func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetMaliciousIDs indicates an expected call of GetMaliciousIDs. -func (mr *MockfetcherMockRecorder) GetMaliciousIDs(arg0, arg1, arg2, arg3 any) *fetcherGetMaliciousIDsCall { +func (mr *MockfetcherMockRecorder) GetMaliciousIDs(arg0, arg1 any) *fetcherGetMaliciousIDsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).GetMaliciousIDs), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).GetMaliciousIDs), arg0, arg1) return &fetcherGetMaliciousIDsCall{Call: call} } @@ -1161,19 +1167,19 @@ type fetcherGetMaliciousIDsCall struct { } // Return rewrite *gomock.Call.Return -func (c *fetcherGetMaliciousIDsCall) Return(arg0 error) *fetcherGetMaliciousIDsCall { - c.Call = c.Call.Return(arg0) +func (c *fetcherGetMaliciousIDsCall) Return(arg0 []byte, arg1 error) *fetcherGetMaliciousIDsCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *fetcherGetMaliciousIDsCall) Do(f func(context.Context, []p2p.Peer, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetMaliciousIDsCall { +func (c *fetcherGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]byte, error)) *fetcherGetMaliciousIDsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *fetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, []p2p.Peer, func([]byte, p2p.Peer), func(error, p2p.Peer)) error) *fetcherGetMaliciousIDsCall { +func (c *fetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]byte, error)) *fetcherGetMaliciousIDsCall { c.Call = c.Call.DoAndReturn(f) return c }