From 77ec5d03cc3cac30a647be4fe92cf99281f11cdd Mon Sep 17 00:00:00 2001 From: Pedro Gomes Date: Mon, 23 Dec 2024 12:02:29 +0000 Subject: [PATCH] Refactor the subscriptions api to meet standard mount approach (#926) * Tweaking subscriptions * return gone when endpoint deprecated * update metrics naming structure * fixing race condition on tests * add utils.gone handler * minor naming * remove closer interface usage * adding thorclient.Subscription.Unsubscribe errors --------- Co-authored-by: tony --- api/accounts/accounts.go | 12 +-- api/metrics.go | 3 +- api/subscriptions/pending_tx_test.go | 88 +++++++++++++++++++- api/subscriptions/subscriptions.go | 120 ++++++++++++++------------- api/utils/http.go | 7 ++ thorclient/common/common.go | 2 +- thorclient/wsclient/client.go | 13 ++- 7 files changed, 175 insertions(+), 70 deletions(-) diff --git a/api/accounts/accounts.go b/api/accounts/accounts.go index 22698bdbd..4e78e8711 100644 --- a/api/accounts/accounts.go +++ b/api/accounts/accounts.go @@ -171,9 +171,6 @@ func (a *Accounts) handleGetStorage(w http.ResponseWriter, req *http.Request) er } func (a *Accounts) handleCallContract(w http.ResponseWriter, req *http.Request) error { - if !a.enabledDeprecated { - return utils.HTTPError(nil, http.StatusGone) - } callData := &CallData{} if err := utils.ParseJSON(req.Body, &callData); err != nil { return utils.BadRequest(errors.WithMessage(err, "body")) @@ -378,13 +375,18 @@ func (a *Accounts) Mount(root *mux.Router, pathPrefix string) { Methods("GET"). Name("GET /accounts/{address}/storage"). HandlerFunc(utils.WrapHandlerFunc(a.handleGetStorage)) + // These two methods are currently deprecated + callContractHandler := utils.HandleGone + if a.enabledDeprecated { + callContractHandler = a.handleCallContract + } sub.Path(""). Methods(http.MethodPost). Name("POST /accounts"). - HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract)) + HandlerFunc(utils.WrapHandlerFunc(callContractHandler)) sub.Path("/{address}"). Methods(http.MethodPost). Name("POST /accounts/{address}"). - HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract)) + HandlerFunc(utils.WrapHandlerFunc(callContractHandler)) } diff --git a/api/metrics.go b/api/metrics.go index 57631be71..dd6f893f2 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -75,9 +75,8 @@ func metricsMiddleware(next http.Handler) http.Handler { if rt != nil && rt.GetName() != "" { enabled = true name = rt.GetName() - if strings.HasPrefix(name, "subscriptions") { + if strings.HasPrefix(name, "WS") { subscription = true - name = "WS " + r.URL.Path } } diff --git a/api/subscriptions/pending_tx_test.go b/api/subscriptions/pending_tx_test.go index 00e6a0140..f7fdbf2bd 100644 --- a/api/subscriptions/pending_tx_test.go +++ b/api/subscriptions/pending_tx_test.go @@ -7,16 +7,22 @@ package subscriptions import ( "math/big" + "net/http" + "net/http/httptest" "testing" "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vechain/thor/v2/api/utils" "github.com/vechain/thor/v2/block" "github.com/vechain/thor/v2/chain" "github.com/vechain/thor/v2/genesis" "github.com/vechain/thor/v2/muxdb" "github.com/vechain/thor/v2/packer" "github.com/vechain/thor/v2/state" + "github.com/vechain/thor/v2/test/datagen" "github.com/vechain/thor/v2/thor" "github.com/vechain/thor/v2/tx" "github.com/vechain/thor/v2/txpool" @@ -149,12 +155,90 @@ func createTx(repo *chain.Repository, addressNumber uint) *tx.Transaction { new(tx.Builder). ChainTag(repo.ChainTag()). GasPriceCoef(1). - Expiration(10). + Expiration(1000). Gas(21000). - Nonce(1). + Nonce(uint64(datagen.RandInt())). Clause(cla). BlockRef(tx.NewBlockRef(0)). Build(), genesis.DevAccounts()[addressNumber].PrivateKey, ) } + +func TestPendingTx_NoWriteAfterUnsubscribe(t *testing.T) { + // Arrange + thorChain := initChain(t) + txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{ + Limit: 100, + LimitPerAccount: 16, + MaxLifetime: time.Hour, + }) + + p := newPendingTx(txPool) + txCh := make(chan *tx.Transaction, txQueueSize) + + // Subscribe and then unsubscribe + p.Subscribe(txCh) + p.Unsubscribe(txCh) + + done := make(chan struct{}) + // Attempt to write a new transaction + trx := createTx(thorChain.Repo(), 0) + assert.NotPanics(t, func() { + p.dispatch(trx, done) // dispatch should not panic after unsubscribe + }, "Dispatching after unsubscribe should not panic") + + select { + case <-txCh: + t.Fatal("Channel should not receive new transactions after unsubscribe") + default: + t.Log("No transactions sent to unsubscribed channel, as expected") + } +} + +func TestPendingTx_UnsubscribeOnWebSocketClose(t *testing.T) { + // Arrange + thorChain := initChain(t) + txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{ + Limit: 100, + LimitPerAccount: 16, + MaxLifetime: time.Hour, + }) + + // Subscriptions setup + sub := New(thorChain.Repo(), []string{"*"}, 100, txPool, false) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + utils.WrapHandlerFunc(sub.handlePendingTransactions)(w, r) + })) + defer server.Close() + + require.Equal(t, len(sub.pendingTx.listeners), 0) + + // Connect as WebSocket client + url := "ws" + server.URL[4:] + "/txpool" + ws, _, err := websocket.DefaultDialer.Dial(url, nil) + assert.NoError(t, err) + defer ws.Close() + + // Add a transaction + trx := createTx(thorChain.Repo(), 0) + txPool.AddLocal(trx) + + // Wait to receive transaction + time.Sleep(500 * time.Millisecond) + sub.pendingTx.mu.Lock() + require.Equal(t, len(sub.pendingTx.listeners), 1) + sub.pendingTx.mu.Unlock() + + // Simulate WebSocket closure + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + ws.Close() + + // Wait for cleanup + time.Sleep(5 * time.Second) + + // Assert cleanup + sub.pendingTx.mu.Lock() + require.Equal(t, len(sub.pendingTx.listeners), 0) + sub.pendingTx.mu.Unlock() +} diff --git a/api/subscriptions/subscriptions.go b/api/subscriptions/subscriptions.go index 715a71308..3d85f139b 100644 --- a/api/subscriptions/subscriptions.go +++ b/api/subscriptions/subscriptions.go @@ -86,7 +86,7 @@ func New(repo *chain.Repository, allowedOrigins []string, backtraceLimit uint32, return sub } -func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (*blockReader, error) { +func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) { position, err := s.parsePosition(req.URL.Query().Get("pos")) if err != nil { return nil, err @@ -94,7 +94,7 @@ func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Reque return newBlockReader(s.repo, position), nil } -func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (*eventReader, error) { +func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (msgReader, error) { position, err := s.parsePosition(req.URL.Query().Get("pos")) if err != nil { return nil, err @@ -134,7 +134,7 @@ func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Reque return newEventReader(s.repo, position, eventFilter), nil } -func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (*transferReader, error) { +func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) { position, err := s.parsePosition(req.URL.Query().Get("pos")) if err != nil { return nil, err @@ -159,7 +159,7 @@ func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Re return newTransferReader(s.repo, position, transferFilter), nil } -func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (*beatReader, error) { +func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (msgReader, error) { position, err := s.parsePosition(req.URL.Query().Get("pos")) if err != nil { return nil, err @@ -167,7 +167,7 @@ func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Reques return newBeatReader(s.repo, position, s.beatCache), nil } -func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Request) (*beat2Reader, error) { +func (s *Subscriptions) handleBeat2Reader(_ http.ResponseWriter, req *http.Request) (msgReader, error) { position, err := s.parsePosition(req.URL.Query().Get("pos")) if err != nil { return nil, err @@ -175,54 +175,6 @@ func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Reque return newBeat2Reader(s.repo, position, s.beat2Cache), nil } -func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request) error { - s.wg.Add(1) - defer s.wg.Done() - - var ( - reader msgReader - err error - ) - switch mux.Vars(req)["subject"] { - case "block": - if reader, err = s.handleBlockReader(w, req); err != nil { - return err - } - case "event": - if reader, err = s.handleEventReader(w, req); err != nil { - return err - } - case "transfer": - if reader, err = s.handleTransferReader(w, req); err != nil { - return err - } - case "beat": - if !s.enabledDeprecated { - return utils.HTTPError(nil, http.StatusGone) - } - if reader, err = s.handleBeatReader(w, req); err != nil { - return err - } - case "beat2": - if reader, err = s.handleBeat2Reader(w, req); err != nil { - return err - } - default: - return utils.HTTPError(errors.New("not found"), http.StatusNotFound) - } - - conn, closed, err := s.setupConn(w, req) - // since the conn is hijacked here, no error should be returned in lines below - if err != nil { - logger.Debug("upgrade to websocket", "err", err) - return nil - } - - err = s.pipe(conn, reader, closed) - s.closeConn(conn, err) - return nil -} - func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *http.Request) error { s.wg.Add(1) defer s.wg.Done() @@ -387,15 +339,69 @@ func (s *Subscriptions) Close() { s.wg.Wait() } +func (s *Subscriptions) websocket(readerFunc func(http.ResponseWriter, *http.Request) (msgReader, error)) utils.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) error { + s.wg.Add(1) + defer s.wg.Done() + + // Call the provided reader function + reader, err := readerFunc(w, req) + if err != nil { + return err + } + + // Setup WebSocket connection + conn, closed, err := s.setupConn(w, req) + if err != nil { + logger.Debug("upgrade to websocket", "err", err) + return err + } + defer s.closeConn(conn, err) + + // Stream messages + err = s.pipe(conn, reader, closed) + if err != nil { + logger.Debug("error in websocket pipe", "err", err) + } + return err + } +} + func (s *Subscriptions) Mount(root *mux.Router, pathPrefix string) { sub := root.PathPrefix(pathPrefix).Subrouter() sub.Path("/txpool"). Methods(http.MethodGet). - Name("subscriptions_pending_tx"). + Name("WS /subscriptions/txpool"). // metrics middleware relies on this name HandlerFunc(utils.WrapHandlerFunc(s.handlePendingTransactions)) - sub.Path("/{subject:beat|beat2|block|event|transfer}"). + + sub.Path("/block"). + Methods(http.MethodGet). + Name("WS /subscriptions/block"). // metrics middleware relies on this name + HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBlockReader))) + + sub.Path("/event"). + Methods(http.MethodGet). + Name("WS /subscriptions/event"). // metrics middleware relies on this name + HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleEventReader))) + + sub.Path("/transfer"). + Methods(http.MethodGet). + Name("WS /subscriptions/transfer"). // metrics middleware relies on this name + HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleTransferReader))) + + sub.Path("/beat2"). + Methods(http.MethodGet). + Name("WS /subscriptions/beat2"). // metrics middleware relies on this name + HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBeat2Reader))) + + // This method is currently deprecated + beatHandler := utils.HandleGone + if s.enabledDeprecated { + beatHandler = s.websocket(s.handleBeatReader) + } + sub.Path("/beat"). Methods(http.MethodGet). - Name("subscriptions_subject"). - HandlerFunc(utils.WrapHandlerFunc(s.handleSubject)) + Name("WS /subscriptions/beat"). // metrics middleware relies on this name + HandlerFunc(utils.WrapHandlerFunc(beatHandler)) } diff --git a/api/utils/http.go b/api/utils/http.go index 2235797de..6379b54f9 100644 --- a/api/utils/http.go +++ b/api/utils/http.go @@ -98,5 +98,12 @@ func WriteJSON(w http.ResponseWriter, obj interface{}) error { return json.NewEncoder(w).Encode(obj) } +// HandleGone is a handler for deprecated endpoints that returns HTTP 410 Gone. +func HandleGone(w http.ResponseWriter, _ *http.Request) error { + w.WriteHeader(http.StatusGone) + _, _ = w.Write([]byte("This endpoint is no longer supported.")) + return nil +} + // M shortcut for type map[string]interface{}. type M map[string]interface{} diff --git a/thorclient/common/common.go b/thorclient/common/common.go index 3bb9b0992..92500ef87 100644 --- a/thorclient/common/common.go +++ b/thorclient/common/common.go @@ -29,5 +29,5 @@ type EventWrapper[T any] struct { // Subscription is used to handle the active subscription type Subscription[T any] struct { EventChan <-chan EventWrapper[T] - Unsubscribe func() + Unsubscribe func() error } diff --git a/thorclient/wsclient/client.go b/thorclient/wsclient/client.go index 9eb1519ab..c987704e3 100644 --- a/thorclient/wsclient/client.go +++ b/thorclient/wsclient/client.go @@ -183,10 +183,17 @@ func subscribe[T any](conn *websocket.Conn) *common.Subscription[*T] { return &common.Subscription[*T]{ EventChan: eventChan, - Unsubscribe: func() { + Unsubscribe: func() error { closed = true - conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - conn.Close() + err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + return fmt.Errorf("failed to issue close message: %w", err) + } + err = conn.Close() + if err != nil { + return fmt.Errorf("failed to close connections: %w", err) + } + return nil }, } }