Skip to content

Commit

Permalink
[API] increase api listener limit (#4374)
Browse files Browse the repository at this point in the history
  • Loading branch information
envestcc authored Sep 2, 2024
1 parent 66ef5cd commit 778b8ee
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 26 deletions.
3 changes: 3 additions & 0 deletions api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Config struct {
BatchRequestLimit int `yaml:"batchRequestLimit"`
// WebsocketRateLimit is the maximum number of messages per second per client.
WebsocketRateLimit int `yaml:"websocketRateLimit"`
// ListenerLimit is the maximum number of listeners.
ListenerLimit int `yaml:"listenerLimit"`
}

// DefaultConfig is the default config
Expand All @@ -38,4 +40,5 @@ var DefaultConfig = Config{
RangeQueryLimit: 1000,
BatchRequestLimit: _defaultBatchRequestLimit,
WebsocketRateLimit: 5,
ListenerLimit: 5000,
}
48 changes: 48 additions & 0 deletions api/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package api

import (
"context"
"sync"
)

type (
streamContextKey struct{}

StreamContext struct {
listenerIDs map[string]struct{}
mutex sync.Mutex
}
)

func (sc *StreamContext) AddListener(id string) {
sc.mutex.Lock()
defer sc.mutex.Unlock()
sc.listenerIDs[id] = struct{}{}
}

func (sc *StreamContext) RemoveListener(id string) {
sc.mutex.Lock()
defer sc.mutex.Unlock()
delete(sc.listenerIDs, id)
}

func (sc *StreamContext) ListenerIDs() []string {
sc.mutex.Lock()
defer sc.mutex.Unlock()
ids := make([]string, 0, len(sc.listenerIDs))
for id := range sc.listenerIDs {
ids = append(ids, id)
}
return ids
}

func WithStreamContext(ctx context.Context) context.Context {
return context.WithValue(ctx, streamContextKey{}, &StreamContext{
listenerIDs: make(map[string]struct{}),
})
}

func StreamFromContext(ctx context.Context) (*StreamContext, bool) {
sc, ok := ctx.Value(streamContextKey{}).(*StreamContext)
return sc, ok
}
2 changes: 1 addition & 1 deletion api/coreservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func newCoreService(
ap: actPool,
cfg: cfg,
registry: registry,
chainListener: NewChainListener(500),
chainListener: NewChainListener(cfg.ListenerLimit),
gs: gasstation.NewGasStation(chain, dao, cfg.GasStation),
readCache: NewReadCache(),
getBlockTime: getBlockTime,
Expand Down
16 changes: 10 additions & 6 deletions api/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,15 +573,17 @@ func (svr *gRPCHandler) StreamBlocks(_ *iotexapi.StreamBlocksRequest, stream iot
errChan := make(chan error)
defer close(errChan)
chainListener := svr.coreService.ChainListener()
if _, err := chainListener.AddResponder(NewGRPCBlockListener(
id, err := chainListener.AddResponder(NewGRPCBlockListener(
func(resp interface{}) (int, error) {
return 0, stream.Send(resp.(*iotexapi.StreamBlocksResponse))
},
errChan,
)); err != nil {
))
if err != nil {
return status.Error(codes.Internal, err.Error())
}
err := <-errChan
err = <-errChan
chainListener.RemoveResponder(id)
if err != nil {
return status.Error(codes.Aborted, err.Error())
}
Expand All @@ -596,16 +598,18 @@ func (svr *gRPCHandler) StreamLogs(in *iotexapi.StreamLogsRequest, stream iotexa
errChan := make(chan error)
defer close(errChan)
chainListener := svr.coreService.ChainListener()
if _, err := chainListener.AddResponder(NewGRPCLogListener(
id, err := chainListener.AddResponder(NewGRPCLogListener(
logfilter.NewLogFilter(in.GetFilter()),
func(in interface{}) (int, error) {
return 0, stream.Send(in.(*iotexapi.StreamLogsResponse))
},
errChan,
)); err != nil {
))
if err != nil {
return status.Error(codes.Internal, err.Error())
}
err := <-errChan
err = <-errChan
chainListener.RemoveResponder(id)
if err != nil {
return status.Error(codes.Aborted, err.Error())
}
Expand Down
6 changes: 6 additions & 0 deletions api/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ func TestGrpcServer_StreamBlocks(t *testing.T) {
}()
return "", nil
})
listener.EXPECT().RemoveResponder(gomock.Any()).DoAndReturn(func(string) (bool, error) {
return true, nil
})
core.EXPECT().ChainListener().Return(listener)
err := grpcSvr.StreamBlocks(&iotexapi.StreamBlocksRequest{}, nil)
require.NoError(err)
Expand Down Expand Up @@ -390,6 +393,9 @@ func TestGrpcServer_StreamLogs(t *testing.T) {
}()
return "", nil
})
listener.EXPECT().RemoveResponder(gomock.Any()).DoAndReturn(func(string) (bool, error) {
return true, nil
})
core.EXPECT().ChainListener().Return(listener)
err := grpcSvr.StreamLogs(&iotexapi.StreamLogsRequest{Filter: &iotexapi.LogsFilter{}}, nil)
require.NoError(err)
Expand Down
3 changes: 3 additions & 0 deletions api/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (cl *chainListener) Stop() error {
return nil
})
cl.streamMap.Reset()
apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count()))
return nil
}

