Skip to content

Commit

Permalink
return keys for function States (#4311)
Browse files Browse the repository at this point in the history
* return keys for States

* Update iterator.go

delete debug lines

* Update iterator.go

* address comment
  • Loading branch information
CoderZhi authored Jun 25, 2024
1 parent afa504c commit b558929
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 45 deletions.
2 changes: 1 addition & 1 deletion action/protocol/poll/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func allBlockMetasFromDB(sr protocol.StateReader, blocksInEpoch uint64) ([]*Bloc
blockmetas := make([]*BlockMeta, 0, iter.Size())
for i := 0; i < iter.Size(); i++ {
bm := &BlockMeta{}
switch err := iter.Next(bm); errors.Cause(err) {
switch _, err := iter.Next(bm); errors.Cause(err) {
case nil:
blockmetas = append(blockmetas, bm)
case state.ErrNilValue:
Expand Down
4 changes: 2 additions & 2 deletions action/protocol/staking/candidate_statereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func (c *candSR) getAllBuckets() ([]*VoteBucket, uint64, error) {
buckets := make([]*VoteBucket, 0, iter.Size())
for i := 0; i < iter.Size(); i++ {
vb := &VoteBucket{}
switch err := iter.Next(vb); errors.Cause(err) {
switch _, err := iter.Next(vb); errors.Cause(err) {
case nil:
buckets = append(buckets, vb)
case state.ErrNilValue:
Expand Down Expand Up @@ -324,7 +324,7 @@ func (c *candSR) getAllCandidates() (CandidateList, uint64, error) {
cands := make(CandidateList, 0, iter.Size())
for i := 0; i < iter.Size(); i++ {
c := &Candidate{}
if err := iter.Next(c); err != nil {
if _, err := iter.Next(c); err != nil {
return nil, height, errors.Wrapf(err, "failed to deserialize candidate")
}
cands = append(cands, c)
Expand Down
8 changes: 6 additions & 2 deletions action/protocol/staking/staking_statereader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ func TestStakingStateReader(t *testing.T) {
testNativeTotalAmount.count++
}
var err error
keys := make([][]byte, len(testNativeBuckets))
states := make([][]byte, len(testNativeBuckets))
for i := range states {
keys[i] = byteutil.Uint64ToBytesBigEndian(uint64(i))
states[i], err = state.Serialize(testNativeBuckets[i])
r.NoError(err)
}
Expand Down Expand Up @@ -164,7 +166,8 @@ func TestStakingStateReader(t *testing.T) {
t.Run("readStateBuckets", func(t *testing.T) {
sf, _, stakeSR, ctx, r := prepare(t)
sf.EXPECT().States(gomock.Any(), gomock.Any()).DoAndReturn(func(arg0 ...protocol.StateOption) (uint64, state.Iterator, error) {
iter := state.NewIterator(states)
iter, err := state.NewIterator(keys, states)
r.NoError(err)
return uint64(1), iter, nil
}).Times(1)
sf.EXPECT().State(gomock.Any(), gomock.Any()).Return(uint64(0), state.ErrStateNotExist).Times(1)
Expand Down Expand Up @@ -193,7 +196,8 @@ func TestStakingStateReader(t *testing.T) {
t.Run("readStateBucketsWithEndorsement", func(t *testing.T) {
sf, _, stakeSR, ctx, r := prepare(t)
sf.EXPECT().States(gomock.Any(), gomock.Any()).DoAndReturn(func(arg0 ...protocol.StateOption) (uint64, state.Iterator, error) {
iter := state.NewIterator(states)
iter, err := state.NewIterator(keys, states)
r.NoError(err)
return uint64(1), iter, nil
}).Times(1)
sf.EXPECT().State(gomock.AssignableToTypeOf(&Endorsement{}), gomock.Any()).DoAndReturn(func(arg0 any, arg1 ...protocol.StateOption) (uint64, error) {
Expand Down
8 changes: 6 additions & 2 deletions state/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,12 +545,16 @@ func (sf *factory) States(opts ...protocol.StateOption) (uint64, state.Iterator,
if cfg.Key != nil {
return sf.currentChainHeight, nil, errors.Wrap(ErrNotSupported, "Read states with key option has not been implemented yet")
}
values, err := readStates(sf.dao, cfg.Namespace, cfg.Keys)
keys, values, err := readStates(sf.dao, cfg.Namespace, cfg.Keys)
if err != nil {
return 0, nil, err
}
iter, err := state.NewIterator(keys, values)
if err != nil {
return 0, nil, err
}

return sf.currentChainHeight, state.NewIterator(values), nil
return sf.currentChainHeight, iter, nil
}

// ReadView reads the view
Expand Down
9 changes: 5 additions & 4 deletions state/factory/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ func testFactoryStates(sf Factory, t *testing.T) {
accounts := make([]*state.Account, 0)
for i := 0; i < iter.Size(); i++ {
c := &state.Account{}
err = iter.Next(c)
_, err = iter.Next(c)
if err != nil {
continue
}
Expand All @@ -676,7 +676,7 @@ func testFactoryStates(sf Factory, t *testing.T) {
accounts = make([]*state.Account, 0)
for i := 0; i < iter.Size(); i++ {
c := &state.Account{}
err = iter.Next(c)
_, err = iter.Next(c)
if err != nil {
continue
}
Expand All @@ -695,7 +695,7 @@ func testFactoryStates(sf Factory, t *testing.T) {
accounts = make([]*state.Account, 0)
for i := 0; i < iter.Size(); i++ {
c := &state.Account{}
err = iter.Next(c)
_, err = iter.Next(c)
if err != nil {
continue
}
Expand All @@ -716,7 +716,8 @@ func testFactoryStates(sf Factory, t *testing.T) {
accounts = make([]*state.Account, 0)
for i := 0; i < iter.Size(); i++ {
c := &state.Account{}
require.NoError(t, iter.Next(c))
_, err = iter.Next(c)
require.NoError(t, err)
accounts = append(accounts, c)
}
require.Equal(t, uint64(90), accounts[0].Balance.Uint64())
Expand Down
8 changes: 6 additions & 2 deletions state/factory/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,16 @@ func (sdb *stateDB) States(opts ...protocol.StateOption) (uint64, state.Iterator
if cfg.Key != nil {
return sdb.currentChainHeight, nil, errors.Wrap(ErrNotSupported, "Read states with key option has not been implemented yet")
}
values, err := readStates(sdb.dao, cfg.Namespace, cfg.Keys)
keys, values, err := readStates(sdb.dao, cfg.Namespace, cfg.Keys)
if err != nil {
return 0, nil, err
}
iter, err := state.NewIterator(keys, values)
if err != nil {
return 0, nil, err
}

return sdb.currentChainHeight, state.NewIterator(values), nil
return sdb.currentChainHeight, iter, nil
}

// StateAtHeight returns a confirmed state at height -- archive mode
Expand Down
21 changes: 13 additions & 8 deletions state/factory/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,30 +108,35 @@ func protocolCommit(ctx context.Context, sr protocol.StateManager) error {
return nil
}

func readStates(kvStore db.KVStore, namespace string, keys [][]byte) ([][]byte, error) {
func readStates(kvStore db.KVStore, namespace string, keys [][]byte) ([][]byte, [][]byte, error) {
var (
ks, values [][]byte
err error
)
if keys == nil {
_, values, err := kvStore.Filter(namespace, func(k, v []byte) bool { return true }, nil, nil)
ks, values, err = kvStore.Filter(namespace, func(k, v []byte) bool { return true }, nil, nil)
if err != nil {
if errors.Cause(err) == db.ErrNotExist || errors.Cause(err) == db.ErrBucketNotExist {
return nil, errors.Wrapf(state.ErrStateNotExist, "failed to get states of ns = %x", namespace)
return nil, nil, errors.Wrapf(state.ErrStateNotExist, "failed to get states of ns = %x", namespace)
}
return nil, err
return nil, nil, err
}
return values, nil
return ks, values, nil
}
var values [][]byte
for _, key := range keys {
value, err := kvStore.Get(namespace, key)
switch errors.Cause(err) {
case db.ErrNotExist, db.ErrBucketNotExist:
values = append(values, nil)
ks = append(ks, key)
case nil:
values = append(values, value)
ks = append(ks, key)
default:
return nil, err
return nil, nil, err
}
}
return values, nil
return ks, values, nil
}

func newTwoLayerTrie(ns string, dao db.KVStore, rootKey string, create bool) (trie.TwoLayerTrie, error) {
Expand Down
8 changes: 6 additions & 2 deletions state/factory/workingset.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,15 @@ func (ws *workingSet) States(opts ...protocol.StateOption) (uint64, state.Iterat
if cfg.Key != nil {
return 0, nil, errors.Wrap(ErrNotSupported, "Read states with key option has not been implemented yet")
}
values, err := ws.store.States(cfg.Namespace, cfg.Keys)
keys, values, err := ws.store.States(cfg.Namespace, cfg.Keys)
if err != nil {
return 0, nil, err
}
return ws.height, state.NewIterator(values), nil
iter, err := state.NewIterator(keys, values)
if err != nil {
return 0, nil, err
}
return ws.height, iter, nil
}

// PutState puts a state into DB
Expand Down
20 changes: 12 additions & 8 deletions state/factory/workingsetstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type (
workingSetStore interface {
db.KVStoreBasic
Commit() error
States(string, [][]byte) ([][]byte, error)
States(string, [][]byte) ([][]byte, [][]byte, error)
Digest() hash.Hash256
Finalize(uint64) error
Snapshot() int
Expand Down Expand Up @@ -108,7 +108,7 @@ func (store *stateDBWorkingSetStore) Delete(ns string, key []byte) error {
return nil
}

func (store *stateDBWorkingSetStore) States(ns string, keys [][]byte) ([][]byte, error) {
func (store *stateDBWorkingSetStore) States(ns string, keys [][]byte) ([][]byte, [][]byte, error) {
if store.readBuffer {
return readStates(store.flusher.KVStoreWithBuffer(), ns, keys)
}
Expand Down Expand Up @@ -183,21 +183,23 @@ func (store *factoryWorkingSetStore) Delete(ns string, key []byte) error {
return err
}

func (store *factoryWorkingSetStore) States(ns string, keys [][]byte) ([][]byte, error) {
func (store *factoryWorkingSetStore) States(ns string, keys [][]byte) ([][]byte, [][]byte, error) {
ks := [][]byte{}
values := [][]byte{}
if keys == nil {
iter, err := mptrie.NewLayerTwoLeafIterator(store.tlt, namespaceKey(ns), legacyKeyLen())
if err != nil {
return nil, err
return nil, nil, err
}
for {
_, value, err := iter.Next()
key, value, err := iter.Next()
if err == trie.ErrEndOfIterator {
break
}
if err != nil {
return nil, err
return nil, nil, err
}
ks = append(ks, key)
values = append(values, value)
}
} else {
Expand All @@ -206,14 +208,16 @@ func (store *factoryWorkingSetStore) States(ns string, keys [][]byte) ([][]byte,
switch errors.Cause(err) {
case state.ErrStateNotExist:
values = append(values, nil)
ks = append(ks, key)
case nil:
values = append(values, value)
ks = append(ks, key)
default:
return nil, err
return nil, nil, err
}
}
}
return values, nil
return ks, values, nil
}
func (store *factoryWorkingSetStore) Digest() hash.Hash256 {
return hash.Hash256b(store.flusher.SerializeQueue())
Expand Down
4 changes: 2 additions & 2 deletions state/factory/workingsetstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestStateDBWorkingSetStore(t *testing.T) {
valueInStore, err = store.Get(namespace, key3)
require.NoError(err)
require.True(bytes.Equal(value3, valueInStore))
valuesInStore, err := store.States(namespace, [][]byte{key1, key2, key3})
_, valuesInStore, err := store.States(namespace, [][]byte{key1, key2, key3})
require.Equal(3, len(valuesInStore))
require.True(bytes.Equal(value1, valuesInStore[0]))
require.True(bytes.Equal(value2, valuesInStore[1]))
Expand All @@ -76,7 +76,7 @@ func TestStateDBWorkingSetStore(t *testing.T) {
require.NoError(store.Delete(namespace, key1))
_, err = store.Get(namespace, key1)
require.Error(err)
valuesInStore, err = store.States(namespace, [][]byte{key1, key2, key3})
_, valuesInStore, err = store.States(namespace, [][]byte{key1, key2, key3})
require.Equal(3, len(valuesInStore))
require.Nil(valuesInStore[0])
require.True(bytes.Equal(value2, valuesInStore[1]))
Expand Down
21 changes: 14 additions & 7 deletions state/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,43 @@ var ErrOutOfBoundary = errors.New("index is out of boundary")
// ErrNilValue is an error when value is nil
var ErrNilValue = errors.New("value is nil")

// ErrInConsistentLength is an error when keys and states have inconsistent length
var ErrInConsistentLength = errors.New("keys and states have inconsistent length")

// Iterator defines an interator to read a set of states
type Iterator interface {
// Size returns the size of the iterator
Size() int
// Next deserializes the next state in the iterator
Next(interface{}) error
Next(interface{}) ([]byte, error)
}

type iterator struct {
keys [][]byte
states [][]byte
index int
}

// NewIterator returns an interator given a list of serialized states
func NewIterator(states [][]byte) Iterator {
return &iterator{index: 0, states: states}
func NewIterator(keys [][]byte, states [][]byte) (Iterator, error) {
if len(keys) != len(states) {
return nil, ErrInConsistentLength
}
return &iterator{index: 0, keys: keys, states: states}, nil
}

func (it *iterator) Size() int {
return len(it.states)
}

func (it *iterator) Next(s interface{}) error {
func (it *iterator) Next(s interface{}) ([]byte, error) {
i := it.index
if i >= len(it.states) {
return ErrOutOfBoundary
return nil, ErrOutOfBoundary
}
it.index = i + 1
if it.states[i] == nil {
return ErrNilValue
return nil, ErrNilValue
}
return Deserialize(s, it.states[i])
return it.keys[i], Deserialize(s, it.states[i])
}
15 changes: 12 additions & 3 deletions state/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,33 @@ func TestIterator(t *testing.T) {
},
}

keys := make([][]byte, len(cands))
states := make([][]byte, len(cands))
for i, cand := range cands {
keys[i] = []byte(cand.Address)
bytes, err := cand.Serialize()
r.NoError(err)
states[i] = bytes
}

iter := NewIterator(states)
_, err := NewIterator(nil, states)
r.Equal(err, ErrInConsistentLength)
_, err = NewIterator(keys, nil)
r.Equal(err, ErrInConsistentLength)

iter, err := NewIterator(keys, states)
r.NoError(err)
r.Equal(iter.Size(), len(states))

for _, cand := range cands {
c := &Candidate{}
err := iter.Next(c)
_, err := iter.Next(c)
r.NoError(err)

r.True(c.Equal(cand))
}

var noneExistCand Candidate
r.Equal(iter.Next(&noneExistCand), ErrOutOfBoundary)
_, err = iter.Next(&noneExistCand)
r.Equal(err, ErrOutOfBoundary)
}
11 changes: 9 additions & 2 deletions testutil/testdb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ func NewMockStateManagerWithoutHeightFunc(ctrl *gomock.Controller) *mock_chainma
if err != nil {
return 0, nil, err
}
var fk [][]byte
var fv [][]byte
if cfg.Keys == nil {
_, fv, err = kv.Filter(cfg.Namespace, func(k, v []byte) bool {
fk, fv, err = kv.Filter(cfg.Namespace, func(k, v []byte) bool {
return true
}, nil, nil)
if err != nil {
Expand All @@ -190,14 +191,20 @@ func NewMockStateManagerWithoutHeightFunc(ctrl *gomock.Controller) *mock_chainma
switch errors.Cause(err) {
case db.ErrNotExist, db.ErrBucketNotExist:
fv = append(fv, nil)
fk = append(fk, key)
case nil:
fv = append(fv, value)
fk = append(fk, key)
default:
return 0, nil, err
}
}
}
return 0, state.NewIterator(fv), nil
iter, err := state.NewIterator(fk, fv)
if err != nil {
return 0, nil, err
}
return 0, iter, nil
},
).AnyTimes()
// sm.EXPECT().Height().Return(uint64(0), nil).AnyTimes()
Expand Down

0 comments on commit b558929

Please sign in to comment.