Skip to content

Commit

Permalink
chore: remove extra code
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Nov 10, 2023
1 parent 5c0d69e commit 0324b98
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 138 deletions.
121 changes: 50 additions & 71 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,13 @@ import (
"context"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net/http"
"sync/atomic"
"time"

"github.com/cloudwego/hertz/pkg/network/standard"

"github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"

"gopkg.in/cenkalti/backoff.v1"
)

var (
Expand All @@ -45,34 +38,31 @@ var (
)

// ConnCallback defines a function to be called on a particular connection event
type ConnCallback func(c *SSEClient)
type ConnCallback func(ctx context.Context, client *Client)

// ResponseValidator validates a response
type ResponseValidator func(c *SSEClient, resp *protocol.Response) error
type ResponseValidator func(ctx context.Context, req *protocol.Request, resp *protocol.Response) error

// SSEClient handles an incoming server stream
type SSEClient struct {
Retry time.Time
ReconnectStrategy backoff.BackOff
// Client handles an incoming server stream
type Client struct {
HertzClient *client.Client
disconnectCallback ConnCallback
connectedCallback ConnCallback
ReconnectNotify backoff.Notify
ResponseValidator ResponseValidator
HertzClient *client.Client
URL string
Headers map[string]string
URL string
Method string
LastEventID atomic.Value // []byte
maxBufferSize int
connected bool
EncodingBase64 bool
Connected bool
LastEventID atomic.Value // []byte
}

var defaultClient, _ = client.NewClient(client.WithResponseBodyStream(true), client.WithDialer(standard.NewDialer()))
var defaultClient, _ = client.NewClient(client.WithResponseBodyStream(true))

// NewClient creates a new client
func NewClient(url string) *SSEClient {
c := &SSEClient{
func NewClient(url string) *Client {
c := &Client{
URL: url,
HertzClient: defaultClient,
Headers: make(map[string]string),
Expand All @@ -84,62 +74,51 @@ func NewClient(url string) *SSEClient {
}

// Subscribe to a data stream
func (c *SSEClient) Subscribe(stream string, handler func(msg *Event)) error {
func (c *Client) Subscribe(stream string, handler func(msg *Event)) error {
return c.SubscribeWithContext(context.Background(), stream, handler)
}

// SubscribeWithContext to a data stream with context
func (c *SSEClient) SubscribeWithContext(ctx context.Context, stream string, handler func(msg *Event)) error {
operation := func() error {
req, resp := protocol.AcquireRequest(), protocol.AcquireResponse()
err := c.request(ctx, req, resp, stream)
func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handler func(msg *Event)) error {
req, resp := protocol.AcquireRequest(), protocol.AcquireResponse()
err := c.request(ctx, req, resp, stream)
if err != nil {
return err
}
defer func() {
protocol.ReleaseRequest(req)
protocol.ReleaseResponse(resp)
}()
if validator := c.ResponseValidator; validator != nil {
err = validator(ctx, req, resp)
if err != nil {
return err
}
defer func() {
protocol.ReleaseRequest(req)
protocol.ReleaseResponse(resp)
}()
if validator := c.ResponseValidator; validator != nil {
err = validator(c, resp)
if err != nil {
return err
}
} else if resp.StatusCode() != 200 {
return fmt.Errorf("could not connect to stream: %s", http.StatusText(resp.StatusCode()))
}
} else if resp.StatusCode() != consts.StatusOK {
return fmt.Errorf("could not connect to stream code: %d", resp.StatusCode())
}

reader := NewEventStreamReader(resp.BodyStream(), c.maxBufferSize)
eventChan, errorChan := c.startReadLoop(ctx, reader)
reader := NewEventStreamReader(resp.BodyStream(), c.maxBufferSize)
eventChan, errorChan := c.startReadLoop(ctx, reader)

for {
select {
case err = <-errorChan:
return err
case msg := <-eventChan:
handler(msg)
}
for {
select {
case err = <-errorChan:
return err
case msg := <-eventChan:
handler(msg)
}
}

// Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method
var err error
if c.ReconnectStrategy != nil {
err = backoff.RetryNotify(operation, c.ReconnectStrategy, c.ReconnectNotify)
} else {
err = backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), c.ReconnectNotify)
}
return err
}

func (c *SSEClient) startReadLoop(ctx context.Context, reader *EventStreamReader) (chan *Event, chan error) {
func (c *Client) startReadLoop(ctx context.Context, reader *EventStreamReader) (chan *Event, chan error) {
outCh := make(chan *Event)
erChan := make(chan error)
go c.readLoop(ctx, reader, outCh, erChan)
return outCh, erChan
}

func (c *SSEClient) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) {
func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) {
for {
// Read each new line and process the type of event
event, err := reader.ReadEvent()
Expand All @@ -150,16 +129,16 @@ func (c *SSEClient) readLoop(ctx context.Context, reader *EventStreamReader, out
}
// run user specified disconnect function
if c.disconnectCallback != nil {
c.Connected = false
c.disconnectCallback(c)
c.connected = false
c.disconnectCallback(ctx, c)
}
erChan <- err
return
}

if !c.Connected && c.connectedCallback != nil {
c.Connected = true
c.connectedCallback(c)
if !c.connected && c.connectedCallback != nil {
c.connected = true
c.connectedCallback(ctx, c)
}

// If we get an error, ignore it.
Expand All @@ -180,31 +159,31 @@ func (c *SSEClient) readLoop(ctx context.Context, reader *EventStreamReader, out
}

// SubscribeRaw to an sse endpoint
func (c *SSEClient) SubscribeRaw(handler func(msg *Event)) error {
func (c *Client) SubscribeRaw(handler func(msg *Event)) error {
return c.Subscribe("", handler)
}

// SubscribeRawWithContext to an sse endpoint with context
func (c *SSEClient) SubscribeRawWithContext(ctx context.Context, handler func(msg *Event)) error {
func (c *Client) SubscribeRawWithContext(ctx context.Context, handler func(msg *Event)) error {
return c.SubscribeWithContext(ctx, "", handler)
}

// OnDisconnect specifies the function to run when the connection disconnects
func (c *SSEClient) OnDisconnect(fn ConnCallback) {
func (c *Client) OnDisconnect(fn ConnCallback) {
c.disconnectCallback = fn
}

// OnConnect specifies the function to run when the connection is successful
func (c *SSEClient) OnConnect(fn ConnCallback) {
func (c *Client) OnConnect(fn ConnCallback) {
c.connectedCallback = fn
}

// SetMaxBufferSize set sse client MaxBufferSize
func (c *SSEClient) SetMaxBufferSize(size int) {
func (c *Client) SetMaxBufferSize(size int) {
c.maxBufferSize = size
}

func (c *SSEClient) request(ctx context.Context, req *protocol.Request, resp *protocol.Response, stream string) error {
func (c *Client) request(ctx context.Context, req *protocol.Request, resp *protocol.Response, stream string) error {
req.SetMethod(c.Method)
req.SetRequestURI(c.URL)
// Setup request, specify stream to connect to
Expand All @@ -229,11 +208,11 @@ func (c *SSEClient) request(ctx context.Context, req *protocol.Request, resp *pr
return err
}

func (c *SSEClient) processEvent(msg []byte) (event *Event, err error) {
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
var e Event

if len(msg) < 1 {
return nil, errors.New("event message was empty")
return nil, fmt.Errorf("event message was empty")
}

// Normalize the crlf to lf to make it easier to split the lines.
Expand Down
73 changes: 11 additions & 62 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"gopkg.in/cenkalti/backoff.v1"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/stretchr/testify/require"
"github.com/cloudwego/hertz/pkg/common/test/assert"
)

var mldata = `{
Expand Down Expand Up @@ -112,22 +109,6 @@ func newMultilineServer(port string) {
h.Spin()
}

//func newServerDisconnect(empty bool, port string) {
// h := server.Default(server.WithHostPorts("0.0.0.0:" + port))
//
// h.GET("/sse", func(ctx context.Context, c *app.RequestContext) {
// // client can tell server last event it received with Last-Event-ID header
// lastEventID := GetLastEventID(c)
// hlog.CtxInfof(ctx, "last event ID: %s", lastEventID)
//
// // you must set status code and response headers before first render call
// c.SetStatusCode(http.StatusOK)
// s := NewStream(c)
// publishMsgs(s, empty, 1000)
// })
// h.Run()
//}

func newServer401(port string) {
h := server.Default(server.WithHostPorts("0.0.0.0:" + port))

Expand Down Expand Up @@ -185,8 +166,8 @@ func TestClientSubscribe(t *testing.T) {

for i := 0; i < 5; i++ {
msg, err := wait(events, time.Second*1)
require.Nil(t, err)
assert.Equal(t, []byte(`ping`), msg)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(`ping`), msg)
}

assert.Nil(t, cErr)
Expand All @@ -211,68 +192,40 @@ func TestClientSubscribeMultiline(t *testing.T) {

for i := 0; i < 5; i++ {
msg, err := wait(events, time.Second*1)
require.Nil(t, err)
assert.Equal(t, []byte(mldata), msg)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(mldata), msg)
}

assert.Nil(t, cErr)
}

//func TestClientOnDisconnect(t *testing.T) {
// go newServerDisconnect(false, "9008")
//
// c := NewClient("http://127.0.0.1:9008/sse")
//
// called := make(chan struct{})
// c.OnDisconnect(func(client *SSEClient) {
// called <- struct{}{}
// })
// file, _ := os.Open("data.txt") // 打开要扫描的文件
// defer file.Close()
//
// scanner := bufio.NewScanner(file) // 创建 Scanner 实例
// go c.startReadLoop(context.Background(), &EventStreamReader{scanner})
//
// time.Sleep(time.Second)
// // c.HertzClient.CloseIdleConnections()
// // server.CloseClientConnections()
//
// assert.Equal(t, struct{}{}, <-called)
//}

func TestClientOnConnect(t *testing.T) {
go newServerOnConnect(false, "9000")
time.Sleep(time.Second)
c := NewClient("http://127.0.0.1:9000/sse")

called := make(chan struct{})
c.OnConnect(func(client *SSEClient) {
c.OnConnect(func(ctx context.Context, client *Client) {
called <- struct{}{}
})

go c.Subscribe("test", func(msg *Event) {})

time.Sleep(time.Second)
assert.Equal(t, struct{}{}, <-called)
assert.DeepEqual(t, struct{}{}, <-called)
}

func TestClientUnsubscribe401(t *testing.T) {
go newServer401("9009")
time.Sleep(time.Second)
c := NewClient("http://127.0.0.1:9009/sse")

// limit retries to 3
c.ReconnectStrategy = backoff.WithMaxTries(
backoff.NewExponentialBackOff(),
3,
)

err := c.SubscribeRaw(func(ev *Event) {
// this shouldn't run
assert.False(t, true)
})

require.NotNil(t, err)
assert.NotNil(t, err)
}

func TestClientLargeData(t *testing.T) {
Expand All @@ -284,10 +237,6 @@ func TestClientLargeData(t *testing.T) {
c := NewClient("http://127.0.0.1:9005/sse")

// limit retries to 3
c.ReconnectStrategy = backoff.WithMaxTries(
backoff.NewExponentialBackOff(),
3,
)

ec := make(chan *Event, 1)

Expand All @@ -298,8 +247,8 @@ func TestClientLargeData(t *testing.T) {
}()

d, err := wait(ec, time.Second)
require.Nil(t, err)
require.Equal(t, data, d)
assert.Nil(t, err)
assert.DeepEqual(t, data, d)
}

func TestTrimHeader(t *testing.T) {
Expand All @@ -323,6 +272,6 @@ func TestTrimHeader(t *testing.T) {

for _, tc := range tests {
got := trimHeader(len(headerData), tc.input)
require.Equal(t, tc.want, got)
assert.DeepEqual(t, tc.want, got)
}
}
Loading

0 comments on commit 0324b98

Please sign in to comment.