Skip to content

Commit

Permalink
Refactor the subscriptions api to meet standard mount approach (#926)
Browse files Browse the repository at this point in the history
* Tweaking subscriptions

* return gone when endpoint deprecated

* update metrics naming structure

* fixing race condition on tests

* add utils.gone handler

* minor naming

* remove closer interface usage

* adding thorclient.Subscription.Unsubscribe errors

---------

Co-authored-by: tony <liboliqi@gmail.com>
  • Loading branch information
otherview and libotony authored Dec 23, 2024
1 parent 5867f22 commit 77ec5d0
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 70 deletions.
12 changes: 7 additions & 5 deletions api/accounts/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ func (a *Accounts) handleGetStorage(w http.ResponseWriter, req *http.Request) er
}

func (a *Accounts) handleCallContract(w http.ResponseWriter, req *http.Request) error {
if !a.enabledDeprecated {
return utils.HTTPError(nil, http.StatusGone)
}
callData := &CallData{}
if err := utils.ParseJSON(req.Body, &callData); err != nil {
return utils.BadRequest(errors.WithMessage(err, "body"))
Expand Down Expand Up @@ -378,13 +375,18 @@ func (a *Accounts) Mount(root *mux.Router, pathPrefix string) {
Methods("GET").
Name("GET /accounts/{address}/storage").
HandlerFunc(utils.WrapHandlerFunc(a.handleGetStorage))

// These two methods are currently deprecated
callContractHandler := utils.HandleGone
if a.enabledDeprecated {
callContractHandler = a.handleCallContract
}
sub.Path("").
Methods(http.MethodPost).
Name("POST /accounts").
HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract))
HandlerFunc(utils.WrapHandlerFunc(callContractHandler))
sub.Path("/{address}").
Methods(http.MethodPost).
Name("POST /accounts/{address}").
HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract))
HandlerFunc(utils.WrapHandlerFunc(callContractHandler))
}
3 changes: 1 addition & 2 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ func metricsMiddleware(next http.Handler) http.Handler {
if rt != nil && rt.GetName() != "" {
enabled = true
name = rt.GetName()
if strings.HasPrefix(name, "subscriptions") {
if strings.HasPrefix(name, "WS") {
subscription = true
name = "WS " + r.URL.Path
}
}

Expand Down
88 changes: 86 additions & 2 deletions api/subscriptions/pending_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ package subscriptions

import (
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vechain/thor/v2/api/utils"
"github.com/vechain/thor/v2/block"
"github.com/vechain/thor/v2/chain"
"github.com/vechain/thor/v2/genesis"
"github.com/vechain/thor/v2/muxdb"
"github.com/vechain/thor/v2/packer"
"github.com/vechain/thor/v2/state"
"github.com/vechain/thor/v2/test/datagen"
"github.com/vechain/thor/v2/thor"
"github.com/vechain/thor/v2/tx"
"github.com/vechain/thor/v2/txpool"
Expand Down Expand Up @@ -149,12 +155,90 @@ func createTx(repo *chain.Repository, addressNumber uint) *tx.Transaction {
new(tx.Builder).
ChainTag(repo.ChainTag()).
GasPriceCoef(1).
Expiration(10).
Expiration(1000).
Gas(21000).
Nonce(1).
Nonce(uint64(datagen.RandInt())).
Clause(cla).
BlockRef(tx.NewBlockRef(0)).
Build(),
genesis.DevAccounts()[addressNumber].PrivateKey,
)
}

func TestPendingTx_NoWriteAfterUnsubscribe(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

p := newPendingTx(txPool)
txCh := make(chan *tx.Transaction, txQueueSize)

// Subscribe and then unsubscribe
p.Subscribe(txCh)
p.Unsubscribe(txCh)

done := make(chan struct{})
// Attempt to write a new transaction
trx := createTx(thorChain.Repo(), 0)
assert.NotPanics(t, func() {
p.dispatch(trx, done) // dispatch should not panic after unsubscribe
}, "Dispatching after unsubscribe should not panic")

select {
case <-txCh:
t.Fatal("Channel should not receive new transactions after unsubscribe")
default:
t.Log("No transactions sent to unsubscribed channel, as expected")
}
}

func TestPendingTx_UnsubscribeOnWebSocketClose(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

// Subscriptions setup
sub := New(thorChain.Repo(), []string{"*"}, 100, txPool, false)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
utils.WrapHandlerFunc(sub.handlePendingTransactions)(w, r)
}))
defer server.Close()

require.Equal(t, len(sub.pendingTx.listeners), 0)

// Connect as WebSocket client
url := "ws" + server.URL[4:] + "/txpool"
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.NoError(t, err)
defer ws.Close()

// Add a transaction
trx := createTx(thorChain.Repo(), 0)
txPool.AddLocal(trx)

// Wait to receive transaction
time.Sleep(500 * time.Millisecond)
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 1)
sub.pendingTx.mu.Unlock()

// Simulate WebSocket closure
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
ws.Close()

// Wait for cleanup
time.Sleep(5 * time.Second)

