From 7d0e2c676e81e1eff0f7d35d7424e3981b8ec489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 20 Jun 2024 12:38:07 +0200 Subject: [PATCH] Calculate and persist ATX weight in DB --- activation/e2e/atx_merge_test.go | 2 + activation/e2e/builds_atx_v2_test.go | 2 +- activation/handler_test.go | 4 +- activation/handler_v1.go | 6 + activation/handler_v2.go | 71 +++++-- activation/handler_v2_test.go | 242 +++++++++++++++++------ activation/post_test.go | 22 ++- api/grpcserver/grpcserver_test.go | 35 ++-- api/grpcserver/v2alpha1/activation.go | 2 +- atxsdata/data.go | 2 +- beacon/beacon.go | 2 +- beacon/beacon_test.go | 31 ++- beacon/handlers.go | 4 +- blocks/generator_test.go | 15 +- cmd/activeset/activeset.go | 2 +- common/types/activation.go | 50 +---- fetch/mesh_data_test.go | 10 +- hare3/eligibility/oracle_test.go | 9 +- hare3/hare_test.go | 1 + malfeasance/wire/malfeasance_test.go | 7 +- mesh/executor_test.go | 17 +- miner/proposal_builder_test.go | 1 + proposals/eligibility_validator_test.go | 1 + sql/atxs/atxs.go | 10 +- sql/migrations/state/0020_atx_weight.sql | 2 + tortoise/model/core.go | 17 +- tortoise/sim/generator.go | 17 +- tortoise/sim/layer.go | 4 +- tortoise/tortoise_test.go | 5 +- 29 files changed, 379 insertions(+), 214 deletions(-) create mode 100644 sql/migrations/state/0020_atx_weight.sql diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index a47582f8a2..a0cf900b48 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -463,6 +463,7 @@ func Test_MarryAndMerge(t *testing.T) { require.Equal(t, totalNumUnits, atx.NumUnits) require.Equal(t, mainID.NodeID(), atx.SmesherID) require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount) + require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight) posATX, err := atxs.Get(db, marriageATX.ID()) require.NoError(t, err) @@ -511,6 +512,7 @@ func Test_MarryAndMerge(t *testing.T) { require.Equal(t, totalNumUnits, atx.NumUnits) require.Equal(t, signers[1].NodeID(), atx.SmesherID) require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount) + require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight) posATX, err = atxs.Get(db, mergedATX.ID()) require.NoError(t, err) diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 057b45beec..9bc896cad7 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -216,7 +216,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { require.NotZero(t, atx.BaseTickHeight) require.NotZero(t, atx.TickCount) - require.NotZero(t, atx.GetWeight()) + require.NotZero(t, atx.Weight) require.NotZero(t, atx.TickHeight()) require.Equal(t, opts.NumUnits, atx.NumUnits) previous = atx diff --git a/activation/handler_test.go b/activation/handler_test.go index d5b03b9be0..30c0681dc2 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -642,7 +642,7 @@ func TestHandler_AtxWeight(t *testing.T) { require.Equal(t, uint64(0), stored1.BaseTickHeight) require.Equal(t, leaves/tickSize, stored1.TickCount) require.Equal(t, leaves/tickSize, stored1.TickHeight()) - require.Equal(t, (leaves/tickSize)*units, stored1.GetWeight()) + require.Equal(t, (leaves/tickSize)*units, stored1.Weight) atx2 := newChainedActivationTxV1(t, atx1, atx1.ID()) atx2.Sign(sig) @@ -657,7 +657,7 @@ func TestHandler_AtxWeight(t *testing.T) { require.Equal(t, stored1.TickHeight(), stored2.BaseTickHeight) require.Equal(t, leaves/tickSize, stored2.TickCount) require.Equal(t, stored1.TickHeight()+leaves/tickSize, stored2.TickHeight()) - require.Equal(t, int(leaves/tickSize)*units, int(stored2.GetWeight())) + require.Equal(t, int(leaves/tickSize)*units, int(stored2.Weight)) } func TestHandler_WrongHash(t *testing.T) { diff --git a/activation/handler_v1.go b/activation/handler_v1.go index cefff79e1b..0126ee9832 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/bits" "sync" "time" @@ -683,6 +684,11 @@ func (h *HandlerV1) processATX( atx.NumUnits = effectiveNumUnits atx.BaseTickHeight = baseTickHeight atx.TickCount = leaves / h.tickSize + hi, weight := bits.Mul64(uint64(atx.NumUnits), atx.TickCount) + if hi != 0 { + return nil, errors.New("atx weight would overflow uint64") + } + atx.Weight = weight proof, err = h.storeAtx(ctx, atx, watx) if err != nil { diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 476820acfa..20851b3f5a 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -1,10 +1,11 @@ package activation import ( + "cmp" "context" "errors" "fmt" - "math" + "math/bits" "slices" "time" @@ -121,9 +122,10 @@ func (h *HandlerV2) processATX( atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, Coinbase: watx.Coinbase, - NumUnits: parts.effectiveUnits, BaseTickHeight: baseTickHeight, - TickCount: parts.leaves / h.tickSize, + NumUnits: parts.effectiveUnits, + TickCount: parts.ticks, + Weight: parts.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2}, @@ -487,10 +489,50 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e } type atxParts struct { - leaves uint64 + ticks uint64 + weight uint64 effectiveUnits uint32 } +type nipostSize struct { + units uint32 + ticks uint64 +} + +func (n *nipostSize) addUnits(units uint32) error { + sum, carry := bits.Add32(n.units, units, 0) + if carry != 0 { + return errors.New("units overflow") + } + n.units = sum + return nil +} + +type nipostSizes []*nipostSize + +func (n nipostSizes) minTicks() uint64 { + return slices.MinFunc(n, func(a, b *nipostSize) int { return cmp.Compare(a.ticks, b.ticks) }).ticks +} + +func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) { + var totalEffectiveNumUnits uint32 + var totalWeight uint64 + for _, ns := range n { + sum, carry := bits.Add32(totalEffectiveNumUnits, ns.units, 0) + if carry != 0 { + return 0, 0, fmt.Errorf("total units overflow (%d + %d)", totalEffectiveNumUnits, ns.units) + } + totalEffectiveNumUnits = sum + + hi, weight := bits.Mul64(uint64(ns.units), ns.ticks) + if hi != 0 { + return 0, 0, fmt.Errorf("weight overflow (%d * %d)", ns.units, ns.ticks) + } + totalWeight += weight + } + return totalEffectiveNumUnits, totalWeight, nil +} + func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error { seen := make(map[uint32]struct{}) for _, niposts := range atx.NiPosts { @@ -534,8 +576,9 @@ func (h *HandlerV2) syntacticallyValidateDeps( } // validate previous ATXs - var totalEffectiveNumUnits uint32 - for _, niposts := range atx.NiPosts { + nipostSizes := make(nipostSizes, len(atx.NiPosts)) + for i, niposts := range atx.NiPosts { + nipostSizes[i] = new(nipostSize) for _, post := range niposts.Posts { if post.MarriageIndex >= uint32(len(equivocationSet)) { err := fmt.Errorf("marriage index out of bounds: %d > %d", post.MarriageIndex, len(equivocationSet)-1) @@ -551,13 +594,13 @@ func (h *HandlerV2) syntacticallyValidateDeps( return nil, nil, fmt.Errorf("validating previous atx: %w", err) } } - totalEffectiveNumUnits += effectiveNumUnits + nipostSizes[i].addUnits(effectiveNumUnits) + } } // validate poet membership proofs - var minLeaves uint64 = math.MaxUint64 - for _, niposts := range atx.NiPosts { + for i, niposts := range atx.NiPosts { // verify PoET memberships in a single go indexedChallenges := make(map[uint64][]byte) @@ -594,7 +637,12 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, nil, fmt.Errorf("invalid poet membership: %w", err) } - minLeaves = min(leaves, minLeaves) + nipostSizes[i].ticks = leaves / h.tickSize + } + + totalEffectiveNumUnits, totalWeight, err := nipostSizes.sumUp() + if err != nil { + return nil, nil, err } // validate all niposts @@ -641,8 +689,9 @@ func (h *HandlerV2) syntacticallyValidateDeps( } parts := &atxParts{ - leaves: minLeaves, + ticks: nipostSizes.minTicks(), effectiveUnits: totalEffectiveNumUnits, + weight: totalWeight, } if atx.Initial == nil { diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 01bf43166f..ffa53f3fbd 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "math" + "slices" "testing" "time" @@ -37,7 +39,10 @@ type marriedId struct { refAtx *wire.ActivationTxV2 } -const poetLeaves = 200 +const ( + tickSize = 20 + poetLeaves = 200 +) func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { lg := zaptest.NewLogger(tb) @@ -50,7 +55,7 @@ func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { atxsdata: atxsdata.New(), edVerifier: signing.NewEdVerifier(), clock: mocks.mclock, - tickSize: 1, + tickSize: tickSize, goldenATXID: golden, nipostValidator: mocks.mValidator, logger: lg, @@ -89,8 +94,12 @@ func (h *handlerMocks) expectVerifyNIPoST(atx *wire.ActivationTxV2) { ).Return(poetLeaves, nil) } -func (h *handlerMocks) expectVerifyNIPoSTs(atx *wire.ActivationTxV2, equivocationSet []types.NodeID) { - for _, nipost := range atx.NiPosts { +func (h *handlerMocks) expectVerifyNIPoSTs( + atx *wire.ActivationTxV2, + equivocationSet []types.NodeID, + poetLeaves []uint64, +) { + for i, nipost := range atx.NiPosts { for _, post := range nipost.Posts { h.mValidator.EXPECT().PostV2( gomock.Any(), @@ -107,7 +116,7 @@ func (h *handlerMocks) expectVerifyNIPoSTs(atx *wire.ActivationTxV2, equivocatio gomock.Any(), nipost.Challenge, gomock.Any(), - ) + ).Return(poetLeaves[i], nil) } } @@ -153,7 +162,11 @@ func (h *handlerMocks) expectAtxV2(atx *wire.ActivationTxV2) { h.expectStoreAtxV2(atx) } -func (h *handlerMocks) expectMergedAtxV2(atx *wire.ActivationTxV2, equivocationSet []types.NodeID) { +func (h *handlerMocks) expectMergedAtxV2( + atx *wire.ActivationTxV2, + equivocationSet []types.NodeID, + poetLeaves []uint64, +) { h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) h.expectFetchDeps(atx) h.mValidator.EXPECT().VRFNonceV2( @@ -162,7 +175,7 @@ func (h *handlerMocks) expectMergedAtxV2(atx *wire.ActivationTxV2, equivocationS atx.VRFNonce, atx.TotalNumUnits(), ) - h.expectVerifyNIPoSTs(atx, equivocationSet) + h.expectVerifyNIPoSTs(atx, equivocationSet, poetLeaves) h.expectStoreAtxV2(atx) } @@ -445,6 +458,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { blob := codec.MustEncode(atx) atxHandler := newV2TestHandler(t, golden) + atxHandler.tickSize = tickSize atxHandler.expectInitialAtxV2(atx) proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) @@ -456,9 +470,10 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.NotNil(t, atx) require.Equal(t, atx.ID(), atxFromDb.ID()) require.Equal(t, atx.Coinbase, atxFromDb.Coinbase) - require.EqualValues(t, poetLeaves, atxFromDb.TickCount) - require.EqualValues(t, poetLeaves, atxFromDb.TickHeight()) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, 0+atxFromDb.TickCount, atxFromDb.TickHeight()) // positioning is golden require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) // processing ATX for the second time should skip checks proof, err = atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) @@ -471,9 +486,10 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { prev := newInitialATXv1(t, golden) prev.Sign(sig) - atxs.Add(atxHandler.cdb, toAtx(t, prev)) + prevAtx := toAtx(t, prev) + atxs.Add(atxHandler.cdb, prevAtx) - atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) + atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), prevAtx.ID()) atx.Sign(sig) blob := codec.MustEncode(atx) atxHandler.expectAtxV2(atx) @@ -486,34 +502,39 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.NoError(t, err) require.Nil(t, atxFromDb.CommitmentATX) - // copies coinbase and VRF nonce from the previous ATX - require.Equal(t, prev.Coinbase, atxFromDb.Coinbase) - require.EqualValues(t, *prev.VRFNonce, atxFromDb.VRFNonce) + + require.Equal(t, atx.Coinbase, atxFromDb.Coinbase) + require.EqualValues(t, atx.VRFNonce, atxFromDb.VRFNonce) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, prevAtx.TickHeight(), atxFromDb.BaseTickHeight) + require.EqualValues(t, prevAtx.TickHeight()+atxFromDb.TickCount, atxFromDb.TickHeight()) + require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) }) t.Run("second ATX, previous V2", func(t *testing.T) { t.Parallel() atxHandler := newV2TestHandler(t, golden) - prev := newInitialATXv2(t, golden) - prev.Sign(sig) - blob := codec.MustEncode(prev) - - atxHandler.expectInitialAtxV2(prev) - proof, err := atxHandler.processATX(context.Background(), peer, prev, blob, time.Now()) - require.NoError(t, err) - require.Nil(t, proof) + prev := atxHandler.createAndProcessInitial(t, sig) - atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) + atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), prev.ID()) atx.Sign(sig) - blob = codec.MustEncode(atx) - atxHandler.expectAtxV2(atx) + blob := codec.MustEncode(atx) - proof, err = atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) + atxHandler.expectAtxV2(atx) + proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) require.NoError(t, err) require.Nil(t, proof) - _, err = atxs.Get(atxHandler.cdb, atx.ID()) + prevAtx, err := atxs.Get(atxHandler.cdb, prev.ID()) + require.NoError(t, err) + atxFromDb, err := atxs.Get(atxHandler.cdb, atx.ID()) require.NoError(t, err) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, prevAtx.TickHeight(), atxFromDb.BaseTickHeight) + require.EqualValues(t, prevAtx.TickHeight()+atxFromDb.TickCount, atxFromDb.TickHeight()) + require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) }) t.Run("second ATX, previous checkpointed", func(t *testing.T) { t.Parallel() @@ -639,46 +660,29 @@ func marryIDs( golden types.ATXID, num int, ) (marriage *wire.ActivationTxV2, other []*wire.ActivationTxV2) { - var ( - marriedIds []marriedId - equivocationSet = []types.NodeID{sig.NodeID()} - ) - for range num { - signer, err := signing.NewEdSigner() - require.NoError(t, err) - atx := atxHandler.createAndProcessInitial(t, signer) - marriedIds = append(marriedIds, marriedId{signer, atx}) - } - - var atxs []*wire.ActivationTxV2 mATX := newInitialATXv2(t, golden) mATX.Marriages = []wire.MarriageCertificate{{ Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), }} - for _, id := range marriedIds { + + for range num { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + atx := atxHandler.createAndProcessInitial(t, signer) + other = append(other, atx) mATX.Marriages = append(mATX.Marriages, wire.MarriageCertificate{ - ReferenceAtx: id.refAtx.ID(), - Signature: id.signer.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + ReferenceAtx: atx.ID(), + Signature: signer.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), }) - equivocationSet = append(equivocationSet, id.signer.NodeID()) } + mATX.Sign(sig) atxHandler.expectInitialAtxV2(mATX) p, err := atxHandler.processATX(context.Background(), "", mATX, codec.MustEncode(mATX), time.Now()) require.NoError(t, err) require.Nil(t, p) - // Other IDs publish their first ATXs. - for _, id := range marriedIds { - atx := newInitialATXv2(t, golden) - atx.Sign(id.signer) - atxHandler.expectInitialAtxV2(atx) - _, err := atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) - require.NoError(t, err) - atxs = append(atxs, atx) - } - - return mATX, atxs + return mATX, other } func TestHandlerV2_ProcessMergedATX(t *testing.T) { @@ -717,7 +721,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.PreviousATXs = previousATXs merged.Sign(sig) - atxHandler.expectMergedAtxV2(merged, equivocationSet) + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{poetLeaves}) p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -726,6 +730,81 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { require.NoError(t, err) require.Equal(t, totalNumUnits, atx.NumUnits) require.Equal(t, sig.NodeID(), atx.SmesherID) + require.EqualValues(t, totalNumUnits*poetLeaves/tickSize, atx.Weight) + }) + t.Run("merged IDs on 2 poets", func(t *testing.T) { + const tickSize = 33 + atxHandler := newV2TestHandler(t, golden) + atxHandler.tickSize = tickSize + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 4) + previousATXs := []types.ATXID{mATX.ID()} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := &wire.ActivationTxV2{ + PublishEpoch: mATX.PublishEpoch + 2, + PreviousATXs: previousATXs, + PositioningATX: mATX.ID(), + Coinbase: types.GenerateAddress([]byte("aaaa")), + VRFNonce: uint64(999), + NiPosts: make([]wire.NiPostsV2, 2), + } + atxsPerPoet := [][]*wire.ActivationTxV2{ + append([]*wire.ActivationTxV2{mATX}, otherATXs[:2]...), + otherATXs[2:], + } + var totalNumUnits uint32 + unitsPerPoet := make([]uint32, 2) + var idx uint32 + for nipostId := range 2 { + for _, atx := range atxsPerPoet[nipostId] { + post := wire.SubPostV2{ + MarriageIndex: idx, + NumUnits: atx.TotalNumUnits(), + PrevATXIndex: idx, + } + unitsPerPoet[nipostId] += post.NumUnits + totalNumUnits += post.NumUnits + merged.NiPosts[nipostId].Posts = append(merged.NiPosts[nipostId].Posts, post) + idx++ + } + } + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + poetLeaves := []uint64{100, 500} + minPoetLeaves := slices.Min(poetLeaves) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, poetLeaves) + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + marriageATX, err := atxs.Get(atxHandler.cdb, mATX.ID()) + require.NoError(t, err) + atx, err := atxs.Get(atxHandler.cdb, merged.ID()) + require.NoError(t, err) + require.Equal(t, totalNumUnits, atx.NumUnits) + require.Equal(t, sig.NodeID(), atx.SmesherID) + require.Equal(t, minPoetLeaves/tickSize, atx.TickCount) + require.Equal(t, marriageATX.TickHeight()+atx.TickCount, atx.TickHeight()) + // the total weight is summed weight on each poet + var weight uint64 + for i := range unitsPerPoet { + ticks := poetLeaves[i] / tickSize + weight += uint64(unitsPerPoet[i]) * ticks + } + require.EqualValues(t, weight, atx.Weight) }) t.Run("signer must be included merged ATX", func(t *testing.T) { atxHandler := newV2TestHandler(t, golden) @@ -758,7 +837,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) atxHandler.expectFetchDeps(merged) - atxHandler.expectVerifyNIPoSTs(merged, equivocationSet) + atxHandler.expectVerifyNIPoSTs(merged, equivocationSet, []uint64{200}) p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) require.ErrorContains(t, err, "ATX signer not present in merged ATX") @@ -1653,6 +1732,57 @@ func Test_MarryingMalicious(t *testing.T) { } } +func Test_CalculatingUnits(t *testing.T) { + t.Parallel() + t.Run("units on 1 nipost must not overflow", func(t *testing.T) { + t.Parallel() + ns := nipostSize{} + require.NoError(t, ns.addUnits(1)) + require.EqualValues(t, 1, ns.units) + require.Error(t, ns.addUnits(math.MaxUint32)) + }) + t.Run("total units on all niposts must not overflow", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 11}, &nipostSize{units: math.MaxUint32 - 10}) + _, _, err := ns.sumUp() + require.Error(t, err) + }) + t.Run("units = sum of units on every nipost", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1}, &nipostSize{units: 10}) + u, _, err := ns.sumUp() + require.NoError(t, err) + require.EqualValues(t, 1+10, u) + }) +} + +func Test_CalculatingWeight(t *testing.T) { + t.Parallel() + t.Run("total weight must not overflow uint64", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: math.MaxUint64}) + _, _, err := ns.sumUp() + require.Error(t, err) + }) + t.Run("weight = sum of weight on every nipost", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: 1000}) + _, w, err := ns.sumUp() + require.NoError(t, err) + require.EqualValues(t, 1*100+10*1000, w) + }) +} + +func Test_CalculatingTicks(t *testing.T) { + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: 1000}) + require.EqualValues(t, 100, ns.minTicks()) +} + func newInitialATXv2(t testing.TB, golden types.ATXID) *wire.ActivationTxV2 { t.Helper() atx := &wire.ActivationTxV2{ diff --git a/activation/post_test.go b/activation/post_test.go index c0273a6369..de51d599df 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -273,15 +273,15 @@ func TestPostSetupManager_findCommitmentAtx_UsesLatestAtx(t *testing.T) { signer, err := signing.NewEdSigner() require.NoError(t, err) - challenge := types.NIPostChallenge{ + atx := &types.ActivationTx{ PublishEpoch: 1, + NumUnits: 2, + Weight: 2, + SmesherID: signer.NodeID(), + TickCount: 1, } - atx := types.NewActivationTx(challenge, types.Address{}, 2) - atx.SmesherID = signer.NodeID() atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.TickCount = 1 - require.NoError(t, err) require.NoError(t, atxs.Add(mgr.db, atx)) mgr.atxsdata.AddFromAtx(atx, false) @@ -323,12 +323,16 @@ func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromInitialAtx(t *te // add an atx by the same node commitmentAtx := types.RandomATXID() - atx := types.NewActivationTx(types.NIPostChallenge{}, types.Address{}, 1) - atx.CommitmentATX = &commitmentAtx - atx.SmesherID = signer.NodeID() + atx := &types.ActivationTx{ + NumUnits: 1, + Weight: 1, + SmesherID: signer.NodeID(), + TickCount: 1, + CommitmentATX: &commitmentAtx, + } + atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.TickCount = 1 require.NoError(t, atxs.Add(mgr.cdb, atx)) atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir, signer.NodeID()) diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index 7bb2eacf57..c434396896 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -88,8 +88,6 @@ var ( addr1 types.Address addr2 types.Address rewardSmesherID = types.RandomNodeID() - prevAtxID = types.ATXID(types.HexToHash32("44444")) - challenge = newChallenge(1, prevAtxID, prevAtxID, postGenesisEpoch) globalAtx *types.ActivationTx globalAtx2 *types.ActivationTx globalTx *types.Transaction @@ -165,12 +163,28 @@ func TestMain(m *testing.M) { addr1 = wallet.Address(signer1.PublicKey().Bytes()) addr2 = wallet.Address(signer2.PublicKey().Bytes()) - globalAtx = types.NewActivationTx(challenge, addr1, numUnits) + globalAtx = &types.ActivationTx{ + PublishEpoch: postGenesisEpoch, + Sequence: 1, + PrevATXID: types.ATXID{4, 4, 4, 4}, + Coinbase: addr1, + NumUnits: numUnits, + Weight: numUnits, + TickCount: 1, + SmesherID: signer.NodeID(), + } globalAtx.SetReceived(time.Now()) - globalAtx.SmesherID = signer.NodeID() - globalAtx.TickCount = 1 - globalAtx2 = types.NewActivationTx(challenge, addr2, numUnits) + globalAtx2 = &types.ActivationTx{ + PublishEpoch: postGenesisEpoch, + Sequence: 1, + PrevATXID: types.ATXID{5, 5, 5, 5}, + Coinbase: addr2, + NumUnits: numUnits, + Weight: numUnits, + TickCount: 1, + SmesherID: signer.NodeID(), + } globalAtx2.SetReceived(time.Now()) globalAtx2.SmesherID = signer.NodeID() globalAtx2.TickCount = 1 @@ -391,15 +405,6 @@ func NewTx(nonce uint64, recipient types.Address, signer *signing.EdSigner) *typ return &tx } -func newChallenge(sequence uint64, prevAtxID, posAtxID types.ATXID, epoch types.EpochID) types.NIPostChallenge { - return types.NIPostChallenge{ - Sequence: sequence, - PrevATXID: prevAtxID, - PublishEpoch: epoch, - PositioningATX: posAtxID, - } -} - func launchServer(tb testing.TB, services ...ServiceAPI) (Config, func()) { cfg := DefaultTestConfig() grpcService, err := NewWithServices(cfg.PublicListener, zaptest.NewLogger(tb).Named("grpc"), cfg, services) diff --git a/api/grpcserver/v2alpha1/activation.go b/api/grpcserver/v2alpha1/activation.go index fc8fd2f424..3019125996 100644 --- a/api/grpcserver/v2alpha1/activation.go +++ b/api/grpcserver/v2alpha1/activation.go @@ -149,7 +149,7 @@ func toAtx(atx *types.ActivationTx) *spacemeshv2alpha1.Activation { PublishEpoch: atx.PublishEpoch.Uint32(), PreviousAtx: atx.PrevATXID[:], Coinbase: atx.Coinbase.String(), - Weight: atx.GetWeight(), + Weight: atx.Weight, Height: atx.TickHeight(), } } diff --git a/atxsdata/data.go b/atxsdata/data.go index f8eae4794c..94c4cfc89d 100644 --- a/atxsdata/data.go +++ b/atxsdata/data.go @@ -76,7 +76,7 @@ func (d *Data) AddFromAtx(atx *types.ActivationTx, malicious bool) *ATX { atx.SmesherID, atx.Coinbase, atx.ID(), - atx.GetWeight(), + atx.Weight, atx.BaseTickHeight, atx.TickHeight(), atx.VRFNonce, diff --git a/beacon/beacon.go b/beacon/beacon.go index 8160597c0f..3a0effae27 100644 --- a/beacon/beacon.go +++ b/beacon/beacon.go @@ -604,7 +604,7 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger *zap.Logger, target ) err := atxs.IterateAtxsWithMalfeasance(pd.cdb, target-1, func(atx *types.ActivationTx, malicious bool) bool { if !malicious { - epochWeight += atx.GetWeight() + epochWeight += atx.Weight } else { logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", atx.SmesherID)) } diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index c72776f17a..bdc1c54fa7 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -114,22 +114,25 @@ func createATX( numUnits uint32, received time.Time, ) types.ATXID { - nonce := types.VRFPostIndex(1) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: lid.GetEpoch()}, - types.GenerateAddress(types.RandomBytes(types.AddressLength)), - numUnits, - ) - atx.VRFNonce = nonce + tb.Helper() + atx := types.ActivationTx{ + PublishEpoch: lid.GetEpoch(), + Coinbase: types.GenerateAddress(types.RandomBytes(types.AddressLength)), + NumUnits: numUnits, + VRFNonce: 1, + TickCount: 1, + Weight: uint64(numUnits), + SmesherID: sig.NodeID(), + } + atx.SetReceived(received) - atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) - atx.TickCount = 1 - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, &atx)) return atx.ID() } func createRandomATXs(tb testing.TB, db *datastore.CachedDB, lid types.LayerID, num int) { + tb.Helper() for i := 0; i < num; i++ { sig, err := signing.NewEdSigner() require.NoError(tb, err) @@ -187,12 +190,8 @@ func TestBeacon_MultipleNodes(t *testing.T) { require.NoError(t, err) require.Equal(t, bootstrap, got) } - for i, node := range testNodes { - if i == 0 { - // make the first node non-smeshing node - continue - } - + // make the first node non-smeshing node + for _, node := range testNodes[1:] { for _, db := range dbs { for _, s := range node.signers { createATX(t, db, atxPublishLid, s, 1, time.Now().Add(-1*time.Second)) diff --git a/beacon/handlers.go b/beacon/handlers.go index 7234572101..89d838e985 100644 --- a/beacon/handlers.go +++ b/beacon/handlers.go @@ -331,7 +331,7 @@ func (pd *ProtocolDriver) storeFirstVotes(m FirstVotingMessage, nodeID types.Nod } voteWeight := new(big.Int) if !malicious { - voteWeight.SetUint64(atx.GetWeight()) + voteWeight.SetUint64(atx.Weight) } else { pd.logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", nodeID)) } @@ -457,7 +457,7 @@ func (pd *ProtocolDriver) storeFollowingVotes(m FollowingVotingMessage, nodeID t } voteWeight := new(big.Int) if !malicious { - voteWeight.SetUint64(atx.GetWeight()) + voteWeight.SetUint64(atx.Weight) } else { pd.logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", nodeID)) } diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 0d4d206464..4145f3ff29 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -154,14 +154,15 @@ func createModifiedATXs( signer, err := signing.NewEdSigner() require.NoError(tb, err) signers = append(signers, signer) - address := types.GenerateAddress(signer.PublicKey().Bytes()) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: lid.GetEpoch()}, - address, - numUnit, - ) + atx := &types.ActivationTx{ + PublishEpoch: lid.GetEpoch(), + Coinbase: types.GenerateAddress(signer.PublicKey().Bytes()), + NumUnits: numUnit, + SmesherID: signer.NodeID(), + TickCount: 1, + Weight: uint64(numUnit), + } atx.SetReceived(time.Now()) - atx.SmesherID = signer.NodeID() atx.SetID(types.RandomATXID()) onAtx(atx) data.AddFromAtx(atx, false) diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 6c3acd6d0c..1046916b02 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -39,7 +39,7 @@ Example: for _, id := range ids { atx, err := atxs.Get(db, id) must(err, "get id %v: %s\n", id, err) - weight += atx.GetWeight() + weight += atx.Weight } fmt.Printf("count = %d\nweight = %d\n", len(ids), weight) } diff --git a/common/types/activation.go b/common/types/activation.go index ca99b151c4..0bc7d65076 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -185,6 +185,13 @@ type ActivationTx struct { TickCount uint64 VRFNonce VRFPostIndex SmesherID NodeID + // Weight of the ATX. The total weight of the epoch is expected to fit in a uint64. + // The total ATX weight is sum(NumUnits * TickCount) for identity it holds. + // Space Units sizes are chosen such that NumUnits for all ATXs in an epoch is expected to be < 10^6. + // PoETs should produce ~10k ticks at genesis, but are expected due to technological advances + // to produce more over time. A uint64 should be large enough to hold the total weight of an epoch, + // for at least the first few years. + Weight uint64 AtxBlob @@ -194,25 +201,6 @@ type ActivationTx struct { validity Validity // whether the chain is fully verified and OK } -// NewActivationTx returns a new activation transaction. The ATXID is calculated and cached. -// NOTE: this function is deprecated and used in a few tests only. -// Create a new ActivationTx with ActivationTx{...}, setting the fields manually. -func NewActivationTx( - challenge NIPostChallenge, - coinbase Address, - numUnits uint32, -) *ActivationTx { - atx := &ActivationTx{ - PublishEpoch: challenge.PublishEpoch, - Sequence: challenge.Sequence, - PrevATXID: challenge.PrevATXID, - CommitmentATX: challenge.CommitmentATX, - Coinbase: coinbase, - NumUnits: numUnits, - } - return atx -} - // TargetEpoch returns the target epoch of the ATX. This is the epoch in which the miner is eligible // to participate thanks to the ATX. func (atx *ActivationTx) TargetEpoch() EpochID { @@ -238,16 +226,6 @@ func (atx *ActivationTx) SetGolden() { atx.golden = true } -// Weight of the ATX. The total weight of the epoch is expected to fit in a uint64 and is -// sum(atx.NumUnits * atx.TickCount for each ATX in a given epoch). -// Space Units sizes are chosen such that NumUnits for all ATXs in an epoch is expected to be < 10^6. -// PoETs should produce ~10k ticks at genesis, but are expected due to technological advances -// to produce more over time. A uint64 should be large enough to hold the total weight of an epoch, -// for at least the first few years. -func (atx *ActivationTx) GetWeight() uint64 { - return getWeight(uint64(atx.NumUnits), atx.TickCount) -} - // TickHeight returns a sum of base tick height and tick count. func (atx *ActivationTx) TickHeight() uint64 { return atx.BaseTickHeight + atx.TickCount @@ -270,7 +248,7 @@ func (atx *ActivationTx) MarshalLogObject(encoder log.ObjectEncoder) error { encoder.AddUint64("sequence_number", atx.Sequence) encoder.AddUint64("base_tick_height", atx.BaseTickHeight) encoder.AddUint64("tick_count", atx.TickCount) - encoder.AddUint64("weight", atx.GetWeight()) + encoder.AddUint64("weight", atx.Weight) encoder.AddUint64("height", atx.TickHeight()) return nil } @@ -400,15 +378,3 @@ type EpochActiveSet struct { } var MaxEpochActiveSetSize = scale.MustGetMaxElements[EpochActiveSet]("Set") - -func getWeight(numUnits, tickCount uint64) uint64 { - return safeMul(numUnits, tickCount) -} - -func safeMul(a, b uint64) uint64 { - c := a * b - if a > 1 && b > 1 && c/b != a { - panic("uint64 overflow") - } - return c -} diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index a83ef09c69..b1916a4a8a 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -458,11 +458,11 @@ func genATXs(tb testing.TB, num uint32) []*types.ActivationTx { require.NoError(tb, err) atxs := make([]*types.ActivationTx, 0, num) for i := uint32(0); i < num; i++ { - atx := types.NewActivationTx( - types.NIPostChallenge{}, - types.Address{1, 2, 3}, - i, - ) + atx := &types.ActivationTx{ + Coinbase: types.Address{1, 2, 3}, + NumUnits: i, + Weight: uint64(i), + } atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) atxs = append(atxs, atx) diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index 60652a04a7..e299ea2f02 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -143,8 +143,7 @@ func (t *testOracle) createActiveSet( miners = append(miners, nodeID) atx := &types.ActivationTx{ PublishEpoch: lid.GetEpoch(), - NumUnits: uint32(i + 1), - TickCount: 1, + Weight: uint64(i + 1), SmesherID: nodeID, } atx.SetID(id) @@ -371,8 +370,7 @@ func Test_VrfSignVerify(t *testing.T) { activeSet := types.RandomActiveSet(numMiners) atx1 := &types.ActivationTx{ PublishEpoch: prevEpoch, - NumUnits: 1 * 1024, - TickCount: 1, + Weight: 1 * 1024, SmesherID: signer.NodeID(), } atx1.SetID(activeSet[0]) @@ -384,9 +382,8 @@ func Test_VrfSignVerify(t *testing.T) { atx2 := &types.ActivationTx{ PublishEpoch: prevEpoch, - NumUnits: 9 * 1024, + Weight: 9 * 1024, SmesherID: signer2.NodeID(), - TickCount: 1, } atx2.SetID(activeSet[1]) atx2.SetReceived(time.Now()) diff --git a/hare3/hare_test.go b/hare3/hare_test.go index c53f78fbe5..acdaa7f398 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -163,6 +163,7 @@ func (n *node) withAtx(min, max int) *node { } else { atx.NumUnits = uint32(min) } + atx.Weight = uint64(atx.NumUnits) * atx.TickCount id := types.ATXID{} n.t.rng.Read(id[:]) atx.SetID(id) diff --git a/malfeasance/wire/malfeasance_test.go b/malfeasance/wire/malfeasance_test.go index b367d24ee0..df927e2145 100644 --- a/malfeasance/wire/malfeasance_test.go +++ b/malfeasance/wire/malfeasance_test.go @@ -25,14 +25,11 @@ func TestMain(m *testing.M) { func TestCodec_MultipleATXs(t *testing.T) { epoch := types.EpochID(11) - a1 := types.NewActivationTx(types.NIPostChallenge{PublishEpoch: epoch}, types.Address{1, 2, 3}, 10) - a2 := types.NewActivationTx(types.NIPostChallenge{PublishEpoch: epoch}, types.Address{3, 2, 1}, 11) - var atxProof wire.AtxProof - for i, a := range []*types.ActivationTx{a1, a2} { + for i := range atxProof.Messages { atxProof.Messages[i] = wire.AtxProofMsg{ InnerMsg: types.ATXMetadata{ - PublishEpoch: a.PublishEpoch, + PublishEpoch: epoch, MsgHash: types.RandomHash(), }, SmesherID: types.RandomNodeID(), diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 01330cfb6e..01645d640f 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -69,16 +69,17 @@ func makeResults(lid types.LayerID, txs ...types.Transaction) []types.Transactio func (t *testExecutor) createATX(epoch types.EpochID, cb types.Address) (types.ATXID, types.NodeID) { sig, err := signing.NewEdSigner() require.NoError(t.tb, err) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: epoch}, - cb, - 11, - ) - atx.VRFNonce = 1 + atx := &types.ActivationTx{ + PublishEpoch: epoch, + Coinbase: cb, + NumUnits: 11, + Weight: 11, + VRFNonce: 1, + TickCount: 1, + SmesherID: sig.NodeID(), + } atx.SetReceived(time.Now()) - atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) - atx.TickCount = 1 require.NoError(t.tb, atxs.Add(t.db, atx)) t.atxsdata.AddFromAtx(atx, false) return atx.ID(), sig.NodeID() diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 542927cc2c..3fd3fa2457 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -75,6 +75,7 @@ func gatx( PublishEpoch: epoch, TickCount: ticks, SmesherID: smesher, + Weight: uint64(units) * ticks, } atx.SetID(id) atx.SetReceived(time.Time{}.Add(1)) diff --git a/proposals/eligibility_validator_test.go b/proposals/eligibility_validator_test.go index 6030327d4f..acdcc9203c 100644 --- a/proposals/eligibility_validator_test.go +++ b/proposals/eligibility_validator_test.go @@ -27,6 +27,7 @@ func gatx( VRFNonce: nonce, TickCount: 100, SmesherID: smesher, + Weight: uint64(units) * 100, } atx.SetID(id) atx.SetReceived(time.Time{}.Add(1)) diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 4f41ab4a68..423d28df65 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -22,7 +22,7 @@ const ( // filters that refer to the id column. const fieldsQuery = `select atxs.id, atxs.nonce, atxs.base_tick_height, atxs.tick_count, atxs.pubkey, atxs.effective_num_units, -atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx` +atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight` const fullQuery = fieldsQuery + ` from atxs` @@ -61,6 +61,7 @@ func decoder(fn decoderCallback) sql.Decoder { a.CommitmentATX = new(types.ATXID) stmt.ColumnBytes(12, a.CommitmentATX[:]) } + a.Weight = uint64(stmt.ColumnInt64(13)) return fn(&a) } @@ -440,13 +441,14 @@ func Add(db sql.Executor, atx *types.ActivationTx) error { } else { stmt.BindNull(13) } + stmt.BindInt64(14, int64(atx.Weight)) } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, pubkey, received, base_tick_height, tick_count, sequence, coinbase, - validity, prev_id) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)`, enc, nil) + validity, prev_id, weight) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)`, enc, nil) if err != nil { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } @@ -776,7 +778,7 @@ func IterateAtxsWithMalfeasance( func(s *sql.Statement) { s.BindInt64(1, int64(publish)) }, func(s *sql.Statement) bool { return decoder(func(atx *types.ActivationTx) bool { - return fn(atx, s.ColumnInt(13) != 0) + return fn(atx, s.ColumnInt(14) != 0) })(s) }, ) diff --git a/sql/migrations/state/0020_atx_weight.sql b/sql/migrations/state/0020_atx_weight.sql new file mode 100644 index 0000000000..4504bd4320 --- /dev/null +++ b/sql/migrations/state/0020_atx_weight.sql @@ -0,0 +1,2 @@ +ALTER TABLE atxs ADD COLUMN weight INTEGER; +INSERT INTO atxs (weight) SELECT effective_num_units * tick_count FROM atxs; diff --git a/tortoise/model/core.go b/tortoise/model/core.go index ce7022fa33..04381a2aa1 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -147,19 +147,20 @@ func (c *core) OnMessage(m Messenger, event Message) { return } - nipost := types.NIPostChallenge{ - PublishEpoch: ev.LayerID.GetEpoch(), + atx := &types.ActivationTx{ + PublishEpoch: ev.LayerID.GetEpoch(), + NumUnits: c.units, + Coinbase: types.GenerateAddress(c.signer.PublicKey().Bytes()), + SmesherID: c.signer.NodeID(), + BaseTickHeight: 1, + TickCount: 2, + Weight: uint64(c.units) * 2, } - addr := types.GenerateAddress(c.signer.PublicKey().Bytes()) - atx := types.NewActivationTx(nipost, addr, c.units) - atx.SmesherID = c.signer.NodeID() atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.BaseTickHeight = 1 - atx.TickCount = 2 c.refBallot = nil c.atx = atx.ID() - c.weight = atx.GetWeight() + c.weight = atx.Weight m.Send(MessageAtx{Atx: atx}) case MessageBlock: diff --git a/tortoise/sim/generator.go b/tortoise/sim/generator.go index 3ebc5a82c4..d89a3be918 100644 --- a/tortoise/sim/generator.go +++ b/tortoise/sim/generator.go @@ -229,23 +229,24 @@ func (g *Generator) generateAtxs() { if err != nil { panic(err) } - address := types.GenerateAddress(sig.PublicKey().Bytes()) - nipost := types.NIPostChallenge{ - PublishEpoch: g.nextLayer.Sub(1).GetEpoch(), - } - atx := types.NewActivationTx(nipost, address, units) var ticks uint64 if g.ticks != nil { ticks = g.ticks[i] } else { ticks = uint64(intInRange(g.rng, g.ticksRange)) } - atx.SmesherID = sig.NodeID() + atx := &types.ActivationTx{ + PublishEpoch: g.nextLayer.Sub(1).GetEpoch(), + Coinbase: types.GenerateAddress(sig.PublicKey().Bytes()), + NumUnits: units, + SmesherID: sig.NodeID(), + BaseTickHeight: g.prevHeight[i], + TickCount: ticks, + Weight: uint64(units) * ticks, + } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.BaseTickHeight = g.prevHeight[i] - atx.TickCount = ticks g.prevHeight[i] += ticks g.activations[i] = atx for _, state := range g.states { diff --git a/tortoise/sim/layer.go b/tortoise/sim/layer.go index 5a1ba8a7d6..5bc6b74d2d 100644 --- a/tortoise/sim/layer.go +++ b/tortoise/sim/layer.go @@ -159,7 +159,7 @@ func (g *Generator) genLayer(cfg nextConf) types.LayerID { } var total uint64 for _, atx := range g.activations { - total += atx.GetWeight() + total += atx.Weight } miners := make([]uint32, len(g.activations)) @@ -182,7 +182,7 @@ func (g *Generator) genLayer(cfg nextConf) types.LayerID { if err != nil { g.logger.Panic("failed to get a beacon", zap.Error(err)) } - n, err := util.GetNumEligibleSlots(atx.GetWeight(), 0, total, g.conf.LayerSize, g.conf.LayersPerEpoch) + n, err := util.GetNumEligibleSlots(atx.Weight, 0, total, g.conf.LayerSize, g.conf.LayersPerEpoch) if err != nil { g.logger.Panic("eligible slots", zap.Error(err)) } diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index d8b2de1b9f..9d1ec119c0 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -475,8 +475,7 @@ func TestComputeExpectedWeight(t *testing.T) { eid := first + types.EpochID(i) atx := &types.ActivationTx{ PublishEpoch: eid - 1, - NumUnits: uint32(weight), - TickCount: 1, + Weight: weight, } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) @@ -500,7 +499,7 @@ func extractAtxsData(db sql.Executor, target types.EpochID) (uint64, uint64, err heights []uint64 ) if err := atxs.IterateAtxsOps(db, builder.FilterEpochOnly(target-1), func(atx *types.ActivationTx) bool { - weight += atx.GetWeight() + weight += atx.Weight heights = append(heights, atx.TickHeight()) return true }); err != nil {