Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the subscriptions api to meet standard mount approach #926

Merged
merged 9 commits into from
Dec 23, 2024
3 changes: 1 addition & 2 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ func metricsMiddleware(next http.Handler) http.Handler {
if rt != nil && rt.GetName() != "" {
enabled = true
name = rt.GetName()
if strings.HasPrefix(name, "subscriptions") {
if strings.HasPrefix(name, "WS") {
subscription = true
name = "WS " + r.URL.Path
}
}

Expand Down
88 changes: 86 additions & 2 deletions api/subscriptions/pending_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ package subscriptions

import (
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vechain/thor/v2/api/utils"
"github.com/vechain/thor/v2/block"
"github.com/vechain/thor/v2/chain"
"github.com/vechain/thor/v2/genesis"
"github.com/vechain/thor/v2/muxdb"
"github.com/vechain/thor/v2/packer"
"github.com/vechain/thor/v2/state"
"github.com/vechain/thor/v2/test/datagen"
"github.com/vechain/thor/v2/thor"
"github.com/vechain/thor/v2/tx"
"github.com/vechain/thor/v2/txpool"
Expand Down Expand Up @@ -149,12 +155,90 @@ func createTx(repo *chain.Repository, addressNumber uint) *tx.Transaction {
new(tx.Builder).
ChainTag(repo.ChainTag()).
GasPriceCoef(1).
Expiration(10).
Expiration(1000).
Gas(21000).
Nonce(1).
Nonce(uint64(datagen.RandInt())).
Clause(cla).
BlockRef(tx.NewBlockRef(0)).
Build(),
genesis.DevAccounts()[addressNumber].PrivateKey,
)
}

func TestPendingTx_NoWriteAfterUnsubscribe(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

p := newPendingTx(txPool)
txCh := make(chan *tx.Transaction, txQueueSize)

// Subscribe and then unsubscribe
p.Subscribe(txCh)
p.Unsubscribe(txCh)

done := make(chan struct{})
// Attempt to write a new transaction
trx := createTx(thorChain.Repo(), 0)
assert.NotPanics(t, func() {
p.dispatch(trx, done) // dispatch should not panic after unsubscribe
}, "Dispatching after unsubscribe should not panic")

select {
case <-txCh:
t.Fatal("Channel should not receive new transactions after unsubscribe")
default:
t.Log("No transactions sent to unsubscribed channel, as expected")
}
}

func TestPendingTx_UnsubscribeOnWebSocketClose(t *testing.T) {
// Arrange
thorChain := initChain(t)
txPool := txpool.New(thorChain.Repo(), thorChain.Stater(), txpool.Options{
Limit: 100,
LimitPerAccount: 16,
MaxLifetime: time.Hour,
})

// Subscriptions setup
sub := New(thorChain.Repo(), []string{"*"}, 100, txPool, false)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
utils.WrapHandlerFunc(sub.handlePendingTransactions)(w, r)
}))
defer server.Close()

require.Equal(t, len(sub.pendingTx.listeners), 0)

// Connect as WebSocket client
url := "ws" + server.URL[4:] + "/txpool"
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.NoError(t, err)
defer ws.Close()

// Add a transaction
trx := createTx(thorChain.Repo(), 0)
txPool.AddLocal(trx)

// Wait to receive transaction
time.Sleep(500 * time.Millisecond)
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 1)
sub.pendingTx.mu.Unlock()

// Simulate WebSocket closure
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
ws.Close()

// Wait for cleanup
time.Sleep(5 * time.Second)

// Assert cleanup
sub.pendingTx.mu.Lock()
require.Equal(t, len(sub.pendingTx.listeners), 0)
sub.pendingTx.mu.Unlock()
}
134 changes: 77 additions & 57 deletions api/subscriptions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func New(repo *chain.Repository, allowedOrigins []string, backtraceLimit uint32,
return sub
}

func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (*blockReader, error) {
func (s *Subscriptions) handleBlockReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBlockReader(s.repo, position), nil
}

func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (*eventReader, error) {
func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *Subscriptions) handleEventReader(w http.ResponseWriter, req *http.Reque
return newEventReader(s.repo, position, eventFilter), nil
}

