From 72f8ec30e728cb7f68c654a1c6650fa45242a9d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 20 Jun 2024 19:01:42 +0200 Subject: [PATCH] Review feedback --- activation/e2e/atx_merge_test.go | 29 +++++++++-------- activation/handler_v2.go | 33 ++++++++++--------- activation/handler_v2_test.go | 55 +++++++++++++++++++++++++++----- activation/wire/wire_v2.go | 2 +- common/fixture/atxs.go | 18 +++++++++++ 5 files changed, 98 insertions(+), 39 deletions(-) diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index a0cf900b48..86c74fa2ed 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -66,17 +66,18 @@ type nipostData struct { func buildNipost( nb *activation.NIPostBuilder, - sig *signing.EdSigner, + signer *signing.EdSigner, publish types.EpochID, previous, positioning types.ATXID, ) (nipostData, error) { - challenge := wire.NIPostChallengeV2{ - PublishEpoch: publish, - PrevATXID: previous, - PositioningATXID: positioning, + postChallenge := &types.NIPostChallenge{ + PublishEpoch: publish, + PrevATXID: previous, + PositioningATX: positioning, } - nipost, err := nb.BuildNIPost(context.Background(), sig, challenge.PublishEpoch, challenge.Hash()) - nb.ResetState(sig.NodeID()) + challenge := wire.NIPostChallengeToWireV2(postChallenge).Hash() + nipost, err := nb.BuildNIPost(context.Background(), signer, challenge, postChallenge) + nb.ResetState(signer.NodeID()) return nipostData{previous, nipost}, err } @@ -304,6 +305,7 @@ func Test_MarryAndMerge(t *testing.T) { logger.Named("nipostBuilder"), poetCfg, clock, + validator, activation.WithPoetClients(poetClient), ) require.NoError(t, err) @@ -338,17 +340,18 @@ func Test_MarryAndMerge(t *testing.T) { eg = errgroup.Group{} for i, signer := range signers { eg.Go(func() error { - post, postInfo, err := nb.Proof(context.Background(), signer.NodeID(), types.EmptyHash32[:]) + post, postInfo, err := nb.Proof(context.Background(), signer.NodeID(), types.EmptyHash32[:], nil) if err != nil { return err } - challenge := wire.NIPostChallengeV2{ - PublishEpoch: publish, - PositioningATXID: goldenATX, - InitialPost: wire.PostToWireV1(post), + postChallenge := &types.NIPostChallenge{ + PublishEpoch: publish, + PositioningATX: goldenATX, + InitialPost: post, } - nipost, err := nb.BuildNIPost(context.Background(), signer, challenge.PublishEpoch, challenge.Hash()) + challenge := wire.NIPostChallengeToWireV2(postChallenge).Hash() + nipost, err := nb.BuildNIPost(context.Background(), signer, challenge, postChallenge) if err != nil { return err } diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 20851b3f5a..72b46d7a49 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "math" "math/bits" "slices" "time" @@ -435,7 +436,8 @@ func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID, if len(atx.Marriages) == 0 { return nil, nil } - var marryingIDs []types.NodeID + marryingIDsSet := make(map[types.NodeID]struct{}, len(atx.Marriages)) + var marryingIDs []types.NodeID // for deterministic order for i, m := range atx.Marriages { var id types.NodeID if m.ReferenceAtx == types.EmptyATXID { @@ -451,6 +453,10 @@ func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID, if !h.edVerifier.Verify(signing.MARRIAGE, id, atx.SmesherID.Bytes(), m.Signature) { return nil, fmt.Errorf("invalid marriage[%d] signature", i) } + if _, ok := marryingIDsSet[id]; ok { + return nil, fmt.Errorf("more than 1 marriage certificate for ID %s", id) + } + marryingIDsSet[id] = struct{}{} marryingIDs = append(marryingIDs, id) } return marryingIDs, nil @@ -515,14 +521,10 @@ func (n nipostSizes) minTicks() uint64 { } func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) { - var totalEffectiveNumUnits uint32 + var totalUnits uint64 var totalWeight uint64 for _, ns := range n { - sum, carry := bits.Add32(totalEffectiveNumUnits, ns.units, 0) - if carry != 0 { - return 0, 0, fmt.Errorf("total units overflow (%d + %d)", totalEffectiveNumUnits, ns.units) - } - totalEffectiveNumUnits = sum + totalUnits += uint64(ns.units) hi, weight := bits.Mul64(uint64(ns.units), ns.ticks) if hi != 0 { @@ -530,7 +532,10 @@ func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) { } totalWeight += weight } - return totalEffectiveNumUnits, totalWeight, nil + if totalUnits > math.MaxUint32 { + return 0, 0, fmt.Errorf("total units overflow: %d", totalUnits) + } + return uint32(totalUnits), totalWeight, nil } func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error { @@ -595,7 +600,6 @@ func (h *HandlerV2) syntacticallyValidateDeps( } } nipostSizes[i].addUnits(effectiveNumUnits) - } } @@ -708,7 +712,6 @@ func (h *HandlerV2) syntacticallyValidateDeps( } func (h *HandlerV2) checkMalicious( - ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2, marrying []types.NodeID, @@ -721,7 +724,7 @@ func (h *HandlerV2) checkMalicious( return true, nil, nil } - proof, err := h.checkDoubleMarry(tx, watx, marrying) + proof, err := h.checkDoubleMarry(tx, marrying) if err != nil { return false, nil, fmt.Errorf("checking double marry: %w", err) } @@ -739,11 +742,7 @@ func (h *HandlerV2) checkMalicious( return false, nil, nil } -func (h *HandlerV2) checkDoubleMarry( - tx *sql.Tx, - watx *wire.ActivationTxV2, - marrying []types.NodeID, -) (*mwire.MalfeasanceProof, error) { +func (h *HandlerV2) checkDoubleMarry(tx *sql.Tx, marrying []types.NodeID) (*mwire.MalfeasanceProof, error) { for _, id := range marrying { married, err := identities.Married(tx, id) if err != nil { @@ -776,7 +775,7 @@ func (h *HandlerV2) storeAtx( ) if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { var err error - malicious, proof, err = h.checkMalicious(ctx, tx, watx, marrying) + malicious, proof, err = h.checkMalicious(tx, watx, marrying) if err != nil { return fmt.Errorf("check malicious: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index ffa53f3fbd..1c84c353fd 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -19,6 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" @@ -121,8 +122,8 @@ func (h *handlerMocks) expectVerifyNIPoSTs( } func (h *handlerMocks) expectStoreAtxV2(atx *wire.ActivationTxV2) { - h.mbeacon.EXPECT().OnAtx(gomock.Any()) - h.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + h.mbeacon.EXPECT().OnAtx(fixture.MatchId(atx.ID())) + h.mtortoise.EXPECT().OnAtx(atx.PublishEpoch+1, atx.ID(), gomock.Any()) h.mValidator.EXPECT().IsVerifyingFullPost().Return(false) } @@ -183,17 +184,24 @@ func (h *v2TestHandler) createAndProcessInitial(t *testing.T, sig *signing.EdSig t.Helper() atx := newInitialATXv2(t, h.handlerMocks.goldenATXID) atx.Sign(sig) - p, err := h.processInitial(atx) + p, err := h.processInitial(t, atx) require.NoError(t, err) require.Nil(t, p) return atx } -func (h *v2TestHandler) processInitial(atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { +func (h *v2TestHandler) processInitial(t *testing.T, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { + t.Helper() h.expectInitialAtxV2(atx) return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now()) } +func (h *v2TestHandler) processSoloAtx(t *testing.T, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { + t.Helper() + h.expectAtxV2(atx) + return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now()) +} + func TestHandlerV2_SyntacticallyValidate(t *testing.T) { t.Parallel() golden := types.RandomATXID() @@ -1234,7 +1242,7 @@ func Test_ValidateMarriages(t *testing.T) { } marriage.Sign(sig) - p, err := atxHandler.processInitial(marriage) + p, err := atxHandler.processInitial(t, marriage) require.NoError(t, err) require.Nil(t, p) @@ -1569,10 +1577,9 @@ func Test_Marriages(t *testing.T) { Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), }, } - atx.Sign(sig) - p, err := atxHandler.processInitial(atx) + p, err := atxHandler.processInitial(t, atx) require.NoError(t, err) require.Nil(t, p) @@ -1588,7 +1595,39 @@ func Test_Marriages(t *testing.T) { require.NoError(t, err) require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, set) }) - t.Run("can't marry twice", func(t *testing.T) { + t.Run("can't marry twice in the same marriage ATX", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + othersSecondAtx := newSoloATXv2(t, othersAtx.PublishEpoch+1, othersAtx.ID(), othersAtx.ID()) + othersSecondAtx.Sign(otherSig) + _, err = atxHandler.processSoloAtx(t, othersSecondAtx) + require.NoError(t, err) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersSecondAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx.Sign(sig) + + _, err = atxHandler.validateMarriages(atx) + require.ErrorContains(t, err, "more than 1 marriage certificate for ID") + }) + t.Run("can't marry twice (separate marriages)", func(t *testing.T) { t.Parallel() atxHandler := newV2TestHandler(t, golden) diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index c844ae7809..a320351152 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -229,7 +229,7 @@ func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte { } tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) - var leafIndex [8]byte + var leafIndex types.Hash32 binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex) tree.AddLeaf(leafIndex[:]) diff --git a/common/fixture/atxs.go b/common/fixture/atxs.go index af3ead08f2..74226aca21 100644 --- a/common/fixture/atxs.go +++ b/common/fixture/atxs.go @@ -77,3 +77,21 @@ func ToAtx(t testing.TB, watx *wire.ActivationTxV1) *types.ActivationTx { atx.TickCount = 1 return atx } + +type idMatcher types.ATXID + +func MatchId(id types.ATXID) idMatcher { + return idMatcher(id) +} + +func (m idMatcher) Matches(x any) bool { + type hasID interface { + ID() types.ATXID + } + v, ok := x.(hasID) + return ok && v.ID() == types.ATXID(m) +} + +func (m idMatcher) String() string { + return "is ATX ID " + types.ATXID(m).String() +}