Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the subscriptions api to meet standard mount approach #926

Merged
merged 9 commits into from
Dec 23, 2024
12 changes: 7 additions & 5 deletions api/accounts/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ func (a *Accounts) handleGetStorage(w http.ResponseWriter, req *http.Request) er
}

func (a *Accounts) handleCallContract(w http.ResponseWriter, req *http.Request) error {
if !a.enabledDeprecated {
return utils.HTTPError(nil, http.StatusGone)
}
callData := &CallData{}
if err := utils.ParseJSON(req.Body, &callData); err != nil {
return utils.BadRequest(errors.WithMessage(err, "body"))
Expand Down Expand Up @@ -378,13 +375,18 @@ func (a *Accounts) Mount(root *mux.Router, pathPrefix string) {
Methods("GET").
Name("GET /accounts/{address}/storage").
HandlerFunc(utils.WrapHandlerFunc(a.handleGetStorage))

// These two methods are currently deprecated
callContractHandler := utils.HandleGone
if a.enabledDeprecated {
callContractHandler = a.handleCallContract
}
sub.Path("").
Methods(http.MethodPost).
Name("POST /accounts").
HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract))
HandlerFunc(utils.WrapHandlerFunc(callContractHandler))
sub.Path("/{address}").
Methods(http.MethodPost).
Name("POST /accounts/{address}").
HandlerFunc(utils.WrapHandlerFunc(a.handleCallContract))
HandlerFunc(utils.WrapHandlerFunc(callContractHandler))
}
3 changes: 1 addition & 2 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ func metricsMiddleware(next http.Handler) http.Handler {
if rt != nil && rt.GetName() != "" {
enabled = true
name = rt.GetName()
if strings.HasPrefix(name, "subscriptions") {
if strings.HasPrefix(name, "WS") {
subscription = true
name = "WS " + r.URL.Path
}
}

Expand Down
88 changes: 86 additions & 2 deletions api/subscriptions/pending_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ package subscriptions

import (
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vechain/thor/v2/api/utils"
"github.com/vechain/thor/v2/block"
"github.com/vechain/thor/v2/chain"
"github.com/vechain/thor/v2/genesis"
"github.com/vechain/thor/v2/muxdb"
"github.com/vechain/thor/v2/packer"
"github.com/vechain/thor/v2/state"
"github.com/vechain/thor/v2/test/datagen"
"github.com/vechain/thor/v2/thor"
"github.com/vechain/thor/v2/tx"
"github.com/vechain/thor/v2/txpool"
Expand Down Expand Up @@ -149,12 +155,90 @@ func createTx(repo *chain.Repository, addressNumber uint) *tx.Transaction {
new(tx.Builder).
ChainTag(repo.ChainTag()).
GasPriceCoef(1).
Expiration(10).
Expiration(1000).
Gas(21000).
Nonce(1).
Nonce(uint64(datagen.RandInt())).
Clause(cla).
BlockRef(tx.NewBlockRef(0)).
Build(),
genesis.DevAccounts()[addressNumber].PrivateKey,
)
}

func TestPendingTx_NoWriteAfterUnsubscribe(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

p := newPendingTx(txPool)
txCh := make(chan *tx.Transaction, txQueueSize)

// Subscribe and then unsubscribe
p.Subscribe(txCh)
p.Unsubscribe(txCh)

done := make(chan struct{})
// Attempt to write a new transaction
trx := createTx(thorChain.Repo(), 0)
assert.NotPanics(t, func() {
p.dispatch(trx, done) // dispatch should not panic after unsubscribe
}, "Dispatching after unsubscribe should not panic")

select {
case <-txCh:
t.Fatal("Channel should not receive new transactions after unsubscribe")
default:
t.Log("No transactions sent to unsubscribed channel, as expected")
}
}

func TestPendingTx_UnsubscribeOnWebSocketClose(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

// Subscriptions setup
sub := New(thorChain.Repo(), []string{"*"}, 100, txPool, false)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
utils.WrapHandlerFunc(sub.handlePendingTransactions)(w, r)
}))
defer server.Close()

require.Equal(t, len(sub.pendingTx.listeners), 0)

// Connect as WebSocket client
url := "ws" + server.URL[4:] + "/txpool"
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.NoError(t, err)
defer ws.Close()

// Add a transaction
trx := createTx(thorChain.Repo(), 0)
txPool.AddLocal(trx)

// Wait to receive transaction
time.Sleep(500 * time.Millisecond)
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 1)
sub.pendingTx.mu.Unlock()

// Simulate WebSocket closure
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
ws.Close()

// Wait for cleanup
time.Sleep(5 * time.Second)

// Assert cleanup
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 0)
sub.pendingTx.mu.Unlock()
}
120 changes: 63 additions & 57 deletions api/subscriptions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func New(repo *chain.Repository, allowedOrigins []string, backtraceLimit uint32,
return sub
}

func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (*blockReader, error) {
func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBlockReader(s.repo, position), nil
}