// Assert cleanup
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 0)
sub.pendingTx.mu.Unlock()
}
120 changes: 63 additions & 57 deletions api/subscriptions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func New(repo *chain.Repository, allowedOrigins []string, backtraceLimit uint32,
return sub
}

func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (*blockReader, error) {
func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBlockReader(s.repo, position), nil
}

func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (*eventReader, error) {
func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Reque
return newEventReader(s.repo, position, eventFilter), nil
}

func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (*transferReader, error) {
func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand All @@ -159,70 +159,22 @@ func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Re
return newTransferReader(s.repo, position, transferFilter), nil
}

func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (*beatReader, error) {
func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeatReader(s.repo, position, s.beatCache), nil
}

func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Request) (*beat2Reader, error) {
func (s *Subscriptions) handleBeat2Reader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeat2Reader(s.repo, position, s.beat2Cache), nil
}

func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

var (
reader msgReader
err error
)
switch mux.Vars(req)["subject"] {
case "block":
if reader, err = s.handleBlockReader(w, req); err != nil {
return err
}
case "event":
if reader, err = s.handleEventReader(w, req); err != nil {
return err
}
case "transfer":
if reader, err = s.handleTransferReader(w, req); err != nil {
return err
}
case "beat":
if !s.enabledDeprecated {
return utils.HTTPError(nil, http.StatusGone)
}
if reader, err = s.handleBeatReader(w, req); err != nil {
return err
}
case "beat2":
if reader, err = s.handleBeat2Reader(w, req); err != nil {
return err
}
default:
return utils.HTTPError(errors.New("not found"), http.StatusNotFound)
}

conn, closed, err := s.setupConn(w, req)
// since the conn is hijacked here, no error should be returned in lines below
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return nil
}

err = s.pipe(conn, reader, closed)
s.closeConn(conn, err)
return nil
}

func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()
Expand Down Expand Up @@ -387,15 +339,69 @@ func (s *Subscriptions) Close() {
s.wg.Wait()
}

func (s *Subscriptions) websocket(readerFunc func(http.ResponseWriter, *http.Request) (msgReader, error)) utils.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

// Call the provided reader function
reader, err := readerFunc(w, req)
if err != nil {
return err
}

// Setup WebSocket connection
conn, closed, err := s.setupConn(w, req)
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return err
}
defer s.closeConn(conn, err)

// Stream messages
err = s.pipe(conn, reader, closed)
if err != nil {
logger.Debug("error in websocket pipe", "err", err)
}
return err
}
}

func (s *Subscriptions) Mount(root *mux.Router, pathPrefix string) {
sub := root.PathPrefix(pathPrefix).Subrouter()

sub.Path("/txpool").
Methods(http.MethodGet).
Name("subscriptions_pending_tx").
Name("WS /subscriptions/txpool"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.handlePendingTransactions))
sub.Path("/{subject:beat|beat2|block|event|transfer}").

sub.Path("/block").
Methods(http.MethodGet).
Name("WS /subscriptions/block"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBlockReader)))

sub.Path("/event").
Methods(http.MethodGet).
Name("WS /subscriptions/event"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleEventReader)))

sub.Path("/transfer").
Methods(http.MethodGet).
Name("WS /subscriptions/transfer"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleTransferReader)))

sub.Path("/beat2").
Methods(http.MethodGet).
Name("WS /subscriptions/beat2"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBeat2Reader)))

// This method is currently deprecated
beatHandler := utils.HandleGone
if s.enabledDeprecated {
beatHandler = s.websocket(s.handleBeatReader)
}
sub.Path("/beat").
Methods(http.MethodGet).
Name("subscriptions_subject").
HandlerFunc(utils.WrapHandlerFunc(s.handleSubject))
Name("WS /subscriptions/beat"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(beatHandler))
}
7 changes: 7 additions & 0 deletions api/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,12 @@ func WriteJSON(w http.ResponseWriter, obj interface{}) error {
return json.NewEncoder(w).Encode(obj)
}

// HandleGone is a handler for deprecated endpoints that returns HTTP 410 Gone.
func HandleGone(w http.ResponseWriter, _ *http.Request) error {
w.WriteHeader(http.StatusGone)
_, _ = w.Write([]byte("This endpoint is no longer supported."))
return nil
}

// M shortcut for type map[string]interface{}.
type M map[string]interface{}
2 changes: 1 addition & 1 deletion thorclient/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ type EventWrapper[T any] struct {
// Subscription is used to handle the active subscription
type Subscription[T any] struct {
EventChan <-chan EventWrapper[T]
Unsubscribe func()
Unsubscribe func() error
}
13 changes: 10 additions & 3 deletions thorclient/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,17 @@ func subscribe[T any](conn *websocket.Conn) *common.Subscription[*T] {

return &common.Subscription[*T]{
EventChan: eventChan,
Unsubscribe: func() {
Unsubscribe: func() error {
closed = true
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
conn.Close()
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return fmt.Errorf("failed to issue close message: %w", err)
}
err = conn.Close()
if err != nil {
return fmt.Errorf("failed to close connections: %w", err)
}
return nil
},
}
}
Expand Down

0 comments on commit 77ec5d0

Please sign in to comment.