From 78b58dbd8cbd94639ae4699fffd3d3157abb3db9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 9 Jan 2025 12:54:25 +0200 Subject: [PATCH] resolve TODOs: add maxSize check for each dfsTrieIterator --- common/interface.go | 2 +- testscommon/state/testTrie.go | 16 ++++++ .../dfsTrieIterator/dfsTrieIterator.go | 8 ++- .../dfsTrieIterator/dfsTrieIterator_test.go | 56 +++++++++++++++---- trie/leavesRetriever/leavesRetriever.go | 24 ++++++-- trie/leavesRetriever/leavesRetriever_test.go | 14 +++++ 6 files changed, 99 insertions(+), 21 deletions(-) diff --git a/common/interface.go b/common/interface.go index 696d4b0182c..efa6b5116fd 100644 --- a/common/interface.go +++ b/common/interface.go @@ -385,7 +385,7 @@ type TrieNodeData interface { // DfsIterator is used to iterate the trie nodes in a depth-first search manner type DfsIterator interface { - GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) + GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) GetIteratorId() []byte Clone() DfsIterator FinishedIteration() bool diff --git a/testscommon/state/testTrie.go b/testscommon/state/testTrie.go index 8744009aa18..bc33a5e2b6b 100644 --- a/testscommon/state/testTrie.go +++ b/testscommon/state/testTrie.go @@ -53,3 +53,19 @@ func AddDataToTrie(tr common.Trie, numLeaves int) { } _ = tr.Commit() } + +// GetTrieWithData returns a trie with some data. +// The added data builds a rootNode that is a branch with 2 leaves and 1 extension node which will have 4 leaves when traversed; +// this way the size of the iterator will be highest when the extension node is reached but 2 leaves will +// have already been retrieved +func GetTrieWithData() common.Trie { + tr := GetNewTrie() + _ = tr.Update([]byte("key1"), []byte("value1")) + _ = tr.Update([]byte("key2"), []byte("value2")) + _ = tr.Update([]byte("key13"), []byte("value3")) + _ = tr.Update([]byte("key23"), []byte("value4")) + _ = tr.Update([]byte("key33"), []byte("value4")) + _ = tr.Update([]byte("key43"), []byte("value4")) + _ = tr.Commit() + return tr +} diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index 5b47e2c1dd2..2224416e282 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -53,11 +53,14 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma } // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. -// TODO add a maxSize that will stop the iteration when the size is reached -func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) { +func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) for { nextNodes := make([]common.TrieNodeData, 0) + if it.size >= maxSize { + return retrievedLeaves, nil + } + if len(retrievedLeaves) >= numLeaves { return retrievedLeaves, nil } @@ -140,7 +143,6 @@ func (it *dfsIterator) IsInterfaceNil() bool { return it == nil } -// TODO add context nil test func checkContextDone(ctx context.Context) bool { if ctx == nil { return false diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index 4489a43a437..b8d71b40173 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -3,6 +3,7 @@ package dfsTrieIterator import ( "context" "fmt" + "math" "testing" "github.com/multiversx/mx-chain-go/testscommon" @@ -14,6 +15,8 @@ import ( "github.com/stretchr/testify/assert" ) +var maxSize = uint64(math.MaxUint64) + func TestNewIterator(t *testing.T) { t.Parallel() @@ -94,7 +97,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, dbWrapper, marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, ctx) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, ctx) assert.Nil(t, err) assert.Equal(t, expectedNumLeaves, len(trieData)) }) @@ -109,7 +112,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, numLeaves, len(trieData)) }) @@ -125,7 +128,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(17, context.Background()) + trieData, err := iterator.GetLeaves(17, maxSize, context.Background()) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) + t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetTrieWithData() + expectedNumRetrievedLeaves := 2 + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + + iteratorMaxSize := uint64(100) + trieData, err := iterator.GetLeaves(5, iteratorMaxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) @@ -142,7 +160,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { numRetrievedLeaves := 0 numIterations := 0 for numRetrievedLeaves < numLeaves { - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -152,6 +170,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.Equal(t, 5, numIterations) }) + t.Run("retrieve leaves with nil iterator does not panic", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumRetrievedLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(numLeaves, maxSize, nil) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) } func TestDfsIterator_GetIteratorId(t *testing.T) { @@ -169,7 +203,7 @@ func TestDfsIterator_GetIteratorId(t *testing.T) { iteratorId := hasher.Compute(string(append(rootHash, iterator.nextNodes[0].GetData()...))) assert.Equal(t, iteratorId, iterator.GetIteratorId()) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -211,7 +245,7 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { numRetrievedLeaves := 0 for numRetrievedLeaves < numLeaves { assert.False(t, iterator.FinishedIteration()) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -237,23 +271,23 @@ func TestDfsIterator_Size(t *testing.T) { iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) assert.Equal(t, uint64(362), iterator.Size()) // 10 branch nodes + 1 root hash - _, err := iterator.GetLeaves(5, context.Background()) + _, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(331), iterator.Size()) // 8 branch nodes + 1 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(300), iterator.Size()) // 6 branch nodes + 2 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(197), iterator.Size()) // 5 branch nodes + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(133), iterator.Size()) // 2 branch nodes + 1 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(32), iterator.Size()) // 1 root hash } diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index 89a11569bc0..5630a3ce7e1 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -78,7 +78,7 @@ func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator commo } func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, numLeaves int, ctx context.Context) (map[string]string, []byte, error) { - leaves, err := iterator.GetLeaves(numLeaves, ctx) + leaves, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx) if err != nil { return nil, nil, err } @@ -92,27 +92,39 @@ func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, nu return leaves, nil, nil } - lr.manageIterators(iteratorId, iterator) + shouldReturnId := lr.manageIterators(iteratorId, iterator) + if !shouldReturnId { + return leaves, nil, nil + } return leaves, iteratorId, nil } -func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) { +func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) bool { lr.mutex.Lock() defer lr.mutex.Unlock() - lr.saveIterator(iteratorId, iterator) + newIteratorPresent := lr.saveIterator(iteratorId, iterator) + if !newIteratorPresent { + return false + } lr.removeIteratorsIfMaxSizeIsExceeded() + return true } -func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) { +func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) bool { _, isPresent := lr.iterators[string(iteratorId)] if isPresent { - return + return true + } + + if iterator.Size() >= lr.maxSize { + return false } lr.lruIteratorIDs = append(lr.lruIteratorIDs, iteratorId) lr.iterators[string(iteratorId)] = iterator lr.size += iterator.Size() + uint64(len(iteratorId)) + return true } func (lr *leavesRetriever) markIteratorAsRecentlyUsed(iteratorId []byte) { diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go index 28dd6131475..1605aaf6fc4 100644 --- a/trie/leavesRetriever/leavesRetriever_test.go +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -202,6 +202,20 @@ func TestLeavesRetriever_GetLeaves(t *testing.T) { assert.Equal(t, 0, len(id)) assert.Equal(t, leavesRetriever.ErrIteratorNotFound, err) }) + t.Run("max size reached on the first iteration", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetTrieWithData() + rootHash, _ := tr.RootHash() + maxSize := uint64(100) + + lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) + leaves, id1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, len(leaves)) + assert.Equal(t, 0, len(id1)) + assert.Equal(t, 0, len(lr.GetIterators())) + }) } func TestLeavesRetriever_Concurrency(t *testing.T) {