Skip to content

Commit

Permalink
resolve TODOs: add maxSize check for each dfsTrieIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
BeniaminDrasovean committed Jan 9, 2025
1 parent b1a8ac4 commit 78b58db
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 21 deletions.
2 changes: 1 addition & 1 deletion common/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ type TrieNodeData interface {

// DfsIterator is used to iterate the trie nodes in a depth-first search manner
type DfsIterator interface {
GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error)
GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error)
GetIteratorId() []byte
Clone() DfsIterator
FinishedIteration() bool
Expand Down
16 changes: 16 additions & 0 deletions testscommon/state/testTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ func AddDataToTrie(tr common.Trie, numLeaves int) {
}
_ = tr.Commit()
}

// GetTrieWithData creates a new trie, inserts a fixed set of six key/value pairs, commits it and returns it.
// The keys are chosen so that the rootNode is a branch holding 2 leaves and 1 extension node, the latter
// yielding 4 leaves when traversed. As a consequence, an iterator over this trie reaches its highest size
// when it arrives at the extension node, at which point 2 leaves have already been retrieved.
func GetTrieWithData() common.Trie {
	trie := GetNewTrie()

	// Insertion order is kept fixed so the resulting trie is always built the same way.
	entries := []struct {
		key   string
		value string
	}{
		{key: "key1", value: "value1"},
		{key: "key2", value: "value2"},
		{key: "key13", value: "value3"},
		{key: "key23", value: "value4"},
		{key: "key33", value: "value4"},
		{key: "key43", value: "value4"},
	}
	for _, entry := range entries {
		// Update errors are explicitly discarded, as in the rest of this helper.
		_ = trie.Update([]byte(entry.key), []byte(entry.value))
	}

	_ = trie.Commit()
	return trie
}
8 changes: 5 additions & 3 deletions trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma
}

// GetLeaves retrieves leaves from the trie. It stops when the requested number of leaves is reached, when the iterator size reaches maxSize, or when the context is done.
// TODO add a maxSize that will stop the iteration when the size is reached
func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) {
func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) {
retrievedLeaves := make(map[string]string)
for {
nextNodes := make([]common.TrieNodeData, 0)
if it.size >= maxSize {
return retrievedLeaves, nil
}

if len(retrievedLeaves) >= numLeaves {
return retrievedLeaves, nil
}
Expand Down Expand Up @@ -140,7 +143,6 @@ func (it *dfsIterator) IsInterfaceNil() bool {
return it == nil
}

// TODO add context nil test
func checkContextDone(ctx context.Context) bool {
if ctx == nil {
return false
Expand Down
56 changes: 45 additions & 11 deletions trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dfsTrieIterator
import (
"context"
"fmt"
"math"
"testing"

"github.com/multiversx/mx-chain-go/testscommon"
Expand All @@ -14,6 +15,8 @@ import (
"github.com/stretchr/testify/assert"
)

var maxSize = uint64(math.MaxUint64)

func TestNewIterator(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -94,7 +97,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) {
_, marshaller, hasher := trieTest.GetDefaultTrieParameters()
iterator, _ := NewIterator(rootHash, dbWrapper, marshaller, hasher)

trieData, err := iterator.GetLeaves(numLeaves, ctx)
trieData, err := iterator.GetLeaves(numLeaves, maxSize, ctx)
assert.Nil(t, err)
assert.Equal(t, expectedNumLeaves, len(trieData))
})
Expand All @@ -109,7 +112,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) {
_, marshaller, hasher := trieTest.GetDefaultTrieParameters()
iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher)

trieData, err := iterator.GetLeaves(numLeaves, context.Background())
trieData, err := iterator.GetLeaves(numLeaves, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, numLeaves, len(trieData))
})
Expand All @@ -125,7 +128,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) {
_, marshaller, hasher := trieTest.GetDefaultTrieParameters()
iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher)

trieData, err := iterator.GetLeaves(17, context.Background())
trieData, err := iterator.GetLeaves(17, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, expectedNumRetrievedLeaves, len(trieData))
})
t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) {
t.Parallel()

tr := trieTest.GetTrieWithData()
expectedNumRetrievedLeaves := 2
rootHash, _ := tr.RootHash()

_, marshaller, hasher := trieTest.GetDefaultTrieParameters()
iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher)

iteratorMaxSize := uint64(100)
trieData, err := iterator.GetLeaves(5, iteratorMaxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, expectedNumRetrievedLeaves, len(trieData))
})
Expand All @@ -142,7 +160,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) {
numRetrievedLeaves := 0
numIterations := 0
for numRetrievedLeaves < numLeaves {
trieData, err := iterator.GetLeaves(5, context.Background())
trieData, err := iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)

numRetrievedLeaves += len(trieData)
Expand All @@ -152,6 +170,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) {
assert.Equal(t, numLeaves, numRetrievedLeaves)
assert.Equal(t, 5, numIterations)
})
t.Run("retrieve leaves with nil context does not panic", func(t *testing.T) {
	t.Parallel()

	tr := trieTest.GetNewTrie()
	numLeaves := 25
	expectedNumRetrievedLeaves := 25
	trieTest.AddDataToTrie(tr, numLeaves)
	rootHash, _ := tr.RootHash()

	_, marshaller, hasher := trieTest.GetDefaultTrieParameters()
	iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher)

	// The iterator itself is valid; it is the context that is nil on purpose.
	// checkContextDone reports a nil context as "not done", so all the leaves
	// are expected to be retrieved without a panic.
	trieData, err := iterator.GetLeaves(numLeaves, maxSize, nil)
	assert.Nil(t, err)
	assert.Equal(t, expectedNumRetrievedLeaves, len(trieData))
})
}

