Skip to content

Commit

Permalink
add support for datagrams (#142)
Browse files Browse the repository at this point in the history
* add support for datagrams

* add a test case
  • Loading branch information
marten-seemann authored Apr 27, 2024
1 parent 375a5dc commit a08801e
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 11 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/quic-go/quic-go v0.43.0
github.com/stretchr/testify v1.8.0
go.uber.org/mock v0.4.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
)

require (
Expand All @@ -19,7 +20,6 @@ require (
github.com/quic-go/qpack v0.4.0 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.15.0 // indirect
Expand Down
33 changes: 31 additions & 2 deletions mock_stream_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (q *acceptQueue[T]) Chan() <-chan struct{} { return q.c }
type Session struct {
sessionID sessionID
qconn http3.Connection
requestStr quic.Stream
requestStr http3.Stream

streamHdr []byte
uniStreamHdr []byte
Expand All @@ -82,7 +82,7 @@ type Session struct {
streams streamsMap
}

func newSession(sessionID sessionID, qconn http3.Connection, requestStr quic.Stream) *Session {
func newSession(sessionID sessionID, qconn http3.Connection, requestStr http3.Stream) *Session {
tracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
ctx, ctxCancel := context.WithCancel(context.WithValue(context.Background(), quic.ConnectionTracingKey, tracingID))
c := &Session{
Expand Down Expand Up @@ -390,6 +390,14 @@ func (s *Session) CloseWithError(code SessionErrorCode, msg string) error {
return err
}

func (s *Session) SendDatagram(b []byte) error {
return s.requestStr.SendDatagram(b)
}

func (s *Session) ReceiveDatagram(ctx context.Context) ([]byte, error) {
return s.requestStr.ReceiveDatagram(ctx)
}

func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool /* first call to close session */, error) {
s.closeMx.Lock()
defer s.closeMx.Unlock()
Expand All @@ -413,6 +421,6 @@ func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool /* fir
)
}

func (c *Session) ConnectionState() quic.ConnectionState {
return c.qconn.ConnectionState()
func (s *Session) ConnectionState() quic.ConnectionState {
return s.qconn.ConnectionState()
}
2 changes: 1 addition & 1 deletion session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session)
}

// AddSession adds a new WebTransport session.
func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr quic.Stream) *Session {
func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr http3.Stream) *Session {
conn := newSession(id, qconn, requestStr)
connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)

Expand Down
5 changes: 3 additions & 2 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@ import (
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_connection_test.go github.com/quic-go/quic-go/http3 Connection && cat mock_connection_test.go | sed s@qerr\\.ApplicationErrorCode@quic.ApplicationErrorCode@g > tmp.go && mv tmp.go mock_connection_test.go && goimports -w mock_connection_test.go"
//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go Stream && cat mock_stream_test.go | sed s@protocol\\.StreamID@quic.StreamID@g | sed s@qerr\\.StreamErrorCode@quic.StreamErrorCode@g > tmp.go && mv tmp.go mock_stream_test.go && goimports -w mock_stream_test.go"
//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go/http3 Stream && cat mock_stream_test.go | sed s@protocol\\.StreamID@quic.StreamID@g | sed s@qerr\\.StreamErrorCode@quic.StreamErrorCode@g > tmp.go && mv tmp.go mock_stream_test.go && goimports -w mock_stream_test.go"

type mockRequestStream struct {
*MockStream
c chan struct{}
}

func newMockRequestStream(ctrl *gomock.Controller) quic.Stream {
func newMockRequestStream(ctrl *gomock.Controller) http3.Stream {
str := NewMockStream(ctrl)
str.EXPECT().Close()
str.EXPECT().CancelRead(gomock.Any())
Expand Down
57 changes: 56 additions & 1 deletion webtransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package webtransport_test

import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
Expand All @@ -15,6 +14,8 @@ import (
"testing"
"time"

"golang.org/x/exp/rand"

"github.com/quic-go/webtransport-go"

"github.com/quic-go/quic-go"
Expand Down Expand Up @@ -595,3 +596,57 @@ func TestWriteCloseRace(t *testing.T) {
<-ready
close(ch)
}

func TestDatagrams(t *testing.T) {
const num = 100
var mx sync.Mutex
m := make(map[string]bool, num)

var counter int
done := make(chan struct{})
serverErrChan := make(chan error, 1)
sess, closeServer := establishSession(t, func(sess *webtransport.Session) {
defer close(done)
for {
b, err := sess.ReceiveDatagram(context.Background())
if err != nil {
return
}
mx.Lock()
if _, ok := m[string(b)]; !ok {
serverErrChan <- errors.New("received unexpected datagram")
return
}
m[string(b)] = true
mx.Unlock()
counter++
}
})
defer closeServer()

errChan := make(chan error, 1)

for i := 0; i < num; i++ {
b := make([]byte, 800)
rand.Read(b)
mx.Lock()
m[string(b)] = false
mx.Unlock()
if err := sess.SendDatagram(b); err != nil {
break
}
}
time.Sleep(scaleDuration(10 * time.Millisecond))
sess.CloseWithError(0, "")
select {
case err := <-serverErrChan:
t.Fatal(err)
case err := <-errChan:
t.Fatal(err)
case <-done:
t.Logf("sent: %d, received: %d", num, counter)
require.Greater(t, counter, num*4/5)
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
}

0 comments on commit a08801e

Please sign in to comment.