Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mininny committed Sep 19, 2024
1 parent 4b88e18 commit 89460e6
Showing 1 changed file with 85 additions and 44 deletions.
129 changes: 85 additions & 44 deletions rvgo/fast/radix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,38 @@ import (
"math/bits"
)

// RadixNode is an interface defining the operations for a node in a radix trie.
type RadixNode interface {
// InvalidateNode invalidates the hash cache along the path to the specified address.
InvalidateNode(addr uint64)
// GenerateProof generates the Merkle proof for the given address.
GenerateProof(addr uint64) [][32]byte
// MerkleizeNode computes the Merkle root hash for the node at the given generalized index.
MerkleizeNode(addr, gindex uint64) [32]byte
}

// SmallRadixNode is a radix trie node with a branching factor of 4 bits.
type SmallRadixNode[C RadixNode] struct {
Children [1 << 4]*C
Hashes [1 << 4][32]byte
HashExists uint16
HashValid uint16
Depth uint16
Children [1 << 4]*C // Array of child nodes, indexed by 4-bit keys.
Hashes [1 << 4][32]byte // Cached hashes for each child node.
ChildExists uint16 // Bitmask indicating which children exist (1 bit per child).
HashValid uint16 // Bitmask indicating which hashes are valid (1 bit per child).
Depth uint64 // The depth of this node in the trie (number of bits from the root).
}

// LargeRadixNode is a radix trie node with a branching factor of 8 bits.
type LargeRadixNode[C RadixNode] struct {
Children [1 << 8]*C
Hashes [1 << 8][32]byte
HashExists [(1 << 8) / 64]uint64
HashValid [(1 << 8) / 64]uint64
Depth uint16
Children [1 << 8]*C // Array of child nodes, indexed by 8-bit keys.
Hashes [1 << 8][32]byte
ChildExists [(1 << 8) / 64]uint64
HashValid [(1 << 8) / 64]uint64
Depth uint64
}

// Define a sequence of radix trie node types (L1 to L11) representing different levels in the trie.
// Each level corresponds to a node type, where L1 is the root node and L11 is the leaf level pointing to Memory.
// The cumulative bit-lengths of the addresses represented by the nodes from L1 to L11 add up to 52 bits.

type L1 = SmallRadixNode[L2]
type L2 = *SmallRadixNode[L3]
type L3 = *SmallRadixNode[L4]
Expand All @@ -38,14 +48,18 @@ type L9 = *LargeRadixNode[L10]
type L10 = *LargeRadixNode[L11]
type L11 = *Memory

// InvalidateNode invalidates the hash cache along the path to the specified address.
// It marks the necessary child hashes as invalid, forcing them to be recomputed when needed.
func (n *SmallRadixNode[C]) InvalidateNode(addr uint64) {
childIdx := addressToRadixPath(addr, n.Depth, 4)
childIdx := addressToRadixPath(addr, n.Depth, 4) // Get the 4-bit child index at the current depth.

branchIdx := (childIdx + 1<<4) / 2 // Compute the index for the hash tree traversal.

branchIdx := (childIdx + 1<<4) / 2
// Traverse up the hash tree, invalidating hashes along the way.
for index := branchIdx; index > 0; index >>= 1 {
hashBit := index & 15
n.HashExists |= 1 << hashBit
n.HashValid &= ^(1 << hashBit)
hashBit := index & 15 // Get the relevant bit position (0-15).
n.ChildExists |= 1 << hashBit // Mark the child as existing.
n.HashValid &= ^(1 << hashBit) // Invalidate the hash at this position.
}
}

Expand All @@ -57,7 +71,7 @@ func (n *LargeRadixNode[C]) InvalidateNode(addr uint64) {
for index := branchIdx; index > 0; index >>= 1 {
hashIndex := index >> 6
hashBit := index & 63
n.HashExists[hashIndex] |= 1 << hashBit
n.ChildExists[hashIndex] |= 1 << hashBit
n.HashValid[hashIndex] &= ^(1 << hashBit)
}
}
Expand All @@ -68,17 +82,23 @@ func (m *Memory) InvalidateNode(addr uint64) {
}
}

