Skip to content

Commit

Permalink
atx: skip malicious atx for positioning atx (#4670)
Browse files Browse the repository at this point in the history
## Motivation
part of #4632

## Changes
- favor latest epoch over earlier
- favor own node id if specified (for positioning atx)
- skip atx from malicious id for positioning atx
- cleanup: remove unnecessary api GetPosAtxID
  • Loading branch information
countvonzero committed Jul 10, 2023
1 parent ba92340 commit fc093ca
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 102 deletions.
5 changes: 3 additions & 2 deletions activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
)

// PoetConfig is the configuration to interact with the poet server.
Expand Down Expand Up @@ -641,9 +642,9 @@ func (b *Builder) broadcast(ctx context.Context, atx *types.ActivationTx) (int,
return len(buf), nil
}

// GetPositioningAtx returns atx id from the newest epoch with the highest tick height.
// GetPositioningAtx returns atx id with the highest tick height.
func (b *Builder) GetPositioningAtx() (types.ATXID, error) {
id, err := b.atxHandler.GetPosAtxID()
id, err := atxs.GetIDWithMaxHeight(b.cdb, b.nodeID)
if err != nil {
if errors.Is(err, sql.ErrNotFound) {
b.log.With().Info("using golden atx as positioning atx", b.goldenATXID)
Expand Down
14 changes: 1 addition & 13 deletions activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ func publishAtx(
t.Helper()

publishEpoch := posEpoch + 1
tab.mhdlr.EXPECT().GetPosAtxID().Return(posAtxId, nil)
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
// time.Now() ~= currentLayer
Expand Down Expand Up @@ -280,7 +279,6 @@ func TestBuilder_RestartSmeshing(t *testing.T) {
tab.mclock.EXPECT().AwaitLayer(gomock.Any()).Return(ch).AnyTimes()
tab.mclock.EXPECT().CurrentLayer().Return(types.LayerID(0)).AnyTimes()
tab.mclock.EXPECT().LayerToTime(gomock.Any()).Return(now).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(types.ATXID{1, 2, 3}, nil).AnyTimes()
return tab.Builder
}

Expand Down Expand Up @@ -379,7 +377,6 @@ func TestBuilder_StopSmeshing_OnPoSTError(t *testing.T) {
tab.mclock.EXPECT().AwaitLayer(gomock.Any()).Return(ch).AnyTimes()
tab.mclock.EXPECT().CurrentLayer().Return(types.LayerID(0)).AnyTimes()
tab.mclock.EXPECT().LayerToTime(gomock.Any()).Return(now).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(types.ATXID{1, 2, 3}, nil).AnyTimes()
tab.msync.EXPECT().RegisterForATXSynced().Return(ch).AnyTimes()
require.NoError(t, tab.StartSmeshing(tab.coinbase, PostSetupOpts{}))

Expand Down Expand Up @@ -430,7 +427,6 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) {
require.NoError(t, err)
require.NoError(t, atxs.Add(tab.cdb, vPrevAtx))

tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil)
tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes()
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
Expand Down Expand Up @@ -468,7 +464,6 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) {

publishEpoch := posEpoch + 1
tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil)
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
// time.Now() ~= currentLayer
Expand Down Expand Up @@ -564,7 +559,6 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi
func() types.LayerID {
return currLayer
}).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil)
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
// time.Now() ~= currentLayer
Expand Down Expand Up @@ -672,7 +666,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) {
nipost := newNIPostWithChallenge(t, types.HexToHash32("55555"), poetBytes)
posAtx := newAtx(t, otherSigner, challenge, nipost, 2, types.Address{})
SignAndFinalizeAtx(otherSigner, posAtx)
vPosAtx, err := posAtx.Verify(0, 1)
vPosAtx, err := posAtx.Verify(0, 2)
r.NoError(err)
r.NoError(atxs.Add(tab.cdb, vPosAtx))

Expand Down Expand Up @@ -721,7 +715,6 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) {

atxChan := make(chan struct{})
tab.mhdlr.EXPECT().AwaitAtx(gomock.Any()).Return(atxChan)
tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil)
tab.mhdlr.EXPECT().UnsubscribeAtx(gomock.Any())

tab.mpub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ string, msg []byte) error {
Expand Down Expand Up @@ -809,7 +802,6 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) {

atxChan := make(chan struct{})
tab.mhdlr.EXPECT().AwaitAtx(gomock.Any()).Return(atxChan)
tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil)
tab.mhdlr.EXPECT().UnsubscribeAtx(gomock.Any())