func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (*eventReader, error) {
func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Reque
return newEventReader(s.repo, position, eventFilter), nil
}

func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (*transferReader, error) {
func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand All @@ -159,70 +159,22 @@ func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Re
return newTransferReader(s.repo, position, transferFilter), nil
}

func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (*beatReader, error) {
func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeatReader(s.repo, position, s.beatCache), nil
}

func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Request) (*beat2Reader, error) {
func (s *Subscriptions) handleBeat2Reader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeat2Reader(s.repo, position, s.beat2Cache), nil
}

func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

var (
reader msgReader
err error
)
switch mux.Vars(req)["subject"] {
case "block":
if reader, err = s.handleBlockReader(w, req); err != nil {
return err
}
case "event":
if reader, err = s.handleEventReader(w, req); err != nil {
return err
}
case "transfer":
if reader, err = s.handleTransferReader(w, req); err != nil {
return err
}
case "beat":
if !s.enabledDeprecated {
return utils.HTTPError(nil, http.StatusGone)
}
if reader, err = s.handleBeatReader(w, req); err != nil {
return err
}
case "beat2":
if reader, err = s.handleBeat2Reader(w, req); err != nil {
return err
}
default:
return utils.HTTPError(errors.New("not found"), http.StatusNotFound)
}

conn, closed, err := s.setupConn(w, req)
// since the conn is hijacked here, no error should be returned in lines below
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return nil
}

err = s.pipe(conn, reader, closed)
s.closeConn(conn, err)
return nil
}

func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()
Expand Down Expand Up @@ -387,15 +339,69 @@ func (s *Subscriptions) Close() {
s.wg.Wait()
}

func (s *Subscriptions) websocket(readerFunc func(http.ResponseWriter, *http.Request) (msgReader, error)) utils.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

// Call the provided reader function
reader, err := readerFunc(w, req)
if err != nil {
return err
}

// Setup WebSocket connection
conn, closed, err := s.setupConn(w, req)
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return err
}
defer s.closeConn(conn, err)

// Stream messages
err = s.pipe(conn, reader, closed)
if err != nil {
logger.Debug("error in websocket pipe", "err", err)
}
return err
}
}

func (s *Subscriptions) Mount(root *mux.Router, pathPrefix string) {
sub := root.PathPrefix(pathPrefix).Subrouter()

sub.Path("/txpool").
Methods(http.MethodGet).
Name("subscriptions_pending_tx").
Name("WS /subscriptions/txpool"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.handlePendingTransactions))
sub.Path("/{subject:beat|beat2|block|event|transfer}").

sub.Path("/block").
Methods(http.MethodGet).
Name("WS /subscriptions/block"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBlockReader)))

sub.Path("/event").
Methods(http.MethodGet).
Name("WS /subscriptions/event"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleEventReader)))

sub.Path("/transfer").
Methods(http.MethodGet).
Name("WS /subscriptions/transfer"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleTransferReader)))

sub.Path("/beat2").
Methods(http.MethodGet).
Name("WS /subscriptions/beat2"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBeat2Reader)))

// This method is currently deprecated
beatHandler := utils.HandleGone
if s.enabledDeprecated {
beatHandler = s.websocket(s.handleBeatReader)
}
sub.Path("/beat").
Methods(http.MethodGet).
Name("subscriptions_subject").
HandlerFunc(utils.WrapHandlerFunc(s.handleSubject))
Name("WS /subscriptions/beat"). // metrics middleware relies on this name
HandlerFunc(utils.WrapHandlerFunc(beatHandler))
}
7 changes: 7 additions & 0 deletions api/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,12 @@ func WriteJSON(w http.ResponseWriter, obj interface{}) error {
return json.NewEncoder(w).Encode(obj)
}

// HandleGone is a handler for deprecated endpoints that returns HTTP 410 Gone.
func HandleGone(w http.ResponseWriter, _ *http.Request) error {
w.WriteHeader(http.StatusGone)
_, _ = w.Write([]byte("This endpoint is no longer supported."))
return nil
}

// M shortcut for type map[string]interface{}.
type M map[string]interface{}
2 changes: 1 addition & 1 deletion thorclient/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ type EventWrapper[T any] struct {
// Subscription is used to handle the active subscription
type Subscription[T any] struct {
EventChan <-chan EventWrapper[T]
Unsubscribe func()
Unsubscribe func() error
}
13 changes: 10 additions & 3 deletions thorclient/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,17 @@ func subscribe[T any](conn *websocket.Conn) *common.Subscription[*T] {

return &common.Subscription[*T]{
EventChan: eventChan,
Unsubscribe: func() {
Unsubscribe: func() error {
closed = true
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
conn.Close()
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return fmt.Errorf("failed to issue close message: %w", err)
}
err = conn.Close()
if err != nil {
return fmt.Errorf("failed to close connections: %w", err)
}
return nil
},
}
}
Expand Down
Loading