From bd6158c549271a385c99ced29ccda0fafb373fff Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 11 Nov 2024 17:19:23 +0200 Subject: [PATCH 01/13] refactor trie root hash computation --- common/interface.go | 2 +- node/node.go | 2 +- testscommon/trie/trieStub.go | 6 +- trie/baseIterator.go | 21 ++-- trie/baseIterator_test.go | 35 +++--- trie/branchNode.go | 160 +++++++++++----------------- trie/branchNode_test.go | 182 +++++++++----------------------- trie/dfsIterator.go | 4 +- trie/dfsIterator_test.go | 10 +- trie/errors.go | 3 + trie/export_test.go | 4 +- trie/extensionNode.go | 77 +++++--------- trie/extensionNode_test.go | 116 +++++++------------- trie/interceptedNode_test.go | 10 +- trie/interface.go | 7 +- trie/leafNode.go | 45 ++------ trie/leafNode_test.go | 37 +------ trie/node.go | 49 ++------- trie/node_test.go | 66 +++--------- trie/patriciaMerkleTrie.go | 55 ++++++---- trie/patriciaMerkleTrie_test.go | 27 +++-- trie/sync.go | 18 +++- trie/sync_test.go | 7 +- 23 files changed, 334 insertions(+), 609 deletions(-) diff --git a/common/interface.go b/common/interface.go index 0c709f44356..19809754471 100644 --- a/common/interface.go +++ b/common/interface.go @@ -51,7 +51,7 @@ type Trie interface { GetSerializedNodes([]byte, uint64) ([][]byte, uint64, error) GetSerializedNode([]byte) ([]byte, error) GetAllLeavesOnChannel(allLeavesChan *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder, trieLeafParser TrieLeafParser) error - GetProof(key []byte) ([][]byte, []byte, error) + GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error) GetStorageManager() StorageManager IsMigratedToLatestVersion() (bool, error) diff --git a/node/node.go b/node/node.go index 001bbd23f30..684afa774a0 100644 --- a/node/node.go +++ b/node/node.go @@ -1438,7 +1438,7 @@ func (n *Node) getProof(rootHash []byte, key []byte) (*common.GetProofResponse, return nil, err } - computedProof, value, err := tr.GetProof(key) + computedProof, value, err := tr.GetProof(key, rootHash) if err != nil { return nil, err } diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index 3644d5cc0cd..8ab3ab07c16 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -25,7 +25,7 @@ type TrieStub struct { AppendToOldHashesCalled func([][]byte) GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error) GetAllLeavesOnChannelCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, trieLeafParser common.TrieLeafParser) error - GetProofCalled func(key []byte) ([][]byte, []byte, error) + GetProofCalled func(key []byte, rootHash []byte) ([][]byte, []byte, error) VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) GetStorageManagerCalled func() common.StorageManager GetSerializedNodeCalled func(bytes []byte) ([]byte, error) @@ -45,9 +45,9 @@ func (ts *TrieStub) GetStorageManager() common.StorageManager { } // GetProof - -func (ts *TrieStub) GetProof(key []byte) ([][]byte, []byte, error) { +func (ts *TrieStub) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) { if ts.GetProofCalled != nil { - return ts.GetProofCalled(key) + return ts.GetProofCalled(key, rootHash) } return nil, nil, nil diff --git a/trie/baseIterator.go b/trie/baseIterator.go index f4889c51154..049cac9462e 100644 --- a/trie/baseIterator.go +++ b/trie/baseIterator.go @@ -12,11 +12,16 @@ type baseIterator struct { } // newBaseIterator creates a new instance of trie iterator -func newBaseIterator(trie common.Trie) (*baseIterator, error) { +func newBaseIterator(trie common.Trie, rootHash []byte) (*baseIterator, error) { if check.IfNil(trie) { return nil, ErrNilTrie } + trie, err := trie.Recreate(rootHash) + if err != nil { + return nil, err + } + pmt, ok := trie.(*patriciaMerkleTrie) if !ok { return nil, ErrWrongTypeAssertion @@ -56,20 +61,10 @@ func (it *baseIterator) next() ([]node, error) { // MarshalizedNode marshalizes the current node, and then returns the serialized node func (it *baseIterator) MarshalizedNode() ([]byte, error) { - err := it.currentNode.setHash() - if err != nil { - return nil, err - } - return it.currentNode.getEncodedNode() } // GetHash returns the current node hash -func (it *baseIterator) GetHash() ([]byte, error) { - err := it.currentNode.setHash() - if err != nil { - return nil, err - } - - return it.currentNode.getHash(), nil +func (it *baseIterator) GetHash() []byte { + return it.currentNode.getHash() } diff --git a/trie/baseIterator_test.go b/trie/baseIterator_test.go index 59307440ff3..c65fc89f0cc 100644 --- a/trie/baseIterator_test.go +++ b/trie/baseIterator_test.go @@ -14,8 +14,9 @@ func TestNewBaseIterator(t *testing.T) { t.Parallel() tr := initTrie() + rootHash, _ := tr.RootHash() - it, err := trie.NewBaseIterator(tr) + it, err := trie.NewBaseIterator(tr, rootHash) assert.Nil(t, err) assert.NotNil(t, it) } @@ -25,7 +26,7 @@ func TestNewBaseIteratorNilTrieShouldErr(t *testing.T) { var tr common.Trie - it, err := trie.NewBaseIterator(tr) + it, err := trie.NewBaseIterator(tr, nil) assert.Nil(t, it) assert.Equal(t, trie.ErrNilTrie, err) } @@ -35,13 +36,15 @@ func TestBaseIterator_HasNext(t *testing.T) { tr := emptyTrie() _ = tr.Update([]byte("dog"), []byte("dog")) - trie.ExecuteUpdatesFromBatch(tr) - it, _ := trie.NewBaseIterator(tr) + _ = tr.Commit() + rootHash, _ := tr.RootHash() + it, _ := trie.NewBaseIterator(tr, rootHash) assert.False(t, it.HasNext()) _ = tr.Update([]byte("doe"), []byte("doe")) - trie.ExecuteUpdatesFromBatch(tr) - it, _ = trie.NewBaseIterator(tr) + _ = tr.Commit() + rootHash, _ = tr.RootHash() + it, _ = trie.NewBaseIterator(tr, rootHash) assert.True(t, it.HasNext()) } @@ -49,7 +52,8 @@ func TestBaseIterator_GetMarshalizedNode(t *testing.T) { t.Parallel() tr := initTrie() - it, _ := trie.NewBaseIterator(tr) + rootHash, _ := tr.RootHash() + it, _ := trie.NewBaseIterator(tr, rootHash) encNode, err := it.MarshalizedNode() assert.Nil(t, err) @@ -64,11 +68,11 @@ func TestBaseIterator_GetHash(t *testing.T) { t.Parallel() tr := initTrie() + _ = tr.Commit() rootHash, _ := tr.RootHash() - it, _ := trie.NewBaseIterator(tr) + it, _ := trie.NewBaseIterator(tr, rootHash) - hash, err := it.GetHash() - assert.Nil(t, err) + hash := it.GetHash() assert.Equal(t, rootHash, hash) } @@ -80,7 +84,7 @@ func TestIterator_Search(t *testing.T) { _ = tr.Update([]byte("dog"), []byte("puppy")) _ = tr.Update([]byte("ddog"), []byte("cat")) _ = tr.Update([]byte("ddoge"), []byte("foo")) - trie.ExecuteUpdatesFromBatch(tr) + _ = tr.Commit() expectedHashes := []string{ "ecc2304769996585131ad6276c1422265813a2b79d60392130c4baa19a9b4e06", @@ -109,9 +113,9 @@ func TestIterator_Search(t *testing.T) { expectedHashes[8], } - it, _ := trie.NewDFSIterator(tr) - nodeHash, err := it.GetHash() - require.Nil(t, err) + rootHash, _ := tr.RootHash() + it, _ := trie.NewDFSIterator(tr, rootHash) + nodeHash := it.GetHash() nodesHashes := make([]string, 0) nodesHashes = append(nodesHashes, hex.EncodeToString(nodeHash)) @@ -120,8 +124,7 @@ func TestIterator_Search(t *testing.T) { err := it.Next() require.Nil(t, err) - nodeHash, err := it.GetHash() - require.Nil(t, err) + nodeHash := it.GetHash() nodesHashes = append(nodesHashes, hex.EncodeToString(nodeHash)) } diff --git a/trie/branchNode.go b/trie/branchNode.go index ac5f90ae3b0..301a4efe47f 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -75,16 +75,8 @@ func (bn *branchNode) getCollapsedBn() (*branchNode, error) { collapsed := bn.clone() for i := range bn.children { if bn.children[i] != nil { - var ok bool - ok, err = hasValidHash(bn.children[i]) - if err != nil { - return nil, err - } - if !ok { - err = bn.children[i].setHash() - if err != nil { - return nil, err - } + if !hasValidHash(bn.children[i]) { + return nil, ErrNodeHashIsNotSet } collapsed.EncodedChildren[i] = bn.children[i].getHash() collapsed.children[i] = nil @@ -93,116 +85,81 @@ func (bn *branchNode) getCollapsedBn() (*branchNode, error) { return collapsed, nil } -func (bn *branchNode) setHash() error { - err := bn.isEmptyOrNil() - if err != nil { - return fmt.Errorf("setHash error %w", err) - } - if bn.getHash() != nil { - return nil - } - if bn.isCollapsed() { - var hash []byte - hash, err = encodeNodeAndGetHash(bn) - if err != nil { - return err - } - bn.hash = hash - return nil - } - hash, err := hashChildrenAndNode(bn) - if err != nil { - return err +func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { + if len(bn.hash) != 0 { + return } - bn.hash = hash - return nil -} -func (bn *branchNode) setRootHash() error { - err := bn.isEmptyOrNil() - if err != nil { - return fmt.Errorf("setRootHash error %w", err) - } - if bn.getHash() != nil { - return nil - } - if bn.isCollapsed() { - var hash []byte - hash, err = encodeNodeAndGetHash(bn) - if err != nil { - return err - } - bn.hash = hash - return nil - } + waitGroup := sync.WaitGroup{} - var wg sync.WaitGroup - errc := make(chan error, nrOfChildren) + encodedChildrenMutex := &sync.Mutex{} + encodedChildren := make([][]byte, nrOfChildren) for i := 0; i < nrOfChildren; i++ { - if bn.children[i] != nil { - wg.Add(1) - go bn.children[i].setHashConcurrent(&wg, errc) + if !goRoutinesManager.ShouldContinueProcessing() { + return } - } - wg.Wait() - if len(errc) != 0 { - for err = range errc { - return err + + if !bn.shouldSetHashForChild(i) { + continue } - } - hashed, err := bn.hashNode() - if err != nil { - return err - } + if !goRoutinesManager.CanStartGoRoutine() { + bn.children[i].setHash(goRoutinesManager) + encChild, err := encodeNodeAndGetHash(bn.children[i]) + if err != nil { + goRoutinesManager.SetError(err) + return + } - bn.hash = hashed - return nil -} + encodedChildrenMutex.Lock() + encodedChildren[i] = encChild + encodedChildrenMutex.Unlock() + continue + } -func (bn *branchNode) setHashConcurrent(wg *sync.WaitGroup, c chan error) { - defer wg.Done() - err := bn.isEmptyOrNil() - if err != nil { - c <- fmt.Errorf("setHashConcurrent error %w", err) - return - } - if bn.getHash() != nil { - return + waitGroup.Add(1) + go func(childPos int) { + bn.children[childPos].setHash(goRoutinesManager) + encChild, err := encodeNodeAndGetHash(bn.children[childPos]) + if err != nil { + goRoutinesManager.SetError(err) + return + } + encodedChildrenMutex.Lock() + encodedChildren[childPos] = encChild + encodedChildrenMutex.Unlock() + waitGroup.Done() + }(i) } - if bn.isCollapsed() { - var hash []byte - hash, err = encodeNodeAndGetHash(bn) - if err != nil { - c <- err - return + + waitGroup.Wait() + + for i := range encodedChildren { + if len(encodedChildren[i]) == 0 { + continue } - bn.hash = hash - return + + bn.EncodedChildren[i] = encodedChildren[i] } - hash, err := hashChildrenAndNode(bn) + + hash, err := encodeNodeAndGetHash(bn) if err != nil { - c <- err + goRoutinesManager.SetError(err) return } bn.hash = hash } -func (bn *branchNode) hashChildren() error { - err := bn.isEmptyOrNil() - if err != nil { - return fmt.Errorf("hashChildren error %w", err) - } - for i := 0; i < nrOfChildren; i++ { - if bn.children[i] != nil { - err = bn.children[i].setHash() - if err != nil { - return err - } - } +func (bn *branchNode) shouldSetHashForChild(childPos int) bool { + bn.childrenMutexes[childPos].RLock() + defer bn.childrenMutexes[childPos].RUnlock() + + if bn.children[childPos] != nil && bn.EncodedChildren[childPos] == nil { + return true } - return nil + + return false } func (bn *branchNode) hashNode() ([]byte, error) { @@ -599,6 +556,7 @@ func (bn *branchNode) modifyNodeAfterInsert( } bn.children[i] = newBnChildren[i] + bn.EncodedChildren[i] = nil bn.setVersionForChild(childVersion, byte(i)) } @@ -754,9 +712,9 @@ func (bn *branchNode) setNewChildren( newChildrenMap.Range(func(childPos int, newChild node) { bn.children[childPos] = newChild + bn.EncodedChildren[childPos] = nil if check.IfNil(newChild) { bn.setVersionForChild(core.NotSpecified, byte(childPos)) - bn.EncodedChildren[childPos] = nil return } diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index 3ba182cc1d4..d6ed4341de2 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "github.com/multiversx/mx-chain-core-go/core/atomic" "testing" @@ -24,6 +23,12 @@ import ( "github.com/stretchr/testify/assert" ) +func getTestGoroutinesManager() common.TrieGoroutinesManager { + th, _ := throttler.NewNumGoRoutinesThrottler(1) + manager, _ := NewGoroutinesManager(th, errChan.NewErrChanWrapper(), make(chan struct{})) + return manager +} + func getTestMarshalizerAndHasher() (marshal.Marshalizer, hashing.Hasher) { marsh := &marshal.GogoProtoMarshalizer{} hash := &testscommon.KeccakMock{} @@ -99,20 +104,21 @@ func initTrie() *patriciaMerkleTrie { _ = tr.Update([]byte("doe"), []byte("reindeer")) _ = tr.Update([]byte("dog"), []byte("puppy")) _ = tr.Update([]byte("ddog"), []byte("cat")) - ExecuteUpdatesFromBatch(tr) + _ = tr.Commit() return tr } func getEncodedTrieNodesAndHashes(tr common.Trie) ([][]byte, [][]byte) { - it, _ := NewDFSIterator(tr) + rootHash, _ := tr.RootHash() + it, _ := NewDFSIterator(tr, rootHash) encNode, _ := it.MarshalizedNode() nodes := make([][]byte, 0) nodes = append(nodes, encNode) hashes := make([][]byte, 0) - hash, _ := it.GetHash() + hash := it.GetHash() hashes = append(hashes, hash) for it.HasNext() { @@ -120,7 +126,7 @@ func getEncodedTrieNodesAndHashes(tr common.Trie) ([][]byte, [][]byte) { encNode, _ = it.MarshalizedNode() nodes = append(nodes, encNode) - hash, _ = it.GetHash() + hash = it.GetHash() hashes = append(hashes, hash) } @@ -148,7 +154,9 @@ func TestBranchNode_getCollapsed(t *testing.T) { t.Parallel() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) collapsedBn.dirty = true + collapsedBn.hash = bn.hash collapsed, err := bn.getCollapsed() assert.Nil(t, err) @@ -190,81 +198,22 @@ func TestBranchNode_setHash(t *testing.T) { bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) hash, _ := encodeNodeAndGetHash(collapsedBn) + manager := getTestGoroutinesManager() - err := bn.setHash() - assert.Nil(t, err) + bn.setHash(manager) + assert.Nil(t, manager.GetError()) assert.Equal(t, hash, bn.hash) } -func TestBranchNode_setRootHash(t *testing.T) { - t.Parallel() - - marsh, hsh := getTestMarshalizerAndHasher() - - trieStorage1, _ := NewTrieStorageManager(GetDefaultTrieStorageManagerParameters()) - trieStorage2, _ := NewTrieStorageManager(GetDefaultTrieStorageManagerParameters()) - maxTrieLevelInMemory := uint(5) - - tr1, _ := NewTrie(trieStorage1, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) - tr2, _ := NewTrie(trieStorage2, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) - - maxIterations := 10000 - for i := 0; i < maxIterations; i++ { - val := hsh.Compute(fmt.Sprint(i)) - _ = tr1.Update(val, val) - _ = tr2.Update(val, val) - } - - ExecuteUpdatesFromBatch(tr1) - ExecuteUpdatesFromBatch(tr2) - - rootNode1 := tr1.GetRootNode() - rootNode2 := tr2.GetRootNode() - err := rootNode1.setRootHash() - _ = rootNode2.setHash() - assert.Nil(t, err) - assert.Equal(t, rootNode1.getHash(), rootNode2.getHash()) -} - -func TestBranchNode_setRootHashCollapsedNode(t *testing.T) { - t.Parallel() - - _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - hash, _ := encodeNodeAndGetHash(collapsedBn) - - err := collapsedBn.setRootHash() - assert.Nil(t, err) - assert.Equal(t, hash, collapsedBn.hash) -} - -func TestBranchNode_setHashEmptyNode(t *testing.T) { - t.Parallel() - - bn := emptyDirtyBranchNode() - - err := bn.setHash() - assert.True(t, errors.Is(err, ErrEmptyBranchNode)) - assert.Nil(t, bn.hash) -} - -func TestBranchNode_setHashNilNode(t *testing.T) { - t.Parallel() - - var bn *branchNode - - err := bn.setHash() - assert.True(t, errors.Is(err, ErrNilBranchNode)) - assert.Nil(t, bn) -} - func TestBranchNode_setHashCollapsedNode(t *testing.T) { t.Parallel() _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) hash, _ := encodeNodeAndGetHash(collapsedBn) + manager := getTestGoroutinesManager() - err := collapsedBn.setHash() - assert.Nil(t, err) + collapsedBn.setHash(manager) + assert.Nil(t, manager.GetError()) assert.Equal(t, hash, collapsedBn.hash) } @@ -278,57 +227,6 @@ func TestBranchNode_setGivenHash(t *testing.T) { assert.Equal(t, expectedHash, bn.hash) } -func TestBranchNode_hashChildren(t *testing.T) { - t.Parallel() - - bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - - for i := range bn.children { - if bn.children[i] != nil { - assert.Nil(t, bn.children[i].getHash()) - } - } - err := bn.hashChildren() - assert.Nil(t, err) - - for i := range bn.children { - if bn.children[i] != nil { - childHash, _ := encodeNodeAndGetHash(bn.children[i]) - assert.Equal(t, childHash, bn.children[i].getHash()) - } - } -} - -func TestBranchNode_hashChildrenEmptyNode(t *testing.T) { - t.Parallel() - - bn := emptyDirtyBranchNode() - - err := bn.hashChildren() - assert.True(t, errors.Is(err, ErrEmptyBranchNode)) -} - -func TestBranchNode_hashChildrenNilNode(t *testing.T) { - t.Parallel() - - var bn *branchNode - - err := bn.hashChildren() - assert.True(t, errors.Is(err, ErrNilBranchNode)) -} - -func TestBranchNode_hashChildrenCollapsedNode(t *testing.T) { - t.Parallel() - - _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - - err := collapsedBn.hashChildren() - assert.Nil(t, err) - - _, collapsedBn2 := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - assert.Equal(t, collapsedBn2, collapsedBn) -} - func TestBranchNode_hashNode(t *testing.T) { t.Parallel() @@ -368,7 +266,7 @@ func TestBranchNode_commit(t *testing.T) { bn, collapsedBn := getBnAndCollapsedBn(marsh, hasher) hash, _ := encodeNodeAndGetHash(collapsedBn) - _ = bn.setHash() + bn.setHash(getTestGoroutinesManager()) err := bn.commitDirty(0, 5, db, db) assert.Nil(t, err) @@ -431,14 +329,14 @@ func TestBranchNode_getEncodedNodeNil(t *testing.T) { assert.Nil(t, encNode) } -func TestBranchNode_resolveCollapsed(t *testing.T) { +func TestBranchNode_resolveIfCollapsed(t *testing.T) { t.Parallel() db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) - _ = bn.setHash() + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) resolved, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) resolved.dirty = false @@ -527,7 +425,7 @@ func TestBranchNode_tryGetCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.setHash() + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) childPos := byte(2) @@ -639,7 +537,7 @@ func TestBranchNode_insertCollapsedNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - _ = bn.setHash() + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) th, _ := throttler.NewNumGoRoutinesThrottler(5) @@ -660,6 +558,7 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) @@ -685,6 +584,7 @@ func TestBranchNode_insertInStoredBnOnNilPos(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) @@ -761,8 +661,8 @@ func TestBranchNode_delete(t *testing.T) { assert.True(t, dirty) assert.Nil(t, goRoutinesManager.GetError()) - _ = expectedBn.setHash() - _ = newBn.setHash() + expectedBn.setHash(getTestGoroutinesManager()) + newBn.setHash(getTestGoroutinesManager()) assert.Equal(t, expectedBn.getHash(), newBn.getHash()) } @@ -771,6 +671,7 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) childPos := byte(2) lnKey := append([]byte{childPos}, []byte("dog")...) @@ -849,7 +750,7 @@ func TestBranchNode_deleteCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.setHash() + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) childPos := byte(2) @@ -1034,6 +935,7 @@ func TestBranchNode_getChildrenCollapsedBn(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) _ = bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) children, err := collapsedBn.getChildren(db) @@ -1056,7 +958,7 @@ func TestBranchNode_loadChildren(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() tr := initTrie() rootNode := tr.GetRootNode() - _ = rootNode.setRootHash() + rootNode.setHash(getTestGoroutinesManager()) nodes, _ := getEncodedTrieNodesAndHashes(tr) nodesCacher, _ := cache.NewLRUCache(100) for i := range nodes { @@ -1170,6 +1072,9 @@ func TestBranchNode_setRootHashCollapsedChildren(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() bn := &branchNode{ + CollapsedBn: CollapsedBn{ + EncodedChildren: make([][]byte, nrOfChildren), + }, baseNode: &baseNode{ marsh: marsh, hasher: hasher, @@ -1184,15 +1089,17 @@ func TestBranchNode_setRootHashCollapsedChildren(t *testing.T) { bn.children[1] = collapsedEn bn.children[2] = collapsedLn - err := bn.setRootHash() - assert.Nil(t, err) + manager := getTestGoroutinesManager() + bn.setHash(manager) + assert.Nil(t, manager.GetError()) } func TestBranchNode_commitCollapsesTrieIfMaxTrieLevelInMemoryIsReached(t *testing.T) { t.Parallel() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = collapsedBn.setRootHash() + bn.setHash(getTestGoroutinesManager()) + collapsedBn.setHash(getTestGoroutinesManager()) err := bn.commitDirty(0, 1, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1224,6 +1131,8 @@ func TestBranchNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) + collapsedBn.setHash(getTestGoroutinesManager()) _ = bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) _ = collapsedBn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) @@ -1238,6 +1147,7 @@ func TestBranchNode_getDirtyHashesFromCleanNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) dirtyHashes := make(common.ModifiedHashes) @@ -1339,6 +1249,7 @@ func TestBranchNode_commitSnapshotChildIsMissingErr(t *testing.T) { } _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + collapsedBn.setHash(getTestGoroutinesManager()) missingNodesChan := make(chan []byte, 10) err := collapsedBn.commitSnapshot(db, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) assert.Nil(t, err) @@ -1654,6 +1565,7 @@ func TestBranchNode_insertOnNilChild(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) err := bn.commitDirty(0, 5, db, db) assert.Nil(t, err) assert.False(t, bn.dirty) @@ -1746,6 +1658,7 @@ func TestBranchNode_insertOnExistingChild(t *testing.T) { children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), marshaller, hasher) bn, _ := newBranchNode(marshaller, hasher) bn.children = children + bn.setHash(getTestGoroutinesManager()) newData := []core.TrieData{ { Key: []byte{1, 2, 3}, @@ -1791,6 +1704,7 @@ func TestBranchNode_insertOnExistingChild(t *testing.T) { children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), marshaller, hasher) bn, _ := newBranchNode(marshaller, hasher) bn.children = children + bn.setHash(getTestGoroutinesManager()) newData := []core.TrieData{ { Key: key, @@ -1826,6 +1740,7 @@ func TestBranchNode_insertBatch(t *testing.T) { children[6], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{7, 8, 9}), "doe"), marshaller, hasher) bn, _ := newBranchNode(marshaller, hasher) bn.children = children + bn.setHash(getTestGoroutinesManager()) newData := []core.TrieData{ { @@ -1884,6 +1799,7 @@ func getNewBn() *branchNode { bn, _ := newBranchNode(marsh, hasher) bn.children = children + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) return bn } diff --git a/trie/dfsIterator.go b/trie/dfsIterator.go index dc4ba666b01..e83c2f127a9 100644 --- a/trie/dfsIterator.go +++ b/trie/dfsIterator.go @@ -7,8 +7,8 @@ type dfsIterator struct { } // NewDFSIterator creates a new depth first traversal iterator -func NewDFSIterator(trie common.Trie) (*dfsIterator, error) { - bit, err := newBaseIterator(trie) +func NewDFSIterator(trie common.Trie, rootHash []byte) (*dfsIterator, error) { + bit, err := newBaseIterator(trie, rootHash) if err != nil { return nil, err } diff --git a/trie/dfsIterator_test.go b/trie/dfsIterator_test.go index 5e9f653db9b..bc2f2b893fa 100644 --- a/trie/dfsIterator_test.go +++ b/trie/dfsIterator_test.go @@ -13,7 +13,7 @@ func TestNewDFSIterator(t *testing.T) { t.Run("nil trie should error", func(t *testing.T) { t.Parallel() - it, err := trie.NewDFSIterator(nil) + it, err := trie.NewDFSIterator(nil, nil) assert.Equal(t, trie.ErrNilTrie, err) assert.Nil(t, it) }) @@ -21,8 +21,10 @@ func TestNewDFSIterator(t *testing.T) { t.Parallel() tr := initTrie() + _ = tr.Commit() + rootHash, _ := tr.RootHash() - it, err := trie.NewDFSIterator(tr) + it, err := trie.NewDFSIterator(tr, rootHash) assert.Nil(t, err) assert.NotNil(t, it) }) @@ -32,8 +34,10 @@ func TestDFSIterator_Next(t *testing.T) { t.Parallel() tr := initTrie() + _ = tr.Commit() + rootHash, _ := tr.RootHash() - it, _ := trie.NewDFSIterator(tr) + it, _ := trie.NewDFSIterator(tr, rootHash) for it.HasNext() { err := it.Next() assert.Nil(t, err) diff --git a/trie/errors.go b/trie/errors.go index 107a02db264..bc7eb638432 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -135,3 +135,6 @@ var ErrNilChanClose = errors.New("nil chan close") // ErrInvalidTypeConversion signals that an invalid type conversion has been provided var ErrInvalidTypeConversion = errors.New("invalid type conversion") + +// ErrNodeHashIsNotSet signals that the node hash is not set +var ErrNodeHashIsNotSet = errors.New("node hash is not set") diff --git a/trie/export_test.go b/trie/export_test.go index 8a6cb238359..ea22af5917e 100644 --- a/trie/export_test.go +++ b/trie/export_test.go @@ -78,8 +78,8 @@ func IsTrieStorageManagerInEpoch(tsm common.StorageManager) bool { } // NewBaseIterator - -func NewBaseIterator(trie common.Trie) (*baseIterator, error) { - return newBaseIterator(trie) +func NewBaseIterator(trie common.Trie, rootHash []byte) (*baseIterator, error) { + return newBaseIterator(trie, rootHash) } // GetDefaultTrieStorageManagerParameters - diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 5204acbb2ca..2108df02c1b 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -5,16 +5,14 @@ import ( "context" "encoding/hex" "fmt" - "io" - "math" - "sync" - "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "io" + "math" ) var _ = node(&extensionNode{}) @@ -63,69 +61,50 @@ func (en *extensionNode) getCollapsedEn() (*extensionNode, error) { return en, nil } collapsed := en.clone() - ok, err := hasValidHash(en.child) - if err != nil { - return nil, err - } - if !ok { - err = en.child.setHash() - if err != nil { - return nil, err - } + if !hasValidHash(en.child) { + return nil, ErrNodeHashIsNotSet } collapsed.EncodedChild = en.child.getHash() collapsed.child = nil return collapsed, nil } -func (en *extensionNode) setHash() error { - err := en.isEmptyOrNil() - if err != nil { - return fmt.Errorf("setHash error %w", err) +func (en *extensionNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { + if len(en.hash) != 0 { + return } - if en.getHash() != nil { - return nil + + if !goRoutinesManager.ShouldContinueProcessing() { + return } - if en.isCollapsed() { - var hash []byte - hash, err = encodeNodeAndGetHash(en) + + if en.shouldSetHashForChild() { + en.child.setHash(goRoutinesManager) + encChild, err := encodeNodeAndGetHash(en.child) if err != nil { - return err + goRoutinesManager.SetError(err) + return } - en.hash = hash - return nil + en.EncodedChild = encChild } - hash, err := hashChildrenAndNode(en) + + hash, err := encodeNodeAndGetHash(en) if err != nil { - return err + goRoutinesManager.SetError(err) + return } en.hash = hash - return nil } -func (en *extensionNode) setHashConcurrent(wg *sync.WaitGroup, c chan error) { - err := en.setHash() - if err != nil { - c <- err - } - wg.Done() -} -func (en *extensionNode) setRootHash() error { - return en.setHash() -} +func (en *extensionNode) shouldSetHashForChild() bool { + en.childMutex.RLock() + defer en.childMutex.RUnlock() -func (en *extensionNode) hashChildren() error { - err := en.isEmptyOrNil() - if err != nil { - return fmt.Errorf("hashChildren error %w", err) - } - if en.child != nil { - err = en.child.setHash() - if err != nil { - return err - } + if en.child != nil && en.EncodedChild == nil { + return true } - return nil + + return false } func (en *extensionNode) hashNode() ([]byte, error) { diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 0fd7d8db6f4..31b75e24a60 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -72,6 +72,8 @@ func TestExtensionNode_getCollapsed(t *testing.T) { en, collapsedEn := getEnAndCollapsedEn() collapsedEn.dirty = true + en.setHash(getTestGoroutinesManager()) + collapsedEn.hash = en.hash collapsed, err := en.getCollapsed() assert.Nil(t, err) @@ -113,40 +115,22 @@ func TestExtensionNode_setHash(t *testing.T) { en, collapsedEn := getEnAndCollapsedEn() hash, _ := encodeNodeAndGetHash(collapsedEn) + manager := getTestGoroutinesManager() - err := en.setHash() - assert.Nil(t, err) + en.setHash(manager) + assert.Nil(t, manager.GetError()) assert.Equal(t, hash, en.hash) } -func TestExtensionNode_setHashEmptyNode(t *testing.T) { - t.Parallel() - - en := &extensionNode{baseNode: &baseNode{}} - - err := en.setHash() - assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) - assert.Nil(t, en.hash) -} - -func TestExtensionNode_setHashNilNode(t *testing.T) { - t.Parallel() - - var en *extensionNode - - err := en.setHash() - assert.True(t, errors.Is(err, ErrNilExtensionNode)) - assert.Nil(t, en) -} - func TestExtensionNode_setHashCollapsedNode(t *testing.T) { t.Parallel() _, collapsedEn := getEnAndCollapsedEn() hash, _ := encodeNodeAndGetHash(collapsedEn) + manager := getTestGoroutinesManager() - err := collapsedEn.setHash() - assert.Nil(t, err) + collapsedEn.setHash(manager) + assert.Nil(t, manager.GetError()) assert.Equal(t, hash, collapsedEn.hash) } @@ -160,49 +144,6 @@ func TestExtensionNode_setGivenHash(t *testing.T) { assert.Equal(t, expectedHash, en.hash) } -func TestExtensionNode_hashChildren(t *testing.T) { - t.Parallel() - - en, _ := getEnAndCollapsedEn() - assert.Nil(t, en.child.getHash()) - - err := en.hashChildren() - assert.Nil(t, err) - - childHash, _ := encodeNodeAndGetHash(en.child) - assert.Equal(t, childHash, en.child.getHash()) -} - -func TestExtensionNode_hashChildrenEmptyNode(t *testing.T) { - t.Parallel() - - en := &extensionNode{} - - err := en.hashChildren() - assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) -} - -func TestExtensionNode_hashChildrenNilNode(t *testing.T) { - t.Parallel() - - var en *extensionNode - - err := en.hashChildren() - assert.True(t, errors.Is(err, ErrNilExtensionNode)) -} - -func TestExtensionNode_hashChildrenCollapsedNode(t *testing.T) { - t.Parallel() - - _, collapsedEn := getEnAndCollapsedEn() - - err := collapsedEn.hashChildren() - assert.Nil(t, err) - - _, collapsedEn2 := getEnAndCollapsedEn() - assert.Equal(t, collapsedEn2, collapsedEn) -} - func TestExtensionNode_hashNode(t *testing.T) { t.Parallel() @@ -240,7 +181,7 @@ func TestExtensionNode_commit(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() hash, _ := encodeNodeAndGetHash(collapsedEn) - _ = en.setHash() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, db, db) assert.Nil(t, err) @@ -277,7 +218,7 @@ func TestExtensionNode_commitCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() _, collapsedEn := getEnAndCollapsedEn() hash, _ := encodeNodeAndGetHash(collapsedEn) - _ = collapsedEn.setHash() + collapsedEn.setHash(getTestGoroutinesManager()) collapsedEn.dirty = true err := collapsedEn.commitDirty(0, 5, db, db) @@ -329,7 +270,7 @@ func TestExtensionNode_resolveCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.setHash() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) _, resolved := getBnAndCollapsedBn(en.marsh, en.hasher) @@ -403,7 +344,7 @@ func TestExtensionNode_tryGetCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.setHash() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) enKey := []byte{100} @@ -476,7 +417,7 @@ func TestExtensionNode_insertCollapsedNode(t *testing.T) { en, collapsedEn := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} - _ = en.setHash() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) th, _ := throttler.NewNumGoRoutinesThrottler(5) @@ -499,6 +440,7 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { en, _ := getEnAndCollapsedEn() enKey := []byte{100} key := append(enKey, []byte{11, 12}...) + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) enHash := en.getHash() @@ -525,6 +467,7 @@ func TestExtensionNode_insertInStoredEnDifferentKey(t *testing.T) { enKey := []byte{1} en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) expectedHashes := [][]byte{en.getHash()} @@ -614,6 +557,7 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) lnPathKey := key + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) bn, key, _ := en.getNext(key, db) @@ -653,7 +597,7 @@ func TestExtensionNode_deleteCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.setHash() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) enKey := []byte{100} @@ -743,6 +687,7 @@ func TestExtensionNode_getChildrenCollapsedEn(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) children, err := collapsedEn.getChildren(db) @@ -766,8 +711,8 @@ func TestExtensionNode_loadChildren(t *testing.T) { tr, _ := newEmptyTrie() _ = tr.Update([]byte("dog"), []byte("puppy")) _ = tr.Update([]byte("ddog"), []byte("cat")) - ExecuteUpdatesFromBatch(tr) - _ = tr.GetRootNode().setRootHash() + _ = tr.Commit() + tr.GetRootNode().setHash(getTestGoroutinesManager()) nodes, _ := getEncodedTrieNodesAndHashes(tr) nodesCacher, _ := cache.NewLRUCache(100) for i := range nodes { @@ -846,7 +791,8 @@ func TestExtensionNode_commitCollapsesTrieIfMaxTrieLevelInMemoryIsReached(t *tes t.Parallel() en, collapsedEn := getEnAndCollapsedEn() - _ = collapsedEn.setRootHash() + collapsedEn.setHash(getTestGoroutinesManager()) + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 1, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -864,6 +810,8 @@ func TestExtensionNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() + en.setHash(getTestGoroutinesManager()) + collapsedEn.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) _ = collapsedEn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) @@ -878,6 +826,7 @@ func TestExtensionNode_getDirtyHashesFromCleanNode(t *testing.T) { db := testscommon.NewMemDbMock() en, _ := getEnAndCollapsedEn() + en.setHash(getTestGoroutinesManager()) _ = en.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) dirtyHashes := make(common.ModifiedHashes) @@ -1151,6 +1100,7 @@ func TestExtensionNode_insertInSameEn(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1173,6 +1123,7 @@ func TestExtensionNode_insertInSameEn(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1207,6 +1158,7 @@ func TestExtensionNode_insertInNewBn(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1239,6 +1191,7 @@ func TestExtensionNode_insertInNewBn(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1270,6 +1223,7 @@ func TestExtensionNode_deleteBatch(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1292,6 +1246,7 @@ func TestExtensionNode_deleteBatch(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) @@ -1316,6 +1271,7 @@ func TestExtensionNode_deleteBatch(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) data := []core.TrieData{ getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 4, 5, 6}), "dog"), } @@ -1324,15 +1280,16 @@ func TestExtensionNode_deleteBatch(t *testing.T) { goRoutinesManager, err := NewGoroutinesManager(th, errChan.NewErrChanWrapper(), make(chan struct{})) assert.Nil(t, err) - _, _ = en.insert(data, goRoutinesManager, nil) - err = en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + newEn, _ := en.insert(data, goRoutinesManager, nil) + newEn.setHash(getTestGoroutinesManager()) + err = newEn.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) dataForRemoval := []core.TrieData{ getTrieDataWithDefaultVersion(string([]byte{1, 2, 7, 7, 8, 9}), "dog"), } - dirty, newNode, modifiedHashes := en.delete(dataForRemoval, goRoutinesManager, nil) + dirty, newNode, modifiedHashes := newEn.delete(dataForRemoval, goRoutinesManager, nil) assert.True(t, dirty) assert.Nil(t, goRoutinesManager.GetError()) assert.Equal(t, 3, len(modifiedHashes)) @@ -1345,6 +1302,7 @@ func TestExtensionNode_deleteBatch(t *testing.T) { t.Parallel() en := getEn() + en.setHash(getTestGoroutinesManager()) err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) diff --git a/trie/interceptedNode_test.go b/trie/interceptedNode_test.go index eae6b884ad0..ec413063aa6 100644 --- a/trie/interceptedNode_test.go +++ b/trie/interceptedNode_test.go @@ -14,29 +14,29 @@ import ( func getDefaultInterceptedTrieNodeParameters() ([]byte, hashing.Hasher) { tr := initTrie() + _ = tr.Commit() nodes, _ := getEncodedTrieNodesAndHashes(tr) return nodes[0], &testscommon.KeccakMock{} } func getEncodedTrieNodesAndHashes(tr common.Trie) ([][]byte, [][]byte) { - it, _ := trie.NewDFSIterator(tr) + rootHash, _ := tr.RootHash() + it, _ := trie.NewDFSIterator(tr, rootHash) encNode, _ := it.MarshalizedNode() nodes := make([][]byte, 0) nodes = append(nodes, encNode) hashes := make([][]byte, 0) - hash, _ := it.GetHash() - hashes = append(hashes, hash) + hashes = append(hashes, it.GetHash()) for it.HasNext() { _ = it.Next() encNode, _ = it.MarshalizedNode() nodes = append(nodes, encNode) - hash, _ = it.GetHash() - hashes = append(hashes, hash) + hashes = append(hashes, it.GetHash()) } return nodes, hashes diff --git a/trie/interface.go b/trie/interface.go index e7c5ed34b3b..b01c009d574 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -3,7 +3,6 @@ package trie import ( "context" "io" - "sync" "time" "github.com/multiversx/mx-chain-core-go/core" @@ -23,13 +22,11 @@ type node interface { getHasher() hashing.Hasher setHasher(hashing.Hasher) - setHash() error - setHashConcurrent(wg *sync.WaitGroup, c chan error) - setRootHash() error + setHash(goRoutinesManager common.TrieGoroutinesManager) + getCollapsed() (node, error) // a collapsed node is a node that instead of the children holds the children hashes getEncodedNode() ([]byte, error) hashNode() ([]byte, error) - hashChildren() error tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, db common.TrieStorageInteractor) (node, [][]byte) diff --git a/trie/leafNode.go b/trie/leafNode.go index 9ca0f276d78..a6b44a94112 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -5,10 +5,6 @@ import ( "context" "encoding/hex" "fmt" - "io" - "math" - "sync" - "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/keyValStorage" @@ -16,6 +12,8 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "io" + "math" ) var _ = node(&leafNode{}) @@ -50,36 +48,16 @@ func (ln *leafNode) getCollapsed() (node, error) { return ln, nil } -func (ln *leafNode) setHash() error { - err := ln.isEmptyOrNil() - if err != nil { - return fmt.Errorf("setHash error %w", err) - } - if ln.getHash() != nil { - return nil +func (ln *leafNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { + if len(ln.hash) != 0 { + return } - hash, err := hashChildrenAndNode(ln) + hash, err := encodeNodeAndGetHash(ln) if err != nil { - return err + goRoutinesManager.SetError(err) + return } ln.hash = hash - return nil -} - -func (ln *leafNode) setHashConcurrent(wg *sync.WaitGroup, c chan error) { - err := ln.setHash() - if err != nil { - c <- err - } - wg.Done() -} - -func (ln *leafNode) setRootHash() error { - return ln.setHash() -} - -func (ln *leafNode) hashChildren() error { - return nil } func (ln *leafNode) hashNode() ([]byte, error) { @@ -149,12 +127,11 @@ func writeNodeOnChannel(ln *leafNode, leavesChan chan core.KeyValueHolder) error return nil } - leafHash, err := computeAndSetNodeHash(ln) - if err != nil { - return err + if len(ln.hash) == 0 { + return ErrNodeHashIsNotSet } - trieLeaf := keyValStorage.NewKeyValStorage(leafHash, ln.Value) + trieLeaf := keyValStorage.NewKeyValStorage(ln.hash, ln.Value) leavesChan <- trieLeaf return nil diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index 65a8491852f..97ff0c7e6db 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -75,32 +75,13 @@ func TestLeafNode_setHash(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) hash, _ := encodeNodeAndGetHash(ln) + manager := getTestGoroutinesManager() - err := ln.setHash() - assert.Nil(t, err) + ln.setHash(manager) + assert.Nil(t, manager.GetError()) assert.Equal(t, hash, ln.hash) } -func TestLeafNode_setHashEmptyNode(t *testing.T) { - t.Parallel() - - ln := &leafNode{baseNode: &baseNode{}} - - err := ln.setHash() - assert.True(t, errors.Is(err, ErrEmptyLeafNode)) - assert.Nil(t, ln.hash) -} - -func TestLeafNode_setHashNilNode(t *testing.T) { - t.Parallel() - - var ln *leafNode - - err := ln.setHash() - assert.True(t, errors.Is(err, ErrNilLeafNode)) - assert.Nil(t, ln) -} - func TestLeafNode_setGivenHash(t *testing.T) { t.Parallel() @@ -111,14 +92,6 @@ func TestLeafNode_setGivenHash(t *testing.T) { assert.Equal(t, expectedHash, ln.hash) } -func TestLeafNode_hashChildren(t *testing.T) { - t.Parallel() - - ln := getLn(getTestMarshalizerAndHasher()) - - assert.Nil(t, ln.hashChildren()) -} - func TestLeafNode_hashNode(t *testing.T) { t.Parallel() @@ -156,7 +129,7 @@ func TestLeafNode_commit(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) hash, _ := encodeNodeAndGetHash(ln) - _ = ln.setHash() + ln.setHash(getTestGoroutinesManager()) err := ln.commitDirty(0, 5, db, db) assert.Nil(t, err) @@ -699,7 +672,7 @@ func TestLeafNode_writeNodeOnChannel(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.setHash() + ln.setHash(getTestGoroutinesManager()) leavesChannel := make(chan core.KeyValueHolder, 2) err := writeNodeOnChannel(ln, leavesChannel) diff --git a/trie/node.go b/trie/node.go index a8a56c4161d..a8d9b3b64e1 100644 --- a/trie/node.go +++ b/trie/node.go @@ -45,20 +45,6 @@ type leafNode struct { *baseNode } -func hashChildrenAndNode(n node) ([]byte, error) { - err := n.hashChildren() - if err != nil { - return nil, err - } - - hashed, err := n.hashNode() - if err != nil { - return nil, err - } - - return hashed, nil -} - func encodeNodeAndGetHash(n node) ([]byte, error) { encNode, err := n.getEncodedNode() if err != nil { @@ -72,9 +58,9 @@ func encodeNodeAndGetHash(n node) ([]byte, error) { // encodeNodeAndCommitToDB will encode and save provided node. It returns the node's value in bytes func encodeNodeAndCommitToDB(n node, db common.BaseStorer) (int, error) { - key, err := computeAndSetNodeHash(n) - if err != nil { - return 0, err + key := n.getHash() + if len(key) == 0 { + return 0, ErrNodeHashIsNotSet } val, err := collapseAndEncodeNode(n) @@ -98,21 +84,6 @@ func collapseAndEncodeNode(n node) ([]byte, error) { return n.getEncodedNode() } -func computeAndSetNodeHash(n node) ([]byte, error) { - key := n.getHash() - if len(key) != 0 { - return key, nil - } - - err := n.setHash() - if err != nil { - return nil, err - } - key = n.getHash() - - return key, nil -} - func getNodeFromDBAndDecode(n []byte, db common.TrieStorageInteractor, marshalizer marshal.Marshalizer, hasher hashing.Hasher) (node, error) { encChild, err := db.Get(n) if err != nil { @@ -152,19 +123,13 @@ func concat(s1 []byte, s2 ...byte) []byte { return r } -func hasValidHash(n node) (bool, error) { - err := n.isEmptyOrNil() - if err != nil { - return false, err - } - +func hasValidHash(n node) bool { childHash := n.getHash() - childIsDirty := n.isDirty() - if childHash == nil || childIsDirty { - return false, nil + if childHash == nil { + return false } - return true, nil + return true } func decodeNode(encNode []byte, marshalizer marshal.Marshalizer, hasher hashing.Hasher) (node, error) { diff --git a/trie/node_test.go b/trie/node_test.go index 6dbc6463e82..864ca002970 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -21,39 +21,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestNode_hashChildrenAndNodeBranchNode(t *testing.T) { - t.Parallel() - - bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - expectedNodeHash, _ := encodeNodeAndGetHash(collapsedBn) - - hash, err := hashChildrenAndNode(bn) - assert.Nil(t, err) - assert.Equal(t, expectedNodeHash, hash) -} - -func TestNode_hashChildrenAndNodeExtensionNode(t *testing.T) { - t.Parallel() - - en, collapsedEn := getEnAndCollapsedEn() - expectedNodeHash, _ := encodeNodeAndGetHash(collapsedEn) - - hash, err := hashChildrenAndNode(en) - assert.Nil(t, err) - assert.Equal(t, expectedNodeHash, hash) -} - -func TestNode_hashChildrenAndNodeLeafNode(t *testing.T) { - t.Parallel() - - ln := getLn(getTestMarshalizerAndHasher()) - expectedNodeHash, _ := encodeNodeAndGetHash(ln) - - hash, err := hashChildrenAndNode(ln) - assert.Nil(t, err) - assert.Equal(t, expectedNodeHash, hash) -} - func TestNode_encodeNodeAndGetHashBranchNode(t *testing.T) { t.Parallel() @@ -120,6 +87,7 @@ func TestNode_encodeNodeAndCommitToDBBranchNode(t *testing.T) { encNode, _ := collapsedBn.marsh.Marshal(collapsedBn) encNode = append(encNode, branch) nodeHash := collapsedBn.hasher.Compute(string(encNode)) + collapsedBn.hash = nodeHash _, err := encodeNodeAndCommitToDB(collapsedBn, db) assert.Nil(t, err) @@ -136,6 +104,7 @@ func TestNode_encodeNodeAndCommitToDBExtensionNode(t *testing.T) { encNode, _ := collapsedEn.marsh.Marshal(collapsedEn) encNode = append(encNode, extension) nodeHash := collapsedEn.hasher.Compute(string(encNode)) + collapsedEn.hash = nodeHash _, err := encodeNodeAndCommitToDB(collapsedEn, db) assert.Nil(t, err) @@ -152,6 +121,7 @@ func TestNode_encodeNodeAndCommitToDBLeafNode(t *testing.T) { encNode, _ := ln.marsh.Marshal(ln) encNode = append(encNode, leaf) nodeHash := ln.hasher.Compute(string(encNode)) + ln.hash = nodeHash _, err := encodeNodeAndCommitToDB(ln, db) assert.Nil(t, err) @@ -165,6 +135,7 @@ func TestNode_getNodeFromDBAndDecodeBranchNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.setHash(getTestGoroutinesManager()) _ = bn.commitDirty(0, 5, db, db) encNode, _ := bn.marsh.Marshal(collapsedBn) @@ -184,6 +155,7 @@ func TestNode_getNodeFromDBAndDecodeExtensionNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() + en.setHash(getTestGoroutinesManager()) _ = en.commitDirty(0, 5, db, db) encNode, _ := en.marsh.Marshal(collapsedEn) @@ -203,6 +175,7 @@ func TestNode_getNodeFromDBAndDecodeLeafNode(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) + ln.setHash(getTestGoroutinesManager()) _ = ln.commitDirty(0, 5, db, db) encNode, _ := ln.marsh.Marshal(ln) @@ -231,27 +204,14 @@ func TestNode_hasValidHash(t *testing.T) { t.Parallel() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - ok, err := hasValidHash(bn) - assert.Nil(t, err) + ok := hasValidHash(bn) assert.False(t, ok) - _ = bn.setHash() - bn.dirty = false - - ok, err = hasValidHash(bn) - assert.Nil(t, err) + bn.setHash(getTestGoroutinesManager()) + ok = hasValidHash(bn) assert.True(t, ok) } -func TestNode_hasValidHashNilNode(t *testing.T) { - t.Parallel() - - var nodeInstance *branchNode - ok, err := hasValidHash(nodeInstance) - assert.Equal(t, ErrNilBranchNode, err) - assert.False(t, ok) -} - func TestNode_decodeNodeBranchNode(t *testing.T) { t.Parallel() @@ -458,9 +418,13 @@ func TestTrieGetObsoleteHashes(t *testing.T) { func TestNode_getDirtyHashes(t *testing.T) { t.Parallel() - tr := initTrie() + tr, _ := newEmptyTrie() + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), []byte("cat")) + ExecuteUpdatesFromBatch(tr) - _ = tr.GetRootNode().setRootHash() + tr.GetRootNode().setHash(getTestGoroutinesManager()) hashes := make(map[string]struct{}) err := tr.GetRootNode().getDirtyHashes(hashes) diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index e3eb6532ba1..0dc4562547a 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -297,10 +297,18 @@ func (tr *patriciaMerkleTrie) getRootHash() ([]byte, error) { if hash != nil { return hash, nil } - err := rootNode.setRootHash() + + manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + if err != nil { + return nil, err + } + + rootNode.setHash(manager) + err = manager.GetError() if err != nil { return nil, err } + return rootNode.getHash(), nil } @@ -326,7 +334,14 @@ func (tr *patriciaMerkleTrie) Commit() error { return nil } - err = rootNode.setRootHash() + + manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + if err != nil { + return err + } + + rootNode.setHash(manager) + err = manager.GetError() if err != nil { return err } @@ -446,7 +461,13 @@ func (tr *patriciaMerkleTrie) GetDirtyHashes() (common.ModifiedHashes, error) { return nil, nil } - err = rootNode.setRootHash() + manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + if err != nil { + return nil, err + } + + rootNode.setHash(manager) + err = manager.GetError() if err != nil { return nil, err } @@ -479,7 +500,6 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage return nil, nil, err } - newRoot.setGivenHash(rootHash) newTr.SetNewRootNode(newRoot) return newTr, newRoot, nil @@ -522,12 +542,7 @@ func (tr *patriciaMerkleTrie) GetSerializedNodes(rootHash []byte, maxBuffToSend log.Trace("GetSerializedNodes", "rootHash", rootHash) size := uint64(0) - newTr, _, err := tr.recreateFromDb(rootHash, tr.trieStorage) - if err != nil { - return nil, 0, err - } - - it, err := NewDFSIterator(newTr) + it, err := NewDFSIterator(tr, rootHash) if err != nil { return nil, 0, err } @@ -646,11 +661,18 @@ func logMapWithTrace(message string, paramName string, hashes common.ModifiedHas } // GetProof computes a Merkle proof for the node that is present at the given key -func (tr *patriciaMerkleTrie) GetProof(key []byte) ([][]byte, []byte, error) { - tr.trieOperationInProgress.SetValue(true) - defer tr.trieOperationInProgress.Reset() +func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) { + trie, err := tr.Recreate(rootHash) + if err != nil { + return nil, nil, err + } - rootNode := tr.GetRootNode() + pmt, ok := trie.(*patriciaMerkleTrie) + if !ok { + return nil, nil, ErrWrongTypeAssertion + } + + rootNode := pmt.GetRootNode() if rootNode == nil { return nil, nil, ErrNilNode } @@ -659,11 +681,6 @@ func (tr *patriciaMerkleTrie) GetProof(key []byte) ([][]byte, []byte, error) { hexKey := keyBytesToHex(key) currentNode := rootNode - err := currentNode.setRootHash() - if err != nil { - return nil, nil, err - } - for { encodedNode, errGet := currentNode.getEncodedNode() if errGet != nil { diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index cc886c80a85..58ce57cad6f 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -73,7 +73,7 @@ func initTrieMultipleValues(nr int) (common.Trie, [][]byte) { func initTrie() common.Trie { tr := emptyTrie() addDefaultDataToTrie(tr) - trie.ExecuteUpdatesFromBatch(tr) + _ = tr.Commit() return tr } @@ -720,9 +720,10 @@ func TestPatriciaMerkleTree_Prove(t *testing.T) { t.Parallel() tr := initTrie() + _ = tr.Commit() rootHash, _ := tr.RootHash() - proof, value, err := tr.GetProof([]byte("dog")) + proof, value, err := tr.GetProof([]byte("dog"), rootHash) assert.Nil(t, err) assert.Equal(t, []byte("puppy"), value) ok, _ := tr.VerifyProof(rootHash, []byte("dog"), proof) @@ -736,7 +737,7 @@ func TestPatriciaMerkleTree_ProveCollapsedTrie(t *testing.T) { _ = tr.Commit() rootHash, _ := tr.RootHash() - proof, _, err := tr.GetProof([]byte("dog")) + proof, _, err := tr.GetProof([]byte("dog"), rootHash) assert.Nil(t, err) ok, _ := tr.VerifyProof(rootHash, []byte("dog"), proof) assert.True(t, ok) @@ -747,7 +748,7 @@ func TestPatriciaMerkleTree_ProveOnEmptyTrie(t *testing.T) { tr := emptyTrie() - proof, _, err := tr.GetProof([]byte("dog")) + proof, _, err := tr.GetProof([]byte("dog"), emptyTrieHash) assert.Nil(t, proof) assert.Equal(t, trie.ErrNilNode, err) } @@ -756,10 +757,11 @@ func TestPatriciaMerkleTree_VerifyProof(t *testing.T) { t.Parallel() tr, val := initTrieMultipleValues(50) + _ = tr.Commit() rootHash, _ := tr.RootHash() for i := range val { - proof, _, _ := tr.GetProof(val[i]) + proof, _, _ := tr.GetProof(val[i], rootHash) ok, err := tr.VerifyProof(rootHash, val[i], proof) assert.Nil(t, err) @@ -778,9 +780,10 @@ func TestPatriciaMerkleTrie_VerifyProofBranchNodeWantHashShouldWork(t *testing.T _ = tr.Update([]byte("dog"), []byte("cat")) _ = tr.Update([]byte("zebra"), []byte("horse")) + _ = tr.Commit() rootHash, _ := tr.RootHash() - proof, _, _ := tr.GetProof([]byte("dog")) + proof, _, _ := tr.GetProof([]byte("dog"), rootHash) ok, err := tr.VerifyProof(rootHash, []byte("dog"), proof) assert.True(t, ok) assert.Nil(t, err) @@ -793,9 +796,10 @@ func TestPatriciaMerkleTrie_VerifyProofExtensionNodeWantHashShouldWork(t *testin _ = tr.Update([]byte("dog"), []byte("cat")) _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Commit() rootHash, _ := tr.RootHash() - proof, _, _ := tr.GetProof([]byte("dog")) + proof, _, _ := tr.GetProof([]byte("dog"), rootHash) ok, err := tr.VerifyProof(rootHash, []byte("dog"), proof) assert.True(t, ok) assert.Nil(t, err) @@ -836,9 +840,11 @@ func TestPatriciaMerkleTrie_VerifyProofFromDifferentTrieShouldNotWork(t *testing _ = tr2.Update([]byte("doe"), []byte("reindeer")) _ = tr2.Update([]byte("dog"), []byte("puppy")) _ = tr2.Update([]byte("dogglesworth"), []byte("caterpillar")) + _ = tr2.Commit() + rootHash2, _ := tr2.RootHash() rootHash, _ := tr1.RootHash() - proof, _, _ := tr2.GetProof([]byte("dogglesworth")) + proof, _, _ := tr2.GetProof([]byte("dogglesworth"), rootHash2) ok, _ := tr1.VerifyProof(rootHash, []byte("dogglesworth"), proof) assert.False(t, ok) } @@ -860,10 +866,11 @@ func TestPatriciaMerkleTrie_GetAndVerifyProof(t *testing.T) { _ = tr.Update(values[i], values[i]) } + _ = tr.Commit() rootHash, _ := tr.RootHash() for i := 0; i < numRuns; i++ { randNum := rand.Intn(nrLeaves) - proof, _, err := tr.GetProof(values[randNum]) + proof, _, err := tr.GetProof(values[randNum], rootHash) if err != nil { dumpTrieContents(tr, values) fmt.Printf("error getting proof for %v, err = %s\n", values[randNum], err.Error()) @@ -1015,7 +1022,7 @@ func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) { ) assert.Nil(t, err) case 13: - _, _, _ = tr.GetProof(initialRootHash) // this might error due to concurrent operations that change the roothash + _, _, _ = tr.GetProof(initialRootHash, initialRootHash) // this might error due to concurrent operations that change the roothash case 14: // extremely hard to compute an existing hash due to concurrent changes. _, _ = tr.VerifyProof([]byte("dog"), []byte("puppy"), [][]byte{[]byte("proof1")}) // this might error due to concurrent operations that change the roothash diff --git a/trie/sync.go b/trie/sync.go index ce48f8c8e6b..6f498e58c08 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "fmt" + "github.com/multiversx/mx-chain-core-go/core/throttler" + "github.com/multiversx/mx-chain-go/common/errChan" "sync" "time" @@ -323,10 +325,6 @@ func getNodeFromCacheOrStorage( if err != nil { return nil, ErrNodeNotFound } - err = existingNode.setHash() - if err != nil { - return nil, ErrNodeNotFound - } return existingNode, nil } @@ -361,7 +359,17 @@ func trieNode( return nil, err } - err = decodedNode.setHash() + th, err := throttler.NewNumGoRoutinesThrottler(1) + if err != nil { + return nil, err + } + goRoutinesManager, err := NewGoroutinesManager(th, errChan.NewErrChanWrapper(), make(chan struct{})) + if err != nil { + return nil, err + } + + decodedNode.setHash(goRoutinesManager) + err = goRoutinesManager.GetError() if err != nil { return nil, err } diff --git a/trie/sync_test.go b/trie/sync_test.go index ab5083eb85a..f78a8d6a91a 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -210,8 +210,9 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { timeout := time.Second * 200 testMarshalizer, testHasher := getTestMarshalizerAndHasher() bn, _ := getBnAndCollapsedBn(testMarshalizer, testHasher) - err := bn.setHash() - require.Nil(t, err) + manager := getTestGoroutinesManager() + bn.setHash(manager) + require.Nil(t, manager.GetError()) rootHash := bn.getHash() _, trieStorage := newEmptyTrie() @@ -223,7 +224,7 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { }, } - err = bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) + err := bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) require.Nil(t, err) leaves, err := bn.getChildren(db) From 00f339acccfb29aa5e341030ae1651cf48a1c2bd Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Tue, 12 Nov 2024 12:22:52 +0200 Subject: [PATCH 02/13] fix failing tests --- .../state/stateTrieSync/stateTrieSync_test.go | 23 +++++++++++++++---- node/node_test.go | 6 ++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 8bfbd584a70..cb933aedba6 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -512,9 +512,26 @@ func testSyncMissingSnapshotNodes(t *testing.T, version int) { checkAllDataTriesAreSynced(t, numDataTrieLeaves, requesterTrie, dataTrieRootHashes) } +func GetAllHashes(t *testing.T, tr common.Trie, rootHash []byte) [][]byte { + iterator, err := trie.NewDFSIterator(tr, rootHash) + require.Nil(t, err) + + hashes := make([][]byte, 0) + hashes = append(hashes, iterator.GetHash()) + for iterator.HasNext() { + err = iterator.Next() + require.Nil(t, err) + + hashes = append(hashes, iterator.GetHash()) + } + + return hashes +} + func copyPartialState(t *testing.T, sourceNode, destinationNode *integrationTests.TestProcessorNode, dataTriesRootHashes [][]byte) { resolverTrie := sourceNode.TrieContainer.Get([]byte(dataRetriever.UserAccountsUnit.String())) - hashes, _ := resolverTrie.GetAllHashes() + rootHash, _ := resolverTrie.RootHash() + hashes := GetAllHashes(t, resolverTrie, rootHash) assert.NotEqual(t, 0, len(hashes)) hashes = append(hashes, getDataTriesHashes(t, resolverTrie, dataTriesRootHashes)...) @@ -531,7 +548,6 @@ func copyPartialState(t *testing.T, sourceNode, destinationNode *integrationTest err = destStorage.Put(hash, val) assert.Nil(t, err) } - } func getDataTriesHashes(t *testing.T, tr common.Trie, dataTriesRootHashes [][]byte) [][]byte { @@ -540,8 +556,7 @@ func getDataTriesHashes(t *testing.T, tr common.Trie, dataTriesRootHashes [][]by dt, err := tr.Recreate(rh) assert.Nil(t, err) - dtHashes, err := dt.GetAllHashes() - assert.Nil(t, err) + dtHashes := GetAllHashes(t, dt, rh) hashes = append(hashes, dtHashes...) } diff --git a/node/node_test.go b/node/node_test.go index 152cf98bdd7..e0c8d9d31c6 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -4004,7 +4004,7 @@ func TestNode_GetProofShouldWork(t *testing.T) { stateComponents.AccountsAPI = &stateMock.AccountsStub{ GetTrieCalled: func(_ []byte) (common.Trie, error) { return &trieMock.TrieStub{ - GetProofCalled: func(key []byte) ([][]byte, []byte, error) { + GetProofCalled: func(key []byte, _ []byte) ([][]byte, []byte, error) { assert.Equal(t, trieKey, hex.EncodeToString(key)) return proof, value, nil }, @@ -4053,7 +4053,7 @@ func TestNode_getProofErrWhenComputingProof(t *testing.T) { stateComponents.AccountsAPI = &stateMock.AccountsStub{ GetTrieCalled: func(_ []byte) (common.Trie, error) { return &trieMock.TrieStub{ - GetProofCalled: func(_ []byte) ([][]byte, []byte, error) { + GetProofCalled: func(_ []byte, _ []byte) ([][]byte, []byte, error) { return nil, nil, expectedErr }, }, nil @@ -4129,7 +4129,7 @@ func TestNode_GetProofDataTrieShouldWork(t *testing.T) { stateComponents.AccountsAPI = &stateMock.AccountsStub{ GetTrieCalled: func(_ []byte) (common.Trie, error) { return &trieMock.TrieStub{ - GetProofCalled: func(key []byte) ([][]byte, []byte, error) { + GetProofCalled: func(key []byte, _ []byte) ([][]byte, []byte, error) { if hex.EncodeToString(key) == mainTrieKey { return mainTrieProof, mainTrieValue, nil } From 3bc343874ebc3cff882ca62c8feffdcaf3dd2e2d Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 14 Nov 2024 12:45:40 +0200 Subject: [PATCH 03/13] remove unnecessary array --- trie/branchNode.go | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/trie/branchNode.go b/trie/branchNode.go index 301a4efe47f..bd514fda6e8 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -92,9 +92,6 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { waitGroup := sync.WaitGroup{} - encodedChildrenMutex := &sync.Mutex{} - encodedChildren := make([][]byte, nrOfChildren) - for i := 0; i < nrOfChildren; i++ { if !goRoutinesManager.ShouldContinueProcessing() { return @@ -112,9 +109,9 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { return } - encodedChildrenMutex.Lock() - encodedChildren[i] = encChild - encodedChildrenMutex.Unlock() + bn.childrenMutexes[i].Lock() + bn.EncodedChildren[i] = encChild + bn.childrenMutexes[i].Unlock() continue } @@ -126,23 +123,15 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { goRoutinesManager.SetError(err) return } - encodedChildrenMutex.Lock() - encodedChildren[childPos] = encChild - encodedChildrenMutex.Unlock() + bn.childrenMutexes[childPos].Lock() + bn.EncodedChildren[childPos] = encChild + bn.childrenMutexes[childPos].Unlock() waitGroup.Done() }(i) } waitGroup.Wait() - for i := range encodedChildren { - if len(encodedChildren[i]) == 0 { - continue - } - - bn.EncodedChildren[i] = encodedChildren[i] - } - hash, err := encodeNodeAndGetHash(bn) if err != nil { goRoutinesManager.SetError(err) From 4edb90edbb0e61d4d461be735ec0f03095810fb6 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 14 Nov 2024 13:03:59 +0200 Subject: [PATCH 04/13] fix setHash deadlock --- trie/branchNode.go | 9 +++++++++ trie/extensionNode.go | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/trie/branchNode.go b/trie/branchNode.go index bd514fda6e8..7c4d4261bd6 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -103,6 +103,10 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { if !goRoutinesManager.CanStartGoRoutine() { bn.children[i].setHash(goRoutinesManager) + if !goRoutinesManager.ShouldContinueProcessing() { + return + } + encChild, err := encodeNodeAndGetHash(bn.children[i]) if err != nil { goRoutinesManager.SetError(err) @@ -118,9 +122,14 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { waitGroup.Add(1) go func(childPos int) { bn.children[childPos].setHash(goRoutinesManager) + if !goRoutinesManager.ShouldContinueProcessing() { + waitGroup.Done() + return + } encChild, err := encodeNodeAndGetHash(bn.children[childPos]) if err != nil { goRoutinesManager.SetError(err) + waitGroup.Done() return } bn.childrenMutexes[childPos].Lock() diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 2108df02c1b..8d7a4b14e2f 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -80,6 +80,10 @@ func (en *extensionNode) setHash(goRoutinesManager common.TrieGoroutinesManager) if en.shouldSetHashForChild() { en.child.setHash(goRoutinesManager) + if !goRoutinesManager.ShouldContinueProcessing() { + return + } + encChild, err := encodeNodeAndGetHash(en.child) if err != nil { goRoutinesManager.SetError(err) From 98869ad52e166f4550a32272070c220cf55213ec Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 27 Nov 2024 14:07:29 +0200 Subject: [PATCH 05/13] fix after merge --- trie/branchNode_test.go | 1 - trie/extensionNode_test.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index b70f7c5f17a..4e90b4663a6 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "testing" "github.com/multiversx/mx-chain-core-go/core" diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 67933eedb9f..22c8a53cd85 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -1274,7 +1274,7 @@ func TestExtensionNode_deleteBatch(t *testing.T) { goRoutinesManager, err := NewGoroutinesManager(th, errChan.NewErrChanWrapper(), make(chan struct{})) assert.Nil(t, err) - newEn, _ := en.insert(data, goRoutinesManager, common.NewModifiedHashesSlice(), nil) + newEn := en.insert(data, goRoutinesManager, common.NewModifiedHashesSlice(), nil) newEn.setHash(getTestGoroutinesManager()) err = newEn.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) assert.Nil(t, err) From e8d5c1f65a76b81ca36ac16d33c03b3a973ed5c9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 2 Dec 2024 14:45:32 +0200 Subject: [PATCH 06/13] remove duplicated code from setHash func --- trie/branchNode.go | 48 +++++++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/trie/branchNode.go b/trie/branchNode.go index 35ccf76228c..65d869040a3 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -102,39 +102,14 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { } if !goRoutinesManager.CanStartGoRoutine() { - bn.children[i].setHash(goRoutinesManager) - if !goRoutinesManager.ShouldContinueProcessing() { - return - } - - encChild, err := encodeNodeAndGetHash(bn.children[i]) - if err != nil { - goRoutinesManager.SetError(err) - return - } - - bn.childrenMutexes[i].Lock() - bn.EncodedChildren[i] = encChild - bn.childrenMutexes[i].Unlock() + bn.setHashForChild(i, goRoutinesManager) continue } waitGroup.Add(1) go func(childPos int) { - bn.children[childPos].setHash(goRoutinesManager) - if !goRoutinesManager.ShouldContinueProcessing() { - waitGroup.Done() - return - } - encChild, err := encodeNodeAndGetHash(bn.children[childPos]) - if err != nil { - goRoutinesManager.SetError(err) - waitGroup.Done() - return - } - bn.childrenMutexes[childPos].Lock() - bn.EncodedChildren[childPos] = encChild - bn.childrenMutexes[childPos].Unlock() + bn.setHashForChild(childPos, goRoutinesManager) + goRoutinesManager.EndGoRoutineProcessing() waitGroup.Done() }(i) } @@ -149,6 +124,23 @@ func (bn *branchNode) setHash(goRoutinesManager common.TrieGoroutinesManager) { bn.hash = hash } +func (bn *branchNode) setHashForChild(childPos int, goRoutinesManager common.TrieGoroutinesManager) { + bn.children[childPos].setHash(goRoutinesManager) + if !goRoutinesManager.ShouldContinueProcessing() { + return + } + + encChild, err := encodeNodeAndGetHash(bn.children[childPos]) + if err != nil { + goRoutinesManager.SetError(err) + return + } + + bn.childrenMutexes[childPos].Lock() + bn.EncodedChildren[childPos] = encChild + bn.childrenMutexes[childPos].Unlock() +} + func (bn *branchNode) shouldSetHashForChild(childPos int) bool { bn.childrenMutexes[childPos].RLock() defer bn.childrenMutexes[childPos].RUnlock() From bf0f1b6b3a9f485bd3d112cf64a960197f170209 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 12 Dec 2024 14:05:31 +0200 Subject: [PATCH 07/13] fix after merge --- trie/baseNode.go | 6 ++++++ trie/patriciaMerkleTrie.go | 18 +++++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/trie/baseNode.go b/trie/baseNode.go index f117c3e85b7..7320a28ebdc 100644 --- a/trie/baseNode.go +++ b/trie/baseNode.go @@ -15,6 +15,9 @@ type baseNode struct { } func (bn *baseNode) getHash() []byte { + bn.mutex.RLock() + defer bn.mutex.RUnlock() + return bn.hash } @@ -23,6 +26,9 @@ func (bn *baseNode) setGivenHash(hash []byte) { } func (bn *baseNode) isDirty() bool { + bn.mutex.RLock() + defer bn.mutex.RUnlock() + return bn.dirty } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index a724c08a6a3..9823b71caa7 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -313,13 +313,13 @@ func (tr *patriciaMerkleTrie) getRootHash() ([]byte, error) { return hash, nil } - manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + err := tr.goRoutinesManager.SetNewErrorChannel(errChan.NewErrChanWrapper()) if err != nil { return nil, err } - rootNode.setHash(manager) - err = manager.GetError() + rootNode.setHash(tr.goRoutinesManager) + err = tr.goRoutinesManager.GetError() if err != nil { return nil, err } @@ -353,13 +353,13 @@ func (tr *patriciaMerkleTrie) Commit() error { return nil } - manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + err = tr.goRoutinesManager.SetNewErrorChannel(errChan.NewErrChanWrapper()) if err != nil { return err } - rootNode.setHash(manager) - err = manager.GetError() + rootNode.setHash(tr.goRoutinesManager) + err = tr.goRoutinesManager.GetError() if err != nil { return err } @@ -482,13 +482,13 @@ func (tr *patriciaMerkleTrie) GetDirtyHashes() (common.ModifiedHashes, error) { return nil, nil } - manager, err := NewGoroutinesManager(tr.goroutinesThrottler, errChan.NewErrChanWrapper(), tr.chanClose) + err = tr.goRoutinesManager.SetNewErrorChannel(errChan.NewErrChanWrapper()) if err != nil { return nil, err } - rootNode.setHash(manager) - err = manager.GetError() + rootNode.setHash(tr.goRoutinesManager) + err = tr.goRoutinesManager.GetError() if err != nil { return nil, err } From a08ea8a879632081f65950aad9b6fb96fc3a3c6b Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 16 Dec 2024 15:03:02 +0200 Subject: [PATCH 08/13] add baseNode mutex --- trie/baseNode.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trie/baseNode.go b/trie/baseNode.go index 20db3ea5971..97df070d6d8 100644 --- a/trie/baseNode.go +++ b/trie/baseNode.go @@ -16,8 +16,6 @@ type baseNode struct { } func (bn *baseNode) getHash() []byte { - //TODO add mutex protection for all methods - bn.mutex.RLock() defer bn.mutex.RUnlock() @@ -25,6 +23,9 @@ func (bn *baseNode) getHash() []byte { } func (bn *baseNode) setGivenHash(hash []byte) { + bn.mutex.Lock() + defer bn.mutex.Unlock() + bn.hash = hash } @@ -36,6 +37,9 @@ func (bn *baseNode) isDirty() bool { } func (bn *baseNode) setDirty(dirty bool) { + bn.mutex.Lock() + defer bn.mutex.Unlock() + bn.dirty = dirty } From 5a65927562932e8a3e25d4ba149fc83eb8135426 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 16 Dec 2024 16:40:51 +0200 Subject: [PATCH 09/13] create disabledGoroutinesManager --- trie/disabledGoroutinesManager.go | 50 ++++++++++++++++++++++++++ trie/disabledGoroutinesManager_test.go | 44 +++++++++++++++++++++++ trie/sync.go | 16 ++------- 3 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 trie/disabledGoroutinesManager.go create mode 100644 trie/disabledGoroutinesManager_test.go diff --git a/trie/disabledGoroutinesManager.go b/trie/disabledGoroutinesManager.go new file mode 100644 index 00000000000..9cfce44906d --- /dev/null +++ b/trie/disabledGoroutinesManager.go @@ -0,0 +1,50 @@ +package trie + +import "github.com/multiversx/mx-chain-go/common" + +type disabledGoroutinesManager struct { + err error +} + +// NewDisabledGoroutinesManager creates a new instance of disabledGoroutinesManager +func NewDisabledGoroutinesManager() *disabledGoroutinesManager { + return &disabledGoroutinesManager{} +} + +// ShouldContinueProcessing returns true if there is no error +func (d *disabledGoroutinesManager) ShouldContinueProcessing() bool { + if d.err != nil { + return false + } + + return true +} + +// CanStartGoRoutine returns false +func (d *disabledGoroutinesManager) CanStartGoRoutine() bool { + return false +} + +// EndGoRoutineProcessing does nothing +func (d *disabledGoroutinesManager) EndGoRoutineProcessing() { +} + +// SetNewErrorChannel does nothing +func (d *disabledGoroutinesManager) SetNewErrorChannel(_ common.BufferedErrChan) error { + return nil +} + +// SetError sets the given error +func (d *disabledGoroutinesManager) SetError(err error) { + d.err = err +} + +// GetError returns the error +func (d *disabledGoroutinesManager) GetError() error { + return d.err +} + +// IsInterfaceNil returns true if there is no value under the interface +func (d *disabledGoroutinesManager) IsInterfaceNil() bool { + return d == nil +} diff --git a/trie/disabledGoroutinesManager_test.go b/trie/disabledGoroutinesManager_test.go new file mode 100644 index 00000000000..c82c66ebca8 --- /dev/null +++ b/trie/disabledGoroutinesManager_test.go @@ -0,0 +1,44 @@ +package trie + +import ( + "errors" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledGoroutinesManager(t *testing.T) { + t.Parallel() + + d := NewDisabledGoroutinesManager() + assert.False(t, check.IfNil(d)) +} + +func TestDisabledGoroutinesManager_ShouldContinueProcessing(t *testing.T) { + t.Parallel() + + d := NewDisabledGoroutinesManager() + assert.True(t, d.ShouldContinueProcessing()) + + d.SetError(errors.New("error")) + assert.False(t, d.ShouldContinueProcessing()) +} + +func TestDisabledGoroutinesManager_CanStartGoRoutine(t *testing.T) { + t.Parallel() + + d := NewDisabledGoroutinesManager() + assert.False(t, d.CanStartGoRoutine()) +} + +func TestDisabledGoroutinesManager_SetAndGetError(t *testing.T) { + t.Parallel() + + d := NewDisabledGoroutinesManager() + assert.Nil(t, d.GetError()) + + err := errors.New("error") + d.SetError(err) + assert.Equal(t, err, d.GetError()) +} diff --git a/trie/sync.go b/trie/sync.go index 6f498e58c08..0e362c87c40 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -4,8 +4,6 @@ import ( "bytes" "context" "fmt" - "github.com/multiversx/mx-chain-core-go/core/throttler" - "github.com/multiversx/mx-chain-go/common/errChan" "sync" "time" @@ -359,17 +357,9 @@ func trieNode( return nil, err } - th, err := throttler.NewNumGoRoutinesThrottler(1) - if err != nil { - return nil, err - } - goRoutinesManager, err := NewGoroutinesManager(th, errChan.NewErrChanWrapper(), make(chan struct{})) - if err != nil { - return nil, err - } - - decodedNode.setHash(goRoutinesManager) - err = goRoutinesManager.GetError() + manager := NewDisabledGoroutinesManager() + decodedNode.setHash(manager) + err = manager.GetError() if err != nil { return nil, err } From 0e4014e7f38811fbe4ecd52801e2d0b31cdd1b4e Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 18 Dec 2024 16:32:54 +0200 Subject: [PATCH 10/13] fix linter issues --- trie/disabledGoroutinesManager.go | 6 +----- trie/node.go | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/trie/disabledGoroutinesManager.go b/trie/disabledGoroutinesManager.go index 9cfce44906d..4133c5fb8e9 100644 --- a/trie/disabledGoroutinesManager.go +++ b/trie/disabledGoroutinesManager.go @@ -13,11 +13,7 @@ func NewDisabledGoroutinesManager() *disabledGoroutinesManager { // ShouldContinueProcessing returns true if there is no error func (d *disabledGoroutinesManager) ShouldContinueProcessing() bool { - if d.err != nil { - return false - } - - return true + return d.err != nil } // CanStartGoRoutine returns false diff --git a/trie/node.go b/trie/node.go index a8d9b3b64e1..448bb422371 100644 --- a/trie/node.go +++ b/trie/node.go @@ -125,11 +125,7 @@ func concat(s1 []byte, s2 ...byte) []byte { func hasValidHash(n node) bool { childHash := n.getHash() - if childHash == nil { - return false - } - - return true + return len(childHash) != 0 } func decodeNode(encNode []byte, marshalizer marshal.Marshalizer, hasher hashing.Hasher) (node, error) { From 381ecec5ae67a5b1a1c612c313e910697f49fbd7 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 18 Dec 2024 16:52:14 +0200 Subject: [PATCH 11/13] small fix --- trie/disabledGoroutinesManager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trie/disabledGoroutinesManager.go b/trie/disabledGoroutinesManager.go index 4133c5fb8e9..e5be63e7f9b 100644 --- a/trie/disabledGoroutinesManager.go +++ b/trie/disabledGoroutinesManager.go @@ -13,7 +13,7 @@ func NewDisabledGoroutinesManager() *disabledGoroutinesManager { // ShouldContinueProcessing returns true if there is no error func (d *disabledGoroutinesManager) ShouldContinueProcessing() bool { - return d.err != nil + return d.err == nil } // CanStartGoRoutine returns false From 0ca88fbad64c671fa019d2ef48c5f4b8a2a769e9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 19 Dec 2024 15:35:26 +0200 Subject: [PATCH 12/13] add setHash concurrency unit tests --- trie/patriciaMerkleTrie_test.go | 122 ++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 31c9d059571..bba864c20d3 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -1598,6 +1598,9 @@ func TestPatriciaMerkleTrie_AddBatchedDataToTrie(t *testing.T) { time.Sleep(time.Millisecond * 100) } }, + SetErrorCalled: func(err error) { + assert.Fail(t, "should not have called this function") + }, } trie.SetGoRoutinesManager(tr, grm) @@ -1663,6 +1666,9 @@ func TestPatriciaMerkleTrie_AddBatchedDataToTrie(t *testing.T) { time.Sleep(time.Millisecond * 100) } }, + SetErrorCalled: func(err error) { + assert.Fail(t, "should not have called this function") + }, } trie.SetGoRoutinesManager(tr, grm) @@ -1828,6 +1834,9 @@ func TestPatriciaMerkleTrie_Get(t *testing.T) { time.Sleep(time.Millisecond * 100) } }, + SetErrorCalled: func(err error) { + assert.Fail(t, "should not have called this function") + }, } trie.SetGoRoutinesManager(tr, grm) @@ -1860,6 +1869,119 @@ func TestPatriciaMerkleTrie_Get(t *testing.T) { }) } +func TestPatriciaMerkleTrie_RootHash(t *testing.T) { + t.Parallel() + + t.Run("set root hash with batched data commits batch", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + numOperations := 1000 + for i := 0; i < numOperations; i++ { + _ = tr.Update([]byte("dog"+strconv.Itoa(i)), []byte("reindeer")) + } + + rootHash, err := tr.RootHash() + assert.Nil(t, err) + assert.NotEqual(t, emptyTrieHash, rootHash) + }) + t.Run("set root hash and update trie concurrently should serialize operations", func(t *testing.T) { + t.Parallel() + + // create trie with some data + tr := emptyTrie() + numOperations := 1000 + for i := 0; i < numOperations; i++ { + _ = tr.Update([]byte("dog"+strconv.Itoa(i)), []byte("reindeer")) + } + trie.ExecuteUpdatesFromBatch(tr) + + // compute rootHash + waitForSignal := atomic.Bool{} + waitForSignal.Store(true) + startedComputingRootHash := atomic.Bool{} + grm := &mock.GoroutinesManagerStub{ + CanStartGoRoutineCalled: func() bool { + startedComputingRootHash.Store(true) + return true + }, + EndGoRoutineProcessingCalled: func() { + for waitForSignal.Load() { + time.Sleep(time.Millisecond * 100) + } + }, + SetErrorCalled: func(err error) { + assert.Fail(t, "should not have called this function") + }, + } + trie.SetGoRoutinesManager(tr, grm) + + go func() { + rootHash1, err := tr.RootHash() + assert.Nil(t, err) + assert.NotEqual(t, emptyTrieHash, rootHash1) + }() + + // wait for start of the computation of the root hash + for !startedComputingRootHash.Load() { + time.Sleep(time.Millisecond * 100) + } + + for i := numOperations; i < numOperations*2; i++ { + _ = tr.Update([]byte("dog"+strconv.Itoa(i)), []byte("reindeer")) + } + setNewErrChanCalled := atomic.Bool{} + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + grm.SetNewErrorChannelCalled = func(common.BufferedErrChan) error { + setNewErrChanCalled.Store(true) + return nil + } + trie.ExecuteUpdatesFromBatch(tr) + wg.Done() + }() + + // commit batch to trie does not start until root hash is fully computed + time.Sleep(time.Millisecond * 500) + assert.False(t, setNewErrChanCalled.Load()) + + waitForSignal.Store(false) + wg.Wait() + assert.True(t, setNewErrChanCalled.Load()) + }) + t.Run("set root hash and get from trie concurrently", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + numOperations := 100000 + for i := 0; i < numOperations; i++ { + _ = tr.Update([]byte("dog"+strconv.Itoa(i)), []byte("reindeer")) + } + trie.ExecuteUpdatesFromBatch(tr) + + wg := sync.WaitGroup{} + wg.Add(1) + setRootHashFinished := atomic.Bool{} + go func() { + for !setRootHashFinished.Load() { + index := rand.Intn(numOperations) + val, _, err := tr.Get([]byte("dog" + strconv.Itoa(index))) + assert.Nil(t, err) + assert.Equal(t, []byte("reindeer"), val) + } + wg.Done() + }() + + rootHash, err := tr.RootHash() + assert.Nil(t, err) + assert.NotEqual(t, emptyTrieHash, rootHash) + + setRootHashFinished.Store(true) + wg.Wait() + }) +} + func TestPatricianMerkleTrie_ConcurrentOperations(t *testing.T) { t.Parallel() From 4ac8de43e8f3a56bd4f2a9fbf0cf41ce8de5ce77 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 8 Jan 2025 11:55:04 +0200 Subject: [PATCH 13/13] fix after review --- trie/patriciaMerkleTrie.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index d2e6f80d7e4..59b2c33ba2e 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -683,19 +683,15 @@ func logMapWithTrace(message string, paramName string, hashes common.ModifiedHas // GetProof computes a Merkle proof for the node that is present at the given key func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) { - trie, err := tr.Recreate(rootHash) - if err != nil { - return nil, nil, err - } - - pmt, ok := trie.(*patriciaMerkleTrie) - if !ok { - return nil, nil, ErrWrongTypeAssertion + //TODO refactor this function to avoid encoding the node after it is retrieved from the DB. + // The encoded node is actually the value from db, thus we can use the retrieved value directly + if len(key) == 0 || bytes.Equal(rootHash, common.EmptyTrieHash) { + return nil, nil, ErrNilNode } - rootNode := pmt.GetRootNode() - if check.IfNil(rootNode) { - return nil, nil, ErrNilNode + rootNode, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher) + if err != nil { + return nil, nil, fmt.Errorf("trie get proof error: %w", err) } var proof [][]byte