Expand Down Expand Up @@ -105,6 +106,7 @@ func (cl *chainListener) AddResponder(responder apitypes.Responder) (string, err
}

cl.streamMap.Set(listenerID, responder)
apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count()))
return listenerID, nil
}

Expand All @@ -122,6 +124,7 @@ func (cl *chainListener) RemoveResponder(listenerID string) (bool, error) {
return false, errListenerNotFound
}
r.Exit()
apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count() - 1))
return cl.streamMap.Delete(listenerID), nil
}

Expand Down
14 changes: 14 additions & 0 deletions api/metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package api

import "github.com/prometheus/client_golang/prometheus"

var (
apiLimitMtcs = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "iotex_api_limit_metrics",
Help: "api limit metrics.",
}, []string{"limit"})
)

func init() {
prometheus.MustRegister(apiLimitMtcs)
}
2 changes: 1 addition & 1 deletion api/serverV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func NewServerV2(
wrappedWeb3Handler := otelhttp.NewHandler(newHTTPHandler(web3Handler), "web3.jsonrpc")

limiter := rate.NewLimiter(rate.Limit(cfg.WebsocketRateLimit), 1)
wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(web3Handler, limiter), "web3.websocket")
wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(coreAPI, web3Handler, limiter), "web3.websocket")

return &ServerV2{
core: coreAPI,
Expand Down
2 changes: 1 addition & 1 deletion api/serverV2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestServerV2(t *testing.T) {
core: core,
grpcServer: NewGRPCServer(core, testutil.RandomPort()),
httpSvr: NewHTTPServer("", testutil.RandomPort(), newHTTPHandler(web3Handler)),
websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(web3Handler, nil)),
websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(core, web3Handler, nil)),
}
ctx := context.Background()

