Skip to content

Commit

Permalink
fixes after review and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BeniaminDrasovean committed Jan 8, 2025
1 parent 354812c commit dc26e41
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 43 deletions.
1 change: 1 addition & 0 deletions common/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,5 @@ type TrieHashesCollector interface {
GetDirtyHashes() ModifiedHashes
AddObsoleteHashes(oldRootHash []byte, oldHashes [][]byte)
GetCollectedData() ([]byte, ModifiedHashes, ModifiedHashes)
IsInterfaceNil() bool
}
6 changes: 5 additions & 1 deletion state/accountsDB.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,11 @@ func (adb *AccountsDB) commit() ([]byte, error) {

// Step 2. commit main trie
if adb.mainTrie.GetStorageManager().IsPruningEnabled() {
hc = hashesCollector.NewHashesCollector(hc)
wrappedHc, err := hashesCollector.NewHashesCollector(hc)
if err != nil {
return nil, err
}
hc = wrappedHc
}
err := adb.mainTrie.Commit(hc)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions state/hashesCollector/dataTrieHashesCollector.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,8 @@ func (hc *dataTrieHashesCollector) GetCollectedData() ([]byte, common.ModifiedHa

return nil, hc.oldHashes, hc.newHashes
}

// IsInterfaceNil returns true if there is no value under the interface
func (hc *dataTrieHashesCollector) IsInterfaceNil() bool {
return hc == nil
}
102 changes: 102 additions & 0 deletions state/hashesCollector/dataTrieHashesCollector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package hashesCollector

import (
"strconv"
"sync"
"testing"

"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/stretchr/testify/assert"
)

func TestNewDataTrieHashesCollector(t *testing.T) {
t.Parallel()

hc := NewDataTrieHashesCollector()
assert.False(t, check.IfNil(hc))
assert.NotNil(t, hc.oldHashes)
assert.NotNil(t, hc.newHashes)
}

func TestDataTrieHashesCollector_AddDirtyHash(t *testing.T) {
t.Parallel()

dthc := NewDataTrieHashesCollector()
numHashes := 1000
wg := &sync.WaitGroup{}
wg.Add(numHashes)
for i := 0; i < numHashes; i++ {
go func(index int) {
dthc.AddDirtyHash([]byte(strconv.Itoa(index)))
wg.Done()
}(i)
}
wg.Wait()

for i := 0; i < numHashes; i++ {
_, ok := dthc.newHashes[strconv.Itoa(i)]
assert.True(t, ok)
}
}

func TestDataTrieHashesCollector_GetDirtyHashes(t *testing.T) {
t.Parallel()

dthc := NewDataTrieHashesCollector()
numHashes := 1000
for i := 0; i < numHashes; i++ {
dthc.AddDirtyHash([]byte(strconv.Itoa(i)))
}

dirtyHashes := dthc.GetDirtyHashes()
assert.Equal(t, numHashes, len(dirtyHashes))
for i := 0; i < numHashes; i++ {
_, ok := dirtyHashes[strconv.Itoa(i)]
assert.True(t, ok)
}
}

func TestDataTrieHashesCollector_AddObsoleteHashes(t *testing.T) {
t.Parallel()

dthc := NewDataTrieHashesCollector()
numHashes := 1000
dirtyHashes := make([][]byte, numHashes)

for i := 0; i < numHashes; i++ {
dirtyHashes[i] = []byte(strconv.Itoa(i))
}

dthc.AddObsoleteHashes(nil, dirtyHashes)

assert.Equal(t, numHashes, len(dthc.oldHashes))
for i := 0; i < numHashes; i++ {
_, ok := dthc.oldHashes[strconv.Itoa(i)]
assert.True(t, ok)
}
}

func TestDataTriehashesCollector_GetCollectedData(t *testing.T) {
t.Parallel()

dthc := NewDataTrieHashesCollector()
numHashes := 1000
dirtyHashes := make([][]byte, numHashes)

for i := 0; i < numHashes; i++ {
dirtyHashes[i] = []byte(strconv.Itoa(i))
dthc.AddDirtyHash(dirtyHashes[i])
}
dthc.AddObsoleteHashes(nil, dirtyHashes)

oldRootHash, oldHashes, newHashes := dthc.GetCollectedData()
assert.Nil(t, oldRootHash)
assert.Equal(t, numHashes, len(oldHashes))
assert.Equal(t, numHashes, len(newHashes))
for i := 0; i < numHashes; i++ {
_, ok := oldHashes[strconv.Itoa(i)]
assert.True(t, ok)
_, ok = newHashes[strconv.Itoa(i)]
assert.True(t, ok)
}
}
5 changes: 5 additions & 0 deletions state/hashesCollector/disabledHashesCollector.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ func (hc *disabledHashesCollector) AddObsoleteHashes(_ []byte, _ [][]byte) {
func (hc *disabledHashesCollector) GetCollectedData() ([]byte, common.ModifiedHashes, common.ModifiedHashes) {
return nil, nil, nil
}

// IsInterfaceNil returns true if there is no value under the interface
func (hc *disabledHashesCollector) IsInterfaceNil() bool {
return hc == nil
}
6 changes: 6 additions & 0 deletions state/hashesCollector/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package hashesCollector

// GetOldRootHash -
func (hc *hashesCollector) GetOldRootHash() []byte {
return hc.oldRootHash
}
18 changes: 16 additions & 2 deletions state/hashesCollector/hashesCollector.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package hashesCollector

import (
"errors"

"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/multiversx/mx-chain-go/common"
)

Expand All @@ -10,13 +13,19 @@ type hashesCollector struct {
oldRootHash []byte
}

// ErrNilTrieHashesCollector is returned when the trie hashes collector is nil.
var ErrNilTrieHashesCollector = errors.New("nil trie hashes collector")

// NewHashesCollector creates a new instance of hashesCollector.
// This collector is used to collect hashes related to the main trie.
func NewHashesCollector(collector common.TrieHashesCollector) *hashesCollector {
func NewHashesCollector(collector common.TrieHashesCollector) (*hashesCollector, error) {
if check.IfNil(collector) {
return nil, ErrNilTrieHashesCollector
}
return &hashesCollector{
TrieHashesCollector: collector,
oldRootHash: nil,
}
}, nil
}

// AddObsoleteHashes adds the old root hash and the old hashes to the collector.
Expand All @@ -30,3 +39,8 @@ func (hc *hashesCollector) GetCollectedData() ([]byte, common.ModifiedHashes, co
_, oldHashes, newHashes := hc.TrieHashesCollector.GetCollectedData()
return hc.oldRootHash, oldHashes, newHashes
}

// IsInterfaceNil returns true if there is no value under the interface
func (hc *hashesCollector) IsInterfaceNil() bool {
return hc == nil
}
71 changes: 71 additions & 0 deletions state/hashesCollector/hashesCollector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package hashesCollector_test

import (
"testing"

"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/state/hashesCollector"
"github.com/multiversx/mx-chain-go/testscommon/trie"
"github.com/stretchr/testify/assert"
)

func TestNewHashesCollector(t *testing.T) {
t.Parallel()

hc, err := hashesCollector.NewHashesCollector(nil)
assert.True(t, check.IfNil(hc))
assert.Equal(t, hashesCollector.ErrNilTrieHashesCollector, err)

hc, err = hashesCollector.NewHashesCollector(&trie.TrieHashesCollectorStub{})
assert.False(t, check.IfNil(hc))
assert.Nil(t, err)
assert.Nil(t, hc.GetOldRootHash())
}

func TestHashesCollector_AddObsoleteHashes(t *testing.T) {
t.Parallel()

addObsoleteHashesCalled := false
oldRootHash := []byte("oldRootHash")
oldHashes := [][]byte{[]byte("oldHash1"), []byte("oldHash2")}
hc := &trie.TrieHashesCollectorStub{
AddObsoleteHashesCalled: func(oldRootHash []byte, oldHashes [][]byte) {
assert.Equal(t, oldRootHash, oldRootHash)
assert.Equal(t, oldHashes, oldHashes)
addObsoleteHashesCalled = true
},
}
wrappedHc, _ := hashesCollector.NewHashesCollector(hc)

wrappedHc.AddObsoleteHashes(oldRootHash, oldHashes)
assert.True(t, addObsoleteHashesCalled)
assert.Equal(t, oldRootHash, wrappedHc.GetOldRootHash())
}

func TestHashesCollector_GetCollectedData(t *testing.T) {
t.Parallel()

getCollectedDataCalled := false
oldHashes := common.ModifiedHashes{"oldHash1": {}, "oldHash2": {}}
newHashes := common.ModifiedHashes{"newHash1": {}, "newHash2": {}}
hc := &trie.TrieHashesCollectorStub{
GetCollectedDataCalled: func() ([]byte, common.ModifiedHashes, common.ModifiedHashes) {
getCollectedDataCalled = true
return []byte("oldRootHash"), oldHashes, newHashes
},
}
wrappedHc, _ := hashesCollector.NewHashesCollector(hc)

oldRootHash, collectedOldHashes, collectedNewHashes := wrappedHc.GetCollectedData()
assert.True(t, getCollectedDataCalled)
assert.Nil(t, oldRootHash)
assert.Equal(t, oldHashes, collectedOldHashes)
assert.Equal(t, newHashes, collectedNewHashes)

wrappedHc.AddObsoleteHashes([]byte("oldRootHash1"), [][]byte{[]byte("oldHash1"), []byte("oldHash2")})
oldRootHash, collectedOldHashes, collectedNewHashes = wrappedHc.GetCollectedData()
assert.Equal(t, []byte("oldRootHash1"), oldRootHash)
assert.Equal(t, oldHashes, collectedOldHashes)
assert.Equal(t, newHashes, collectedNewHashes)
}
46 changes: 46 additions & 0 deletions testscommon/trie/trieHashesCollectorStub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package trie

import "github.com/multiversx/mx-chain-go/common"

// TrieHashesCollectorStub is a stub for the TrieHashesCollector interface.
type TrieHashesCollectorStub struct {
AddDirtyHashCalled func(hash []byte)
GetDirtyHashesCalled func() common.ModifiedHashes
AddObsoleteHashesCalled func(oldRootHash []byte, oldHashes [][]byte)
GetCollectedDataCalled func() ([]byte, common.ModifiedHashes, common.ModifiedHashes)
}

// AddDirtyHash -
func (h *TrieHashesCollectorStub) AddDirtyHash(hash []byte) {
if h.AddDirtyHashCalled != nil {
h.AddDirtyHashCalled(hash)
}
}

// GetDirtyHashes -
func (h *TrieHashesCollectorStub) GetDirtyHashes() common.ModifiedHashes {
if h.GetDirtyHashesCalled != nil {
return h.GetDirtyHashesCalled()
}
return nil
}

// AddObsoleteHashes -
func (h *TrieHashesCollectorStub) AddObsoleteHashes(oldRootHash []byte, oldHashes [][]byte) {
if h.AddObsoleteHashesCalled != nil {
h.AddObsoleteHashesCalled(oldRootHash, oldHashes)
}
}

// GetCollectedData -
func (h *TrieHashesCollectorStub) GetCollectedData() ([]byte, common.ModifiedHashes, common.ModifiedHashes) {
if h.GetCollectedDataCalled != nil {
return h.GetCollectedDataCalled()
}
return nil, nil, nil
}

// IsInterfaceNil -
func (h *TrieHashesCollectorStub) IsInterfaceNil() bool {
return h == nil
}
15 changes: 2 additions & 13 deletions trie/branchNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,8 @@ func (bn *branchNode) commitDirty(

waitGroup.Wait()

bn.dirty = false
encNode, err := bn.getEncodedNode()
if err != nil {
goRoutinesManager.SetError(err)
return
}
hash := bn.hasher.Compute(string(encNode))
bn.hash = hash
hashesCollector.AddDirtyHash(hash)

err = targetDb.Put(hash, encNode)
if err != nil {
goRoutinesManager.SetError(err)
ok := saveDirtyNodeToStorage(bn, goRoutinesManager, hashesCollector, targetDb, bn.hasher)
if !ok {
return
}

Expand Down
15 changes: 2 additions & 13 deletions trie/extensionNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,8 @@ func (en *extensionNode) commitDirty(
en.EncodedChild = child.getHash()
}

en.dirty = false
encNode, err := en.getEncodedNode()
if err != nil {
goRoutinesManager.SetError(err)
return
}
hash := en.hasher.Compute(string(encNode))
en.hash = hash
hashesCollector.AddDirtyHash(hash)

err = targetDb.Put(hash, encNode)
if err != nil {
goRoutinesManager.SetError(err)
ok := saveDirtyNodeToStorage(en, goRoutinesManager, hashesCollector, targetDb, en.hasher)
if !ok {
return
}

Expand Down
15 changes: 1 addition & 14 deletions trie/leafNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,7 @@ func (ln *leafNode) commitDirty(
return
}

ln.dirty = false
encNode, err := ln.getEncodedNode()
if err != nil {
goRoutinesManager.SetError(err)
return
}
hash := ln.hasher.Compute(string(encNode))
ln.hash = hash
hashesCollector.AddDirtyHash(hash)

err = targetDb.Put(hash, encNode)
if err != nil {
goRoutinesManager.SetError(err)
}
saveDirtyNodeToStorage(ln, goRoutinesManager, hashesCollector, targetDb, ln.hasher)
}

func (ln *leafNode) commitSnapshot(
Expand Down
Loading

0 comments on commit dc26e41

Please sign in to comment.