From 89460e6cac7648b25db406f8db1298772cd60277 Mon Sep 17 00:00:00 2001 From: Minhyuk Kim Date: Thu, 19 Sep 2024 10:43:05 -0600 Subject: [PATCH] Add comments --- rvgo/fast/radix.go | 129 +++++++++++++++++++++++++++++---------------- 1 file changed, 85 insertions(+), 44 deletions(-) diff --git a/rvgo/fast/radix.go b/rvgo/fast/radix.go index 5ef5012..8ebcb1a 100644 --- a/rvgo/fast/radix.go +++ b/rvgo/fast/radix.go @@ -4,28 +4,38 @@ import ( "math/bits" ) +// RadixNode is an interface defining the operations for a node in a radix trie. type RadixNode interface { + // InvalidateNode invalidates the hash cache along the path to the specified address. InvalidateNode(addr uint64) + // GenerateProof generates the Merkle proof for the given address. GenerateProof(addr uint64) [][32]byte + // MerkleizeNode computes the Merkle root hash for the node at the given generalized index. MerkleizeNode(addr, gindex uint64) [32]byte } +// SmallRadixNode is a radix trie node with a branching factor of 4 bits. type SmallRadixNode[C RadixNode] struct { - Children [1 << 4]*C - Hashes [1 << 4][32]byte - HashExists uint16 - HashValid uint16 - Depth uint16 + Children [1 << 4]*C // Array of child nodes, indexed by 4-bit keys. + Hashes [1 << 4][32]byte // Cached hashes for each child node. + ChildExists uint16 // Bitmask indicating which children exist (1 bit per child). + HashValid uint16 // Bitmask indicating which hashes are valid (1 bit per child). + Depth uint64 // The depth of this node in the trie (number of bits from the root). } +// LargeRadixNode is a radix trie node with a branching factor of 8 bits. type LargeRadixNode[C RadixNode] struct { - Children [1 << 8]*C - Hashes [1 << 8][32]byte - HashExists [(1 << 8) / 64]uint64 - HashValid [(1 << 8) / 64]uint64 - Depth uint16 + Children [1 << 8]*C // Array of child nodes, indexed by 8-bit keys. + Hashes [1 << 8][32]byte + ChildExists [(1 << 8) / 64]uint64 + HashValid [(1 << 8) / 64]uint64 + Depth uint64 } +// Define a sequence of radix trie node types (L1 to L11) representing different levels in the trie. +// Each level corresponds to a node type, where L1 is the root node and L11 is the leaf level pointing to Memory. +// The cumulative bit-lengths of the addresses represented by the nodes from L1 to L11 add up to 52 bits. + type L1 = SmallRadixNode[L2] type L2 = *SmallRadixNode[L3] type L3 = *SmallRadixNode[L4] @@ -38,14 +48,18 @@ type L9 = *LargeRadixNode[L10] type L10 = *LargeRadixNode[L11] type L11 = *Memory +// InvalidateNode invalidates the hash cache along the path to the specified address. +// It marks the necessary child hashes as invalid, forcing them to be recomputed when needed. func (n *SmallRadixNode[C]) InvalidateNode(addr uint64) { - childIdx := addressToRadixPath(addr, n.Depth, 4) + childIdx := addressToRadixPath(addr, n.Depth, 4) // Get the 4-bit child index at the current depth. + + branchIdx := (childIdx + 1<<4) / 2 // Compute the index for the hash tree traversal. - branchIdx := (childIdx + 1<<4) / 2 + // Traverse up the hash tree, invalidating hashes along the way. for index := branchIdx; index > 0; index >>= 1 { - hashBit := index & 15 - n.HashExists |= 1 << hashBit - n.HashValid &= ^(1 << hashBit) + hashBit := index & 15 // Get the relevant bit position (0-15). + n.ChildExists |= 1 << hashBit // Mark the child as existing. + n.HashValid &= ^(1 << hashBit) // Invalidate the hash at this position. } } @@ -57,7 +71,7 @@ func (n *LargeRadixNode[C]) InvalidateNode(addr uint64) { for index := branchIdx; index > 0; index >>= 1 { hashIndex := index >> 6 hashBit := index & 63 - n.HashExists[hashIndex] |= 1 << hashBit + n.ChildExists[hashIndex] |= 1 << hashBit n.HashValid[hashIndex] &= ^(1 << hashBit) } } @@ -68,17 +82,23 @@ func (m *Memory) InvalidateNode(addr uint64) { } } +// GenerateProof generates the Merkle proof for the given address. +// It collects the necessary sibling hashes along the path to reconstruct the Merkle proof. func (n *SmallRadixNode[C]) GenerateProof(addr uint64) [][32]byte { var proofs [][32]byte path := addressToRadixPath(addr, n.Depth, 4) if n.Children[path] == nil { + // When no child exists at this path, the rest of the proofs are zero hashes. proofs = zeroHashRange(0, 60-n.Depth-4) } else { + // Recursively generate proofs from the child node. proofs = (*n.Children[path]).GenerateProof(addr) } + + // Collect sibling hashes along the path for the proof. for idx := path + 1<<4; idx > 1; idx >>= 1 { - sibling := idx ^ 1 + sibling := idx ^ 1 // Get the sibling index. proofs = append(proofs, n.MerkleizeNode(addr>>(64-n.Depth), sibling)) } @@ -106,31 +126,37 @@ func (m *Memory) GenerateProof(addr uint64) [][32]byte { pageIndex := addr >> PageAddrSize if p, ok := m.pages[pageIndex]; ok { - return p.GenerateProof(addr) + return p.GenerateProof(addr) // Generate proof from the page. } else { - return zeroHashRange(0, 8) + return zeroHashRange(0, 8) // Return zero hashes if the page does not exist. } } +// MerkleizeNode computes the Merkle root hash for the node at the given generalized index. +// It recursively computes the hash of the subtree rooted at the given index. +// Note: The 'addr' parameter represents the partial address accumulated up to this node, not the full address. It represents the path taken in the trie to reach this node. func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { - depth := uint16(bits.Len64(gindex)) + depth := uint64(bits.Len64(gindex)) // Get the depth of the current gindex. if depth <= 4 { hashBit := gindex & 15 - if (n.HashExists & (1 << hashBit)) != 0 { + if (n.ChildExists & (1 << hashBit)) != 0 { if (n.HashValid & (1 << hashBit)) != 0 { + // Return the cached hash if valid. return n.Hashes[gindex] } else { left := n.MerkleizeNode(addr, gindex<<1) right := n.MerkleizeNode(addr, (gindex<<1)|1) + // Hash the pair and cache the result. r := HashPair(left, right) n.Hashes[gindex] = r n.HashValid |= 1 << hashBit return r } } else { + // Return zero hash for non-existent child. return zeroHashes[64-5+1-(depth+n.Depth)] } } @@ -140,21 +166,26 @@ func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { } childIndex := gindex - 1<<4 + if n.Children[childIndex] == nil { + // Return zero hash if child does not exist. return zeroHashes[64-5+1-(depth+n.Depth)] } + + // Update the partial address by appending the child index bits. + // This accumulates the address as we traverse deeper into the trie. addr <<= 4 addr |= childIndex return (*n.Children[childIndex]).MerkleizeNode(addr, 1) } func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { - depth := uint16(bits.Len64(gindex)) + depth := uint64(bits.Len64(gindex)) if depth <= 8 { hashIndex := gindex >> 6 hashBit := gindex & 63 - if (n.HashExists[hashIndex] & (1 << hashBit)) != 0 { + if (n.ChildExists[hashIndex] & (1 << hashBit)) != 0 { if (n.HashValid[hashIndex] & (1 << hashBit)) != 0 { return n.Hashes[gindex] } else { @@ -171,7 +202,7 @@ func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { } } - if depth > 8<<1 { + if depth > 16 { panic("gindex too deep") } @@ -196,17 +227,19 @@ func (m *Memory) MerkleizeNode(addr, gindex uint64) [32]byte { } } +// MerkleRoot computes the Merkle root hash of the entire memory. func (m *Memory) MerkleRoot() [32]byte { return (*m.radix).MerkleizeNode(0, 1) } +// MerkleProof generates the Merkle proof for the specified address in memory. func (m *Memory) MerkleProof(addr uint64) [ProofLen * 32]byte { proofs := m.radix.GenerateProof(addr) - return encodeProofs(proofs) } -func zeroHashRange(start, end uint16) [][32]byte { +// zeroHashRange returns a slice of zero hashes from start to end. +func zeroHashRange(start, end uint64) [][32]byte { proofs := make([][32]byte, end-start) if start == 0 { proofs[0] = zeroHashes[0] @@ -218,6 +251,7 @@ func zeroHashRange(start, end uint16) [][32]byte { return proofs } +// encodeProofs encodes the list of proof hashes into a byte array. func encodeProofs(proofs [][32]byte) [ProofLen * 32]byte { var out [ProofLen * 32]byte for i := 0; i < ProofLen; i++ { @@ -226,37 +260,41 @@ func encodeProofs(proofs [][32]byte) [ProofLen * 32]byte { return out } -func addressToRadixPath(addr uint64, position, count uint16) uint64 { - // Calculate the total shift amount - totalShift := PageAddrSize + 52 - position - count +// addressToRadixPath extracts a segment of bits from an address, starting from 'position' with 'count' bits. +// It returns the extracted bits as a uint64. +func addressToRadixPath(addr, position, count uint64) uint64 { + // Calculate the total shift amount. + totalShift := 64 - position - count - // Shift the address to bring the desired bits to the LSB + // Shift the address to bring the desired bits to the LSB. addr >>= totalShift - // Extract the desired bits using a mask + // Extract the desired bits using a mask. return addr & ((1 << count) - 1) } -func (m *Memory) addressToBranchPath(addr uint64) []uint64 { - addr >>= PageAddrSize - +// addressToRadixPaths converts an address into a slice of radix path indices based on the branch factors. +func (m *Memory) addressToRadixPaths(addr uint64) []uint64 { path := make([]uint64, len(m.branchFactors)) - for i := len(m.branchFactors) - 1; i >= 0; i-- { - bits := m.branchFactors[i] - mask := (1 << bits) - 1 // Create a mask for the current segment - path[i] = addr & uint64(mask) // Extract the segment using the mask - addr >>= bits // Shift the gindex to the right by the number of bits processed + var position uint64 + + for index, branchFactor := range m.branchFactors { + path[index] = addressToRadixPath(addr, position, branchFactor) + position += branchFactor } + return path } +// AllocPage allocates a new page at the specified page index in memory. func (m *Memory) AllocPage(pageIndex uint64) *CachedPage { p := &CachedPage{Data: new(Page)} m.pages[pageIndex] = p addr := pageIndex << PageAddrSize - branchPaths := m.addressToBranchPath(addr) + branchPaths := m.addressToRadixPaths(addr) + // Build the radix trie path to the new page, creating nodes as necessary. radixLevel1 := m.radix if (*radixLevel1).Children[branchPaths[0]] == nil { node := &SmallRadixNode[L3]{Depth: 4} @@ -329,18 +367,21 @@ func (m *Memory) AllocPage(pageIndex uint64) *CachedPage { return p } +// Invalidate invalidates the cache along the path from the specified address up to the root. +// It ensures that any cached hashes are recomputed when needed. func (m *Memory) Invalidate(addr uint64) { - // find page, and invalidate addr within it + // Find the page and invalidate the address within it. if p, ok := m.pageLookup(addr >> PageAddrSize); ok { prevValid := p.Ok[1] - if !prevValid { // if the page was already invalid before, then nodes to mem-root will also still be. + if !prevValid { + // If the page was already invalid, the nodes up to the root are also invalid. return } - } else { // no page? nothing to invalidate + } else { return } - branchPaths := m.addressToBranchPath(addr) + branchPaths := m.addressToRadixPaths(addr) currentLevel1 := m.radix currentLevel1.InvalidateNode(addr)