Skip to content

Commit

Permalink
Merge branch 'master' into dependabot_06-24
Browse files Browse the repository at this point in the history
  • Loading branch information
jmwample authored Jun 18, 2024
2 parents 8ccf21e + 7e7e7d8 commit fd29b5c
Show file tree
Hide file tree
Showing 23 changed files with 703 additions and 112 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ app:
[ -d $(EXE_DIR) ] || mkdir -p $(EXE_DIR)
go build -o ${EXE_DIR}/application ./cmd/application

app-dbg:
[ -d $(EXE_DIR) ] || mkdir -p $(EXE_DIR)
go build -tags debug -o ${EXE_DIR}/application ./cmd/application

libtd:
cd ./libtapdance/ && make libtapdance.a

Expand Down
2 changes: 2 additions & 0 deletions cmd/application/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func main() {
flag.StringVar(&zmqAddress, "zmq-address", "ipc://@zmq-proxy", "Address of ZMQ proxy")
flag.Parse()

startPProf()

// Init stats
cj.Stat()

Expand Down
6 changes: 6 additions & 0 deletions cmd/application/pprof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//go:build !debug

package main

func startPProf() {
}
15 changes: 15 additions & 0 deletions cmd/application/pprof_dbg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build debug

package main

import (
"log"
"net/http"
_ "net/http/pprof"
)

func startPProf() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
}
47 changes: 36 additions & 11 deletions pkg/dtls/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
Expand Down Expand Up @@ -33,6 +34,40 @@ func Client(conn net.Conn, config *Config) (net.Conn, error) {

// DialWithContext creates a DTLS connection to the given network address using the given shared secret
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (net.Conn, error) {

dtlsConn, err := dtlsCtx(ctx, conn, config)
if err != nil {
return nil, fmt.Errorf("error creating dtls connection: %w", err)
}

ddl, ok := ctx.Deadline()
if ok {
err := conn.SetDeadline(ddl)
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
if err != nil {
dtlsConn.Close()
return nil, err
}

err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

err = wrappedConn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

return wrappedConn, nil
}

func dtlsCtx(ctx context.Context, conn net.Conn, config *Config) (net.Conn, error) {
clientCert, serverCert, err := certsFromSeed(config.PSK)

if err != nil {
Expand Down Expand Up @@ -68,16 +103,6 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (net.
VerifyPeerCertificate: verifyServerCertificate,
}

dtlsConn, err := dtls.ClientWithContext(ctx, conn, dtlsConf)
return dtls.ClientWithContext(ctx, conn, dtlsConf)

if err != nil {
return nil, fmt.Errorf("error creating dtls connection: %v", err)
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
if err != nil {
return nil, err
}

return wrappedConn, nil
}
28 changes: 28 additions & 0 deletions pkg/dtls/goroutine_leak_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dtls

import (
"runtime"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func passGoroutineLeak(testFunc func(*testing.T), t *testing.T) bool {
initialGoroutines := runtime.NumGoroutine()

testFunc(t)

time.Sleep(2 * time.Second)

return runtime.NumGoroutine() <= initialGoroutines
}

func TestGoroutineLeak(t *testing.T) {
testFuncs := []func(*testing.T){TestSend, TestServerFail, TestClientFail, TestListenSuccess, TestListenFail, TestFailSCTP}

for _, test := range testFuncs {
require.True(t, passGoroutineLeak(test, t))
}

}
106 changes: 82 additions & 24 deletions pkg/dtls/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package dtls
import (
"bytes"
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
Expand All @@ -14,6 +16,9 @@ const recvChBufSize = 64
type hbConn struct {
stream msgStream

closeOnce sync.Once
closed chan struct{}

recvCh chan errBytes
waiting uint32
hb []byte
Expand All @@ -34,6 +39,7 @@ func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize i
recvCh: make(chan errBytes, recvChBufSize),
timeout: conf.Interval,
hb: conf.Heartbeat,
closed: make(chan struct{}),
maxMessageSize: maxMessageSize,
}

Expand All @@ -48,12 +54,19 @@ func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize i
func (c *hbConn) hbLoop() {
for {
if atomic.LoadUint32(&c.waiting) == 0 {
c.stream.Close()
c.Close()
return
}

atomic.StoreUint32(&c.waiting, 0)
time.Sleep(c.timeout)
timer := time.NewTimer(c.timeout)
select {
case <-c.closed:
timer.Stop()
return
case <-timer.C:
continue
}
}

}
Expand All @@ -62,6 +75,12 @@ func (c *hbConn) recvLoop() {
for {
buffer := make([]byte, c.maxMessageSize)

err := c.stream.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
c.Close()
return
}

n, err := c.stream.Read(buffer)

if bytes.Equal(c.hb, buffer[:n]) {
Expand All @@ -70,15 +89,25 @@ func (c *hbConn) recvLoop() {
}

if err != nil {
c.recvCh <- errBytes{nil, err}
c.Close()
return
}

c.recvCh <- errBytes{buffer[:n], err}
timer := time.NewTimer(c.timeout)
select {
case c.recvCh <- errBytes{buffer[:n], err}:
timer.Stop()
continue
case <-timer.C:
c.Close()
return
}
}

}

func (c *hbConn) Close() error {
c.closeOnce.Do(func() { close(c.closed) })
return c.stream.Close()
}

Expand All @@ -87,18 +116,22 @@ func (c *hbConn) Write(b []byte) (n int, err error) {
}

func (c *hbConn) Read(b []byte) (int, error) {
readBytes := <-c.recvCh
if readBytes.err != nil {
return 0, readBytes.err
}
select {
case <-c.closed:
return 0, net.ErrClosed
case readBytes := <-c.recvCh:
if readBytes.err != nil {
return 0, readBytes.err
}

if len(b) < len(readBytes.b) {
return 0, ErrInsufficientBuffer
}
if len(b) < len(readBytes.b) {
return 0, ErrInsufficientBuffer
}

n := copy(b, readBytes.b)
n := copy(b, readBytes.b)

return n, nil
return n, nil
}
}

func (c *hbConn) BufferedAmount() uint64 {
Expand All @@ -117,19 +150,44 @@ func (c *hbConn) OnBufferedAmountLow(f func()) {
c.stream.OnBufferedAmountLow(f)
}

type hbClient struct {
msgStream
conf heartbeatConfig

closeOnce sync.Once
closed chan struct{}
}

// heartbeatClient sends heartbeats over conn with config
func heartbeatClient(conn msgStream, config *heartbeatConfig) error {
func heartbeatClient(conn msgStream, config *heartbeatConfig) (msgStream, error) {
conf := validate(config)
go func() {
for {
_, err := conn.Write(conf.Heartbeat)
if err != nil {
return
}

time.Sleep(conf.Interval / 2)
client := &hbClient{msgStream: conn,
conf: conf,
closed: make(chan struct{}),
}
go client.sendLoop()
return client, nil
}

func (c *hbClient) sendLoop() {
for {
_, err := c.Write(c.conf.Heartbeat)
if err != nil {
return
}

timer := time.NewTimer(c.conf.Interval / 2)
select {
case <-c.closed:
timer.Stop()
return
case <-timer.C:
continue
}
}
}

}()
return nil
func (c *hbClient) Close() error {
c.closeOnce.Do(func() { close(c.closed) })
return c.msgStream.Close()
}
14 changes: 7 additions & 7 deletions pkg/dtls/heartbeat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func TestHeartbeatReadWrite(t *testing.T) {
s, err := heartbeatServer(server, conf, maxMsgSize)
require.Nil(t, err)

err = heartbeatClient(client, conf)
c, err := heartbeatClient(client, conf)
require.Nil(t, err)

recvd := 0
Expand All @@ -140,7 +140,7 @@ func TestHeartbeatReadWrite(t *testing.T) {
wg.Add(1)
go func(ctx1 context.Context) {
defer wg.Done()
defer client.Close()
defer c.Close()
defer server.Close()
for i := 0; i < sendTimes; i++ {
select {
Expand Down Expand Up @@ -174,12 +174,12 @@ func TestHeartbeatReadWrite(t *testing.T) {
for i := 0; i < sendTimes; i++ {
select {
case <-ctx2.Done():
client.Close()
c.Close()
return
default:
err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
_, err = client.Write(toSend)
_, err = c.Write(toSend)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Log("encountered error writing", err)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestHeartbeatSend(t *testing.T) {
}
}()

err := heartbeatClient(client, conf)
_, err := heartbeatClient(client, conf)
require.Nil(t, err)

duration := 2
Expand Down Expand Up @@ -268,7 +268,7 @@ func TestHeartbeatInsufficientBuf(t *testing.T) {
s, err := heartbeatServer(server, conf, maxMsgSize)
require.Nil(t, err)

err = heartbeatClient(client, conf)
c, err := heartbeatClient(client, conf)
require.Nil(t, err)

toSend := []byte("testtt")
Expand All @@ -282,7 +282,7 @@ func TestHeartbeatInsufficientBuf(t *testing.T) {
require.ErrorIs(t, err, ErrInsufficientBuffer)
}()

_, err = client.Write(toSend)
_, err = c.Write(toSend)
require.Nil(t, err)

wg.Wait()
Expand Down
Loading

0 comments on commit fd29b5c

Please sign in to comment.