diff --git a/activation/activation.go b/activation/activation.go index fffcee3490..277fcec760 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" ) // PoetConfig is the configuration to interact with the poet server. @@ -641,9 +642,9 @@ func (b *Builder) broadcast(ctx context.Context, atx *types.ActivationTx) (int, return len(buf), nil } -// GetPositioningAtx returns atx id from the newest epoch with the highest tick height. +// GetPositioningAtx returns atx id with the highest tick height. func (b *Builder) GetPositioningAtx() (types.ATXID, error) { - id, err := b.atxHandler.GetPosAtxID() + id, err := atxs.GetIDWithMaxHeight(b.cdb, b.nodeID) if err != nil { if errors.Is(err, sql.ErrNotFound) { b.log.With().Info("using golden atx as positioning atx", b.goldenATXID) diff --git a/activation/activation_test.go b/activation/activation_test.go index f9d02b447e..51b9eb4617 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -177,7 +177,6 @@ func publishAtx( t.Helper() publishEpoch := posEpoch + 1 - tab.mhdlr.EXPECT().GetPosAtxID().Return(posAtxId, nil) tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { // time.Now() ~= currentLayer @@ -280,7 +279,6 @@ func TestBuilder_RestartSmeshing(t *testing.T) { tab.mclock.EXPECT().AwaitLayer(gomock.Any()).Return(ch).AnyTimes() tab.mclock.EXPECT().CurrentLayer().Return(types.LayerID(0)).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).Return(now).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(types.ATXID{1, 2, 3}, nil).AnyTimes() return tab.Builder } @@ -379,7 +377,6 @@ func TestBuilder_StopSmeshing_OnPoSTError(t *testing.T) { tab.mclock.EXPECT().AwaitLayer(gomock.Any()).Return(ch).AnyTimes() tab.mclock.EXPECT().CurrentLayer().Return(types.LayerID(0)).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).Return(now).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(types.ATXID{1, 2, 3}, nil).AnyTimes() tab.msync.EXPECT().RegisterForATXSynced().Return(ch).AnyTimes() require.NoError(t, tab.StartSmeshing(tab.coinbase, PostSetupOpts{})) @@ -430,7 +427,6 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) { require.NoError(t, err) require.NoError(t, atxs.Add(tab.cdb, vPrevAtx)) - tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil) tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { @@ -468,7 +464,6 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) { publishEpoch := posEpoch + 1 tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil) tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { // time.Now() ~= currentLayer @@ -564,7 +559,6 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi func() types.LayerID { return currLayer }).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil) tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { // time.Now() ~= currentLayer @@ -672,7 +666,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { nipost := newNIPostWithChallenge(t, types.HexToHash32("55555"), poetBytes) posAtx := newAtx(t, otherSigner, challenge, nipost, 2, types.Address{}) SignAndFinalizeAtx(otherSigner, posAtx) - vPosAtx, err := posAtx.Verify(0, 1) + vPosAtx, err := posAtx.Verify(0, 2) r.NoError(err) r.NoError(atxs.Add(tab.cdb, vPosAtx)) @@ -721,7 +715,6 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { atxChan := make(chan struct{}) tab.mhdlr.EXPECT().AwaitAtx(gomock.Any()).Return(atxChan) - tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil) tab.mhdlr.EXPECT().UnsubscribeAtx(gomock.Any()) tab.mpub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ string, msg []byte) error { @@ -809,7 +802,6 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { atxChan := make(chan struct{}) tab.mhdlr.EXPECT().AwaitAtx(gomock.Any()).Return(atxChan) - tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil) tab.mhdlr.EXPECT().UnsubscribeAtx(gomock.Any()) tab.mpub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ string, msg []byte) error { @@ -852,7 +844,6 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) { require.NoError(t, atxs.Add(tab.cdb, vPosAtx)) tab.mclock.EXPECT().CurrentLayer().Return(posEpoch.FirstLayer()).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(vPosAtx.ID(), nil) tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { // time.Now() ~= currentLayer @@ -921,7 +912,6 @@ func TestBuilder_NIPostPublishRecovery(t *testing.T) { publishEpoch := posEpoch + 1 tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil) tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { // time.Now() ~= currentLayer @@ -1029,7 +1019,6 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) { currLayer := posEpoch.FirstLayer() tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() - tab.mhdlr.EXPECT().GetPosAtxID().Return(prevAtx.ID(), nil).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( func(got types.LayerID) time.Time { return genesis.Add(layerDuration * time.Duration(got)) @@ -1174,7 +1163,6 @@ func TestWaitPositioningAtx(t *testing.T) { }).AnyTimes() // everything else are stubs that are irrelevant for the test - tab.mhdlr.EXPECT().GetPosAtxID().Return(tab.goldenATXID, nil).AnyTimes() tab.mpost.EXPECT().LastOpts().Return(&PostSetupOpts{}).AnyTimes() tab.mpost.EXPECT().CommitmentAtx().Return(tab.goldenATXID, nil).AnyTimes() index := types.VRFPostIndex(0) diff --git a/activation/handler.go b/activation/handler.go index 88b88bedef..068d5c89e8 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -457,15 +457,6 @@ func (h *Handler) GetEpochAtxs(epochID types.EpochID) (ids []types.ATXID, err er return } -// GetPosAtxID returns the best (highest layer id), currently known to this node, pos atx id. -func (h *Handler) GetPosAtxID() (types.ATXID, error) { - id, err := atxs.GetAtxIDWithMaxHeight(h.cdb) - if err != nil { - return types.EmptyATXID, fmt.Errorf("failed to get positioning atx: %w", err) - } - return id, nil -} - // HandleAtxData handles atxs received by sync. func (h *Handler) HandleAtxData(ctx context.Context, peer p2p.Peer, data []byte) error { err := h.HandleGossipAtx(ctx, peer, data) diff --git a/activation/handler_test.go b/activation/handler_test.go index 246899ae12..9ade20c959 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -878,49 +878,6 @@ func BenchmarkNewActivationDb(b *testing.B) { b.Logf("\n>>> Total time: %v\n\n", time.Since(start)) } -func TestHandler_GetPosAtx(t *testing.T) { - // Arrange - r := require.New(t) - - goldenATXID := types.ATXID{2, 3, 4} - atxHdlr := newTestHandler(t, goldenATXID) - - currentLayer := types.LayerID(10) - - sig, err := signing.NewEdSigner() - r.NoError(err) - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - coinbase := types.Address{2, 4, 5} - - // Act & Assert - - // ATX stored should become top ATX - atx1 := newActivationTx(t, sig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch(), 0, 100, coinbase, 100, &types.NIPost{}) - r.NoError(atxs.Add(atxHdlr.cdb, atx1)) - - id, err := atxHdlr.GetPosAtxID() - r.NoError(err) - r.Equal(atx1.ID(), id) - - // higher-layer ATX stored should become new top ATX - atx2 := newActivationTx(t, otherSig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch()+2, 0, 100, coinbase, 100, &types.NIPost{}) - r.NoError(atxs.Add(atxHdlr.cdb, atx2)) - - id, err = atxHdlr.GetPosAtxID() - r.NoError(err) - r.Equal(atx2.ID(), id) - - // lower-layer ATX stored should NOT become new top ATX - atx3 := newActivationTx(t, sig, 0, types.EmptyATXID, types.EmptyATXID, nil, currentLayer.GetEpoch()+1, 0, 100, coinbase, 100, &types.NIPost{}) - r.NoError(atxs.Add(atxHdlr.cdb, atx3)) - - id, err = atxHdlr.GetPosAtxID() - r.NoError(err) - r.NotEqual(atx3.ID(), id) - r.Equal(atx2.ID(), id) -} - func TestHandler_AwaitAtx(t *testing.T) { // Arrange r := require.New(t) diff --git a/activation/interface.go b/activation/interface.go index e91007ac7e..95a526bf00 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -47,7 +47,6 @@ type nipostBuilder interface { } type atxHandler interface { - GetPosAtxID() (types.ATXID, error) AwaitAtx(id types.ATXID) chan struct{} UnsubscribeAtx(id types.ATXID) } diff --git a/activation/mocks.go b/activation/mocks.go index 167d8c6e1e..c983fc15b8 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -419,21 +419,6 @@ func (mr *MockatxHandlerMockRecorder) AwaitAtx(id interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitAtx", reflect.TypeOf((*MockatxHandler)(nil).AwaitAtx), id) } -// GetPosAtxID mocks base method. -func (m *MockatxHandler) GetPosAtxID() (types.ATXID, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPosAtxID") - ret0, _ := ret[0].(types.ATXID) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetPosAtxID indicates an expected call of GetPosAtxID. -func (mr *MockatxHandlerMockRecorder) GetPosAtxID() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPosAtxID", reflect.TypeOf((*MockatxHandler)(nil).GetPosAtxID)) -} - // UnsubscribeAtx mocks base method. func (m *MockatxHandler) UnsubscribeAtx(id types.ATXID) { m.ctrl.T.Helper() diff --git a/activation/post.go b/activation/post.go index e319573a02..141f5efe6c 100644 --- a/activation/post.go +++ b/activation/post.go @@ -408,7 +408,7 @@ func (mgr *PostSetupManager) commitmentAtx(ctx context.Context, dataDir string) // It will use the ATX with the highest height seen by the node and defaults to the goldenATX, // when no ATXs have yet been published. func (mgr *PostSetupManager) findCommitmentAtx(ctx context.Context) (types.ATXID, error) { - atx, err := atxs.GetAtxIDWithMaxHeight(mgr.db) + atx, err := atxs.GetIDWithMaxHeight(mgr.db, types.EmptyNodeID) switch { case errors.Is(err, sql.ErrNotFound): mgr.logger.With().Info("using golden atx as commitment atx") diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index b679d077d9..270cc87528 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -259,7 +259,7 @@ func validateAndPreserveData(tb testing.TB, db *sql.Database, deps []*types.Veri mvalidator.EXPECT().NIPostChallenge(&vatx.ActivationTx.NIPostChallenge, cdb, vatx.SmesherID) } mvalidator.EXPECT().PositioningAtx(&vatx.PositioningATX, cdb, goldenAtx, vatx.PublishEpoch, layersPerEpoch) - mvalidator.EXPECT().NIPost(gomock.Any(), vatx.SmesherID, gomock.Any(), vatx.NIPost, gomock.Any(), vatx.NumUnits) + mvalidator.EXPECT().NIPost(gomock.Any(), vatx.SmesherID, gomock.Any(), vatx.NIPost, gomock.Any(), vatx.NumUnits).Return(uint64(1111111), nil) mreceiver.EXPECT().OnAtx(gomock.Any()) mtrtl.EXPECT().OnAtx(gomock.Any()) require.NoError(tb, atxHandler.HandleAtxData(context.Background(), "self", encoded)) @@ -555,7 +555,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) require.NoError(t, poets.Add(olddb, types.PoetProofRef(vatx.GetPoetProofRef()), encoded, proofs[i].PoetServiceID, proofs[i].RoundID)) } - commitment, err := atxs.GetAtxIDWithMaxHeight(olddb) + commitment, err := atxs.GetIDWithMaxHeight(olddb, types.EmptyNodeID) require.NoError(t, err) require.NoError(t, olddb.Close()) diff --git a/datastore/store.go b/datastore/store.go index 7d0aa47970..d7907a6823 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -263,7 +263,7 @@ func (db *CachedDB) IdentityExists(nodeID types.NodeID) (bool, error) { } func (db *CachedDB) MaxHeightAtx() (types.ATXID, error) { - return atxs.GetAtxIDWithMaxHeight(db) + return atxs.GetIDWithMaxHeight(db, types.EmptyNodeID) } // Hint marks which DB should be queried for a certain provided hash. diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index f63211fca8..c92ede107b 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -294,12 +294,12 @@ func Add(db sql.Executor, atx *types.VerifiedActivationTx) error { return nil } -// GetAtxIDWithMaxHeight returns the ID of the atx from the last 2 epoch with the highest (or tied for the highest) tick height. +// GetIDWithMaxHeight returns the ID of the atx from the last 2 epoch with the highest (or tied for the highest) tick height. // it is possible that some poet servers are faster than others and the network ends up having its highest ticked atx still in // previous epoch and the atxs building on top of it have not been published yet. selecting from the last two epochs to strike // a balance between being fair to honest miners while not giving unfair advantage for malicious actors who retroactively // publish a high tick atx many epochs back. -func GetAtxIDWithMaxHeight(db sql.Executor) (types.ATXID, error) { +func GetIDWithMaxHeight(db sql.Executor, pref types.NodeID) (types.ATXID, error) { var ( rst types.ATXID max uint64 @@ -309,13 +309,24 @@ func GetAtxIDWithMaxHeight(db sql.Executor) (types.ATXID, error) { stmt.ColumnBytes(0, id[:]) height := uint64(stmt.ColumnInt64(1)) + uint64(stmt.ColumnInt64(2)) if height >= max { - max = height - rst = id + var smesher types.NodeID + stmt.ColumnBytes(3, smesher[:]) + if height > max { + max = height + rst = id + } else if pref != types.EmptyNodeID && smesher == pref { + // height is equal. prefer atxs from `pref` + rst = id + } } return true } - if rows, err := db.Exec("select id, base_tick_height, tick_count, epoch from atxs where epoch >= (select max(epoch) from atxs)-1;", nil, dec); err != nil { + if rows, err := db.Exec(` + select id, base_tick_height, tick_count, pubkey + from atxs left join identities using(pubkey) + where identities.pubkey is null and epoch >= (select max(epoch) from atxs)-1 + order by epoch desc;`, nil, dec); err != nil { return types.ATXID{}, fmt.Errorf("select positioning atx: %w", err) } else if rows == 0 { return types.ATXID{}, sql.ErrNotFound diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 646d10e22d..b2f41649be 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -13,6 +13,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" ) const layersPerEpoch = 5 @@ -577,7 +578,20 @@ func newAtx(signer *signing.EdSigner, opts ...createAtxOpt) (*types.VerifiedActi return atx.Verify(0, 1) } -func TestPositioningID(t *testing.T) { +func createIdentities(tb testing.TB, db sql.Executor, n int, midxs ...int) []*signing.EdSigner { + var sigs []*signing.EdSigner + for i := 0; i < n; i++ { + sig, err := signing.NewEdSigner() + require.NoError(tb, err) + sigs = append(sigs, sig) + } + for _, idx := range midxs { + require.NoError(tb, identities.SetMalicious(db, sigs[idx].NodeID(), []byte("bad"))) + } + return sigs +} + +func TestGetIDWithMaxHeight(t *testing.T) { type header struct { coinbase types.Address base, count uint64 @@ -586,6 +600,8 @@ func TestPositioningID(t *testing.T) { for _, tc := range []struct { desc string atxs []header + pref int + midxs []int expect int }{ { @@ -598,29 +614,74 @@ func TestPositioningID(t *testing.T) { {coinbase: types.Address{2}, base: 1, count: 2, epoch: 2}, }, expect: 1, + pref: -1, }, { desc: "highest in prev epoch", atxs: []header{ {coinbase: types.Address{1}, base: 1, count: 3, epoch: 1}, // too old - {coinbase: types.Address{1}, base: 1, count: 2, epoch: 2}, - {coinbase: types.Address{2}, base: 1, count: 1, epoch: 3}, + {coinbase: types.Address{2}, base: 1, count: 2, epoch: 2}, + {coinbase: types.Address{3}, base: 1, count: 1, epoch: 3}, + }, + pref: -1, + expect: 1, + }, + { + desc: "prefer later epoch", + atxs: []header{ + {coinbase: types.Address{1}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{2}, base: 1, count: 2, epoch: 2}, + }, + pref: -1, + expect: 1, + }, + { + desc: "prefer node id", + atxs: []header{ + {coinbase: types.Address{1}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{2}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{3}, base: 1, count: 2, epoch: 2}, }, + pref: 1, expect: 1, }, + { + desc: "skip malicious id", + atxs: []header{ + {coinbase: types.Address{1}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{2}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{3}, base: 1, count: 1, epoch: 2}, + }, + pref: 1, + midxs: []int{0, 1}, + expect: 2, + }, + { + desc: "skip malicious id not found", + atxs: []header{ + {coinbase: types.Address{1}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{2}, base: 1, count: 2, epoch: 1}, + {coinbase: types.Address{3}, base: 1, count: 2, epoch: 2}, + }, + pref: 1, + midxs: []int{0, 1, 2}, + expect: -1, + }, { desc: "by tick height", atxs: []header{ {coinbase: types.Address{1}, base: 1, count: 2, epoch: 1}, {coinbase: types.Address{2}, base: 1, count: 1, epoch: 1}, }, + pref: -1, expect: 0, }, } { t.Run(tc.desc, func(t *testing.T) { db := sql.InMemory() + sigs := createIdentities(t, db, len(tc.atxs), tc.midxs...) ids := []types.ATXID{} - for _, atx := range tc.atxs { + for i, atx := range tc.atxs { full := &types.ActivationTx{ InnerActivationTx: types.InnerActivationTx{ NIPostChallenge: types.NIPostChallenge{ @@ -630,10 +691,7 @@ func TestPositioningID(t *testing.T) { NumUnits: 2, }, } - - sig, err := signing.NewEdSigner() - require.NoError(t, err) - require.NoError(t, activation.SignAndFinalizeAtx(sig, full)) + require.NoError(t, activation.SignAndFinalizeAtx(sigs[i], full)) full.SetEffectiveNumUnits(full.NumUnits) full.SetReceived(time.Now()) @@ -643,8 +701,12 @@ func TestPositioningID(t *testing.T) { require.NoError(t, atxs.Add(db, vAtx)) ids = append(ids, full.ID()) } - rst, err := atxs.GetAtxIDWithMaxHeight(db) - if len(tc.atxs) == 0 { + var pref types.NodeID + if tc.pref > 0 { + pref = sigs[tc.pref].NodeID() + } + rst, err := atxs.GetIDWithMaxHeight(db, pref) + if len(tc.atxs) == 0 || tc.expect < 0 { require.ErrorIs(t, err, sql.ErrNotFound) } else { require.Equal(t, ids[tc.expect], rst)