tab.mpub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ string, msg []byte) error {
Expand Down Expand Up @@ -852,7 +844,6 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) {
require.NoError(t, atxs.Add(tab.cdb, vPosAtx))

tab.mclock.EXPECT().CurrentLayer().Return(posEpoch.FirstLayer()).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil)
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
// time.Now() ~= currentLayer
Expand Down Expand Up @@ -921,7 +912,6 @@ func TestBuilder_NIPostPublishRecovery(t *testing.T) {

publishEpoch := posEpoch + 1
tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil)
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
// time.Now() ~= currentLayer
Expand Down Expand Up @@ -1029,7 +1019,6 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) {

currLayer := posEpoch.FirstLayer()
tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes()
tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil).AnyTimes()
tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn(
func(got types.LayerID) time.Time {
return genesis.Add(layerDuration * time.Duration(got))
Expand Down Expand Up @@ -1174,7 +1163,6 @@ func TestWaitPositioningAtx(t *testing.T) {
}).AnyTimes()

// everything else are stubs that are irrelevant for the test
tab.mhdlr.EXPECT().GetPosAtxID().Return(tab.goldenATXID, nil).AnyTimes()
tab.mpost.EXPECT().LastOpts().Return(&PostSetupOpts{}).AnyTimes()
tab.mpost.EXPECT().CommitmentAtx().Return(tab.goldenATXID, nil).AnyTimes()
index := types.VRFPostIndex(0)
Expand Down
9 changes: 0 additions & 9 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,6 @@ func (h *Handler) GetEpochAtxs(epochID types.EpochID) (ids []types.ATXID, err er
return
}

// GetPosAtxID returns the best (highest layer id), currently known to this node, pos atx id.
func (h *Handler) GetPosAtxID() (types.ATXID, error) {
id, err := atxs.GetAtxIDWithMaxHeight(h.cdb)
if err != nil {
return types.EmptyATXID, fmt.Errorf("failed to get positioning atx: %w", err)
}
return id, nil
}

// HandleAtxData handles atxs received by sync.
func (h *Handler) HandleAtxData(ctx context.Context, peer p2p.Peer, data []byte) error {
err := h.HandleGossipAtx(ctx, peer, data)
Expand Down
43 changes: 0 additions & 43 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,49 +878,6 @@ func BenchmarkNewActivationDb(b *testing.B) {
b.Logf("\n>>> Total time: %v\n\n", time.Since(start))
}

func TestHandler_GetPosAtx(t *testing.T) {
// Arrange
r := require.New(t)

goldenATXID := types.ATXID{2, 3, 4}
atxHdlr := newTestHandler(t, goldenATXID)

currentLayer := types.LayerID(10)

sig, err := signing.NewEdSigner()
r.NoError(err)
otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
coinbase := types.Address{2, 4, 5}

// Act & Assert

// ATX stored should become top ATX
atx1 := newActivationTx(t, sig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch(), 0, 100, coinbase, 100, &types.NIPost{})
r.NoError(atxs.Add(atxHdlr.cdb, atx1))

id, err := atxHdlr.GetPosAtxID()
r.NoError(err)
r.Equal(atx1.ID(), id)

// higher-layer ATX stored should become new top ATX
atx2 := newActivationTx(t, otherSig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch()+2, 0, 100, coinbase, 100, &types.NIPost{})
r.NoError(atxs.Add(atxHdlr.cdb, atx2))

id, err = atxHdlr.GetPosAtxID()
r.NoError(err)
r.Equal(atx2.ID(), id)

// lower-layer ATX stored should NOT become new top ATX
atx3 := newActivationTx(t, sig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch()+1, 0, 100, coinbase, 100, &types.NIPost{})
r.NoError(atxs.Add(atxHdlr.cdb, atx3))

id, err = atxHdlr.GetPosAtxID()
r.NoError(err)
r.NotEqual(atx3.ID(), id)
r.Equal(atx2.ID(), id)
}

func TestHandler_AwaitAtx(t *testing.T) {
// Arrange
r := require.New(t)
Expand Down
1 change: 0 additions & 1 deletion activation/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ type nipostBuilder interface {
}

type atxHandler interface {
GetPosAtxID() (types.ATXID, error)
AwaitAtx(id types.ATXID) chan struct{}
UnsubscribeAtx(id types.ATXID)
}
Expand Down
15 changes: 0 additions & 15 deletions activation/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion activation/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (mgr *PostSetupManager) commitmentAtx(ctx context.Context, dataDir string)
// It will use the ATX with the highest height seen by the node and defaults to the goldenATX,
// when no ATXs have yet been published.
func (mgr *PostSetupManager) findCommitmentAtx(ctx context.Context) (types.ATXID, error) {
atx, err := atxs.GetAtxIDWithMaxHeight(mgr.db)
atx, err := atxs.GetIDWithMaxHeight(mgr.db, types.EmptyNodeID)
switch {
case errors.Is(err, sql.ErrNotFound):
mgr.logger.With().Info("using golden atx as commitment atx")
Expand Down
4 changes: 2 additions & 2 deletions checkpoint/recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func validateAndPreserveData(tb testing.TB, db *sql.Database, deps []*types.Veri
mvalidator.EXPECT().NIPostChallenge(&vatx.ActivationTx.NIPostChallenge, cdb, vatx.SmesherID)
}
mvalidator.EXPECT().PositioningAtx(&vatx.PositioningATX, cdb, goldenAtx, vatx.PublishEpoch, layersPerEpoch)
mvalidator.EXPECT().NIPost(gomock.Any(), vatx.SmesherID, gomock.Any(), vatx.NIPost, gomock.Any(), vatx.NumUnits)
mvalidator.EXPECT().NIPost(gomock.Any(), vatx.SmesherID, gomock.Any(), vatx.NIPost, gomock.Any(), vatx.NumUnits).Return(uint64(1111111), nil)
mreceiver.EXPECT().OnAtx(gomock.Any())
mtrtl.EXPECT().OnAtx(gomock.Any())
require.NoError(tb, atxHandler.HandleAtxData(context.Background(), "self", encoded))
Expand Down Expand Up @@ -555,7 +555,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T)
require.NoError(t, poets.Add(olddb, types.PoetProofRef(vatx.GetPoetProofRef()), encoded, proofs[i].PoetServiceID, proofs[i].RoundID))
}

commitment, err := atxs.GetAtxIDWithMaxHeight(olddb)
commitment, err := atxs.GetIDWithMaxHeight(olddb, types.EmptyNodeID)
require.NoError(t, err)
require.NoError(t, olddb.Close())

Expand Down
2 changes: 1 addition & 1 deletion datastore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (db *CachedDB) IdentityExists(nodeID types.NodeID) (bool, error) {
}

func (db *CachedDB) MaxHeightAtx() (types.ATXID, error) {
return atxs.GetAtxIDWithMaxHeight(db)
return atxs.GetIDWithMaxHeight(db, types.EmptyNodeID)
}

// Hint marks which DB should be queried for a certain provided hash.
Expand Down
21 changes: 16 additions & 5 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,12 @@ func Add(db sql.Executor, atx *types.VerifiedActivationTx) error {
return nil
}

// GetAtxIDWithMaxHeight returns the ID of the atx from the last 2 epoch with the highest (or tied for the highest) tick height.
// GetIDWithMaxHeight returns the ID of the atx from the last 2 epoch with the highest (or tied for the highest) tick height.
// it is possible that some poet servers are faster than others and the network ends up having its highest ticked atx still in
// previous epoch and the atxs building on top of it have not been published yet. selecting from the last two epochs to strike
// a balance between being fair to honest miners while not giving unfair advantage for malicious actors who retroactively
// publish a high tick atx many epochs back.
func GetAtxIDWithMaxHeight(db sql.Executor) (types.ATXID, error) {
func GetIDWithMaxHeight(db sql.Executor, pref types.NodeID) (types.ATXID, error) {
var (
rst types.ATXID
max uint64
Expand All @@ -309,13 +309,24 @@ func GetAtxIDWithMaxHeight(db sql.Executor) (types.ATXID, error) {
stmt.ColumnBytes(0, id[:])
height := uint64(stmt.ColumnInt64(1)) + uint64(stmt.ColumnInt64(2))
if height >= max {
max = height
rst = id
var smesher types.NodeID
stmt.ColumnBytes(3, smesher[:])
if height > max {
max = height
rst = id
} else if pref != types.EmptyNodeID && smesher == pref {
// height is equal. prefer atxs from `pref`
rst = id
}
}
return true
}

if rows, err := db.Exec("select id, base_tick_height, tick_count, epoch from atxs where epoch >= (select max(epoch) from atxs)-1;", nil, dec); err != nil {
if rows, err := db.Exec(`
select id, base_tick_height, tick_count, pubkey
from atxs left join identities using(pubkey)
where identities.pubkey is null and epoch >= (select max(epoch) from atxs)-1
order by epoch desc;`, nil, dec); err != nil {
return types.ATXID{}, fmt.Errorf("select positioning atx: %w", err)
} else if rows == 0 {
return types.ATXID{}, sql.ErrNotFound
Expand Down
Loading

0 comments on commit fc093ca

Please sign in to comment.