diff --git a/sse.go b/sse.go index 12bf601..ab0aeb1 100644 --- a/sse.go +++ b/sse.go @@ -13,7 +13,6 @@ func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (*Stream, error) { // LastEventID returns a last ID known by user. // If it's not presented - empty string will be returnes -// func LastEventID(r *http.Request) string { return r.Header.Get("Last-Event-ID") } @@ -29,7 +28,7 @@ func (u Upgrader) UpgradeHTTP(r *http.Request, w http.ResponseWriter) (*Stream, return nil, ErrNotHijacker } - _, bw, err := hj.Hijack() + nc, bw, err := hj.Hijack() if err != nil { http.Error(w, http.ErrHijacked.Error(), http.StatusInternalServerError) return nil, http.ErrHijacked @@ -41,6 +40,7 @@ func (u Upgrader) UpgradeHTTP(r *http.Request, w http.ResponseWriter) (*Stream, } s := &Stream{ + nc: nc, bw: bw, w: w, } diff --git a/stream.go b/stream.go index 57faac0..b7d2bd1 100644 --- a/stream.go +++ b/stream.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net" "strconv" "time" ) @@ -12,6 +13,7 @@ import ( type Stream struct { bw *bufio.ReadWriter w io.Writer + nc net.Conn } type BinaryMarshaler interface { @@ -26,9 +28,11 @@ func (s *Stream) Flush() error { return s.bw.Flush() } -// Close sends close event with empth data. +// Close sends close event with empty data and closes underlying connection. func (s *Stream) Close() error { - _, err := s.w.Write([]byte("event:close\ndata:\n\n")) + defer s.nc.Close() + + _, err := s.bw.Write([]byte("event:close\ndata:\n\n")) if err != nil { return err } diff --git a/stream_test.go b/stream_test.go index 9305b14..fc2b449 100644 --- a/stream_test.go +++ b/stream_test.go @@ -3,6 +3,8 @@ package sse import ( "bufio" "bytes" + "io" + "net/http" "testing" "time" ) @@ -38,3 +40,37 @@ func newStream() (*Stream, *bytes.Buffer) { } return s, buf } + +func TestStream_Close(t *testing.T) { + server := newServer(func(w http.ResponseWriter, r *http.Request) { + u := Upgrader{} + + stream, err := u.UpgradeHTTP(r, w) + if err != nil { + t.Fatal(err) + } + + if err := stream.Close(); err != nil { + t.Fatalf("stream.Close() = %v, want nil", err) + } + }) + + client := http.DefaultClient + + resp, err := client.Do(newStreamRequest(server.URL)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + want := "event:close\ndata:\n\n" + + if got := string(body); got != want { + t.Fatalf("got %#v, want %#v", got, want) + } +}