func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (*transferReader, error) {
func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
Expand All @@ -159,70 +159,22 @@ func (s *Subscriptions) handleTransferReader(_ http.ResponseWriter, req *http.Re
return newTransferReader(s.repo, position, transferFilter), nil
}

func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (*beatReader, error) {
func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeatReader(s.repo, position, s.beatCache), nil
}

func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Request) (*beat2Reader, error) {
func (s *Subscriptions) handleBeat2Reader(_ http.ResponseWriter, req *http.Request) (msgReader, error) {
position, err := s.parsePosition(req.URL.Query().Get("pos"))
if err != nil {
return nil, err
}
return newBeat2Reader(s.repo, position, s.beat2Cache), nil
}

func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

var (
reader msgReader
err error
)
switch mux.Vars(req)["subject"] {
case "block":
if reader, err = s.handleBlockReader(w, req); err != nil {
return err
}
case "event":
if reader, err = s.handleEventReader(w, req); err != nil {
return err
}
case "transfer":
if reader, err = s.handleTransferReader(w, req); err != nil {
return err
}
case "beat":
if !s.enabledDeprecated {
return utils.HTTPError(nil, http.StatusGone)
}
if reader, err = s.handleBeatReader(w, req); err != nil {
return err
}
case "beat2":
if reader, err = s.handleBeat2Reader(w, req); err != nil {
return err
}
default:
return utils.HTTPError(errors.New("not found"), http.StatusNotFound)
}

conn, closed, err := s.setupConn(w, req)
// since the conn is hijacked here, no error should be returned in lines below
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return nil
}

err = s.pipe(conn, reader, closed)
s.closeConn(conn, err)
return nil
}

func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()
Expand Down Expand Up @@ -387,15 +339,83 @@ func (s *Subscriptions) Close() {
s.wg.Wait()
}

func (s *Subscriptions) websocket(readerFunc func(http.ResponseWriter, *http.Request) (msgReader, error)) utils.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) error {
s.wg.Add(1)
defer s.wg.Done()

// Call the provided reader function
reader, err := readerFunc(w, req)
if err != nil {
return err
}

// Ensure cleanup when the connection closes
defer func() {
if closer, ok := reader.(interface{ Close() }); ok {
otherview marked this conversation as resolved.
Show resolved Hide resolved
closer.Close()
}
}()

// Setup WebSocket connection
conn, closed, err := s.setupConn(w, req)
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return err
}
defer s.closeConn(conn, err)

// Stream messages
err = s.pipe(conn, reader, closed)
if err != nil {
logger.Debug("error in websocket pipe", "err", err)
}
return err
}
}

// handleGone is a handler for deprecated endpoints that returns HTTP 410 Gone.
func handleGone(w http.ResponseWriter, _ *http.Request) error {
libotony marked this conversation as resolved.
Show resolved Hide resolved
w.WriteHeader(http.StatusGone)
_, _ = w.Write([]byte("This endpoint is no longer supported."))
return nil
}

func (s *Subscriptions) Mount(root *mux.Router, pathPrefix string) {
sub := root.PathPrefix(pathPrefix).Subrouter()

sub.Path("/txpool").
Methods(http.MethodGet).
Name("subscriptions_pending_tx").
Name("WS /subscriptions/txpool").
HandlerFunc(utils.WrapHandlerFunc(s.handlePendingTransactions))
sub.Path("/{subject:beat|beat2|block|event|transfer}").

sub.Path("/block").
Methods(http.MethodGet).
Name("WS /subscriptions/block").
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBlockReader)))

sub.Path("/event").
Methods(http.MethodGet).
Name("WS /subscriptions/event").
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleEventReader)))

sub.Path("/transfer").
Methods(http.MethodGet).
Name("subscriptions_subject").
HandlerFunc(utils.WrapHandlerFunc(s.handleSubject))
Name("WS /subscriptions/transfer").
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleTransferReader)))

sub.Path("/beat2").
Methods(http.MethodGet).
Name("WS /subscriptions/beat2").
HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBeat2Reader)))

deprecatedBeat := sub.Path("/beat").
Methods(http.MethodGet).
Name("WS /subscriptions/beat")

if s.enabledDeprecated {
deprecatedBeat.HandlerFunc(utils.WrapHandlerFunc(s.websocket(s.handleBeatReader)))
} else {
deprecatedBeat.HandlerFunc(utils.WrapHandlerFunc(handleGone))
}
}
Loading