diff --git a/proposals/handler.go b/proposals/handler.go index 65dd60c1a8..387708593f 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -234,6 +234,11 @@ func (h *Handler) checkBallotSyntacticValidity(ctx context.Context, b *types.Bal return err } + if err := h.checkVotesConsistency(ctx, b); err != nil { + h.logger.WithContext(ctx).With().Warning("ballot votes consistency check failed", log.Err(err)) + return err + } + if eligible, err := h.validator.CheckEligibility(ctx, b); err != nil || !eligible { h.logger.WithContext(ctx).With().Warning("ballot eligibility check failed", log.Err(err)) return errNotEligible @@ -271,14 +276,8 @@ func (h *Handler) checkBallotDataIntegrity(ctx context.Context, b *types.Ballot) } set[atx] = struct{}{} } - } else { - if b.EpochData != nil { - return errUnexpectedEpochData - } - } - - if err := h.checkVotesConsistency(ctx, b); err != nil { - return err + } else if b.EpochData != nil { + return errUnexpectedEpochData } return nil } diff --git a/proposals/handler_test.go b/proposals/handler_test.go index d3a1e8044f..adc33c5f09 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -231,6 +231,9 @@ func TestBallot_BallotDoubleVotedWithinHdist(t *testing.T) { require.GreaterOrEqual(t, 2, len(b.Votes.Support)) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) cutoff := b.LayerIndex.Sub(th.cfg.Hdist) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[0]).Return(cutoff.Add(1), nil) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[1]).Return(cutoff.Add(1), nil) @@ -245,6 +248,9 @@ func TestBallot_BallotDoubleVotedWithinHdist_LyrBfrHdist(t *testing.T) { require.GreaterOrEqual(t, 2, len(b.Votes.Support)) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[0]).Return(b.LayerIndex.Sub(1), nil) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[1]).Return(b.LayerIndex.Sub(1), nil) th.mm.EXPECT().SetIdentityMalicious(b.SmesherID()).Return(nil) @@ -257,6 +263,9 @@ func TestBallot_BallotDoubleVotedWithinHdist_SetMaliciousError(t *testing.T) { require.GreaterOrEqual(t, 2, len(b.Votes.Support)) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) cutoff := b.LayerIndex.Sub(th.cfg.Hdist) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[0]).Return(cutoff, nil) th.mm.EXPECT().GetBlockLayer(b.Votes.Support[1]).Return(cutoff, nil) @@ -270,13 +279,13 @@ func TestBallot_BallotDoubleVotedOutsideHdist(t *testing.T) { b := createBallot(t) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) cutoff := b.LayerIndex.Sub(th.cfg.Hdist) for _, bid := range b.Votes.Support { th.mm.EXPECT().GetBlockLayer(bid).Return(cutoff.Sub(1), nil) } - th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) - th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) - th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) th.mv.EXPECT().CheckEligibility(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, ballot *types.Ballot) (bool, error) { require.Equal(t, b.ID(), ballot.ID()) @@ -297,6 +306,9 @@ func TestBallot_ConflictingForAndAgainst(t *testing.T) { b = signAndInit(t, b) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), append(b.Votes.Support, b.Votes.Against...)).Return(nil).Times(1) for i, bid := range b.Votes.Support { th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) } @@ -310,6 +322,9 @@ func TestBallot_ConflictingForAndAbstain(t *testing.T) { b = signAndInit(t, b) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) for i, bid := range b.Votes.Support { th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) } @@ -326,6 +341,9 @@ func TestBallot_ConflictingAgainstAndAbstain(t *testing.T) { b = signAndInit(t, b) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), append(b.Votes.Support, b.Votes.Against...)).Return(nil).Times(1) for i, bid := range b.Votes.Against { th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) } @@ -340,6 +358,9 @@ func TestBallot_ExceedMaxExceptions(t *testing.T) { b = signAndInit(t, b) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) + th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) + th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) + th.mf.EXPECT().GetBlocks(gomock.Any(), b.Votes.Support).Return(nil).Times(1) for i, bid := range b.Votes.Support { th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) } @@ -351,9 +372,6 @@ func TestBallot_BallotsNotAvailable(t *testing.T) { b := createBallot(t) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) - for i, bid := range b.Votes.Support { - th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) - } errUnknown := errors.New("unknown") th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(errUnknown).Times(1) require.ErrorIs(t, th.HandleBallotData(context.TODO(), data), errUnknown) @@ -364,9 +382,6 @@ func TestBallot_ATXsNotAvailable(t *testing.T) { b := createBallot(t) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) - for i, bid := range b.Votes.Support { - th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) - } th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) errUnknown := errors.New("unknown") th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(errUnknown).Times(1) @@ -378,9 +393,6 @@ func TestBallot_BlocksNotAvailable(t *testing.T) { b := createBallot(t) data := encodeBallot(t, b) th.mm.EXPECT().HasBallot(b.ID()).Return(false).Times(1) - for i, bid := range b.Votes.Support { - th.mm.EXPECT().GetBlockLayer(bid).Return(b.LayerIndex.Sub(uint32(i+1)), nil) - } th.mf.EXPECT().GetBallots(gomock.Any(), []types.BallotID{b.Votes.Base, b.RefBallot}).Return(nil).Times(1) th.mf.EXPECT().GetAtxs(gomock.Any(), types.ATXIDList{b.AtxID}).Return(nil).Times(1) errUnknown := errors.New("unknown")