// GenerateProof generates the Merkle proof for the given address.
// It collects the necessary sibling hashes along the path to reconstruct the Merkle proof.
func (n *SmallRadixNode[C]) GenerateProof(addr uint64) [][32]byte {
var proofs [][32]byte
path := addressToRadixPath(addr, n.Depth, 4)

if n.Children[path] == nil {
// When no child exists at this path, the rest of the proofs are zero hashes.
proofs = zeroHashRange(0, 60-n.Depth-4)
} else {
// Recursively generate proofs from the child node.
proofs = (*n.Children[path]).GenerateProof(addr)
}

// Collect sibling hashes along the path for the proof.
for idx := path + 1<<4; idx > 1; idx >>= 1 {
sibling := idx ^ 1
sibling := idx ^ 1 // Get the sibling index.
proofs = append(proofs, n.MerkleizeNode(addr>>(64-n.Depth), sibling))
}

Expand Down Expand Up @@ -106,31 +126,37 @@ func (m *Memory) GenerateProof(addr uint64) [][32]byte {
pageIndex := addr >> PageAddrSize

if p, ok := m.pages[pageIndex]; ok {
return p.GenerateProof(addr)
return p.GenerateProof(addr) // Generate proof from the page.
} else {
return zeroHashRange(0, 8)
return zeroHashRange(0, 8) // Return zero hashes if the page does not exist.

Check warning on line 131 in rvgo/fast/radix.go

View check run for this annotation

Codecov / codecov/patch

rvgo/fast/radix.go#L131

Added line #L131 was not covered by tests
}
}

// MerkleizeNode computes the Merkle root hash for the node at the given generalized index.
// It recursively computes the hash of the subtree rooted at the given index.
// Note: The 'addr' parameter represents the partial address accumulated up to this node, not the full address. It represents the path taken in the trie to reach this node.
func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
depth := uint16(bits.Len64(gindex))
depth := uint64(bits.Len64(gindex)) // Get the depth of the current gindex.

if depth <= 4 {
hashBit := gindex & 15

if (n.HashExists & (1 << hashBit)) != 0 {
if (n.ChildExists & (1 << hashBit)) != 0 {
if (n.HashValid & (1 << hashBit)) != 0 {
// Return the cached hash if valid.
return n.Hashes[gindex]
} else {
left := n.MerkleizeNode(addr, gindex<<1)
right := n.MerkleizeNode(addr, (gindex<<1)|1)

// Hash the pair and cache the result.
r := HashPair(left, right)
n.Hashes[gindex] = r
n.HashValid |= 1 << hashBit
return r
}
} else {
// Return zero hash for non-existent child.
return zeroHashes[64-5+1-(depth+n.Depth)]
}
}
Expand All @@ -140,21 +166,26 @@ func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
}

childIndex := gindex - 1<<4

if n.Children[childIndex] == nil {
// Return zero hash if child does not exist.
return zeroHashes[64-5+1-(depth+n.Depth)]
}

// Update the partial address by appending the child index bits.
// This accumulates the address as we traverse deeper into the trie.
addr <<= 4
addr |= childIndex
return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
}

func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
depth := uint16(bits.Len64(gindex))
depth := uint64(bits.Len64(gindex))

