Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Jun 21, 2024
1 parent 7d0e2c6 commit 72f8ec3
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 39 deletions.
29 changes: 16 additions & 13 deletions activation/e2e/atx_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,18 @@ type nipostData struct {

func buildNipost(
nb *activation.NIPostBuilder,
sig *signing.EdSigner,
signer *signing.EdSigner,
publish types.EpochID,
previous, positioning types.ATXID,
) (nipostData, error) {
challenge := wire.NIPostChallengeV2{
PublishEpoch: publish,
PrevATXID: previous,
PositioningATXID: positioning,
postChallenge := &types.NIPostChallenge{
PublishEpoch: publish,
PrevATXID: previous,
PositioningATX: positioning,
}
nipost, err := nb.BuildNIPost(context.Background(), sig, challenge.PublishEpoch, challenge.Hash())
nb.ResetState(sig.NodeID())
challenge := wire.NIPostChallengeToWireV2(postChallenge).Hash()
nipost, err := nb.BuildNIPost(context.Background(), signer, challenge, postChallenge)
nb.ResetState(signer.NodeID())
return nipostData{previous, nipost}, err
}

Expand Down Expand Up @@ -304,6 +305,7 @@ func Test_MarryAndMerge(t *testing.T) {
logger.Named("nipostBuilder"),
poetCfg,
clock,
validator,
activation.WithPoetClients(poetClient),
)
require.NoError(t, err)
Expand Down Expand Up @@ -338,17 +340,18 @@ func Test_MarryAndMerge(t *testing.T) {
eg = errgroup.Group{}
for i, signer := range signers {
eg.Go(func() error {
post, postInfo, err := nb.Proof(context.Background(), signer.NodeID(), types.EmptyHash32[:])
post, postInfo, err := nb.Proof(context.Background(), signer.NodeID(), types.EmptyHash32[:], nil)
if err != nil {
return err
}

challenge := wire.NIPostChallengeV2{
PublishEpoch: publish,
PositioningATXID: goldenATX,
InitialPost: wire.PostToWireV1(post),
postChallenge := &types.NIPostChallenge{
PublishEpoch: publish,
PositioningATX: goldenATX,
InitialPost: post,
}
nipost, err := nb.BuildNIPost(context.Background(), signer, challenge.PublishEpoch, challenge.Hash())
challenge := wire.NIPostChallengeToWireV2(postChallenge).Hash()
nipost, err := nb.BuildNIPost(context.Background(), signer, challenge, postChallenge)
if err != nil {
return err
}
Expand Down
33 changes: 16 additions & 17 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"math"
"math/bits"
"slices"
"time"
Expand Down Expand Up @@ -435,7 +436,8 @@ func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID,
if len(atx.Marriages) == 0 {
return nil, nil
}
var marryingIDs []types.NodeID
marryingIDsSet := make(map[types.NodeID]struct{}, len(atx.Marriages))
var marryingIDs []types.NodeID // for deterministic order
for i, m := range atx.Marriages {
var id types.NodeID
if m.ReferenceAtx == types.EmptyATXID {
Expand All @@ -451,6 +453,10 @@ func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID,
if !h.edVerifier.Verify(signing.MARRIAGE, id, atx.SmesherID.Bytes(), m.Signature) {
return nil, fmt.Errorf("invalid marriage[%d] signature", i)
}
if _, ok := marryingIDsSet[id]; ok {
return nil, fmt.Errorf("more than 1 marriage certificate for ID %s", id)
}
marryingIDsSet[id] = struct{}{}
marryingIDs = append(marryingIDs, id)
}
return marryingIDs, nil
Expand Down Expand Up @@ -515,22 +521,21 @@ func (n nipostSizes) minTicks() uint64 {
}

func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) {
var totalEffectiveNumUnits uint32
var totalUnits uint64
var totalWeight uint64
for _, ns := range n {
sum, carry := bits.Add32(totalEffectiveNumUnits, ns.units, 0)
if carry != 0 {
return 0, 0, fmt.Errorf("total units overflow (%d + %d)", totalEffectiveNumUnits, ns.units)
}
totalEffectiveNumUnits = sum
totalUnits += uint64(ns.units)

hi, weight := bits.Mul64(uint64(ns.units), ns.ticks)
if hi != 0 {
return 0, 0, fmt.Errorf("weight overflow (%d * %d)", ns.units, ns.ticks)
}
totalWeight += weight
}
return totalEffectiveNumUnits, totalWeight, nil
if totalUnits > math.MaxUint32 {
return 0, 0, fmt.Errorf("total units overflow: %d", totalUnits)
}
return uint32(totalUnits), totalWeight, nil
}

func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error {
Expand Down Expand Up @@ -595,7 +600,6 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}
nipostSizes[i].addUnits(effectiveNumUnits)

}
}

Expand Down Expand Up @@ -708,7 +712,6 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}

func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
Expand All @@ -721,7 +724,7 @@ func (h *HandlerV2) checkMalicious(
return true, nil, nil
}

proof, err := h.checkDoubleMarry(tx, watx, marrying)
proof, err := h.checkDoubleMarry(tx, marrying)
if err != nil {
return false, nil, fmt.Errorf("checking double marry: %w", err)
}
Expand All @@ -739,11 +742,7 @@ func (h *HandlerV2) checkMalicious(
return false, nil, nil
}

func (h *HandlerV2) checkDoubleMarry(
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
) (*mwire.MalfeasanceProof, error) {
func (h *HandlerV2) checkDoubleMarry(tx *sql.Tx, marrying []types.NodeID) (*mwire.MalfeasanceProof, error) {
for _, id := range marrying {
married, err := identities.Married(tx, id)
if err != nil {
Expand Down Expand Up @@ -776,7 +775,7 @@ func (h *HandlerV2) storeAtx(
)
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
var err error
malicious, proof, err = h.checkMalicious(ctx, tx, watx, marrying)
malicious, proof, err = h.checkMalicious(tx, watx, marrying)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down
55 changes: 47 additions & 8 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/spacemeshos/go-spacemesh/activation/wire"
"github.com/spacemeshos/go-spacemesh/atxsdata"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/fixture"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/datastore"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
Expand Down Expand Up @@ -121,8 +122,8 @@ func (h *handlerMocks) expectVerifyNIPoSTs(
}

func (h *handlerMocks) expectStoreAtxV2(atx *wire.ActivationTxV2) {
h.mbeacon.EXPECT().OnAtx(gomock.Any())
h.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
h.mbeacon.EXPECT().OnAtx(fixture.MatchId(atx.ID()))
h.mtortoise.EXPECT().OnAtx(atx.PublishEpoch+1, atx.ID(), gomock.Any())
h.mValidator.EXPECT().IsVerifyingFullPost().Return(false)
}

Expand Down Expand Up @@ -183,17 +184,24 @@ func (h *v2TestHandler) createAndProcessInitial(t *testing.T, sig *signing.EdSig
t.Helper()
atx := newInitialATXv2(t, h.handlerMocks.goldenATXID)
atx.Sign(sig)
p, err := h.processInitial(atx)
p, err := h.processInitial(t, atx)
require.NoError(t, err)
require.Nil(t, p)
return atx
}

func (h *v2TestHandler) processInitial(atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) {
func (h *v2TestHandler) processInitial(t *testing.T, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) {
t.Helper()
h.expectInitialAtxV2(atx)
return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now())
}

func (h *v2TestHandler) processSoloAtx(t *testing.T, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) {
t.Helper()
h.expectAtxV2(atx)
return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now())
}

func TestHandlerV2_SyntacticallyValidate(t *testing.T) {
t.Parallel()
golden := types.RandomATXID()
Expand Down Expand Up @@ -1234,7 +1242,7 @@ func Test_ValidateMarriages(t *testing.T) {
}
marriage.Sign(sig)

p, err := atxHandler.processInitial(marriage)
p, err := atxHandler.processInitial(t, marriage)
require.NoError(t, err)
require.Nil(t, p)

Expand Down Expand Up @@ -1569,10 +1577,9 @@ func Test_Marriages(t *testing.T) {
Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()),
},
}

atx.Sign(sig)

p, err := atxHandler.processInitial(atx)
p, err := atxHandler.processInitial(t, atx)
require.NoError(t, err)
require.Nil(t, p)

Expand All @@ -1588,7 +1595,39 @@ func Test_Marriages(t *testing.T) {
require.NoError(t, err)
require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, set)
})
t.Run("can't marry twice", func(t *testing.T) {
t.Run("can't marry twice in the same marriage ATX", func(t *testing.T) {
t.Parallel()
atxHandler := newV2TestHandler(t, golden)

otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
othersAtx := atxHandler.createAndProcessInitial(t, otherSig)

othersSecondAtx := newSoloATXv2(t, othersAtx.PublishEpoch+1, othersAtx.ID(), othersAtx.ID())
othersSecondAtx.Sign(otherSig)
_, err = atxHandler.processSoloAtx(t, othersSecondAtx)
require.NoError(t, err)

atx := newInitialATXv2(t, golden)
atx.Marriages = []wire.MarriageCertificate{
{
Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()),
},
{
ReferenceAtx: othersAtx.ID(),
Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()),
},
{
ReferenceAtx: othersSecondAtx.ID(),
Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()),
},
}
atx.Sign(sig)

_, err = atxHandler.validateMarriages(atx)
require.ErrorContains(t, err, "more than 1 marriage certificate for ID")
})
t.Run("can't marry twice (separate marriages)", func(t *testing.T) {
t.Parallel()
atxHandler := newV2TestHandler(t, golden)

Expand Down
2 changes: 1 addition & 1 deletion activation/wire/wire_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte {
}
tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes())

var leafIndex [8]byte
var leafIndex types.Hash32
binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex)
tree.AddLeaf(leafIndex[:])

Expand Down
18 changes: 18 additions & 0 deletions common/fixture/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ func ToAtx(t testing.TB, watx *wire.ActivationTxV1) *types.ActivationTx {
atx.TickCount = 1
return atx
}

type idMatcher types.ATXID

func MatchId(id types.ATXID) idMatcher {
return idMatcher(id)
}

func (m idMatcher) Matches(x any) bool {
type hasID interface {
ID() types.ATXID
}
v, ok := x.(hasID)
return ok && v.ID() == types.ATXID(m)
}

func (m idMatcher) String() string {
return "is ATX ID " + types.ATXID(m).String()
}

0 comments on commit 72f8ec3

Please sign in to comment.