Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem: mempool iteration is not thread safe #699

Merged
merged 8 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
### Bug Fixes

* (x/bank) [#20028](https://github.com/cosmos/cosmos-sdk/pull/20028) Align query with multi denoms for send-enabled.
* (baseapp) [#699](https://github.com/crypto-org-chain/cosmos-sdk/pull/699) Fix data race in mempool iteration.

## [Unreleased-Upstream]

Expand Down
37 changes: 24 additions & 13 deletions baseapp/abci_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
}

iterator := h.mempool.Select(ctx, req.Txs)
selectedTxsSignersSeqs := make(map[string]uint64)
var selectedTxsNums int
for iterator != nil {
memTx := iterator.Tx()
var (
resError error
selectedTxsNums int
invalidTxs []sdk.Tx // invalid txs to be removed after the iteration
)
h.mempool.SelectBy(ctx, req.Txs, func(memTx mempool.Tx) bool {
signerData, err := h.signerExtAdapter.GetSigners(memTx.Tx)
if err != nil {
return nil, err
// propagate the error to the caller
resError = err
return false
}

// If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before
Expand All @@ -315,8 +319,7 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
txSignersSeqs[signer.Signer.String()] = signer.Sequence
}
if !shouldAdd {
iterator = iterator.Next()
continue
return true
}

// NOTE: Since transaction verification was already executed in CheckTx,
Expand All @@ -325,14 +328,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
// check again.
txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx.Tx)
if err != nil {
err := h.mempool.Remove(memTx.Tx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
}
invalidTxs = append(invalidTxs, memTx.Tx)
} else {
stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx.Tx, txBz, memTx.GasWanted)
if stop {
break
return false
}

txsLen := len(h.txSelector.SelectedTxs(ctx))
Expand All @@ -353,7 +353,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
selectedTxsNums = txsLen
}

iterator = iterator.Next()
return true
})

if resError != nil {
return nil, resError
}

for _, tx := range invalidTxs {
err := h.mempool.Remove(tx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
}
}

return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
Expand Down
6 changes: 4 additions & 2 deletions types/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ type Mempool interface {
InsertWithGasWanted(context.Context, sdk.Tx, uint64) error

// Select returns an Iterator over the app-side mempool. If txs are specified,
// then they shall be incorporated into the Iterator. The Iterator must
// closed by the caller.
// then they shall be incorporated into the Iterator. The Iterator is not thread-safe to use.
Select(context.Context, [][]byte) Iterator

// SelectBy use callback to iterate over the mempool.
SelectBy(context.Context, [][]byte, func(Tx) bool)

// CountTx returns the number of transactions currently in the mempool.
CountTx() int

Expand Down
1 change: 1 addition & 0 deletions types/mempool/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ type NoOpMempool struct{}
func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) InsertWithGasWanted(context.Context, sdk.Tx, uint64) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) SelectBy(context.Context, [][]byte, func(Tx) bool) {}
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
16 changes: 15 additions & 1 deletion types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,13 @@ func (i *PriorityNonceIterator[C]) Tx() Tx {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterator {
func (mp *PriorityNonceMempool[C]) Select(ctx context.Context, txs [][]byte) Iterator {
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.doSelect(ctx, txs)
}

func (mp *PriorityNonceMempool[C]) doSelect(_ context.Context, _ [][]byte) Iterator {
if mp.priorityIndex.Len() == 0 {
return nil
}
Expand All @@ -378,6 +382,16 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato
return iterator.iteratePriority()
}

func (mp *PriorityNonceMempool[C]) SelectBy(ctx context.Context, txs [][]byte, callback func(Tx) bool) {
mp.mtx.Lock()
defer mp.mtx.Unlock()

iter := mp.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

type reorderKey[C comparable] struct {
deleteKey txMeta[C]
insertKey txMeta[C]
Expand Down
85 changes: 85 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mempool_test

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -396,6 +398,89 @@ func (s *MempoolTestSuite) TestIterator() {
}
}

func (s *MempoolTestSuite) TestIteratorConcurrency() {
t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

tests := []struct {
txs []txSpec
fail bool
}{
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: 8, n: 2, a: sb},
},
},
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: math.MinInt64, n: 2, a: sb},
},
},
}

for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.DefaultPriorityMempool()

// create test txs and insert into mempool
for i, ts := range tt.txs {
tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}

// iterate through txs
stdCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

id := len(tt.txs)
for {
select {
case <-stdCtx.Done():
return
default:
id++
tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}
}
}()

var i int
pool.SelectBy(ctx, nil, func(memTx mempool.Tx) bool {
tx := memTx.Tx.(testTx)
if tx.id < len(tt.txs) {
require.Equal(t, tt.txs[tx.id].p, int(tx.priority))
require.Equal(t, tt.txs[tx.id].n, int(tx.nonce))
require.Equal(t, tt.txs[tx.id].a, tx.address)
i++
}
return i < len(tt.txs)
})
require.Equal(t, i, len(tt.txs))
cancel()
wg.Wait()
})
}
}

func (s *MempoolTestSuite) TestPriorityTies() {
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3)
Expand Down
16 changes: 15 additions & 1 deletion types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,13 @@ func (snm *SenderNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
func (snm *SenderNonceMempool) Select(ctx context.Context, txs [][]byte) Iterator {
snm.mtx.Lock()
defer snm.mtx.Unlock()
return snm.doSelect(ctx, txs)
}

func (snm *SenderNonceMempool) doSelect(_ context.Context, _ [][]byte) Iterator {
var senders []string

senderCursors := make(map[string]*skiplist.Element)
Expand Down Expand Up @@ -199,6 +203,16 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
return iter.Next()
}

func (snm *SenderNonceMempool) SelectBy(ctx context.Context, txs [][]byte, callback func(Tx) bool) {
snm.mtx.Lock()
defer snm.mtx.Unlock()

iter := snm.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

// CountTx returns the total count of txs in the mempool.
func (snm *SenderNonceMempool) CountTx() int {
snm.mtx.Lock()
Expand Down
Loading