func TestDfsIterator_GetIteratorId(t *testing.T) {
Expand All @@ -169,7 +203,7 @@ func TestDfsIterator_GetIteratorId(t *testing.T) {
iteratorId := hasher.Compute(string(append(rootHash, iterator.nextNodes[0].GetData()...)))
assert.Equal(t, iteratorId, iterator.GetIteratorId())

trieData, err := iterator.GetLeaves(5, context.Background())
trieData, err := iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)

numRetrievedLeaves += len(trieData)
Expand Down Expand Up @@ -211,7 +245,7 @@ func TestDfsIterator_FinishedIteration(t *testing.T) {
numRetrievedLeaves := 0
for numRetrievedLeaves < numLeaves {
assert.False(t, iterator.FinishedIteration())
trieData, err := iterator.GetLeaves(5, context.Background())
trieData, err := iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)

numRetrievedLeaves += len(trieData)
Expand All @@ -237,23 +271,23 @@ func TestDfsIterator_Size(t *testing.T) {
iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher)
assert.Equal(t, uint64(362), iterator.Size()) // 10 branch nodes + 1 root hash

_, err := iterator.GetLeaves(5, context.Background())
_, err := iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, uint64(331), iterator.Size()) // 8 branch nodes + 1 leaf node + 1 root hash

_, err = iterator.GetLeaves(5, context.Background())
_, err = iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, uint64(300), iterator.Size()) // 6 branch nodes + 2 leaf node + 1 root hash

_, err = iterator.GetLeaves(5, context.Background())
_, err = iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, uint64(197), iterator.Size()) // 5 branch nodes + 1 root hash

_, err = iterator.GetLeaves(5, context.Background())
_, err = iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, uint64(133), iterator.Size()) // 2 branch nodes + 1 leaf node + 1 root hash

_, err = iterator.GetLeaves(5, context.Background())
_, err = iterator.GetLeaves(5, maxSize, context.Background())
assert.Nil(t, err)
assert.Equal(t, uint64(32), iterator.Size()) // 1 root hash
}
24 changes: 18 additions & 6 deletions trie/leavesRetriever/leavesRetriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator commo
}

func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, numLeaves int, ctx context.Context) (map[string]string, []byte, error) {
leaves, err := iterator.GetLeaves(numLeaves, ctx)
leaves, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx)
if err != nil {
return nil, nil, err
}
Expand All @@ -92,27 +92,39 @@ func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, nu
return leaves, nil, nil
}

lr.manageIterators(iteratorId, iterator)
shouldReturnId := lr.manageIterators(iteratorId, iterator)
if !shouldReturnId {
return leaves, nil, nil
}
return leaves, iteratorId, nil
}

func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) {
func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) bool {
lr.mutex.Lock()
defer lr.mutex.Unlock()

lr.saveIterator(iteratorId, iterator)
newIteratorPresent := lr.saveIterator(iteratorId, iterator)
if !newIteratorPresent {
return false
}
lr.removeIteratorsIfMaxSizeIsExceeded()
return true
}

func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) {
func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) bool {
_, isPresent := lr.iterators[string(iteratorId)]
if isPresent {
return
return true
}

if iterator.Size() >= lr.maxSize {
return false
}

lr.lruIteratorIDs = append(lr.lruIteratorIDs, iteratorId)
lr.iterators[string(iteratorId)] = iterator
lr.size += iterator.Size() + uint64(len(iteratorId))
return true
}

func (lr *leavesRetriever) markIteratorAsRecentlyUsed(iteratorId []byte) {
Expand Down
14 changes: 14 additions & 0 deletions trie/leavesRetriever/leavesRetriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ func TestLeavesRetriever_GetLeaves(t *testing.T) {
assert.Equal(t, 0, len(id))
assert.Equal(t, leavesRetriever.ErrIteratorNotFound, err)
})
t.Run("max size reached on the first iteration", func(t *testing.T) {
t.Parallel()

tr := trieTest.GetTrieWithData()
rootHash, _ := tr.RootHash()
maxSize := uint64(100)

lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize)
leaves, id1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background())
assert.Nil(t, err)
assert.Equal(t, 2, len(leaves))
assert.Equal(t, 0, len(id1))
assert.Equal(t, 0, len(lr.GetIterators()))
})
}

func TestLeavesRetriever_Concurrency(t *testing.T) {
Expand Down

0 comments on commit 78b58db

Please sign in to comment.