Skip to content

Commit

Permalink
Calculate and persist ATX weight in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Jun 20, 2024
1 parent 29301a2 commit 7d0e2c6
Show file tree
Hide file tree
Showing 29 changed files with 379 additions and 214 deletions.
2 changes: 2 additions & 0 deletions activation/e2e/atx_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ func Test_MarryAndMerge(t *testing.T) {
require.Equal(t, totalNumUnits, atx.NumUnits)
require.Equal(t, mainID.NodeID(), atx.SmesherID)
require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount)
require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight)

posATX, err := atxs.Get(db, marriageATX.ID())
require.NoError(t, err)
Expand Down Expand Up @@ -511,6 +512,7 @@ func Test_MarryAndMerge(t *testing.T) {
require.Equal(t, totalNumUnits, atx.NumUnits)
require.Equal(t, signers[1].NodeID(), atx.SmesherID)
require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount)
require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight)

posATX, err = atxs.Get(db, mergedATX.ID())
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion activation/e2e/builds_atx_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) {

require.NotZero(t, atx.BaseTickHeight)
require.NotZero(t, atx.TickCount)
require.NotZero(t, atx.GetWeight())
require.NotZero(t, atx.Weight)
require.NotZero(t, atx.TickHeight())
require.Equal(t, opts.NumUnits, atx.NumUnits)
previous = atx
Expand Down
4 changes: 2 additions & 2 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ func TestHandler_AtxWeight(t *testing.T) {
require.Equal(t, uint64(0), stored1.BaseTickHeight)
require.Equal(t, leaves/tickSize, stored1.TickCount)
require.Equal(t, leaves/tickSize, stored1.TickHeight())
require.Equal(t, (leaves/tickSize)*units, stored1.GetWeight())
require.Equal(t, (leaves/tickSize)*units, stored1.Weight)

atx2 := newChainedActivationTxV1(t, atx1, atx1.ID())
atx2.Sign(sig)
Expand All @@ -657,7 +657,7 @@ func TestHandler_AtxWeight(t *testing.T) {
require.Equal(t, stored1.TickHeight(), stored2.BaseTickHeight)
require.Equal(t, leaves/tickSize, stored2.TickCount)
require.Equal(t, stored1.TickHeight()+leaves/tickSize, stored2.TickHeight())
require.Equal(t, int(leaves/tickSize)*units, int(stored2.GetWeight()))
require.Equal(t, int(leaves/tickSize)*units, int(stored2.Weight))
}

func TestHandler_WrongHash(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"math/bits"
"sync"
"time"

Expand Down Expand Up @@ -683,6 +684,11 @@ func (h *HandlerV1) processATX(
atx.NumUnits = effectiveNumUnits
atx.BaseTickHeight = baseTickHeight
atx.TickCount = leaves / h.tickSize
hi, weight := bits.Mul64(uint64(atx.NumUnits), atx.TickCount)
if hi != 0 {
return nil, errors.New("atx weight would overflow uint64")
}
atx.Weight = weight

proof, err = h.storeAtx(ctx, atx, watx)
if err != nil {
Expand Down
71 changes: 60 additions & 11 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package activation

import (
"cmp"
"context"
"errors"
"fmt"
"math"
"math/bits"
"slices"
"time"

Expand Down Expand Up @@ -121,9 +122,10 @@ func (h *HandlerV2) processATX(
atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
Coinbase: watx.Coinbase,
NumUnits: parts.effectiveUnits,
BaseTickHeight: baseTickHeight,
TickCount: parts.leaves / h.tickSize,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2},
Expand Down Expand Up @@ -487,10 +489,50 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
}

type atxParts struct {
leaves uint64
ticks uint64
weight uint64
effectiveUnits uint32
}

type nipostSize struct {
units uint32
ticks uint64
}

func (n *nipostSize) addUnits(units uint32) error {
sum, carry := bits.Add32(n.units, units, 0)
if carry != 0 {
return errors.New("units overflow")
}
n.units = sum
return nil
}

type nipostSizes []*nipostSize

func (n nipostSizes) minTicks() uint64 {
return slices.MinFunc(n, func(a, b *nipostSize) int { return cmp.Compare(a.ticks, b.ticks) }).ticks
}

func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) {
var totalEffectiveNumUnits uint32
var totalWeight uint64
for _, ns := range n {
sum, carry := bits.Add32(totalEffectiveNumUnits, ns.units, 0)
if carry != 0 {
return 0, 0, fmt.Errorf("total units overflow (%d + %d)", totalEffectiveNumUnits, ns.units)
}
totalEffectiveNumUnits = sum

hi, weight := bits.Mul64(uint64(ns.units), ns.ticks)
if hi != 0 {
return 0, 0, fmt.Errorf("weight overflow (%d * %d)", ns.units, ns.ticks)
}
totalWeight += weight
}
return totalEffectiveNumUnits, totalWeight, nil
}

func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error {
seen := make(map[uint32]struct{})
for _, niposts := range atx.NiPosts {
Expand Down Expand Up @@ -534,8 +576,9 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}

// validate previous ATXs
var totalEffectiveNumUnits uint32
for _, niposts := range atx.NiPosts {
nipostSizes := make(nipostSizes, len(atx.NiPosts))
for i, niposts := range atx.NiPosts {
nipostSizes[i] = new(nipostSize)
for _, post := range niposts.Posts {
if post.MarriageIndex >= uint32(len(equivocationSet)) {
err := fmt.Errorf("marriage index out of bounds: %d > %d", post.MarriageIndex, len(equivocationSet)-1)
Expand All @@ -551,13 +594,13 @@ func (h *HandlerV2) syntacticallyValidateDeps(
return nil, nil, fmt.Errorf("validating previous atx: %w", err)
}
}
totalEffectiveNumUnits += effectiveNumUnits
nipostSizes[i].addUnits(effectiveNumUnits)

}
}

// validate poet membership proofs
var minLeaves uint64 = math.MaxUint64
for _, niposts := range atx.NiPosts {
for i, niposts := range atx.NiPosts {
// verify PoET memberships in a single go
indexedChallenges := make(map[uint64][]byte)

Expand Down Expand Up @@ -594,7 +637,12 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, nil, fmt.Errorf("invalid poet membership: %w", err)
}
minLeaves = min(leaves, minLeaves)
nipostSizes[i].ticks = leaves / h.tickSize
}

totalEffectiveNumUnits, totalWeight, err := nipostSizes.sumUp()
if err != nil {
return nil, nil, err
}

// validate all niposts
Expand Down Expand Up @@ -641,8 +689,9 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}

parts := &atxParts{
leaves: minLeaves,
ticks: nipostSizes.minTicks(),
effectiveUnits: totalEffectiveNumUnits,
weight: totalWeight,
}

if atx.Initial == nil {
Expand Down
Loading

0 comments on commit 7d0e2c6

Please sign in to comment.