From 09508f5302f1e443140c3361be554cfb8e78070c Mon Sep 17 00:00:00 2001 From: wizeguyy Date: Mon, 7 Oct 2024 10:49:51 -0500 Subject: [PATCH] Configurable subscription limit --- cmd/utils/cmd.go | 2 +- cmd/utils/flags.go | 7 ++++ quai/backend.go | 10 +++--- quai/filters/api.go | 56 ++++++++++++++++++++++-------- quai/filters/filter_system_test.go | 12 +++---- 5 files changed, 62 insertions(+), 25 deletions(-) diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 394e1e4e5c..239e5d2d88 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -134,7 +134,7 @@ func makeFullNode(p2p quai.NetworkingAPI, nodeLocation common.Location, slicesRu // The second return value is the full node instance, which may be nil if the // node is running as a light client. func RegisterQuaiService(stack *node.Node, p2p quai.NetworkingAPI, cfg quaiconfig.Config, nodeCtx int, currentExpansionNumber uint8, startingExpansionNumber uint64, genesisBlock *types.WorkObject, logger *log.Logger) (quaiapi.Backend, error) { - backend, err := quai.New(stack, p2p, &cfg, nodeCtx, currentExpansionNumber, startingExpansionNumber, genesisBlock, logger) + backend, err := quai.New(stack, p2p, &cfg, nodeCtx, currentExpansionNumber, startingExpansionNumber, genesisBlock, logger, viper.GetInt(WSMaxSubsFlag.Name)) if err != nil { Fatalf("Failed to register the Quai service: %v", err) } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index b9583d022c..0ac05c2b3d 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -144,6 +144,7 @@ var RPCFlags = []Flag{ HTTPPortStartFlag, WSEnabledFlag, WSListenAddrFlag, + WSMaxSubsFlag, WSApiFlag, WSAllowedOriginsFlag, WSPathPrefixFlag, @@ -605,6 +606,12 @@ var ( Usage: "WS-RPC server listening interface" + generateEnvDoc(c_RPCFlagPrefix+"ws-addr"), } + WSMaxSubsFlag = Flag{ + Name: c_RPCFlagPrefix + "ws-max-subs", + Value: 1000, + Usage: "maximum concurrent subscribers to the WS-RPC server", + } + WSApiFlag = Flag{ Name: c_RPCFlagPrefix + "ws-api", Value: "", diff --git a/quai/backend.go b/quai/backend.go index f8ac0e9a3c..bf931164dd 100644 --- a/quai/backend.go +++ b/quai/backend.go @@ -75,12 +75,13 @@ type Quai struct { lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase) - logger *log.Logger + logger *log.Logger + maxWsSubs int } // New creates a new Quai object (including the // initialisation of the common Quai object) -func New(stack *node.Node, p2p NetworkingAPI, config *quaiconfig.Config, nodeCtx int, currentExpansionNumber uint8, startingExpansionNumber uint64, genesisBlock *types.WorkObject, logger *log.Logger) (*Quai, error) { +func New(stack *node.Node, p2p NetworkingAPI, config *quaiconfig.Config, nodeCtx int, currentExpansionNumber uint8, startingExpansionNumber uint64, genesisBlock *types.WorkObject, logger *log.Logger, maxWsSubs int) (*Quai, error) { // Ensure configuration values are compatible and sane if config.Miner.GasPrice == nil || config.Miner.GasPrice.Cmp(common.Big0) <= 0 { logger.WithFields(log.Fields{ @@ -160,6 +161,7 @@ func New(stack *node.Node, p2p NetworkingAPI, config *quaiconfig.Config, nodeCtx primaryCoinbase: config.Miner.PrimaryCoinbase, bloomRequests: make(chan chan *bloombits.Retrieval), logger: logger, + maxWsSubs: maxWsSubs, } // Copy the chainConfig @@ -293,12 +295,12 @@ func (s *Quai) APIs() []rpc.API { }, { Namespace: "eth", Version: "1.0", - Service: filters.NewPublicFilterAPI(s.APIBackend, 5*time.Minute), + Service: filters.NewPublicFilterAPI(s.APIBackend, 5*time.Minute, s.maxWsSubs), Public: true, }, { Namespace: "quai", Version: "1.0", - Service: filters.NewPublicFilterAPI(s.APIBackend, 5*time.Minute), + Service: filters.NewPublicFilterAPI(s.APIBackend, 5*time.Minute, s.maxWsSubs), Public: true, }, { Namespace: "admin", diff --git a/quai/filters/api.go b/quai/filters/api.go index 4502ac289f..29ae4890ba 100644 --- a/quai/filters/api.go +++ b/quai/filters/api.go @@ -55,24 +55,28 @@ type filter struct { // PublicFilterAPI offers support to create and manage filters. This will allow external clients to retrieve various // information related to the Quai protocol such as blocks, transactions and logs. type PublicFilterAPI struct { - backend Backend - mux *event.TypeMux - quit chan struct{} - chainDb ethdb.Database - events *EventSystem - filtersMu sync.Mutex - filters map[rpc.ID]*filter - timeout time.Duration + backend Backend + mux *event.TypeMux + quit chan struct{} + chainDb ethdb.Database + events *EventSystem + filtersMu sync.Mutex + filters map[rpc.ID]*filter + timeout time.Duration + subscriptionLimit int + activeSubscriptions int } // NewPublicFilterAPI returns a new PublicFilterAPI instance. -func NewPublicFilterAPI(backend Backend, timeout time.Duration) *PublicFilterAPI { +func NewPublicFilterAPI(backend Backend, timeout time.Duration, subscriptionLimit int) *PublicFilterAPI { api := &PublicFilterAPI{ - backend: backend, - chainDb: backend.ChainDb(), - events: NewEventSystem(backend), - filters: make(map[rpc.ID]*filter), - timeout: timeout, + backend: backend, + chainDb: backend.ChainDb(), + events: NewEventSystem(backend), + filters: make(map[rpc.ID]*filter), + timeout: timeout, + subscriptionLimit: subscriptionLimit, + activeSubscriptions: 0, } go api.timeoutLoop(timeout) @@ -166,6 +170,10 @@ func (api *PublicFilterAPI) NewPendingTransactionFilter() rpc.ID { // NewPendingTransactions creates a subscription that is triggered each time a transaction // enters the transaction pool and was signed from one of the transactions this nodes manages. func (api *PublicFilterAPI) NewPendingTransactions(ctx context.Context) (*rpc.Subscription, error) { + if api.activeSubscriptions >= api.subscriptionLimit { + return &rpc.Subscription{}, errors.New("too many subscribers") + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported @@ -181,7 +189,9 @@ func (api *PublicFilterAPI) NewPendingTransactions(ctx context.Context) (*rpc.Su "stacktrace": string(debug.Stack()), }).Fatal("Go-Quai Panicked") } + api.activeSubscriptions -= 1 }() + api.activeSubscriptions += 1 txHashes := make(chan []common.Hash, 128) pendingTxSub := api.events.SubscribePendingTxs(txHashes) @@ -251,6 +261,10 @@ func (api *PublicFilterAPI) NewBlockFilter() rpc.ID { // NewHeads send a notification each time a new (header) block is appended to the chain. func (api *PublicFilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, error) { + if api.activeSubscriptions >= api.subscriptionLimit { + return &rpc.Subscription{}, errors.New("too many subscribers") + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported @@ -266,7 +280,9 @@ func (api *PublicFilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, er "stacktrace": string(debug.Stack()), }).Fatal("Go-Quai Panicked") } + api.activeSubscriptions -= 1 }() + api.activeSubscriptions += 1 headers := make(chan *types.WorkObject) headersSub := api.events.SubscribeNewHeads(headers) @@ -291,6 +307,10 @@ func (api *PublicFilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, er // Accesses send a notification each time the specified address is accessed func (api *PublicFilterAPI) Accesses(ctx context.Context, addr common.Address) (*rpc.Subscription, error) { + if api.activeSubscriptions >= api.subscriptionLimit { + return &rpc.Subscription{}, errors.New("too many subscribers") + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported @@ -306,7 +326,9 @@ func (api *PublicFilterAPI) Accesses(ctx context.Context, addr common.Address) ( "stacktrace": string(debug.Stack()), }).Fatal("Go-Quai Panicked") } + api.activeSubscriptions -= 1 }() + api.activeSubscriptions += 1 headers := make(chan *types.WorkObject) headersSub := api.events.SubscribeNewHeads(headers) @@ -348,6 +370,10 @@ func (api *PublicFilterAPI) Accesses(ctx context.Context, addr common.Address) ( // Logs creates a subscription that fires for all new log that match the given filter criteria. func (api *PublicFilterAPI) Logs(ctx context.Context, crit FilterCriteria) (*rpc.Subscription, error) { + if api.activeSubscriptions >= api.subscriptionLimit { + return &rpc.Subscription{}, errors.New("too many subscribers") + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported @@ -371,7 +397,9 @@ func (api *PublicFilterAPI) Logs(ctx context.Context, crit FilterCriteria) (*rpc "stacktrace": string(debug.Stack()), }).Fatal("Go-Quai Panicked") } + api.activeSubscriptions -= 1 }() + api.activeSubscriptions += 1 for { select { case logs := <-matchedLogs: diff --git a/quai/filters/filter_system_test.go b/quai/filters/filter_system_test.go index 9fa83b0c32..cf490b9306 100644 --- a/quai/filters/filter_system_test.go +++ b/quai/filters/filter_system_test.go @@ -189,7 +189,7 @@ func TestPendingTxFilter(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) to = common.HexToAddress("0x0094f5ea0ba39494ce83a213fffba74279579268", common.Location{0, 0}) @@ -311,7 +311,7 @@ func TestLogFilterCreation(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) testCases = []struct { crit FilterCriteria @@ -354,7 +354,7 @@ func TestInvalidLogFilterCreation(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) ) // different situations where log filter creation should fail. @@ -376,7 +376,7 @@ func TestInvalidGetLogsRequest(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) blockHash = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111") ) @@ -401,7 +401,7 @@ func TestLogFilter(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) firstAddr = common.HexToAddressBytes("0x0011111111111111111111111111111111111111") secondAddr = common.HexToAddressBytes("0x0022222222222222222222222222222222222222") @@ -514,7 +514,7 @@ func TestPendingLogsSubscription(t *testing.T) { var ( db = rawdb.NewMemoryDatabase(log.Global) backend = &testBackend{db: db} - api = NewPublicFilterAPI(backend, deadline) + api = NewPublicFilterAPI(backend, deadline, 1) firstAddr = common.HexToAddressBytes("0x0011111111111111111111111111111111111111") secondAddr = common.HexToAddressBytes("0x0022222222222222222222222222222222222222")