Expand Down
19 changes: 13 additions & 6 deletions api/web3server.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ var (
errInvalidBlock = errors.New("invalid block")
errUnsupportedAction = errors.New("the type of action is not supported")
errMsgBatchTooLarge = errors.New("batch too large")
errHTTPNotSupported = errors.New("http not supported")

_pendingBlockNumber = "pending"
_latestBlockNumber = "latest"
Expand Down Expand Up @@ -224,7 +225,11 @@ func (svr *web3Handler) handleWeb3Req(ctx context.Context, web3Req *gjson.Result
case "eth_newBlockFilter":
res, err = svr.newBlockFilter()
case "eth_subscribe":
res, err = svr.subscribe(web3Req, writer)
sc, ok := StreamFromContext(ctx)
if !ok {
return errHTTPNotSupported
}
res, err = svr.subscribe(sc, web3Req, writer)
case "eth_unsubscribe":
res, err = svr.unsubscribe(web3Req)
//TODO: enable debug api after archive mode is supported
Expand Down Expand Up @@ -924,35 +929,36 @@ func (svr *web3Handler) getFilterLogs(in *gjson.Result) (interface{}, error) {
return svr.getLogsWithFilter(from, to, filterObj.Address, filterObj.Topics)
}

func (svr *web3Handler) subscribe(in *gjson.Result, writer apitypes.Web3ResponseWriter) (interface{}, error) {
func (svr *web3Handler) subscribe(ctx *StreamContext, in *gjson.Result, writer apitypes.Web3ResponseWriter) (interface{}, error) {
subscription := in.Get("params.0")
if !subscription.Exists() {
return nil, errInvalidFormat
}
switch subscription.String() {
case "newHeads":
return svr.streamBlocks(writer)
return svr.streamBlocks(ctx, writer)
case "logs":
filter, err := parseLogRequest(in.Get("params.1"))
if err != nil {
return nil, err
}
return svr.streamLogs(filter, writer)
return svr.streamLogs(ctx, filter, writer)
default:
return nil, errInvalidFormat
}
}

func (svr *web3Handler) streamBlocks(writer apitypes.Web3ResponseWriter) (interface{}, error) {
func (svr *web3Handler) streamBlocks(ctx *StreamContext, writer apitypes.Web3ResponseWriter) (interface{}, error) {
chainListener := svr.coreService.ChainListener()
streamID, err := chainListener.AddResponder(NewWeb3BlockListener(writer.Write))
if err != nil {
return nil, err
}
ctx.AddListener(streamID)
return streamID, nil
}

func (svr *web3Handler) streamLogs(filterObj *filterObject, writer apitypes.Web3ResponseWriter) (interface{}, error) {
func (svr *web3Handler) streamLogs(ctx *StreamContext, filterObj *filterObject, writer apitypes.Web3ResponseWriter) (interface{}, error) {
filter, err := newLogFilterFrom(filterObj.Address, filterObj.Topics)
if err != nil {
return nil, err
Expand All @@ -962,6 +968,7 @@ func (svr *web3Handler) streamLogs(filterObj *filterObject, writer apitypes.Web3
if err != nil {
return nil, err
}
ctx.AddListener(streamID)
return streamID, nil
}

Expand Down
15 changes: 10 additions & 5 deletions api/web3server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1125,34 +1125,39 @@ func TestSubscribe(t *testing.T) {

t.Run("newHeads subscription", func(t *testing.T) {
in := gjson.Parse(`{"params":["newHeads"]}`)
ret, err := web3svr.subscribe(&in, writer)
sc, _ := StreamFromContext(WithStreamContext(context.Background()))
ret, err := web3svr.subscribe(sc, &in, writer)
require.NoError(err)
require.Equal("streamid_1", ret.(string))
})

t.Run("logs subscription", func(t *testing.T) {
in := gjson.Parse(`{"params":["logs",{"fromBlock":"1","fromBlock":"2","address":["0x0000000000000000000000000000000000000001"],"topics":[["0x5f746f70696331"]]}]}`)
ret, err := web3svr.subscribe(&in, writer)
sc, _ := StreamFromContext(WithStreamContext(context.Background()))
ret, err := web3svr.subscribe(sc, &in, writer)
require.NoError(err)
require.Equal("streamid_1", ret.(string))
})

t.Run("logs topic not array", func(t *testing.T) {
in := gjson.Parse(`{"params":["logs",{"fromBlock":"1","fromBlock":"2","address":["0x0000000000000000000000000000000000000001"],"topics":["0x5f746f70696331"]}]}`)
ret, err := web3svr.subscribe(&in, writer)
sc, _ := StreamFromContext(WithStreamContext(context.Background()))
ret, err := web3svr.subscribe(sc, &in, writer)
require.NoError(err)
require.Equal("streamid_1", ret.(string))
})

t.Run("nil params", func(t *testing.T) {
inNil := gjson.Parse(`{"params":[]}`)
_, err := web3svr.subscribe(&inNil, writer)
sc, _ := StreamFromContext(WithStreamContext(context.Background()))
_, err := web3svr.subscribe(sc, &inNil, writer)
require.EqualError(err, errInvalidFormat.Error())
})

t.Run("nil logs", func(t *testing.T) {
inNil := gjson.Parse(`{"params":["logs"]}`)
_, err := web3svr.subscribe(&inNil, writer)
sc, _ := StreamFromContext(WithStreamContext(context.Background()))
_, err := web3svr.subscribe(sc, &inNil, writer)
require.EqualError(err, errInvalidFormat.Error())
})
}
Expand Down
22 changes: 16 additions & 6 deletions api/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ const (

// WebsocketHandler handles requests from websocket protocol
type WebsocketHandler struct {
msgHandler Web3Handler
limiter *rate.Limiter
coreService CoreService
msgHandler Web3Handler
limiter *rate.Limiter
}

var upgrader = websocket.Upgrader{
Expand Down Expand Up @@ -75,14 +76,15 @@ func (c *safeWebsocketConn) SetWriteDeadline(t time.Time) error {
}

// NewWebsocketHandler creates a new websocket handler
func NewWebsocketHandler(web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler {
func NewWebsocketHandler(coreService CoreService, web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler {
if limiter == nil {
// set the limiter to the maximum possible rate
limiter = rate.NewLimiter(rate.Limit(math.MaxFloat64), 1)
}
return &WebsocketHandler{
msgHandler: web3Handler,
limiter: limiter,
msgHandler: web3Handler,
limiter: limiter,
coreService: coreService,
}
}

Expand Down Expand Up @@ -112,10 +114,18 @@ func (wsSvr *WebsocketHandler) handleConnection(ctx context.Context, ws *websock
return nil
})

ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(WithStreamContext(ctx))
safeWs := &safeWebsocketConn{ws: ws}
go ping(ctx, safeWs, cancel)

defer func() {
// clean up the stream context
sc, _ := StreamFromContext(ctx)
for _, id := range sc.ListenerIDs() {
wsSvr.coreService.ChainListener().RemoveResponder(id)
}
}()

for {
select {
case <-ctx.Done():
Expand Down

0 comments on commit 778b8ee

Please sign in to comment.