diff --git a/CHANGELOG.md b/CHANGELOG.md index bfb46ad..d1cde4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,20 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [v0.6.0] - 2022-08-07 (Beta) + +## Changes + +- **[BREAKING]** The `server` now concurrently process incoming packets from connections by calling handler functions in a goroutine. + This is done to avoid blocking the main packet processing loop when the handler for an incoming packet is slow. +- The `UPDATE` Action has been completely removed from the `server` and the `client` - the context can no longer be + updated from a handler function. +- The `SetConcurrency` function has been added to the `server` to set the concurrency of the packet processing + goroutines. +- `io/ioutil.Discard` has been replaced with `io.Discard` because it was being deprecated +- The `README.md` file has been updated to reflect the `frisbee-go` package name, and better direct users to the `frpc.io` website. +- @jimmyaxod has been added as a maintainer for the `frisbee-go` package. + ## [v0.5.4] - 2022-07-28 (Beta) ## Features @@ -300,7 +314,8 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). Initial Release of Frisbee -[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.5.4...HEAD +[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.6.0...HEAD +[v0.6.0]: https://github.com/loopholelabs/frisbee/compare/v0.5.4...v0.6.0 [v0.5.4]: https://github.com/loopholelabs/frisbee/compare/v0.5.3...v0.5.4 [v0.5.3]: https://github.com/loopholelabs/frisbee/compare/v0.5.2...v0.5.3 [v0.5.2]: https://github.com/loopholelabs/frisbee/compare/v0.5.1...v0.5.2 diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 64f52b5..831258a 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -1,3 +1,4 @@ - Shivansh Vij @shivanshvij - Alex Sørlie Glomsaas @supermanifolds - Felicitas Pojtinger @pojntfx +- Jimmy Moore @jimmyaxod diff --git a/README.md b/README.md index d2c063e..f87cc60 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,8 @@ same is true for selected other new features explicitly marked as ## Usage and Documentation -Usage instructions and documentation for Frisbee is available -at [https://frpc.io/frisbee](https://frpc.io/frisbee). The Frisbee framework also has great +Usage instructions and documentation for `frisbee-go` are available +at [https://frpc.io/frisbee](https://frpc.io/frisbee). This library also has great documentation coverage using [GoDoc](https://godoc.org/github.com/loopholelabs/frisbee-go). ## Contributing diff --git a/async_test.go b/async_test.go index 7662b72..6437538 100644 --- a/async_test.go +++ b/async_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io" - "io/ioutil" "net" "runtime" "sync" @@ -38,7 +37,7 @@ func TestNewAsync(t *testing.T) { const packetSize = 512 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -95,7 +94,7 @@ func TestAsyncLargeWrite(t *testing.T) { const testSize = 100000 const packetSize = 512 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -142,7 +141,7 @@ func TestAsyncRawConn(t *testing.T) { const testSize = 100000 const packetSize = 32 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() require.NoError(t, err) @@ -204,7 +203,7 @@ func TestAsyncReadClose(t *testing.T) { reader, writer := net.Pipe() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) readerConn := NewAsync(reader, &emptyLogger) writerConn := NewAsync(writer, &emptyLogger) @@ -252,7 +251,7 @@ func TestAsyncReadAvailableClose(t *testing.T) { reader, writer := net.Pipe() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) readerConn := NewAsync(reader, &emptyLogger) writerConn := NewAsync(writer, &emptyLogger) @@ -302,7 +301,7 @@ func TestAsyncWriteClose(t *testing.T) { reader, writer := net.Pipe() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) readerConn := NewAsync(reader, &emptyLogger) writerConn := NewAsync(writer, &emptyLogger) @@ -352,7 +351,7 @@ func TestAsyncWriteClose(t *testing.T) { func TestAsyncTimeout(t *testing.T) { t.Parallel() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() require.NoError(t, err) @@ -380,7 +379,7 @@ func TestAsyncTimeout(t *testing.T) { assert.Equal(t, uint32(0), p.Metadata.ContentLength) assert.Equal(t, 0, len(*p.Content)) - time.Sleep(DefaultDeadline * 5) + time.Sleep(DefaultDeadline * 2) err = writerConn.Error() require.NoError(t, err) @@ -400,7 +399,7 @@ func TestAsyncTimeout(t *testing.T) { require.NoError(t, err) runtime.Gosched() - time.Sleep(DefaultDeadline * 5) + time.Sleep(DefaultDeadline * 2) runtime.Gosched() p, err = readerConn.ReadPacket() @@ -417,7 +416,7 @@ func TestAsyncTimeout(t *testing.T) { err = readerConn.Error() if err == nil { runtime.Gosched() - time.Sleep(DefaultDeadline * 10) + time.Sleep(DefaultDeadline * 3) runtime.Gosched() } require.Error(t, readerConn.Error()) @@ -431,7 +430,7 @@ func TestAsyncTimeout(t *testing.T) { func BenchmarkAsyncThroughputPipe(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -451,7 +450,7 @@ func BenchmarkAsyncThroughputPipe(b *testing.B) { func BenchmarkAsyncThroughputNetwork(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() if err != nil { @@ -530,7 +529,7 @@ func BenchmarkAsyncThroughputNetworkMultiple(b *testing.B) { b.ReportAllocs() for i := 0; i < numClients; i++ { go func() { - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() if err != nil { diff --git a/client.go b/client.go index 9a014a2..ea63224 100644 --- a/client.go +++ b/client.go @@ -179,33 +179,44 @@ func (c *Client) handleConn() { var action Action var err error var handlerFunc Handler -LOOP: - if c.closed.Load() { - c.wg.Done() - return - } - p, err = c.conn.ReadPacket() - if err != nil { - c.Logger().Debug().Err(err).Msg("error while getting packet frisbee connection") - c.wg.Done() - _ = c.Close() - return - } - handlerFunc = c.handlerTable[p.Metadata.Operation] - if handlerFunc != nil { - packetCtx := c.ctx - if c.PacketContext != nil { - packetCtx = c.PacketContext(packetCtx, p) + for { + if c.closed.Load() { + c.wg.Done() + return } - outgoing, action = handlerFunc(packetCtx, p) - if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { - err = c.conn.WritePacket(outgoing) - if outgoing != p { - packet.Put(outgoing) + p, err = c.conn.ReadPacket() + if err != nil { + c.Logger().Debug().Err(err).Msg("error while getting packet frisbee connection") + c.wg.Done() + _ = c.Close() + return + } + handlerFunc = c.handlerTable[p.Metadata.Operation] + if handlerFunc != nil { + packetCtx := c.ctx + if c.PacketContext != nil { + packetCtx = c.PacketContext(packetCtx, p) } - packet.Put(p) - if err != nil { - c.Logger().Error().Err(err).Msg("error while writing to frisbee conn") + outgoing, action = handlerFunc(packetCtx, p) + if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { + err = c.conn.WritePacket(outgoing) + if outgoing != p { + packet.Put(outgoing) + } + packet.Put(p) + if err != nil { + c.Logger().Error().Err(err).Msg("error while writing to frisbee conn") + c.wg.Done() + _ = c.Close() + return + } + } else { + packet.Put(p) + } + switch action { + case NONE: + case CLOSE: + c.Logger().Debug().Msgf("Closing connection %s because of CLOSE action", c.conn.RemoteAddr()) c.wg.Done() _ = c.Close() return @@ -213,22 +224,7 @@ LOOP: } else { packet.Put(p) } - switch action { - case NONE: - case UPDATE: - if c.UpdateContext != nil { - c.ctx = c.UpdateContext(c.ctx, c.conn) - } - case CLOSE: - c.Logger().Debug().Msgf("Closing connection %s because of CLOSE action", c.conn.RemoteAddr()) - c.wg.Done() - _ = c.Close() - return - } - } else { - packet.Put(p) } - goto LOOP } func (c *Client) heartbeat() { diff --git a/client_test.go b/client_test.go index ec8e1d5..e07f7b9 100644 --- a/client_test.go +++ b/client_test.go @@ -25,7 +25,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "io/ioutil" + "io" "net" "testing" ) @@ -63,10 +63,12 @@ func TestClientRaw(t *testing.T) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) + s.SetConcurrency(1) + s.ConnContext = func(ctx context.Context, c *Async) context.Context { return context.WithValue(ctx, clientConnContextKey, c) } @@ -156,10 +158,12 @@ func TestClientStaleClose(t *testing.T) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) + s.SetConcurrency(1) + serverConn, clientConn, err := pair.New() require.NoError(t, err) @@ -214,12 +218,14 @@ func BenchmarkThroughputClient(b *testing.B) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } + s.SetConcurrency(1) + serverConn, clientConn, err := pair.New() if err != nil { b.Fatal(err) @@ -297,12 +303,14 @@ func BenchmarkThroughputResponseClient(b *testing.B) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } + s.SetConcurrency(1) + serverConn, clientConn, err := pair.New() if err != nil { b.Fatal(err) diff --git a/conn.go b/conn.go index 20ee837..a3f4ca4 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ import ( "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" - "io/ioutil" + "io" "net" "time" ) @@ -31,7 +31,7 @@ import ( const DefaultBufferSize = 1 << 16 var ( - defaultLogger = zerolog.New(ioutil.Discard) + defaultLogger = zerolog.New(io.Discard) DefaultDeadline = time.Second * 5 diff --git a/doc.go b/doc.go index 5edb565..fceac11 100644 --- a/doc.go +++ b/doc.go @@ -24,60 +24,60 @@ // // In depth documentation and examples can be found at https://loopholelabs.io/docs/frisbee // -// -// An Echo Example +// # An Echo Example // // As a starting point, a very basic echo server: // -// package main +// package main // -// import ( -// "github.com/loopholelabs/frisbee" -// "github.com/loopholelabs/frisbee/pkg/packet" -// "github.com/rs/zerolog/log" -// "os" -// "os/signal" -// ) +// import ( +// "github.com/loopholelabs/frisbee-go" +// "github.com/loopholelabs/frisbee-go/pkg/packet" +// "github.com/rs/zerolog/log" +// "os" +// "os/signal" +// ) // -// const PING = uint16(10) -// const PONG = uint16(11) +// const PING = uint16(10) +// const PONG = uint16(11) // -// func handlePing(_ *frisbee.Async, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { -// if incoming.Metadata.ContentLength > 0 { -// log.Printf("Server Received Metadata: %s\n", incoming.Content) -// incoming.Metadata.Operation = PONG -// outgoing = incoming -// } +// func handlePing(_ *frisbee.Async, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { +// if incoming.Metadata.ContentLength > 0 { +// log.Printf("Server Received Metadata: %s\n", incoming.Content) +// incoming.Metadata.Operation = PONG +// outgoing = incoming +// } // -// return -// } +// return +// } // -// func main() { -// handlerTable := make(frisbee.ServerRouter) -// handlerTable[PING] = handlePing -// exit := make(chan os.Signal) -// signal.Notify(exit, os.Interrupt) +// func main() { +// handlerTable := make(frisbee.ServerRouter) +// handlerTable[PING] = handlePing +// exit := make(chan os.Signal) +// signal.Notify(exit, os.Interrupt) // -// s := frisbee.NewServer(":8192", handlerTable, 0) -// err := s.Start() -// if err != nil { -// panic(err) -// } +// s := frisbee.NewServer(":8192", handlerTable, 0) +// err := s.Start() +// if err != nil { +// panic(err) +// } // -// <-exit -// err = s.Shutdown() -// if err != nil { -// panic(err) +// <-exit +// err = s.Shutdown() +// if err != nil { +// panic(err) +// } // } -// } // // And an accompanying echo client: +// // package main // // import ( // "fmt" -// "github.com/loopholelabs/frisbee" -// "github.com/loopholelabs/frisbee/pkg/packet" +// "github.com/loopholelabs/frisbee-go" +// "github.com/loopholelabs/frisbee-go/pkg/packet" // "github.com/rs/zerolog/log" // "os" // "os/signal" diff --git a/frisbee.go b/frisbee.go index dde0ef4..4de6ba2 100644 --- a/frisbee.go +++ b/frisbee.go @@ -31,7 +31,7 @@ var ( ConnectionClosed = errors.New("connection closed") ConnectionNotInitialized = errors.New("connection not initialized") InvalidBufferLength = errors.New("invalid buffer length") - InvalidHandlerTable = errors.New("invalid handlePacket table configuration, a reserved value may have been used") + InvalidHandlerTable = errors.New("invalid handler table configuration, a reserved value may have been used") ) // Action is an ENUM used to modify the state of the client or server from a Handler function @@ -46,9 +46,6 @@ const ( // NONE is used to do nothing (default) NONE = Action(iota) - // UPDATE is used to trigger an UpdateContext call on the Server or Client - UPDATE - // CLOSE is used to close the frisbee connection CLOSE ) diff --git a/go.mod b/go.mod index 4a9658e..ac495dc 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/loopholelabs/frisbee-go go 1.18 require ( - github.com/loopholelabs/common v0.2.0 + github.com/loopholelabs/common v0.4.4 github.com/loopholelabs/polyglot-go v0.3.0 github.com/loopholelabs/testing v0.2.3 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 2de7cce..a1f06a3 100644 --- a/go.sum +++ b/go.sum @@ -8,9 +8,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/loopholelabs/common v0.2.0 h1:irIV5qsMK2ghO7Bu1XSXnN/6rtP5mazSE1WSBDhi/eg= -github.com/loopholelabs/common v0.2.0/go.mod h1:hB8T0eBGDcZ4cvBfDoD3yUs6bUzeZed9VqhzN4lI7nU= -github.com/loopholelabs/frisbee v0.5.0 h1:wB/PggORrg3IxqppNsSSvRdKG5w3F1A8tRxQhPVvmCI= +github.com/loopholelabs/common v0.4.4 h1:Ge+1v1WiLYgR/4pziOQoJAwUqUm1c9j6nQvnkiFFBsk= +github.com/loopholelabs/common v0.4.4/go.mod h1:YKnljczr4jgxkHhhAwIHh3CJXaff89YBd8Vp3pwpG3k= github.com/loopholelabs/polyglot-go v0.3.0 h1:iOqPw5B3krCMYfgDaPgPoh1A87ACE8lKdbpbERM58pY= github.com/loopholelabs/polyglot-go v0.3.0/go.mod h1:9/Hr1nFO9Al46806vMP3DB2k8blQ3gazBPaoOsdgo34= github.com/loopholelabs/testing v0.2.3 h1:4nVuK5ctaE6ua5Z0dYk2l7xTFmcpCYLUeGjRBp8keOA= diff --git a/options.go b/options.go index a4394bf..57fc549 100644 --- a/options.go +++ b/options.go @@ -19,7 +19,7 @@ package frisbee import ( "crypto/tls" "github.com/rs/zerolog" - "io/ioutil" + "io" "time" ) @@ -27,11 +27,12 @@ import ( type Option func(opts *Options) // DefaultLogger is the default logger used within frisbee -var DefaultLogger = zerolog.New(ioutil.Discard) +var DefaultLogger = zerolog.New(io.Discard) // Options is used to provide the frisbee client and server with configuration options. // // Default Values: +// // options := Options { // KeepAlive: time.Minute * 3, // Logger: &DefaultLogger, diff --git a/options_test.go b/options_test.go index fcf4051..14f806f 100644 --- a/options_test.go +++ b/options_test.go @@ -20,7 +20,7 @@ import ( "crypto/tls" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" - "io/ioutil" + "io" "testing" "time" ) @@ -73,7 +73,7 @@ func TestDisableOptions(t *testing.T) { func TestIndividualOptions(t *testing.T) { t.Parallel() - logger := zerolog.New(ioutil.Discard) + logger := zerolog.New(io.Discard) tlsConfig := &tls.Config{ InsecureSkipVerify: true, } diff --git a/server.go b/server.go index 6d13b45..8bb5f2a 100644 --- a/server.go +++ b/server.go @@ -53,6 +53,8 @@ type Server struct { connections map[*Async]struct{} connectionsMu sync.Mutex startedCh chan struct{} + concurrency uint64 + limiter chan struct{} // baseContext is used to define the base context for this Server and all incoming connections baseContext func() context.Context @@ -148,6 +150,21 @@ func (s *Server) GetHandlerTable() HandlerTable { return s.handlerTable } +// SetConcurrency sets the maximum number of concurrent goroutines that will be created +// by the server to handle incoming packets. +// +// An important caveat of this is that handlers must always thread-safe if they share resources +// between connections. If the concurrency is set to a value != 1, then the handlers +// must also be thread-safe if they share resources per connection. +// +// This function should not be called once the server has started. +func (s *Server) SetConcurrency(concurrency uint64) { + s.concurrency = concurrency + if s.concurrency > 1 { + s.limiter = make(chan struct{}, s.concurrency) + } +} + // Start will start the frisbee server and its reactor goroutines // to receive and handle incoming connections. If the baseContext, ConnContext, // onClosed, OnShutdown, or preWrite functions have not been defined, it will @@ -205,70 +222,193 @@ func (s *Server) handleListener() error { } backoff = 0 - s.wg.Add(1) - go func() { - s.ServeConn(newConn) - s.wg.Done() - }() + s.ServeConn(newConn) } } -func (s *Server) handlePacket(frisbeeConn *Async, connCtx context.Context) (err error) { +func (s *Server) handleSinglePacket(frisbeeConn *Async, connCtx context.Context) { var p *packet.Packet var outgoing *packet.Packet var action Action var handlerFunc Handler + var err error p, err = frisbeeConn.ReadPacket() if err != nil { + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, err) return } if s.ConnContext != nil { connCtx = s.ConnContext(connCtx, frisbeeConn) } - goto HANDLE -LOOP: - p, err = frisbeeConn.ReadPacket() + for { + handlerFunc = s.handlerTable[p.Metadata.Operation] + if handlerFunc != nil { + packetCtx := connCtx + if s.PacketContext != nil { + packetCtx = s.PacketContext(packetCtx, p) + } + outgoing, action = handlerFunc(packetCtx, p) + if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { + s.preWrite() + err = frisbeeConn.WritePacket(outgoing) + if outgoing != p { + packet.Put(outgoing) + } + packet.Put(p) + if err != nil { + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, err) + return + } + } else { + packet.Put(p) + } + switch action { + case NONE: + case CLOSE: + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, nil) + return + } + } else { + packet.Put(p) + } + p, err = frisbeeConn.ReadPacket() + if err != nil { + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, err) + return + } + } +} + +func (s *Server) handler(conn *Async, closed *atomic.Bool, wg *sync.WaitGroup, ctx context.Context, cancel context.CancelFunc) func(*packet.Packet) { + return func(p *packet.Packet) { + handlerFunc := s.handlerTable[p.Metadata.Operation] + if handlerFunc != nil { + packetCtx := ctx + if s.PacketContext != nil { + packetCtx = s.PacketContext(packetCtx, p) + } + outgoing, action := handlerFunc(packetCtx, p) + if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { + s.preWrite() + err := conn.WritePacket(outgoing) + if outgoing != p { + packet.Put(outgoing) + } + packet.Put(p) + if err != nil { + _ = conn.Close() + if closed.CAS(false, true) { + s.onClosed(conn, err) + } + cancel() + wg.Done() + return + } + } else { + packet.Put(p) + } + switch action { + case NONE: + case CLOSE: + _ = conn.Close() + if closed.CAS(false, true) { + s.onClosed(conn, nil) + } + cancel() + } + } else { + packet.Put(p) + } + wg.Done() + } +} + +func (s *Server) handleUnlimitedPacket(frisbeeConn *Async, connCtx context.Context) { + p, err := frisbeeConn.ReadPacket() if err != nil { + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, err) return } -HANDLE: - handlerFunc = s.handlerTable[p.Metadata.Operation] - if handlerFunc != nil { - packetCtx := connCtx - if s.PacketContext != nil { - packetCtx = s.PacketContext(packetCtx, p) - } - outgoing, action = handlerFunc(packetCtx, p) - if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { - s.preWrite() - err = frisbeeConn.WritePacket(outgoing) - if outgoing != p { - packet.Put(outgoing) + if s.ConnContext != nil { + connCtx = s.ConnContext(connCtx, frisbeeConn) + } + wg := new(sync.WaitGroup) + closed := atomic.NewBool(false) + connCtx, cancel := context.WithCancel(connCtx) + handler := s.handler(frisbeeConn, closed, wg, connCtx, cancel) + for { + wg.Add(1) + go handler(p) + p, err = frisbeeConn.ReadPacket() + if err != nil { + _ = frisbeeConn.Close() + if closed.CAS(false, true) { + s.onClosed(frisbeeConn, err) } - packet.Put(p) + cancel() + wg.Wait() + return + } + } +} + +func (s *Server) handleLimitedPacket(frisbeeConn *Async, connCtx context.Context) { + p, err := frisbeeConn.ReadPacket() + if err != nil { + _ = frisbeeConn.Close() + s.onClosed(frisbeeConn, err) + } + if s.ConnContext != nil { + connCtx = s.ConnContext(connCtx, frisbeeConn) + } + wg := new(sync.WaitGroup) + closed := atomic.NewBool(false) + connCtx, cancel := context.WithCancel(connCtx) + uHandler := s.handler(frisbeeConn, closed, wg, connCtx, cancel) + handler := func(p *packet.Packet) { + uHandler(p) + <-s.limiter + } + for { + select { + case s.limiter <- struct{}{}: + wg.Add(1) + go handler(p) + p, err = frisbeeConn.ReadPacket() if err != nil { + _ = frisbeeConn.Close() + if closed.CAS(false, true) { + s.onClosed(frisbeeConn, err) + } + cancel() + wg.Wait() return } - } else { - packet.Put(p) - } - switch action { - case NONE: - case UPDATE: - if s.UpdateContext != nil { - connCtx = s.UpdateContext(connCtx, frisbeeConn) + case <-connCtx.Done(): + _ = frisbeeConn.Close() + if closed.CAS(false, true) { + s.onClosed(frisbeeConn, err) } - case CLOSE: + wg.Wait() return } - } else { - packet.Put(p) } - goto LOOP } -// ServeConn takes a TCP net.Conn and serves it using the Server -func (s *Server) ServeConn(newConn net.Conn) { +// ServeConn takes a net.Conn and starts a goroutine to handle it using the Server. +func (s *Server) ServeConn(conn net.Conn) { + s.wg.Add(1) + go s.serveConn(conn) +} + +// serveConn takes a net.Conn and serves it using the Server +// and assumes that the server's wait group has been incremented by 1. +func (s *Server) serveConn(newConn net.Conn) { var err error switch v := newConn.(type) { case *net.TCPConn: @@ -276,12 +416,14 @@ func (s *Server) ServeConn(newConn net.Conn) { if err != nil { s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive") _ = v.Close() + s.wg.Done() return } err = v.SetKeepAlivePeriod(s.options.KeepAlive) if err != nil { s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive Period") _ = v.Close() + s.wg.Done() return } } @@ -291,19 +433,24 @@ func (s *Server) ServeConn(newConn net.Conn) { s.connectionsMu.Lock() if s.shutdown.Load() { + s.wg.Done() return } s.connections[frisbeeConn] = struct{}{} s.connectionsMu.Unlock() - - err = s.handlePacket(frisbeeConn, connCtx) - _ = frisbeeConn.Close() - s.onClosed(frisbeeConn, err) + if s.concurrency == 0 { + s.handleUnlimitedPacket(frisbeeConn, connCtx) + } else if s.concurrency == 1 { + s.handleSinglePacket(frisbeeConn, connCtx) + } else { + s.handleLimitedPacket(frisbeeConn, connCtx) + } s.connectionsMu.Lock() if !s.shutdown.Load() { delete(s.connections, frisbeeConn) } s.connectionsMu.Unlock() + s.wg.Done() } // Logger returns the server's logger (useful for ServerRouter functions) diff --git a/server_test.go b/server_test.go index d51f3ae..0ef88bc 100644 --- a/server_test.go +++ b/server_test.go @@ -27,10 +27,12 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "io/ioutil" + "go.uber.org/atomic" + "io" "net" "sync" "testing" + "time" ) // trunk-ignore-all(golangci-lint/staticcheck) @@ -39,7 +41,7 @@ const ( serverConnContextKey = "conn" ) -func TestServerRaw(t *testing.T) { +func TestServerRawSingle(t *testing.T) { t.Parallel() const testSize = 100 @@ -65,10 +67,12 @@ func TestServerRaw(t *testing.T) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) + s.SetConcurrency(1) + s.ConnContext = func(ctx context.Context, c *Async) context.Context { return context.WithValue(ctx, serverConnContextKey, c) } @@ -119,14 +123,14 @@ func TestServerRaw(t *testing.T) { write, err := rawServerConn.Write(serverBytes) assert.NoError(t, err) - assert.Equal(t, len(serverBytes), write) + assert.Equal(t, cap(serverBytes), write) - clientBuffer := make([]byte, len(serverBytes)) - read, err := rawClientConn.Read(clientBuffer) + clientBuffer := make([]byte, cap(serverBytes)) + read, err := rawClientConn.Read(clientBuffer[:]) assert.NoError(t, err) - assert.Equal(t, len(serverBytes), read) + assert.Equal(t, cap(serverBytes), read) - assert.Equal(t, serverBytes, clientBuffer) + assert.Equal(t, serverBytes, clientBuffer[:read]) err = c.Close() assert.NoError(t, err) @@ -139,7 +143,7 @@ func TestServerRaw(t *testing.T) { assert.NoError(t, err) } -func TestServerStaleClose(t *testing.T) { +func TestServerStaleCloseSingle(t *testing.T) { t.Parallel() const testSize = 100 @@ -162,10 +166,12 @@ func TestServerStaleClose(t *testing.T) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) + s.SetConcurrency(1) + serverConn, clientConn, err := pair.New() require.NoError(t, err) @@ -205,7 +211,7 @@ func TestServerStaleClose(t *testing.T) { assert.NoError(t, err) } -func TestServerMultipleConnections(t *testing.T) { +func TestServerMultipleConnectionsSingle(t *testing.T) { t.Parallel() const testSize = 100 @@ -232,10 +238,12 @@ func TestServerMultipleConnections(t *testing.T) { return } - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) + s.SetConcurrency(1) + var wg sync.WaitGroup wg.Add(1) @@ -302,131 +310,1079 @@ func TestServerMultipleConnections(t *testing.T) { t.Run("100", func(t *testing.T) { runner(t, 100) }) } -func BenchmarkThroughputServer(b *testing.B) { - const testSize = 1<<16 - 1 +func TestServerRawUnlimited(t *testing.T) { + t.Parallel() + + const testSize = 100 const packetSize = 512 + clientHandlerTable := make(HandlerTable) + serverHandlerTable := make(HandlerTable) - handlerTable := make(HandlerTable) + serverIsRaw := make(chan struct{}, 1) - handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { return } - emptyLogger := zerolog.New(ioutil.Discard) - server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) - if err != nil { - b.Fatal(err) + var rawServerConn, rawClientConn net.Conn + serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + c := ctx.Value(serverConnContextKey).(*Async) + rawServerConn = c.Raw() + serverIsRaw <- struct{}{} + return } - serverConn, clientConn, err := pair.New() - if err != nil { - b.Fatal(err) + clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + return } - go server.ServeConn(serverConn) + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) - frisbeeConn := NewAsync(clientConn, &emptyLogger) + s.SetConcurrency(0) + + s.ConnContext = func(ctx context.Context, c *Async) context.Context { + return context.WithValue(ctx, serverConnContextKey, c) + } + + serverConn, clientConn, err := pair.New() + require.NoError(t, err) + + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + + _, err = c.Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = c.FromConn(clientConn) + assert.NoError(t, err) data := make([]byte, packetSize) _, _ = rand.Read(data) p := packet.Get() - p.Metadata.Operation = metadata.PacketPing - p.Content.Write(data) p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + assert.Equal(t, polyglot.Buffer(data), *p.Content) - b.Run("test", func(b *testing.B) { - b.SetBytes(testSize * packetSize) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - for q := 0; q < testSize; q++ { - p.Metadata.Id = uint16(q) - err = frisbeeConn.WritePacket(p) - if err != nil { - b.Fatal(err) - } - } - } - }) - b.StopTimer() + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = c.WritePacket(p) + assert.NoError(t, err) + } + + p.Reset() + assert.Equal(t, 0, len(*p.Content)) + p.Metadata.Operation = metadata.PacketProbe + + err = c.WritePacket(p) + require.NoError(t, err) packet.Put(p) - err = frisbeeConn.Close() - if err != nil { - b.Fatal(err) - } - err = server.Shutdown() - if err != nil { - b.Fatal(err) - } + rawClientConn, err = c.Raw() + require.NoError(t, err) + + <-serverIsRaw + + serverBytes := []byte("SERVER WRITE") + + write, err := rawServerConn.Write(serverBytes) + assert.NoError(t, err) + assert.Equal(t, len(serverBytes), write) + + clientBuffer := make([]byte, len(serverBytes)) + read, err := rawClientConn.Read(clientBuffer) + assert.NoError(t, err) + assert.Equal(t, len(serverBytes), read) + + assert.Equal(t, serverBytes, clientBuffer) + + err = c.Close() + assert.NoError(t, err) + err = rawClientConn.Close() + assert.NoError(t, err) + + err = s.Shutdown() + assert.NoError(t, err) + err = rawServerConn.Close() + assert.NoError(t, err) } -func BenchmarkThroughputResponseServer(b *testing.B) { - const testSize = 1<<16 - 1 - const packetSize = 512 +func TestServerStaleCloseUnlimited(t *testing.T) { + t.Parallel() - serverConn, clientConn, err := pair.New() - if err != nil { - b.Fatal(err) - } + const testSize = 100 + const packetSize = 512 + clientHandlerTable := make(HandlerTable) + serverHandlerTable := make(HandlerTable) - handlerTable := make(HandlerTable) + finished := make(chan struct{}, 1) - handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { - if incoming.Metadata.Id == testSize-1 { - incoming.Reset() - incoming.Metadata.Id = testSize - incoming.Metadata.Operation = metadata.PacketPong + count := atomic.NewUint32(0) + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if count.Inc() == testSize-1 { outgoing = incoming + action = CLOSE + count.Store(0) } return } - emptyLogger := zerolog.New(ioutil.Discard) - server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) - if err != nil { - b.Fatal(err) + clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + finished <- struct{}{} + return } - go server.ServeConn(serverConn) + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) - frisbeeConn := NewAsync(clientConn, &emptyLogger) + s.SetConcurrency(0) + + serverConn, clientConn, err := pair.New() + require.NoError(t, err) + + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + _, err = c.Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = c.FromConn(clientConn) + require.NoError(t, err) data := make([]byte, packetSize) _, _ = rand.Read(data) - p := packet.Get() - p.Metadata.Operation = metadata.PacketPing - p.Content.Write(data) p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + assert.Equal(t, polyglot.Buffer(data), *p.Content) - b.Run("test", func(b *testing.B) { - b.SetBytes(testSize * packetSize) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - for q := 0; q < testSize; q++ { - p.Metadata.Id = uint16(q) - err = frisbeeConn.WritePacket(p) - if err != nil { - b.Fatal(err) - } - } - readPacket, err := frisbeeConn.ReadPacket() - if err != nil { - b.Fatal(err) - } + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = c.WritePacket(p) + assert.NoError(t, err) + } + packet.Put(p) + <-finished - if readPacket.Metadata.Id != testSize { - b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + _, err = c.conn.ReadPacket() + assert.ErrorIs(t, err, ConnectionClosed) + + err = c.Close() + assert.NoError(t, err) + + err = s.Shutdown() + assert.NoError(t, err) +} + +func TestServerMultipleConnectionsUnlimited(t *testing.T) { + t.Parallel() + + const testSize = 100 + const packetSize = 512 + + runner := func(t *testing.T, num int) { + finished := make([]chan struct{}, num) + clientTables := make([]HandlerTable, num) + for i := 0; i < num; i++ { + idx := i + finished[idx] = make(chan struct{}, 1) + clientTables[i] = make(HandlerTable) + clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + finished[idx] <- struct{}{} + return } + } + clientCounts := make([]*atomic.Uint32, num) + for i := 0; i < num; i++ { + clientCounts[i] = atomic.NewUint32(0) + } - if readPacket.Metadata.Operation != metadata.PacketPong { - b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) + serverHandlerTable := make(HandlerTable) + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if clientCounts[incoming.Metadata.Id].Inc() == testSize-1 { + outgoing = incoming + action = CLOSE + clientCounts[incoming.Metadata.Id].Store(0) } + return + } + + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) + + s.SetConcurrency(0) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := s.Start(conn.Listen) + require.NoError(t, err) + wg.Done() + }() + + <-s.started() + listenAddr := s.listener.Addr().String() + + clients := make([]*Client, num) + for i := 0; i < num; i++ { + clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + _, err = clients[i].Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = clients[i].Connect(listenAddr) + require.NoError(t, err) + } + + data := make([]byte, packetSize) + _, err = rand.Read(data) + assert.NoError(t, err) + + var clientWg sync.WaitGroup + for i := 0; i < num; i++ { + idx := i + clientWg.Add(1) + go func() { + p := packet.Get() + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + p.Metadata.Id = uint16(idx) + assert.Equal(t, polyglot.Buffer(data), *p.Content) + for q := 0; q < testSize; q++ { + err := clients[idx].WritePacket(p) + assert.NoError(t, err) + } + <-finished[idx] + err := clients[idx].Close() + assert.NoError(t, err) + clientWg.Done() + packet.Put(p) + }() + } + + clientWg.Wait() + + err = s.Shutdown() + assert.NoError(t, err) + wg.Wait() + + } + + t.Run("1", func(t *testing.T) { runner(t, 1) }) + t.Run("2", func(t *testing.T) { runner(t, 2) }) + t.Run("3", func(t *testing.T) { runner(t, 3) }) + t.Run("5", func(t *testing.T) { runner(t, 5) }) + t.Run("10", func(t *testing.T) { runner(t, 10) }) + t.Run("100", func(t *testing.T) { runner(t, 100) }) +} + +func TestServerRawLimited(t *testing.T) { + t.Parallel() + + const testSize = 100 + const packetSize = 512 + clientHandlerTable := make(HandlerTable) + serverHandlerTable := make(HandlerTable) + + serverIsRaw := make(chan struct{}, 1) + + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + return + } + + var rawServerConn, rawClientConn net.Conn + serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + c := ctx.Value(serverConnContextKey).(*Async) + rawServerConn = c.Raw() + serverIsRaw <- struct{}{} + return + } + + clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + return + } + + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) + + s.SetConcurrency(10) + + s.ConnContext = func(ctx context.Context, c *Async) context.Context { + return context.WithValue(ctx, serverConnContextKey, c) + } + + serverConn, clientConn, err := pair.New() + require.NoError(t, err) + + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + + _, err = c.Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = c.FromConn(clientConn) + assert.NoError(t, err) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + p := packet.Get() + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + assert.Equal(t, polyglot.Buffer(data), *p.Content) + + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = c.WritePacket(p) + assert.NoError(t, err) + } + + p.Reset() + assert.Equal(t, 0, len(*p.Content)) + p.Metadata.Operation = metadata.PacketProbe + + err = c.WritePacket(p) + require.NoError(t, err) + + packet.Put(p) + + rawClientConn, err = c.Raw() + require.NoError(t, err) + + <-serverIsRaw + + serverBytes := []byte("SERVER WRITE") + + write, err := rawServerConn.Write(serverBytes) + assert.NoError(t, err) + assert.Equal(t, len(serverBytes), write) + + clientBuffer := make([]byte, len(serverBytes)) + read, err := rawClientConn.Read(clientBuffer) + assert.NoError(t, err) + assert.Equal(t, len(serverBytes), read) + + assert.Equal(t, serverBytes, clientBuffer) + + err = c.Close() + assert.NoError(t, err) + err = rawClientConn.Close() + assert.NoError(t, err) + + err = s.Shutdown() + assert.NoError(t, err) + err = rawServerConn.Close() + assert.NoError(t, err) +} + +func TestServerStaleCloseLimited(t *testing.T) { + t.Parallel() + + const testSize = 100 + const packetSize = 512 + clientHandlerTable := make(HandlerTable) + serverHandlerTable := make(HandlerTable) + + finished := make(chan struct{}, 1) + + count := atomic.NewUint32(0) + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if count.Inc() == testSize-1 { + outgoing = incoming + action = CLOSE + count.Store(0) + } + return + } + + clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + finished <- struct{}{} + return + } + + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) + + s.SetConcurrency(10) + + serverConn, clientConn, err := pair.New() + require.NoError(t, err) + + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + _, err = c.Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = c.FromConn(clientConn) + require.NoError(t, err) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + p := packet.Get() + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + assert.Equal(t, polyglot.Buffer(data), *p.Content) + + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = c.WritePacket(p) + assert.NoError(t, err) + } + packet.Put(p) + <-finished + + _, err = c.conn.ReadPacket() + assert.ErrorIs(t, err, ConnectionClosed) + + err = c.Close() + assert.NoError(t, err) + + err = s.Shutdown() + assert.NoError(t, err) +} + +func TestServerMultipleConnectionsLimited(t *testing.T) { + t.Parallel() + + const testSize = 100 + const packetSize = 512 + + runner := func(t *testing.T, num int) { + finished := make([]chan struct{}, num) + clientTables := make([]HandlerTable, num) + for i := 0; i < num; i++ { + idx := i + finished[idx] = make(chan struct{}, 1) + clientTables[i] = make(HandlerTable) + clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + finished[idx] <- struct{}{} + return + } + } + + clientCounts := make([]*atomic.Uint32, num) + for i := 0; i < num; i++ { + clientCounts[i] = atomic.NewUint32(0) + } + + serverHandlerTable := make(HandlerTable) + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if clientCounts[incoming.Metadata.Id].Inc() == testSize-1 { + outgoing = incoming + action = CLOSE + clientCounts[incoming.Metadata.Id].Store(0) + } + return + } + + emptyLogger := zerolog.New(io.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) + + s.SetConcurrency(10) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := s.Start(conn.Listen) + require.NoError(t, err) + wg.Done() + }() + + <-s.started() + listenAddr := s.listener.Addr().String() + + clients := make([]*Client, num) + for i := 0; i < num; i++ { + clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + _, err = clients[i].Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = clients[i].Connect(listenAddr) + require.NoError(t, err) + } + + data := make([]byte, packetSize) + _, err = rand.Read(data) + assert.NoError(t, err) + + var clientWg sync.WaitGroup + for i := 0; i < num; i++ { + idx := i + clientWg.Add(1) + go func() { + p := packet.Get() + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + p.Metadata.Id = uint16(idx) + assert.Equal(t, polyglot.Buffer(data), *p.Content) + for q := 0; q < testSize; q++ { + err := clients[idx].WritePacket(p) + assert.NoError(t, err) + } + <-finished[idx] + err := clients[idx].Close() + assert.NoError(t, err) + clientWg.Done() + packet.Put(p) + }() + } + + clientWg.Wait() + + err = s.Shutdown() + assert.NoError(t, err) + wg.Wait() + + } + + t.Run("1", func(t *testing.T) { runner(t, 1) }) + t.Run("2", func(t *testing.T) { runner(t, 2) }) + t.Run("3", func(t *testing.T) { runner(t, 3) }) + t.Run("5", func(t *testing.T) { runner(t, 5) }) + t.Run("10", func(t *testing.T) { runner(t, 10) }) + t.Run("100", func(t *testing.T) { runner(t, 100) }) +} + +func BenchmarkThroughputServerSingle(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + handlerTable := make(HandlerTable) + + handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(1) + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + } + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputServerUnlimited(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + handlerTable := make(HandlerTable) + + handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + time.Sleep(time.Millisecond * 50) + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(0) + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + } + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputServerLimited(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + handlerTable := make(HandlerTable) + + handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + time.Sleep(time.Millisecond * 50) + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(1 << 14) + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + } + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputResponseServerSingle(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + handlerTable := make(HandlerTable) + + handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if incoming.Metadata.Id == testSize-1 { + incoming.Reset() + incoming.Metadata.Id = testSize + incoming.Metadata.Operation = metadata.PacketPong + outgoing = incoming + } + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(1) + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + readPacket, err := frisbeeConn.ReadPacket() + if err != nil { + b.Fatal(err) + } + + if readPacket.Metadata.Id != testSize { + b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + } + + if readPacket.Metadata.Operation != metadata.PacketPong { + b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) + } + packet.Put(readPacket) + } + + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputResponseServerSlowSingle(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + handlerTable := make(HandlerTable) + + handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + time.Sleep(time.Microsecond * 50) + if incoming.Metadata.Id == testSize-1 { + incoming.Reset() + incoming.Metadata.Id = testSize + incoming.Metadata.Operation = metadata.PacketPong + outgoing = incoming + } + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(1) + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + readPacket, err := frisbeeConn.ReadPacket() + if err != nil { + b.Fatal(err) + } + + if readPacket.Metadata.Id != testSize { + b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + } + + if readPacket.Metadata.Operation != metadata.PacketPong { + b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) + } + packet.Put(readPacket) + } + + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputResponseServerSlowUnlimited(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + handlerTable := make(HandlerTable) + + count := atomic.NewUint64(0) + handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + time.Sleep(time.Microsecond * 50) + if count.Inc() == testSize-1 { + incoming.Reset() + incoming.Metadata.Id = testSize + incoming.Metadata.Operation = metadata.PacketPong + outgoing = incoming + count.Store(0) + } + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(0) + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + readPacket, err := frisbeeConn.ReadPacket() + if err != nil { + b.Fatal(err) + } + + if readPacket.Metadata.Id != testSize { + b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + } + + if readPacket.Metadata.Operation != metadata.PacketPong { + b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) + } + + packet.Put(readPacket) + } + + }) + b.StopTimer() + + packet.Put(p) + + err = frisbeeConn.Close() + if err != nil { + b.Fatal(err) + } + err = server.Shutdown() + if err != nil { + b.Fatal(err) + } +} + +func BenchmarkThroughputResponseServerSlowLimited(b *testing.B) { + const testSize = 1<<16 - 1 + const packetSize = 512 + + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + + handlerTable := make(HandlerTable) + + count := atomic.NewUint64(0) + handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + time.Sleep(time.Microsecond * 50) + if count.Inc() == testSize-1 { + incoming.Reset() + incoming.Metadata.Id = testSize + incoming.Metadata.Operation = metadata.PacketPong + outgoing = incoming + count.Store(0) + } + return + } + + emptyLogger := zerolog.New(io.Discard) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) + if err != nil { + b.Fatal(err) + } + + server.SetConcurrency(100) + + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) + + data := make([]byte, packetSize) + _, _ = rand.Read(data) + + p := packet.Get() + p.Metadata.Operation = metadata.PacketPing + + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + + b.Run("test", func(b *testing.B) { + b.SetBytes(testSize * packetSize) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err = frisbeeConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + readPacket, err := frisbeeConn.ReadPacket() + if err != nil { + b.Fatal(err) + } + + if readPacket.Metadata.Id != testSize { + b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + } + + if readPacket.Metadata.Operation != metadata.PacketPong { + b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) + } + packet.Put(readPacket) } diff --git a/sync_test.go b/sync_test.go index f434288..bcf1ca1 100644 --- a/sync_test.go +++ b/sync_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io" - "io/ioutil" "net" "testing" ) @@ -34,7 +33,7 @@ func TestNewSync(t *testing.T) { t.Parallel() const packetSize = 512 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -105,7 +104,7 @@ func TestSyncLargeWrite(t *testing.T) { const testSize = 100000 const packetSize = 512 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -163,7 +162,7 @@ func TestSyncRawConn(t *testing.T) { const testSize = 100000 const packetSize = 32 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() require.NoError(t, err) @@ -236,7 +235,7 @@ func TestSyncReadClose(t *testing.T) { reader, writer := net.Pipe() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) readerConn := NewSync(reader, &emptyLogger) writerConn := NewSync(writer, &emptyLogger) @@ -286,7 +285,7 @@ func TestSyncWriteClose(t *testing.T) { reader, writer := net.Pipe() - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) readerConn := NewSync(reader, &emptyLogger) writerConn := NewSync(writer, &emptyLogger) @@ -334,7 +333,7 @@ func TestSyncWriteClose(t *testing.T) { func BenchmarkSyncThroughputPipe(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer := net.Pipe() @@ -354,7 +353,7 @@ func BenchmarkSyncThroughputPipe(b *testing.B) { func BenchmarkSyncThroughputNetwork(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() if err != nil { diff --git a/throughput_test.go b/throughput_test.go index 7146237..b7ba120 100644 --- a/throughput_test.go +++ b/throughput_test.go @@ -24,7 +24,6 @@ import ( "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "io" - "io/ioutil" "net" "testing" "time" @@ -33,7 +32,7 @@ import ( func BenchmarkAsyncThroughputLarge(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() if err != nil { @@ -56,7 +55,7 @@ func BenchmarkAsyncThroughputLarge(b *testing.B) { func BenchmarkSyncThroughputLarge(b *testing.B) { const testSize = 100 - emptyLogger := zerolog.New(ioutil.Discard) + emptyLogger := zerolog.New(io.Discard) reader, writer, err := pair.New() if err != nil {