if depth <= 8 {
hashIndex := gindex >> 6
hashBit := gindex & 63
if (n.HashExists[hashIndex] & (1 << hashBit)) != 0 {
if (n.ChildExists[hashIndex] & (1 << hashBit)) != 0 {
if (n.HashValid[hashIndex] & (1 << hashBit)) != 0 {
return n.Hashes[gindex]
} else {
Expand All @@ -171,7 +202,7 @@ func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
}
}

if depth > 8<<1 {
if depth > 16 {
panic("gindex too deep")

Check warning on line 206 in rvgo/fast/radix.go

View check run for this annotation

Codecov / codecov/patch

rvgo/fast/radix.go#L206

Added line #L206 was not covered by tests
}

Expand All @@ -196,17 +227,19 @@ func (m *Memory) MerkleizeNode(addr, gindex uint64) [32]byte {
}
}

// MerkleRoot computes the Merkle root hash of the entire memory.
func (m *Memory) MerkleRoot() [32]byte {
return (*m.radix).MerkleizeNode(0, 1)
}

// MerkleProof generates the Merkle proof for the specified address in memory.
func (m *Memory) MerkleProof(addr uint64) [ProofLen * 32]byte {
proofs := m.radix.GenerateProof(addr)

return encodeProofs(proofs)
}

func zeroHashRange(start, end uint16) [][32]byte {
// zeroHashRange returns a slice of zero hashes from start to end.
func zeroHashRange(start, end uint64) [][32]byte {
proofs := make([][32]byte, end-start)
if start == 0 {
proofs[0] = zeroHashes[0]
Expand All @@ -218,6 +251,7 @@ func zeroHashRange(start, end uint16) [][32]byte {
return proofs
}

// encodeProofs encodes the list of proof hashes into a byte array.
func encodeProofs(proofs [][32]byte) [ProofLen * 32]byte {
var out [ProofLen * 32]byte
for i := 0; i < ProofLen; i++ {
Expand All @@ -226,37 +260,41 @@ func encodeProofs(proofs [][32]byte) [ProofLen * 32]byte {
return out
}

func addressToRadixPath(addr uint64, position, count uint16) uint64 {
// Calculate the total shift amount
totalShift := PageAddrSize + 52 - position - count
// addressToRadixPath extracts a segment of bits from an address, starting from 'position' with 'count' bits.
// It returns the extracted bits as a uint64.
func addressToRadixPath(addr, position, count uint64) uint64 {
// Calculate the total shift amount.
totalShift := 64 - position - count

// Shift the address to bring the desired bits to the LSB
// Shift the address to bring the desired bits to the LSB.
addr >>= totalShift

// Extract the desired bits using a mask
// Extract the desired bits using a mask.
return addr & ((1 << count) - 1)
}

func (m *Memory) addressToBranchPath(addr uint64) []uint64 {
addr >>= PageAddrSize

// addressToRadixPaths converts an address into a slice of radix path indices based on the branch factors.
func (m *Memory) addressToRadixPaths(addr uint64) []uint64 {
path := make([]uint64, len(m.branchFactors))
for i := len(m.branchFactors) - 1; i >= 0; i-- {
bits := m.branchFactors[i]
mask := (1 << bits) - 1 // Create a mask for the current segment
path[i] = addr & uint64(mask) // Extract the segment using the mask
addr >>= bits // Shift the gindex to the right by the number of bits processed
var position uint64

for index, branchFactor := range m.branchFactors {
path[index] = addressToRadixPath(addr, position, branchFactor)
position += branchFactor
}

return path
}

// AllocPage allocates a new page at the specified page index in memory.
func (m *Memory) AllocPage(pageIndex uint64) *CachedPage {
p := &CachedPage{Data: new(Page)}
m.pages[pageIndex] = p

addr := pageIndex << PageAddrSize
branchPaths := m.addressToBranchPath(addr)
branchPaths := m.addressToRadixPaths(addr)

// Build the radix trie path to the new page, creating nodes as necessary.
radixLevel1 := m.radix
if (*radixLevel1).Children[branchPaths[0]] == nil {
node := &SmallRadixNode[L3]{Depth: 4}
Expand Down Expand Up @@ -329,18 +367,21 @@ func (m *Memory) AllocPage(pageIndex uint64) *CachedPage {
return p
}

// Invalidate invalidates the cache along the path from the specified address up to the root.
// It ensures that any cached hashes are recomputed when needed.
func (m *Memory) Invalidate(addr uint64) {
// find page, and invalidate addr within it
// Find the page and invalidate the address within it.
if p, ok := m.pageLookup(addr >> PageAddrSize); ok {
prevValid := p.Ok[1]
if !prevValid { // if the page was already invalid before, then nodes to mem-root will also still be.
if !prevValid {
// If the page was already invalid, the nodes up to the root are also invalid.
return
}
} else { // no page? nothing to invalidate
} else {
return

Check warning on line 381 in rvgo/fast/radix.go

View check run for this annotation

Codecov / codecov/patch

rvgo/fast/radix.go#L380-L381

Added lines #L380 - L381 were not covered by tests
}

branchPaths := m.addressToBranchPath(addr)
branchPaths := m.addressToRadixPaths(addr)

currentLevel1 := m.radix
currentLevel1.InvalidateNode(addr)
Expand Down

0 comments on commit 89460e6

Please sign in to comment.