diff --git a/.deepsource.toml b/.deepsource.toml deleted file mode 100644 index 42006a6..0000000 --- a/.deepsource.toml +++ /dev/null @@ -1,22 +0,0 @@ -version = 1 - -test_patterns = ["**/*_test.go"] - -[[analyzers]] -name = "secrets" -enabled = true - -[[analyzers]] -name = "test-coverage" -enabled = true - -[[analyzers]] -name = "go" -enabled = true - -[analyzers.meta] -import_root = "github.com/loopholelabs/frisbee" - -[[transformers]] -name = "gofmt" -enabled = true \ No newline at end of file diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml index c70a901..d845573 100644 --- a/.github/workflows/benchmarks.yaml +++ b/.github/workflows/benchmarks.yaml @@ -20,7 +20,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: Go Build Cache uses: actions/cache@v2 @@ -50,7 +50,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: Go Build Cache uses: actions/cache@v2 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yaml similarity index 89% rename from .github/workflows/lint.yml rename to .github/workflows/lint.yaml index a38f10e..1a33ea6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yaml @@ -19,4 +19,4 @@ jobs: key: trunk-${{ runner.os }} - name: Trunk Check - uses: trunk-io/trunk-action@v1.0.0 + uses: trunk-io/trunk-action@v1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yaml similarity index 96% rename from .github/workflows/tests.yml rename to .github/workflows/tests.yaml index cbb017b..11885a3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yaml @@ -20,7 +20,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: Go Build Cache uses: actions/cache@v2 @@ -50,7 +50,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: Go Build Cache uses: actions/cache@v2 diff --git a/.trunk/.gitignore b/.trunk/.gitignore index 7feb17f..12e4785 100644 --- a/.trunk/.gitignore +++ b/.trunk/.gitignore @@ -1 +1,2 @@ *out +*log \ No newline at end of file diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index b087158..a9805aa 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -1,10 +1,15 @@ version: 0.1 cli: - version: 0.11.0-beta + version: 0.15.0-beta lint: enabled: - - gitleaks@8.7.1 - - gofmt@1.18.1 - - golangci-lint@1.45.2 + - actionlint@1.6.13 + - gitleaks@8.8.7 + - gofmt@1.18.3 + - golangci-lint@1.46.2 - markdownlint@0.31.1 - prettier@2.6.2 + ignore: + - linters: [ALL] + paths: + - dist/** diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bccfd8..c386588 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [v0.5.1] - 2022-07-20 (Beta) + +## Fixes + +- Fixed an issue where new connections in the server would be overwritten sometimes due to a pointer error + +## Changes + +- FRPC is now called fRPC +- fRPC has been moved into its own [repository](https://github.com/loopholelabs/frpc-go) +- We're using the [Common](https://github.com/loopholelabs/common) library for our queues and packets +- Packets now use the [Polyglot-Go](https://github.com/loopholelabs/polylgot-go) library for serialization + ## [v0.5.0] - 2022-05-18 (Beta) ## Changes @@ -265,7 +278,9 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). Initial Release of Frisbee -[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.4.6...HEAD +[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.5.1...HEAD +[v0.5.1]: https://github.com/loopholelabs/frisbee/compare/v0.5.0...v0.5.1 +[v0.5.0]: https://github.com/loopholelabs/frisbee/compare/v0.4.6...v0.5.0 [v0.4.6]: https://github.com/loopholelabs/frisbee/compare/v0.4.5...v0.4.6 [v0.4.5]: https://github.com/loopholelabs/frisbee/compare/v0.4.4...v0.4.5 [v0.4.4]: https://github.com/loopholelabs/frisbee/compare/v0.4.3...v0.4.4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9ec2b6..3f31b30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,7 +7,7 @@ Frisbee uses GitHub to manage reviews of pull requests. [MAINTAINERS.md](MAINTAINERS.md)) in the description of the pull request. - If you plan to do something more involved, first discuss your ideas - on our [slack](https://join.slack.com/t/loopholelabs/shared_invite/zt-pntffh2t-l6mQJdBDafG3x1JJabMAFA). + on our [discord](https://loopholelabs.io/discord). This will avoid unnecessary work and surely give you and us a good deal of inspiration. diff --git a/MAINTAINERS.md b/MAINTAINERS.md index e2210ce..64f52b5 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -1,2 +1,3 @@ - Shivansh Vij @shivanshvij - Alex Sørlie Glomsaas @supermanifolds +- Felicitas Pojtinger @pojntfx diff --git a/README.md b/README.md index 3a0f92f..d2c063e 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,16 @@ -# Frisbee +# Frisbee-Go [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://www.apache.org/licenses/LICENSE-2.0) -[![Tests](https://github.com/loopholelabs/frisbee/actions/workflows/tests.yml/badge.svg)](https://github.com/loopholelabs/frisbee/actions/workflows/tests.yml) -[![Benchmarks](https://github.com/loopholelabs/frisbee/actions/workflows/benchmarks.yaml/badge.svg)](https://github.com/loopholelabs/frisbee/actions/workflows/benchmarks.yaml) -[![Go Report Card](https://goreportcard.com/badge/github.com/loopholelabs/frisbee)](https://goreportcard.com/report/github.com/loopholelabs/frisbee) -[![go-doc](https://godoc.org/github.com/loopholelabs/frisbee?status.svg)](https://godoc.org/github.com/loopholelabs/frisbee) +[![Tests](https://github.com/loopholelabs/frisbee-go/actions/workflows/tests.yaml/badge.svg)](https://github.com/loopholelabs/frisbee-go/actions/workflows/tests.yaml) +[![Benchmarks](https://github.com/loopholelabs/frisbee-go/actions/workflows/benchmarks.yaml/badge.svg)](https://github.com/loopholelabs/frisbee-go/actions/workflows/benchmarks.yaml) +[![Go Report Card](https://goreportcard.com/badge/github.com/loopholelabs/frisbee-go)](https://goreportcard.com/report/github.com/loopholelabs/frisbee-go) +[![go-doc](https://godoc.org/github.com/loopholelabs/frisbee-go?status.svg)](https://godoc.org/github.com/loopholelabs/frisbee-go) -This is the [Go](http://golang.org) library for -[Frisbee](https://frpc.io/concepts/frisbee), a bring-your-own protocol messaging framework designed for performance and +This is the [Go](http://golang.org) implementation of [Frisbee](https://frpc.io/frisbee), a bring-your-own +protocol messaging framework designed for performance and stability. -[FRPC](https://frpc.io) is a lightweight, fast, and secure RPC framework for Go that uses Frisbee under the hood. This -repository houses both projects, with **FRPC** being contained in the -[protoc-gen-frpc]("/protoc-gen-frpc") folder. - -**This library requires Go1.16 or later.** +**This library requires Go1.18 or later.** ## Important note about releases and stability @@ -26,37 +22,14 @@ same is true for selected other new features explicitly marked as ## Usage and Documentation Usage instructions and documentation for Frisbee is available -at [https://frpc.io/concepts/frisbee](https://frpc.io/concepts/frisbee). The Frisbee framework also has great -documentation coverage using [GoDoc](https://godoc.org/github.com/loopholelabs/frisbee). - -## FRPC - -The FRPC Generator is still in very early **Alpha**. While it is functional and being used within other products -we're building at [Loophole Labs][loophomepage], the `proto3` spec has a myriad of edge-cases that make it difficult to -guarantee validity of generated RPC frameworks without extensive real-world use. - -That being said, as the library matures and usage of FRPC grows we'll be able to increase our testing -coverage and fix any edge case bugs. One of the major benefits to the RPC framework is that reading the generated code -is extremely straight forward, making it easy to debug potential issues down the line. - -### Usage and Documentation - -Usage instructions and documentations for FRPC are available at [https://frpc.io/](https://frpc.io). - -### Unsupported Features - -The Frisbee RPC Generator currently does not support the following features, though they are actively being worked on: - -- `OneOf` Message Types -- Streaming Messages between the client and server - -Example `Proto3` files can be found [here](/protoc-gen-frpc/examples). +at [https://frpc.io/frisbee](https://frpc.io/frisbee). The Frisbee framework also has great +documentation coverage using [GoDoc](https://godoc.org/github.com/loopholelabs/frisbee-go). ## Contributing -Bug reports and pull requests are welcome on GitHub at [https://github.com/loopholelabs/frisbee][gitrepo]. For more +Bug reports and pull requests are welcome on GitHub at [https://github.com/loopholelabs/frisbee-go][gitrepo]. For more contribution information check -out [the contribution guide](https://github.com/loopholelabs/frisbee/blob/master/CONTRIBUTING.md). +out [the contribution guide](https://github.com/loopholelabs/frisbee-go/blob/master/CONTRIBUTING.md). ## License @@ -71,7 +44,6 @@ Everyone interacting in the Frisbee project’s codebases, issue trackers, chat [![https://loopholelabs.io][loopholelabs]](https://loopholelabs.io) -[gitrepo]: https://github.com/loopholelabs/frisbee +[gitrepo]: https://github.com/loopholelabs/frisbee-go [loopholelabs]: https://cdn.loopholelabs.io/loopholelabs/LoopholeLabsLogo.svg -[homepage]: https://loopholelabs.io/docs/frisbee [loophomepage]: https://loopholelabs.io diff --git a/async.go b/async.go index 8ffd6c4..17c1417 100644 --- a/async.go +++ b/async.go @@ -21,10 +21,10 @@ import ( "context" "crypto/tls" "encoding/binary" - "github.com/loopholelabs/frisbee/internal/dialer" - "github.com/loopholelabs/frisbee/internal/queue" - "github.com/loopholelabs/frisbee/pkg/metadata" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/common/pkg/queue" + "github.com/loopholelabs/frisbee-go/internal/dialer" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" "go.uber.org/atomic" @@ -43,7 +43,7 @@ type Async struct { closed *atomic.Bool writer *bufio.Writer flusher chan struct{} - incoming *queue.Circular + incoming *queue.Circular[packet.Packet, *packet.Packet] logger *zerolog.Logger wg sync.WaitGroup error *atomic.Error @@ -85,7 +85,7 @@ func NewAsync(c net.Conn, logger *zerolog.Logger) (conn *Async) { conn: c, closed: atomic.NewBool(false), writer: bufio.NewWriterSize(c, DefaultBufferSize), - incoming: queue.NewCircular(DefaultBufferSize), + incoming: queue.NewCircular[packet.Packet, *packet.Packet](DefaultBufferSize), flusher: make(chan struct{}, 3), logger: logger, error: atomic.NewError(nil), @@ -174,11 +174,11 @@ func (c *Async) CloseChannel() <-chan struct{} { // // If packet.Metadata.ContentLength == 0, then the content array must be nil. Otherwise, it is required that packet.Metadata.ContentLength == len(content). func (c *Async) WritePacket(p *packet.Packet) error { - if int(p.Metadata.ContentLength) != len(p.Content.B) { + if int(p.Metadata.ContentLength) != len(*p.Content) { return InvalidContentLength } - encodedMetadata := metadata.Get() + encodedMetadata := metadata.GetBuffer() binary.BigEndian.PutUint16(encodedMetadata[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], p.Metadata.Id) binary.BigEndian.PutUint16(encodedMetadata[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], p.Metadata.Operation) binary.BigEndian.PutUint32(encodedMetadata[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], p.Metadata.ContentLength) @@ -190,7 +190,7 @@ func (c *Async) WritePacket(p *packet.Packet) error { } _, err := c.writer.Write(encodedMetadata[:]) - metadata.Put(encodedMetadata) + metadata.PutBuffer(encodedMetadata) if err != nil { c.Unlock() if c.closed.Load() { @@ -213,7 +213,7 @@ func (c *Async) WritePacket(p *packet.Packet) error { return c.closeWithError(err) } } - _, err = c.writer.Write(p.Content.B[:p.Metadata.ContentLength]) + _, err = c.writer.Write((*p.Content)[:p.Metadata.ContentLength]) if err != nil { c.Unlock() if c.closed.Load() { diff --git a/async_test.go b/async_test.go index 682f23c..a7e27cb 100644 --- a/async_test.go +++ b/async_test.go @@ -18,7 +18,8 @@ package frisbee import ( "crypto/rand" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" + "github.com/loopholelabs/polyglot-go" "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -27,6 +28,7 @@ import ( "io/ioutil" "net" "runtime" + "sync" "testing" "time" ) @@ -57,7 +59,7 @@ func TestNewAsync(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) data := make([]byte, packetSize) _, _ = rand.Read(data) @@ -76,8 +78,8 @@ func TestNewAsync(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, len(data), len(p.Content.B)) - assert.Equal(t, data, p.Content.B) + assert.Equal(t, len(data), len(*p.Content)) + assert.Equal(t, polyglot.Buffer(data), *p.Content) packet.Put(p) @@ -123,8 +125,8 @@ func TestAsyncLargeWrite(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, len(randomData[i]), len(p.Content.B)) - assert.Equal(t, randomData[i], p.Content.B) + assert.Equal(t, len(randomData[i]), len(*p.Content)) + assert.Equal(t, polyglot.Buffer(randomData[i]), *p.Content) packet.Put(p) } @@ -171,8 +173,8 @@ func TestAsyncRawConn(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, packetSize, len(p.Content.B)) - assert.Equal(t, randomData, p.Content.B) + assert.Equal(t, packetSize, len(*p.Content)) + assert.Equal(t, polyglot.Buffer(randomData), *p.Content) } rawReaderConn := readerConn.Raw() @@ -225,7 +227,7 @@ func TestAsyncReadClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) err = readerConn.conn.Close() assert.NoError(t, err) @@ -276,7 +278,7 @@ func TestAsyncReadAvailableClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) p, err = readerConn.ReadPacket() require.NoError(t, err) @@ -284,7 +286,7 @@ func TestAsyncReadAvailableClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) p, err = readerConn.ReadPacket() require.Error(t, err) @@ -323,7 +325,7 @@ func TestAsyncWriteClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) err = writerConn.WritePacket(p) assert.NoError(t, err) @@ -376,7 +378,7 @@ func TestAsyncTimeout(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) time.Sleep(defaultDeadline * 5) @@ -407,7 +409,7 @@ func TestAsyncTimeout(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) _, err = readerConn.ReadPacket() require.ErrorIs(t, err, ConnectionClosed) @@ -468,3 +470,107 @@ func BenchmarkAsyncThroughputNetwork(b *testing.B) { _ = readerConn.Close() _ = writerConn.Close() } + +func BenchmarkAsyncThroughputNetworkMultiple(b *testing.B) { + const testSize = 100 + + throughputRunner := func(testSize uint32, packetSize uint32, readerConn Conn, writerConn Conn) func(b *testing.B) { + return func(b *testing.B) { + var err error + + randomData := make([]byte, packetSize) + + p := packet.Get() + p.Metadata.Id = 64 + p.Metadata.Operation = 32 + p.Content.Write(randomData) + p.Metadata.ContentLength = packetSize + for i := 0; i < b.N; i++ { + done := make(chan struct{}, 1) + errCh := make(chan error, 1) + go func() { + for i := uint32(0); i < testSize; i++ { + p, err := readerConn.ReadPacket() + if err != nil { + errCh <- err + return + } + packet.Put(p) + } + done <- struct{}{} + }() + for i := uint32(0); i < testSize; i++ { + select { + case err = <-errCh: + b.Fatal(err) + default: + err = writerConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + } + select { + case <-done: + continue + case err = <-errCh: + b.Fatal(err) + } + } + + packet.Put(p) + } + } + + runner := func(numClients int, packetSize uint32) func(b *testing.B) { + return func(b *testing.B) { + var wg sync.WaitGroup + wg.Add(numClients) + b.SetBytes(int64(testSize * packetSize)) + b.ReportAllocs() + for i := 0; i < numClients; i++ { + go func() { + emptyLogger := zerolog.New(ioutil.Discard) + + reader, writer, err := pair.New() + if err != nil { + b.Error(err) + } + + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) + throughputRunner(testSize, packetSize, readerConn, writerConn)(b) + + _ = readerConn.Close() + _ = writerConn.Close() + wg.Done() + }() + } + wg.Wait() + } + } + + b.Run("1 Pair, 32 Bytes", runner(1, 32)) + b.Run("2 Pair, 32 Bytes", runner(2, 32)) + b.Run("5 Pair, 32 Bytes", runner(5, 32)) + b.Run("10 Pair, 32 Bytes", runner(10, 32)) + b.Run("Half CPU Pair, 32 Bytes", runner(runtime.NumCPU()/2, 32)) + b.Run("CPU Pair, 32 Bytes", runner(runtime.NumCPU(), 32)) + b.Run("Double CPU Pair, 32 Bytes", runner(runtime.NumCPU()*2, 32)) + + b.Run("1 Pair, 512 Bytes", runner(1, 512)) + b.Run("2 Pair, 512 Bytes", runner(2, 512)) + b.Run("5 Pair, 512 Bytes", runner(5, 512)) + b.Run("10 Pair, 512 Bytes", runner(10, 512)) + b.Run("Half CPU Pair, 512 Bytes", runner(runtime.NumCPU()/2, 512)) + b.Run("CPU Pair, 512 Bytes", runner(runtime.NumCPU(), 512)) + b.Run("Double CPU Pair, 512 Bytes", runner(runtime.NumCPU()*2, 512)) + + b.Run("1 Pair, 4096 Bytes", runner(1, 4096)) + b.Run("2 Pair, 4096 Bytes", runner(2, 4096)) + b.Run("5 Pair, 4096 Bytes", runner(5, 4096)) + b.Run("10 Pair, 4096 Bytes", runner(10, 4096)) + b.Run("Half CPU Pair, 4096 Bytes", runner(runtime.NumCPU()/2, 4096)) + b.Run("CPU Pair, 4096 Bytes", runner(runtime.NumCPU(), 4096)) + b.Run("Double CPU Pair, 4096 Bytes", runner(runtime.NumCPU()*2, 4096)) +} diff --git a/client.go b/client.go index 595e76e..a17cdb3 100644 --- a/client.go +++ b/client.go @@ -18,7 +18,7 @@ package frisbee import ( "context" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/rs/zerolog" "go.uber.org/atomic" "net" @@ -198,7 +198,7 @@ LOOP: packetCtx = c.PacketContext(packetCtx, p) } outgoing, action = handlerFunc(packetCtx, p) - if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(outgoing.Content.B)) { + if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { err = c.conn.WritePacket(outgoing) if outgoing != p { packet.Put(outgoing) diff --git a/client_test.go b/client_test.go index cd355ae..ec8e1d5 100644 --- a/client_test.go +++ b/client_test.go @@ -19,8 +19,8 @@ package frisbee import ( "context" "crypto/rand" - "github.com/loopholelabs/frisbee/pkg/metadata" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" diff --git a/conn.go b/conn.go index 1f7c664..840e9be 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,7 @@ package frisbee import ( "context" "crypto/tls" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" "io/ioutil" diff --git a/frisbee.go b/frisbee.go index e7130b8..dde0ef4 100644 --- a/frisbee.go +++ b/frisbee.go @@ -18,9 +18,9 @@ package frisbee import ( "context" - "github.com/loopholelabs/frisbee/pkg/content" - "github.com/loopholelabs/frisbee/pkg/metadata" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/packet" + "github.com/loopholelabs/polyglot-go" "github.com/pkg/errors" "time" ) @@ -86,7 +86,7 @@ var ( Metadata: &metadata.Metadata{ Operation: HEARTBEAT, }, - Content: content.New(), + Content: polyglot.NewBuffer(), } // PINGPacket is a pre-allocated Frisbee Packet for PING Packets @@ -94,7 +94,7 @@ var ( Metadata: &metadata.Metadata{ Operation: PING, }, - Content: content.New(), + Content: polyglot.NewBuffer(), } // PONGPacket is a pre-allocated Frisbee Packet for PONG Packets @@ -102,7 +102,7 @@ var ( Metadata: &metadata.Metadata{ Operation: PONG, }, - Content: content.New(), + Content: polyglot.NewBuffer(), } ) diff --git a/frisbee_test.go b/frisbee_test.go index a209a0d..6d8ada8 100644 --- a/frisbee_test.go +++ b/frisbee_test.go @@ -18,8 +18,8 @@ package frisbee_test import ( "context" - "github.com/loopholelabs/frisbee" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/rs/zerolog" "os" ) diff --git a/go.mod b/go.mod index 30316ea..4a9658e 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,23 @@ -module github.com/loopholelabs/frisbee +module github.com/loopholelabs/frisbee-go -go 1.16 +go 1.18 require ( - github.com/google/go-cmp v0.5.6 // indirect + github.com/loopholelabs/common v0.2.0 + github.com/loopholelabs/polyglot-go v0.3.0 github.com/loopholelabs/testing v0.2.3 github.com/pkg/errors v0.9.1 - github.com/rs/zerolog v1.20.0 - github.com/stretchr/testify v1.7.0 - go.uber.org/atomic v1.7.0 + github.com/rs/zerolog v1.27.0 + github.com/stretchr/testify v1.8.0 + go.uber.org/atomic v1.9.0 go.uber.org/goleak v1.1.12 - google.golang.org/protobuf v1.27.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1cf1025..2de7cce 100644 --- a/go.sum +++ b/go.sum @@ -1,33 +1,41 @@ -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/loopholelabs/common v0.2.0 h1:irIV5qsMK2ghO7Bu1XSXnN/6rtP5mazSE1WSBDhi/eg= +github.com/loopholelabs/common v0.2.0/go.mod h1:hB8T0eBGDcZ4cvBfDoD3yUs6bUzeZed9VqhzN4lI7nU= +github.com/loopholelabs/frisbee v0.5.0 h1:wB/PggORrg3IxqppNsSSvRdKG5w3F1A8tRxQhPVvmCI= +github.com/loopholelabs/polyglot-go v0.3.0 h1:iOqPw5B3krCMYfgDaPgPoh1A87ACE8lKdbpbERM58pY= +github.com/loopholelabs/polyglot-go v0.3.0/go.mod h1:9/Hr1nFO9Al46806vMP3DB2k8blQ3gazBPaoOsdgo34= github.com/loopholelabs/testing v0.2.3 h1:4nVuK5ctaE6ua5Z0dYk2l7xTFmcpCYLUeGjRBp8keOA= github.com/loopholelabs/testing v0.2.3/go.mod h1:gqtGY91soYD1fQoKQt/6kP14OYpS7gcbcIgq5mc9m8Q= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= -github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.27.0 h1:1T7qCieN22GVc8S4Q2yuexzBb1EqjbgjSH9RohbMjKs= +github.com/rs/zerolog v1.27.0/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -46,25 +54,23 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= -google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/helpers_test.go b/helpers_test.go index 9d25ef6..82ad431 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -19,7 +19,7 @@ package frisbee import ( "testing" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" "go.uber.org/goleak" ) diff --git a/internal/queue/circular.go b/internal/queue/circular.go deleted file mode 100644 index 6c04242..0000000 --- a/internal/queue/circular.go +++ /dev/null @@ -1,170 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "github.com/loopholelabs/frisbee/pkg/packet" - "sync" - "unsafe" -) - -type Circular struct { - _padding0 [8]uint64 //nolint:structcheck,unused - head uint64 - _padding1 [8]uint64 //nolint:structcheck,unused - tail uint64 - _padding2 [8]uint64 //nolint:structcheck,unused - maxSize uint64 - _padding3 [8]uint64 //nolint:structcheck,unused - closed bool - _padding4 [8]uint64 //nolint:structcheck,unused - lock *sync.Mutex - _padding5 [8]uint64 //nolint:structcheck,unused - notEmpty *sync.Cond - _padding6 [8]uint64 //nolint:structcheck,unused - notFull *sync.Cond - _padding7 [8]uint64 //nolint:structcheck,unused - nodes []unsafe.Pointer -} - -func NewCircular(maxSize uint64) *Circular { - q := &Circular{} - q.lock = &sync.Mutex{} - q.notFull = sync.NewCond(q.lock) - q.notEmpty = sync.NewCond(q.lock) - - q.head = 0 - q.tail = 0 - maxSize++ - if maxSize < 2 { - q.maxSize = 2 - } else { - q.maxSize = round(maxSize) - } - - q.nodes = make([]unsafe.Pointer, q.maxSize) - return q -} - -func (q *Circular) IsEmpty() (empty bool) { - q.lock.Lock() - empty = q.isEmpty() - q.lock.Unlock() - return -} - -func (q *Circular) isEmpty() bool { - return q.head == q.tail -} - -func (q *Circular) IsFull() (full bool) { - q.lock.Lock() - full = q.isFull() - q.lock.Unlock() - return -} - -func (q *Circular) isFull() bool { - return q.head == (q.tail+1)%q.maxSize -} - -func (q *Circular) IsClosed() (closed bool) { - q.lock.Lock() - closed = q.isClosed() - q.lock.Unlock() - return -} - -func (q *Circular) isClosed() bool { - return q.closed -} - -func (q *Circular) Length() (size int) { - q.lock.Lock() - size = q.length() - q.lock.Unlock() - return -} - -func (q *Circular) length() int { - if q.tail < q.head { - return int(q.maxSize - q.head + q.tail) - } - return int(q.tail - q.head) -} - -func (q *Circular) Close() { - q.lock.Lock() - q.closed = true - q.notFull.Broadcast() - q.notEmpty.Broadcast() - q.lock.Unlock() -} - -func (q *Circular) Push(p *packet.Packet) error { - q.lock.Lock() -LOOP: - if q.isClosed() { - q.lock.Unlock() - return Closed - } - if q.isFull() { - q.notFull.Wait() - goto LOOP - } - - q.nodes[q.tail] = unsafe.Pointer(p) - q.tail = (q.tail + 1) % q.maxSize - q.notEmpty.Signal() - q.lock.Unlock() - return nil -} - -func (q *Circular) Pop() (p *packet.Packet, err error) { - q.lock.Lock() -LOOP: - if q.isClosed() { - q.lock.Unlock() - return nil, Closed - } - if q.isEmpty() { - q.notEmpty.Wait() - goto LOOP - } - - p = (*packet.Packet)(q.nodes[q.head]) - q.head = (q.head + 1) % q.maxSize - q.notFull.Signal() - q.lock.Unlock() - return -} - -func (q *Circular) Drain() (packets []*packet.Packet) { - q.lock.Lock() - if q.isEmpty() { - q.lock.Unlock() - return nil - } - if size := int(q.head) - int(q.tail); size > 0 { - packets = make([]*packet.Packet, 0, size) - } else { - packets = make([]*packet.Packet, 0, -1*size) - } - for i := 0; i < cap(packets); i++ { - packets = append(packets, (*packet.Packet)(q.nodes[q.head])) - q.head = (q.head + 1) % q.maxSize - } - q.lock.Unlock() - return packets -} diff --git a/internal/queue/circular_test.go b/internal/queue/circular_test.go deleted file mode 100644 index 74c58ed..0000000 --- a/internal/queue/circular_test.go +++ /dev/null @@ -1,318 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "testing" - "time" - - "github.com/loopholelabs/frisbee/pkg/packet" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCircular(t *testing.T) { - t.Parallel() - - testPacket := packet.Get - testPacket2 := func() *packet.Packet { - p := packet.Get() - p.Content.Write([]byte{1}) - return p - } - - t.Run("success", func(t *testing.T) { - rb := NewCircular(1) - p := testPacket() - err := rb.Push(p) - assert.NoError(t, err) - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p, actual) - }) - t.Run("out of capacity", func(t *testing.T) { - rb := NewCircular(0) - err := rb.Push(testPacket()) - assert.NoError(t, err) - }) - t.Run("out of capacity with non zero capacity, blocking", func(t *testing.T) { - rb := NewCircular(1) - p1 := testPacket() - err := rb.Push(p1) - assert.NoError(t, err) - doneCh := make(chan struct{}, 1) - p2 := testPacket2() - go func() { - err = rb.Push(p2) - assert.NoError(t, err) - doneCh <- struct{}{} - }() - select { - case <-doneCh: - t.Fatal("LockFree did not block on full write") - case <-time.After(time.Millisecond * 10): - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p1, actual) - select { - case <-doneCh: - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - case <-time.After(time.Millisecond * 10): - t.Fatal("Circular did not unblock on read from full write") - } - } - }) - t.Run("length calculations", func(t *testing.T) { - rb := NewCircular(1) - p1 := testPacket() - - err := rb.Push(p1) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - assert.Equal(t, uint64(0), rb.head) - assert.Equal(t, uint64(1), rb.tail) - - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p1, actual) - assert.Equal(t, 0, rb.Length()) - assert.Equal(t, uint64(1), rb.head) - assert.Equal(t, uint64(1), rb.tail) - - err = rb.Push(p1) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - assert.Equal(t, uint64(1), rb.head) - assert.Equal(t, uint64(0), rb.tail) - - rb = NewCircular(4) - - err = rb.Push(p1) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - assert.Equal(t, uint64(0), rb.head) - assert.Equal(t, uint64(1), rb.tail) - - p2 := testPacket2() - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 2, rb.Length()) - assert.Equal(t, uint64(0), rb.head) - assert.Equal(t, uint64(2), rb.tail) - - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 3, rb.Length()) - assert.Equal(t, uint64(0), rb.head) - assert.Equal(t, uint64(3), rb.tail) - - actual, err = rb.Pop() - require.NoError(t, err) - assert.Equal(t, p1, actual) - assert.Equal(t, 2, rb.Length()) - assert.Equal(t, uint64(1), rb.head) - assert.Equal(t, uint64(3), rb.tail) - - actual, err = rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 1, rb.Length()) - assert.Equal(t, uint64(2), rb.head) - assert.Equal(t, uint64(3), rb.tail) - - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 2, rb.Length()) - assert.Equal(t, uint64(2), rb.head) - assert.Equal(t, uint64(4), rb.tail) - - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 3, rb.Length()) - assert.Equal(t, uint64(2), rb.head) - assert.Equal(t, uint64(5), rb.tail) - - actual, err = rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 2, rb.Length()) - assert.Equal(t, uint64(3), rb.head) - assert.Equal(t, uint64(5), rb.tail) - - actual, err = rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 1, rb.Length()) - assert.Equal(t, uint64(4), rb.head) - assert.Equal(t, uint64(5), rb.tail) - - actual, err = rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 0, rb.Length()) - assert.Equal(t, uint64(5), rb.head) - assert.Equal(t, uint64(5), rb.tail) - }) - t.Run("buffer closed", func(t *testing.T) { - rb := NewCircular(1) - assert.False(t, rb.IsClosed()) - rb.Close() - assert.True(t, rb.IsClosed()) - err := rb.Push(testPacket()) - assert.ErrorIs(t, Closed, err) - _, err = rb.Pop() - assert.ErrorIs(t, Closed, err) - }) - t.Run("pop empty", func(t *testing.T) { - done := make(chan struct{}, 1) - rb := NewCircular(1) - go func() { - _, _ = rb.Pop() - done <- struct{}{} - }() - assert.Equal(t, 0, len(done)) - _ = rb.Push(testPacket()) - <-done - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, blocking", func(t *testing.T) { - rb := NewCircular(4) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - err := rb.Push(p1) - assert.NoError(t, err) - err = rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - - assert.Equal(t, 3, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p1, actual) - assert.Equal(t, 2, rb.Length()) - - err = rb.Push(p4) - assert.NoError(t, err) - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p2, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 2, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - - assert.Equal(t, 1, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p5, actual) - assert.NotEqual(t, p1, p5) - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, non-blocking", func(t *testing.T) { - rb := NewCircular(4) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - p6 := testPacket() - p6.Metadata.Id = 6 - - err := rb.Push(p1) - assert.NoError(t, err) - err = rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - err = rb.Push(p4) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 5, rb.Length()) - - err = rb.Push(p6) - assert.NoError(t, err) - - assert.Equal(t, 6, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p1, actual) - - assert.Equal(t, 5, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p2, actual) - - assert.Equal(t, 4, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - assert.NotEqual(t, p1, p4) - assert.Equal(t, 2, rb.Length()) - }) -} diff --git a/internal/queue/lockfree.go b/internal/queue/lockfree.go deleted file mode 100644 index 271fdef..0000000 --- a/internal/queue/lockfree.go +++ /dev/null @@ -1,259 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "github.com/loopholelabs/frisbee/pkg/packet" - "runtime" - "sync/atomic" - "unsafe" -) - -// node is a struct that keeps track of its own position as well as a piece of data -// stored as an unsafe.Pointer. Normally we would store the pointer to a packet.Packet -// directly, however benchmarking shows performance improvements with unsafe.Pointer instead -type node struct { - _padding0 [8]uint64 //nolint:structcheck,unused - position uint64 - _padding1 [8]uint64 //nolint:structcheck,unused - data unsafe.Pointer -} - -// nodes is a struct type containing a slice of node pointers -type nodes []*node - -// LockFree is the struct used to store a blocking or non-blocking FIFO queue of type *packet.Packet -// -// In it's non-blocking form it acts as a ringbuffer, overwriting old data when new data arrives. In its blocking -// form it waits for a space in the queue to open up before it adds the item to the LockFree. -type LockFree struct { - _padding0 [8]uint64 //nolint:structcheck,unused - head uint64 - _padding1 [8]uint64 //nolint:structcheck,unused - tail uint64 - _padding2 [8]uint64 //nolint:structcheck,unused - mask uint64 - _padding3 [8]uint64 //nolint:structcheck,unused - closed uint64 - _padding4 [8]uint64 //nolint:structcheck,unused - nodes nodes - _padding5 [8]uint64 //nolint:structcheck,unused - overflow func() (uint64, error) -} - -// NewLockFree creates a new LockFree with blocking or non-blocking behavior -func NewLockFree(size uint64, blocking bool) *LockFree { - q := new(LockFree) - if size < 1 { - size = 1 - } - if blocking { - q.overflow = q.blocker - } else { - q.overflow = q.unblocker - } - q.init(size) - return q -} - -// init actually initializes a queue and can be used in the future to reuse LockFree structs -// with their own pool -func (q *LockFree) init(size uint64) { - size = round(size) - q.nodes = make(nodes, size) - for i := uint64(0); i < size; i++ { - q.nodes[i] = &node{position: i} - } - q.mask = size - 1 -} - -// blocker is a LockFree.overflow function that blocks a Push operation from -// proceeding if the LockFree is ever full of data. -// -// If two Push operations happen simultaneously, blocker will block both of them until -// a Pop takes place, and unblock both of them at the same time. This can cause problems, -// however in our use case it won't because there shouldn't ever be more than one producer -// operating on the LockFree at any given time. There may be multiple consumers in the future, -// but that won't cause any problems. -// -// If we decide to use this as an MPMC LockFree instead of a SPMC LockFree (which is how we currently use it) -// then we can solve this bug by replacing the existing `default` switch case in the Push function with the -// following snippet: -// ``` -// default: -// head, err = q.overflow() -// if err != nil { -// return err -// } -// ``` -func (q *LockFree) blocker() (head uint64, err error) { -LOOP: - head = atomic.LoadUint64(&q.head) - if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { - if atomic.LoadUint64(&q.closed) == 1 { - err = Closed - return - } - runtime.Gosched() - goto LOOP - } - return -} - -// unblocker is a LockFree.overflow function that unblocks a Push operation from -// proceeding if the LockFree is full of data. It does this by adding its own Pop() -// operation before proceeding with the Push attempt. -// -// If two Push operations happen simultaneously, unblocker will unblock them both -// by running two Pop() operations. This function will also be called whenever there -// is a Push conflict (when two Push operations attempt to modify the queue concurrently). -// -// In highly concurrent situations we may lose more data than we should, however since we will -// be using this as a SPMC LockFree, this conflict will never arise. -func (q *LockFree) unblocker() (head uint64, err error) { - head = atomic.LoadUint64(&q.head) - if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { - var p *packet.Packet - p, err = q.Pop() - packet.Put(p) - if err != nil { - return - } - } - return -} - -// Push appends an item of type *packet.Packet to the LockFree, and will block -// until the item is pushed successfully (with the blocking function depending -// on whether this is a blocking LockFree). -// -// This method is not meant to be used concurrently, and the LockFree is meant to operate -// as an SPMC LockFree with one producer operating at a time. If we want to use this as an MPMC LockFree -// we can modify this Push function by replacing the existing `default` switch case with the -// following snippet: -// ``` -// default: -// head, err = q.overflow() -// if err != nil { -// return err -// } -// ``` -func (q *LockFree) Push(item *packet.Packet) error { - var newNode *node - head, err := q.overflow() - if err != nil { - return err - } -RETRY: - for { - if atomic.LoadUint64(&q.closed) == 1 { - return Closed - } - - newNode = q.nodes[head&q.mask] - switch dif := atomic.LoadUint64(&newNode.position) - head; { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.head, head, head+1) { - break RETRY - } - default: - head = atomic.LoadUint64(&q.head) - } - runtime.Gosched() - } - newNode.data = unsafe.Pointer(item) - atomic.StoreUint64(&newNode.position, head+1) - return nil -} - -// Pop removes an item from the start of the LockFree and returns it to the caller. -// This method blocks until an item is available, but unblocks when the LockFree is closed. -// This allows for long-term listeners to wait on the LockFree until either an item is available -// or the LockFree is closed. -// -// This method is safe to be used concurrently and is even optimized for the SPMC use case. -func (q *LockFree) Pop() (*packet.Packet, error) { - var oldNode *node - var oldPosition = atomic.LoadUint64(&q.tail) -RETRY: - if atomic.LoadUint64(&q.closed) == 1 { - return nil, Closed - } - - oldNode = q.nodes[oldPosition&q.mask] - switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { - goto DONE - } - default: - oldPosition = atomic.LoadUint64(&q.tail) - } - runtime.Gosched() - goto RETRY -DONE: - data := oldNode.data - oldNode.data = nil - atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) - return (*packet.Packet)(data), nil -} - -// Close marks the LockFree as closed, returns any waiting Pop() calls, -// and blocks all future Push calls from occurring. -func (q *LockFree) Close() { - atomic.CompareAndSwapUint64(&q.closed, 0, 1) -} - -// IsClosed returns whether the LockFree has been closed -func (q *LockFree) IsClosed() bool { - return atomic.LoadUint64(&q.closed) == 1 -} - -// Length is the current number of items in the LockFree -func (q *LockFree) Length() int { - return int(atomic.LoadUint64(&q.head) - atomic.LoadUint64(&q.tail)) -} - -// Drain drains all the current packets in the queue and returns them to the caller. -// -// It is an unsafe function that should only be used once, only after the queue has been closed, -// and only while there are no producers writing to it. If used incorrectly it has the potential -// to infinitely block the caller. If used correctly, it allows a single caller to drain any remaining -// packets in the queue after the queue has been closed. -func (q *LockFree) Drain() []*packet.Packet { - length := q.Length() - packets := make([]*packet.Packet, 0, length) - for i := 0; i < length; i++ { - var oldNode *node - var oldPosition = atomic.LoadUint64(&q.tail) - RETRY: - oldNode = q.nodes[oldPosition&q.mask] - switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { - goto DONE - } - default: - oldPosition = atomic.LoadUint64(&q.tail) - } - runtime.Gosched() - goto RETRY - DONE: - data := oldNode.data - oldNode.data = nil - atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) - packets = append(packets, (*packet.Packet)(data)) - } - return packets -} diff --git a/internal/queue/lockfree_test.go b/internal/queue/lockfree_test.go deleted file mode 100644 index 41e4ff5..0000000 --- a/internal/queue/lockfree_test.go +++ /dev/null @@ -1,242 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "testing" - "time" - - "github.com/loopholelabs/frisbee/pkg/packet" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLockFree(t *testing.T) { - t.Parallel() - - testPacket := packet.Get - testPacket2 := func() *packet.Packet { - p := packet.Get() - p.Content.Write([]byte{1}) - return p - } - - t.Run("success", func(t *testing.T) { - rb := NewLockFree(1, false) - p := testPacket() - err := rb.Push(p) - assert.NoError(t, err) - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p, actual) - }) - t.Run("out of capacity", func(t *testing.T) { - rb := NewLockFree(0, false) - err := rb.Push(testPacket()) - assert.NoError(t, err) - }) - t.Run("out of capacity with non zero capacity, blocking", func(t *testing.T) { - rb := NewLockFree(1, true) - p1 := testPacket() - err := rb.Push(p1) - assert.NoError(t, err) - doneCh := make(chan struct{}, 1) - p2 := testPacket2() - go func() { - err = rb.Push(p2) - assert.NoError(t, err) - doneCh <- struct{}{} - }() - select { - case <-doneCh: - t.Fatal("LockFree did not block on full write") - case <-time.After(time.Millisecond * 10): - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p1, actual) - select { - case <-doneCh: - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - case <-time.After(time.Millisecond * 10): - t.Fatal("LockFree did not unblock on read from full write") - } - } - }) - t.Run("out of capacity with non zero capacity, non-blocking", func(t *testing.T) { - rb := NewLockFree(1, false) - p1 := testPacket() - err := rb.Push(p1) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - p2 := testPacket2() - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 0, rb.Length()) - }) - t.Run("buffer closed", func(t *testing.T) { - rb := NewLockFree(1, false) - assert.False(t, rb.IsClosed()) - rb.Close() - assert.True(t, rb.IsClosed()) - err := rb.Push(testPacket()) - assert.ErrorIs(t, Closed, err) - _, err = rb.Pop() - assert.ErrorIs(t, Closed, err) - }) - t.Run("pop empty", func(t *testing.T) { - done := make(chan struct{}, 1) - rb := NewLockFree(1, false) - go func() { - _, _ = rb.Pop() - done <- struct{}{} - }() - assert.Equal(t, 0, len(done)) - _ = rb.Push(testPacket()) - <-done - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, blocking", func(t *testing.T) { - rb := NewLockFree(4, true) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - err := rb.Push(p1) - assert.NoError(t, err) - err = rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - - assert.Equal(t, 3, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p1, actual) - assert.Equal(t, 2, rb.Length()) - - err = rb.Push(p4) - assert.NoError(t, err) - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p2, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 2, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - - assert.Equal(t, 1, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p5, actual) - assert.NotEqual(t, p1, p5) - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, non-blocking", func(t *testing.T) { - rb := NewLockFree(4, false) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - p6 := testPacket() - p6.Metadata.Id = 6 - - err := rb.Push(p1) - assert.NoError(t, err) - err = rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - err = rb.Push(p4) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - err = rb.Push(p6) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - - assert.Equal(t, 2, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p5, actual) - - assert.Equal(t, 1, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p6, actual) - assert.NotEqual(t, p1, p6) - assert.Equal(t, 0, rb.Length()) - }) -} diff --git a/internal/queue/queue.go b/internal/queue/queue.go deleted file mode 100644 index 51e2d64..0000000 --- a/internal/queue/queue.go +++ /dev/null @@ -1,35 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "github.com/pkg/errors" -) - -var ( - Closed = errors.New("queue is closed") -) - -// round takes an uint64 value and rounds up to the nearest power of 2 -func round(value uint64) uint64 { - value-- - value |= value >> 1 - value |= value >> 2 - value |= value >> 4 - value |= value >> 8 - value |= value >> 16 - value |= value >> 32 - value++ - return value -} diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go deleted file mode 100644 index 0bba933..0000000 --- a/internal/queue/queue_test.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package queue - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestRound(t *testing.T) { - t.Parallel() - tcs := []struct { - in uint64 - expected uint64 - }{ - {in: 0, expected: 0x0}, - {in: 1, expected: 0x1}, - {in: 2, expected: 0x2}, - {in: 3, expected: 0x4}, - {in: 4, expected: 0x4}, - {in: 5, expected: 0x8}, - {in: 7, expected: 0x8}, - {in: 8, expected: 0x8}, - {in: 9, expected: 0x10}, - {in: 16, expected: 0x10}, - {in: 32, expected: 0x20}, - {in: 0xFFFFFFF0, expected: 0x100000000}, - {in: 0xFFFFFFFF, expected: 0x100000000}, - } - for _, tc := range tcs { - assert.Equalf(t, tc.expected, round(tc.in), "in: %d", tc.in) - } -} diff --git a/pkg/content/content.go b/pkg/content/content.go deleted file mode 100644 index 067fa0a..0000000 --- a/pkg/content/content.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package content - -const ( - defaultSize = 512 -) - -// Content is a packet's byte buffer that grows and is recycled as required -type Content struct { - B []byte -} - -// New returns a new Content struct with an allocated array of size defaultSize -func New() *Content { - return &Content{ - B: make([]byte, 0, defaultSize), - } -} - -// Reset resets the underlying byte slice for future use -func (c *Content) Reset() { - c.B = c.B[:0] -} - -// Write efficiently copies the byte slice b into the content buffer, however it -// does *not* update the content length. -func (c *Content) Write(b []byte) int { - if cap(c.B)-len(c.B) < len(b) { - c.B = append(c.B[:len(c.B)], b...) - } else { - c.B = c.B[:len(c.B)+copy(c.B[len(c.B):cap(c.B)], b)] - } - return len(b) -} diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index 2bf81bf..e1455b0 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -55,41 +55,24 @@ type Metadata struct { ContentLength uint32 // 4 Bytes } -type Handler struct{} - -func NewDefaultHandler() Handler { - return NewHandler() -} - -func NewHandler() Handler { - return Handler{} -} - -func (*Handler) Encode(id, operation uint16, contentLength uint32) ([Size]byte, error) { - return Encode(id, operation, contentLength) -} - -func (*Handler) Decode(buf []byte) (Metadata, error) { - return Decode(buf) -} - // Encode Metadata -func (fm *Metadata) Encode() (result [Size]byte, err error) { +func (fm *Metadata) Encode() (b *Buffer, err error) { defer func() { if recoveredErr := recover(); recoveredErr != nil { err = errors.Wrap(recoveredErr.(error), Encoding.Error()) } }() - binary.BigEndian.PutUint16(result[IdOffset:IdOffset+IdSize], fm.Id) - binary.BigEndian.PutUint16(result[OperationOffset:OperationOffset+OperationSize], fm.Operation) - binary.BigEndian.PutUint32(result[ContentLengthOffset:ContentLengthOffset+ContentLengthSize], fm.ContentLength) + b = NewBuffer() + binary.BigEndian.PutUint16(b[IdOffset:IdOffset+IdSize], fm.Id) + binary.BigEndian.PutUint16(b[OperationOffset:OperationOffset+OperationSize], fm.Operation) + binary.BigEndian.PutUint32(b[ContentLengthOffset:ContentLengthOffset+ContentLengthSize], fm.ContentLength) return } // Decode Metadata -func (fm *Metadata) Decode(buf [Size]byte) (err error) { +func (fm *Metadata) Decode(buf *Buffer) (err error) { defer func() { if recoveredErr := recover(); recoveredErr != nil { err = errors.Wrap(recoveredErr.(error), Decoding.Error()) @@ -103,8 +86,7 @@ func (fm *Metadata) Decode(buf [Size]byte) (err error) { return nil } -// Encode without a Handler -func Encode(id, operation uint16, contentLength uint32) ([Size]byte, error) { +func Encode(id, operation uint16, contentLength uint32) (*Buffer, error) { metadata := Metadata{ Id: id, Operation: operation, @@ -114,13 +96,11 @@ func Encode(id, operation uint16, contentLength uint32) ([Size]byte, error) { return metadata.Encode() } -// Decode without a Handler -func Decode(buf []byte) (metadata Metadata, err error) { +func Decode(buf []byte) (*Metadata, error) { if len(buf) < Size { - return Metadata{}, InvalidBufferLength + return nil, InvalidBufferLength } - err = metadata.Decode(*(*[Size]byte)(unsafe.Pointer(&buf[0]))) - - return + m := new(Metadata) + return m, m.Decode((*Buffer)(unsafe.Pointer(&buf[0]))) } diff --git a/pkg/metadata/metadata_test.go b/pkg/metadata/metadata_test.go index ee86678..786bea2 100644 --- a/pkg/metadata/metadata_test.go +++ b/pkg/metadata/metadata_test.go @@ -32,7 +32,7 @@ func TestMessageEncodeDecode(t *testing.T) { ContentLength: uint32(0), } - correct := [Size]byte{} + correct := NewBuffer() binary.BigEndian.PutUint16(correct[IdOffset:IdOffset+IdSize], uint16(64)) binary.BigEndian.PutUint16(correct[OperationOffset:OperationOffset+OperationSize], PacketProbe) @@ -49,43 +49,6 @@ func TestMessageEncodeDecode(t *testing.T) { assert.Equal(t, message, decoderMessage) } -func TestDefaultHandler(t *testing.T) { - t.Parallel() - - defaultHandler := NewDefaultHandler() - assert.Equal(t, defaultHandler, NewHandler()) -} - -func TestEncodeDecodeHandler(t *testing.T) { - t.Parallel() - - handler := NewHandler() - - encodedBytes, err := handler.Encode(64, PacketProbe, 512) - assert.NoError(t, err) - - message, err := handler.Decode(encodedBytes[:]) - require.NoError(t, err) - assert.Equal(t, uint32(512), message.ContentLength) - assert.Equal(t, uint16(64), message.Id) - assert.Equal(t, PacketProbe, message.Operation) - - emptyEncodedBytes, err := handler.Encode(64, PacketPing, 0) - assert.Equal(t, nil, err) - - emptyMessage, err := handler.Decode(emptyEncodedBytes[:]) - require.NoError(t, err) - assert.Equal(t, uint32(0), emptyMessage.ContentLength) - assert.Equal(t, uint16(64), emptyMessage.Id) - assert.Equal(t, PacketPing, emptyMessage.Operation) - - invalidMessage, err := handler.Decode(emptyEncodedBytes[8:]) - require.Error(t, err) - assert.ErrorIs(t, InvalidBufferLength, err) - assert.Equal(t, uint32(0), invalidMessage.ContentLength) - assert.Equal(t, uint16(0), invalidMessage.Id) -} - func TestEncodeDecode(t *testing.T) { t.Parallel() @@ -110,36 +73,7 @@ func TestEncodeDecode(t *testing.T) { invalidMessage, err := Decode(emptyEncodedBytes[1:]) require.Error(t, err) assert.ErrorIs(t, InvalidBufferLength, err) - assert.Equal(t, uint32(0), invalidMessage.ContentLength) - assert.Equal(t, uint16(0), invalidMessage.Id) -} - -func BenchmarkEncodeHandler(b *testing.B) { - handler := NewHandler() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = handler.Encode(uint16(i), PacketProbe, 512) - } -} - -func BenchmarkDecodeHandler(b *testing.B) { - handler := NewHandler() - encodedMessage, _ := handler.Encode(0, PacketProbe, 512) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = handler.Decode(encodedMessage[:]) - } -} - -func BenchmarkEncodeDecodeHandler(b *testing.B) { - handler := NewHandler() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - encodedMessage, _ := handler.Encode(uint16(i), PacketProbe, 512) - _, _ = handler.Decode(encodedMessage[:]) - } + assert.Nil(t, invalidMessage) } func BenchmarkEncode(b *testing.B) { diff --git a/pkg/metadata/pool.go b/pkg/metadata/pool.go index f74c3f9..c5d112e 100644 --- a/pkg/metadata/pool.go +++ b/pkg/metadata/pool.go @@ -16,36 +16,26 @@ package metadata -import "sync" +import ( + "github.com/loopholelabs/common/pkg/pool" +) -type Pool struct { - pool sync.Pool -} +type Buffer [Size]byte -func NewPool() *Pool { - return new(Pool) +func NewBuffer() *Buffer { + return new(Buffer) } -func (p *Pool) Get() *[Size]byte { - v := p.pool.Get() - if v == nil { - v = &[Size]byte{} - } - return v.(*[Size]byte) -} - -func (p *Pool) Put(b *[Size]byte) { - p.pool.Put(b) -} +func (b *Buffer) Reset() {} var ( - pool = NewPool() + bufferPool = pool.NewPool[Buffer, *Buffer](NewBuffer) ) -func Get() *[Size]byte { - return pool.Get() +func GetBuffer() *Buffer { + return bufferPool.Get() } -func Put(b *[Size]byte) { - pool.Put(b) +func PutBuffer(b *Buffer) { + bufferPool.Put(b) } diff --git a/pkg/packet/decode.go b/pkg/packet/decode.go deleted file mode 100644 index b87d673..0000000 --- a/pkg/packet/decode.go +++ /dev/null @@ -1,230 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "math" - - "github.com/pkg/errors" -) - -const ( - emptyString = "" -) - -var ( - InvalidSlice = errors.New("invalid slice encoding") - InvalidMap = errors.New("invalid map encoding") - InvalidBytes = errors.New("invalid bytes encoding") - InvalidString = errors.New("invalid string encoding") - InvalidError = errors.New("invalid error encoding") - InvalidBool = errors.New("invalid bool encoding") - InvalidUint8 = errors.New("invalid uint8 encoding") - InvalidUint16 = errors.New("invalid uint16 encoding") - InvalidUint32 = errors.New("invalid uint32 encoding") - InvalidUint64 = errors.New("invalid uint64 encoding") - InvalidInt32 = errors.New("invalid int32 encoding") - InvalidInt64 = errors.New("invalid int64 encoding") - InvalidFloat32 = errors.New("invalid float32 encoding") - InvalidFloat64 = errors.New("invalid float64 encoding") -) - -func decodeNil(b []byte) ([]byte, bool) { - if len(b) > 0 { - if b[0] == NilKind[0] { - return b[1:], true - } - } - return b, false -} - -func decodeMap(b []byte, keyKind, valueKind Kind) ([]byte, uint32, error) { - if len(b) > 2 { - if b[0] == MapKind[0] && b[1] == keyKind[0] && b[2] == valueKind[0] { - var size uint32 - var err error - b, size, err = decodeUint32(b[3:]) - if err != nil { - return b, 0, InvalidMap - } - return b, size, nil - } - } - return b, 0, InvalidMap -} - -func decodeSlice(b []byte, kind Kind) ([]byte, uint32, error) { - if len(b) > 1 { - if b[0] == SliceKind[0] && b[1] == kind[0] { - var size uint32 - var err error - b, size, err = decodeUint32(b[2:]) - if err != nil { - return b, 0, InvalidSlice - } - return b, size, nil - } - } - return b, 0, InvalidSlice -} - -func decodeBytes(b, ret []byte) ([]byte, []byte, error) { - if len(b) > 0 { - if b[0] == BytesKind[0] { - var size uint32 - var err error - b, size, err = decodeUint32(b[1:]) - if err != nil { - return b, nil, InvalidBytes - } - if len(b) > int(size)-1 { - if len(ret) < int(size) { - if ret == nil { - ret = make([]byte, size) - copy(ret, b[:size]) - } else { - ret = append(ret[:0], b[:size]...) - } - } else { - copy(ret[0:], b[:size]) - } - return b[size:], ret, nil - } - } - } - return b, nil, InvalidBytes -} - -func decodeString(b []byte) ([]byte, string, error) { - if len(b) > 0 { - if b[0] == StringKind[0] { - var size uint32 - var err error - b, size, err = decodeUint32(b[1:]) - if err != nil { - return b, emptyString, InvalidString - } - if len(b) > int(size)-1 { - return b[size:], string(b[:size]), nil - } - } - } - return b, emptyString, InvalidString -} - -func decodeError(b []byte) ([]byte, error, error) { - if len(b) > 0 { - if b[0] == ErrorKind[0] { - var val string - var err error - b, val, err = decodeString(b[1:]) - if err != nil { - return b, nil, InvalidError - } - return b, Error(val), nil - } - } - return b, nil, InvalidError -} - -func decodeBool(b []byte) ([]byte, bool, error) { - if len(b) > 1 { - if b[0] == BoolKind[0] { - if b[1] == trueBool { - return b[2:], true, nil - } else { - return b[2:], false, nil - } - } - } - return b, false, InvalidBool -} - -func decodeUint8(b []byte) ([]byte, uint8, error) { - if len(b) > 1 { - if b[0] == Uint8Kind[0] { - return b[2:], b[1], nil - } - } - return b, 0, InvalidUint8 -} - -func decodeUint16(b []byte) ([]byte, uint16, error) { - if len(b) > 2 { - if b[0] == Uint16Kind[0] { - return b[3:], uint16(b[2]) | uint16(b[1])<<8, nil - } - } - return b, 0, InvalidUint16 -} - -func decodeUint32(b []byte) ([]byte, uint32, error) { - if len(b) > 4 { - if b[0] == Uint32Kind[0] { - return b[5:], uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24, nil - } - } - return b, 0, InvalidUint32 -} - -func decodeUint64(b []byte) ([]byte, uint64, error) { - if len(b) > 8 { - if b[0] == Uint64Kind[0] { - return b[9:], uint64(b[8]) | uint64(b[7])<<8 | uint64(b[6])<<16 | uint64(b[5])<<24 | - uint64(b[4])<<32 | uint64(b[3])<<40 | uint64(b[2])<<48 | uint64(b[1])<<56, nil - } - } - return b, 0, InvalidUint64 -} - -func decodeInt32(b []byte) ([]byte, int32, error) { - if len(b) > 4 { - if b[0] == Int32Kind[0] { - return b[5:], int32(uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24), nil - } - } - return b, 0, InvalidInt32 -} - -func decodeInt64(b []byte) ([]byte, int64, error) { - if len(b) > 8 { - if b[0] == Int64Kind[0] { - return b[9:], int64(uint64(b[8]) | uint64(b[7])<<8 | uint64(b[6])<<16 | uint64(b[5])<<24 | - uint64(b[4])<<32 | uint64(b[3])<<40 | uint64(b[2])<<48 | uint64(b[1])<<56), nil - } - } - return b, 0, InvalidInt64 -} - -func decodeFloat32(b []byte) ([]byte, float32, error) { - if len(b) > 4 { - if b[0] == Float32Kind[0] { - return b[5:], math.Float32frombits(uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24), nil - } - } - return b, 0, InvalidFloat32 -} - -func decodeFloat64(b []byte) ([]byte, float64, error) { - if len(b) > 8 { - if b[0] == Float64Kind[0] { - return b[9:], math.Float64frombits(uint64(b[8]) | uint64(b[7])<<8 | uint64(b[6])<<16 | uint64(b[5])<<24 | - uint64(b[4])<<32 | uint64(b[3])<<40 | uint64(b[2])<<48 | uint64(b[1])<<56), nil - } - } - return b, 0, InvalidFloat64 -} diff --git a/pkg/packet/decode_test.go b/pkg/packet/decode_test.go deleted file mode 100644 index 75dad61..0000000 --- a/pkg/packet/decode_test.go +++ /dev/null @@ -1,606 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestDecodeNil(t *testing.T) { - t.Parallel() - - p := Get() - encodeNil(p) - - var value bool - - remaining, value := decodeNil(p.Content.B) - assert.True(t, value) - assert.Equal(t, 0, len(remaining)) - - _, value = decodeNil(p.Content.B[1:]) - assert.False(t, value) - - remaining, value = decodeNil(p.Content.B) - assert.True(t, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.True(t, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeNil(p) - _, _ = decodeNil(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - Put(p) -} - -func TestDecodeMap(t *testing.T) { - t.Parallel() - - p := Get() - encodeMap(p, 32, StringKind, Uint32Kind) - - remaining, size, err := decodeMap(p.Content.B, StringKind, Uint32Kind) - assert.NoError(t, err) - assert.Equal(t, uint32(32), size) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeMap(p.Content.B[1:], StringKind, Uint32Kind) - assert.ErrorIs(t, err, InvalidMap) - - _, _, err = decodeMap(p.Content.B, StringKind, Float64Kind) - assert.ErrorIs(t, err, InvalidMap) - - remaining, size, err = decodeMap(p.Content.B, StringKind, Uint32Kind) - assert.NoError(t, err) - assert.Equal(t, uint32(32), size) - assert.Equal(t, 0, len(remaining)) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeNil(p) - remaining, size, err = decodeMap(p.Content.B, StringKind, Uint32Kind) - p.Content.Reset() - }) - assert.Zero(t, n) - Put(p) -} - -func TestDecodeBytes(t *testing.T) { - t.Parallel() - - p := Get() - v := []byte("Test Bytes") - encodeBytes(p, v) - - var value []byte - - remaining, value, err := decodeBytes(p.Content.B, value) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, value, err = decodeBytes(p.Content.B[1:], value) - assert.ErrorIs(t, err, InvalidBytes) - - remaining, value, err = decodeBytes(p.Content.B, value) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeBytes(p, v) - remaining, value, err = decodeBytes(p.Content.B, value) - p.Content.Reset() - }) - assert.Zero(t, n) - - n = testing.AllocsPerRun(100, func() { - encodeBytes(p, v) - remaining, value, err = decodeBytes(p.Content.B, nil) - p.Content.Reset() - }) - assert.Equal(t, float64(1), n) - - s := [][]byte{v, v, v, v, v} - encodeSlice(p, uint32(len(s)), BytesKind) - for _, sb := range s { - encodeBytes(p, sb) - } - var size uint32 - - remaining, size, err = decodeSlice(p.Content.B, BytesKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(s)), size) - - sValue := make([][]byte, size) - for i := uint32(0); i < size; i++ { - remaining, sValue[i], err = decodeBytes(remaining, nil) - assert.NoError(t, err) - assert.Equal(t, s[i], sValue[i]) - } - - assert.Equal(t, s, sValue) - assert.Equal(t, 0, len(remaining)) - - Put(p) -} - -func TestDecodeString(t *testing.T) { - t.Parallel() - - p := Get() - v := "Test String" - encodeString(p, v) - - var value string - - remaining, value, err := decodeString(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeString(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidString) - - remaining, value, err = decodeString(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeString(p, v) - remaining, value, err = decodeString(p.Content.B) - p.Content.Reset() - }) - assert.Equal(t, float64(1), n) - - s := []string{v, v, v, v, v} - encodeSlice(p, uint32(len(s)), StringKind) - for _, sb := range s { - encodeString(p, sb) - } - var size uint32 - - remaining, size, err = decodeSlice(p.Content.B, StringKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(s)), size) - - sValue := make([]string, size) - for i := uint32(0); i < size; i++ { - remaining, sValue[i], err = decodeString(remaining) - assert.NoError(t, err) - assert.Equal(t, s[i], sValue[i]) - } - - assert.Equal(t, s, sValue) - assert.Equal(t, 0, len(remaining)) - - Put(p) -} - -func TestDecodeError(t *testing.T) { - t.Parallel() - - p := Get() - v := errors.New("Test Error") - encodeError(p, v) - - var value error - - remaining, value, err := decodeError(p.Content.B) - assert.NoError(t, err) - assert.ErrorIs(t, value, v) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeError(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidError) - - remaining, value, err = decodeError(p.Content.B) - assert.NoError(t, err) - assert.ErrorIs(t, value, v) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.ErrorIs(t, value, v) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeError(p, v) - remaining, value, err = decodeError(p.Content.B) - p.Content.Reset() - }) - assert.Equal(t, float64(2), n) - - s := []error{v, v, v, v, v} - encodeSlice(p, uint32(len(s)), ErrorKind) - for _, sb := range s { - encodeError(p, sb) - } - var size uint32 - - remaining, size, err = decodeSlice(p.Content.B, ErrorKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(s)), size) - - sValue := make([]error, size) - for i := uint32(0); i < size; i++ { - remaining, sValue[i], err = decodeError(remaining) - assert.NoError(t, err) - assert.ErrorIs(t, sValue[i], s[i]) - } - - assert.Equal(t, 0, len(remaining)) - - Put(p) -} - -func TestDecodeBool(t *testing.T) { - t.Parallel() - - p := Get() - encodeBool(p, true) - - var value bool - - remaining, value, err := decodeBool(p.Content.B) - assert.NoError(t, err) - assert.True(t, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeBool(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidBool) - - remaining, value, err = decodeBool(p.Content.B) - assert.NoError(t, err) - assert.True(t, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.True(t, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeBool(p, true) - remaining, value, err = decodeBool(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - s := []bool{true, true, false, true, true} - encodeSlice(p, uint32(len(s)), BoolKind) - for _, sb := range s { - encodeBool(p, sb) - } - var size uint32 - - remaining, size, err = decodeSlice(p.Content.B, BoolKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(s)), size) - - sValue := make([]bool, size) - for i := uint32(0); i < size; i++ { - remaining, sValue[i], err = decodeBool(remaining) - assert.NoError(t, err) - assert.Equal(t, s[i], sValue[i]) - } - - assert.Equal(t, s, sValue) - assert.Equal(t, 0, len(remaining)) - - Put(p) -} - -func TestDecodeUint8(t *testing.T) { - t.Parallel() - - p := Get() - v := uint8(32) - encodeUint8(p, v) - - var value uint8 - - remaining, value, err := decodeUint8(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeUint8(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidUint8) - - remaining, value, err = decodeUint8(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint8(p, v) - remaining, value, err = decodeUint8(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - Put(p) -} - -func TestDecodeUint16(t *testing.T) { - t.Parallel() - - p := Get() - v := uint16(1024) - encodeUint16(p, v) - - var value uint16 - - remaining, value, err := decodeUint16(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeUint16(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidUint16) - - remaining, value, err = decodeUint16(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint16(p, v) - remaining, value, err = decodeUint16(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - Put(p) -} - -func TestDecodeUint32(t *testing.T) { - t.Parallel() - - p := Get() - v := uint32(4294967290) - encodeUint32(p, v) - - var value uint32 - - remaining, value, err := decodeUint32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeUint32(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidUint32) - - remaining, value, err = decodeUint32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint32(p, v) - remaining, value, err = decodeUint32(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecodeUint64(t *testing.T) { - t.Parallel() - - p := Get() - v := uint64(18446744073709551610) - encodeUint64(p, v) - - var value uint64 - - remaining, value, err := decodeUint64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeUint64(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidUint64) - - remaining, value, err = decodeUint64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint64(p, v) - remaining, value, err = decodeUint64(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecodeInt32(t *testing.T) { - t.Parallel() - - p := Get() - v := int32(-2147483648) - encodeInt32(p, v) - - var value int32 - - remaining, value, err := decodeInt32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeInt32(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidInt32) - - remaining, value, err = decodeInt32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeInt32(p, v) - remaining, value, err = decodeInt32(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecodeInt64(t *testing.T) { - t.Parallel() - - p := Get() - v := int64(-9223372036854775808) - encodeInt64(p, v) - - var value int64 - - remaining, value, err := decodeInt64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeInt64(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidInt64) - - remaining, value, err = decodeInt64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeInt64(p, v) - remaining, value, err = decodeInt64(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecodeFloat32(t *testing.T) { - t.Parallel() - - p := Get() - v := float32(-12311.12429) - encodeFloat32(p, v) - - var value float32 - - remaining, value, err := decodeFloat32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeFloat32(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidFloat32) - - remaining, value, err = decodeFloat32(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeFloat32(p, v) - remaining, value, err = decodeFloat32(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecodeFloat64(t *testing.T) { - t.Parallel() - - p := Get() - v := -12311241.1242009 - encodeFloat64(p, v) - - var value float64 - - remaining, value, err := decodeFloat64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - _, _, err = decodeFloat64(p.Content.B[1:]) - assert.ErrorIs(t, err, InvalidFloat64) - - remaining, value, err = decodeFloat64(p.Content.B) - assert.NoError(t, err) - assert.Equal(t, v, value) - assert.Equal(t, 0, len(remaining)) - - p.Content.B[len(p.Content.B)-1] = 'S' - assert.Equal(t, v, value) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeFloat64(p, v) - remaining, value, err = decodeFloat64(p.Content.B) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} diff --git a/pkg/packet/decoder.go b/pkg/packet/decoder.go deleted file mode 100644 index d92b9a8..0000000 --- a/pkg/packet/decoder.go +++ /dev/null @@ -1,126 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "sync" -) - -var decoderPool sync.Pool - -type Decoder struct { - b []byte -} - -func GetDecoder(b []byte) (d *Decoder) { - v := decoderPool.Get() - if v == nil { - d = &Decoder{ - b: b, - } - } else { - d = v.(*Decoder) - d.b = b - } - return -} - -func ReturnDecoder(d *Decoder) { - if d != nil { - d.b = nil - decoderPool.Put(d) - } -} - -func (d *Decoder) Return() { - ReturnDecoder(d) -} - -func (d *Decoder) Nil() (value bool) { - d.b, value = decodeNil(d.b) - return -} - -func (d *Decoder) Map(keyKind, valueKind Kind) (size uint32, err error) { - d.b, size, err = decodeMap(d.b, keyKind, valueKind) - return -} - -func (d *Decoder) Slice(kind Kind) (size uint32, err error) { - d.b, size, err = decodeSlice(d.b, kind) - return -} - -func (d *Decoder) Bytes(b []byte) (value []byte, err error) { - d.b, value, err = decodeBytes(d.b, b) - return -} - -func (d *Decoder) String() (value string, err error) { - d.b, value, err = decodeString(d.b) - return -} - -func (d *Decoder) Error() (value, err error) { - d.b, value, err = decodeError(d.b) - return -} - -func (d *Decoder) Bool() (value bool, err error) { - d.b, value, err = decodeBool(d.b) - return -} - -func (d *Decoder) Uint8() (value uint8, err error) { - d.b, value, err = decodeUint8(d.b) - return -} - -func (d *Decoder) Uint16() (value uint16, err error) { - d.b, value, err = decodeUint16(d.b) - return -} - -func (d *Decoder) Uint32() (value uint32, err error) { - d.b, value, err = decodeUint32(d.b) - return -} - -func (d *Decoder) Uint64() (value uint64, err error) { - d.b, value, err = decodeUint64(d.b) - return -} - -func (d *Decoder) Int32() (value int32, err error) { - d.b, value, err = decodeInt32(d.b) - return -} - -func (d *Decoder) Int64() (value int64, err error) { - d.b, value, err = decodeInt64(d.b) - return -} - -func (d *Decoder) Float32() (value float32, err error) { - d.b, value, err = decodeFloat32(d.b) - return -} - -func (d *Decoder) Float64() (value float64, err error) { - d.b, value, err = decodeFloat64(d.b) - return -} diff --git a/pkg/packet/decoder_test.go b/pkg/packet/decoder_test.go deleted file mode 100644 index 98d2818..0000000 --- a/pkg/packet/decoder_test.go +++ /dev/null @@ -1,527 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestDecoderNil(t *testing.T) { - t.Parallel() - - p := Get() - Encoder(p).Nil() - - d := GetDecoder(p.Content.B) - value := d.Nil() - assert.True(t, value) - - value = d.Nil() - assert.False(t, value) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Nil() - d = GetDecoder(p.Content.B) - value = d.Nil() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderMap(t *testing.T) { - t.Parallel() - - p := Get() - m := make(map[string]uint32) - m["1"] = 1 - m["2"] = 2 - m["3"] = 3 - - e := Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v := range m { - e.String(k).Uint32(v) - } - - d := GetDecoder(p.Content.B) - size, err := d.Map(StringKind, Uint32Kind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(m)), size) - - mv := make(map[string]uint32, size) - var k string - var v uint32 - for i := uint32(0); i < size; i++ { - k, err = d.String() - assert.NoError(t, err) - v, err = d.Uint32() - assert.NoError(t, err) - mv[k] = v - } - assert.Equal(t, m, mv) - - size, err = d.Map(StringKind, Uint32Kind) - assert.ErrorIs(t, err, InvalidMap) - assert.Equal(t, uint32(0), size) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - e = Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v = range m { - e.String(k).Uint32(v) - } - d = GetDecoder(p.Content.B) - size, err = d.Map(StringKind, Uint32Kind) - for i := uint32(0); i < size; i++ { - _, _ = d.String() - _, _ = d.Uint32() - } - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderSlice(t *testing.T) { - t.Parallel() - - p := Get() - m := []string{"1", "2", "3"} - - e := Encoder(p).Slice(uint32(len(m)), StringKind) - for _, v := range m { - e.String(v) - } - - d := GetDecoder(p.Content.B) - size, err := d.Slice(StringKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(m)), size) - - mv := make([]string, size) - for i := range mv { - mv[i], err = d.String() - assert.NoError(t, err) - assert.Equal(t, m[i], mv[i]) - } - assert.Equal(t, m, mv) - - size, err = d.Slice(StringKind) - assert.ErrorIs(t, err, InvalidSlice) - assert.Equal(t, uint32(0), size) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - e = Encoder(p).Slice(uint32(len(m)), StringKind) - for _, v := range m { - e.String(v) - } - d = GetDecoder(p.Content.B) - size, err = d.Slice(StringKind) - for i := uint32(0); i < size; i++ { - _, _ = d.String() - } - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderBytes(t *testing.T) { - t.Parallel() - - p := Get() - v := []byte("Test String") - - Encoder(p).Bytes(v) - - d := GetDecoder(p.Content.B) - value, err := d.Bytes(nil) - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Bytes(value) - assert.ErrorIs(t, err, InvalidBytes) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Bytes(v) - d = GetDecoder(p.Content.B) - value, err = d.Bytes(value) - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderString(t *testing.T) { - t.Parallel() - - p := Get() - v := "Test String" - - Encoder(p).String(v) - - d := GetDecoder(p.Content.B) - value, err := d.String() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.String() - assert.ErrorIs(t, err, InvalidString) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).String(v) - d = GetDecoder(p.Content.B) - value, err = d.String() - d.Return() - p.Content.Reset() - }) - assert.Equal(t, float64(1), n) - - Put(p) -} - -func TestDecoderError(t *testing.T) { - t.Parallel() - - p := Get() - v := errors.New("Test Error") - - Encoder(p).Error(v) - - d := GetDecoder(p.Content.B) - value, err := d.Error() - assert.NoError(t, err) - assert.ErrorIs(t, value, v) - - value, err = d.Error() - assert.ErrorIs(t, err, InvalidError) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Error(v) - d = GetDecoder(p.Content.B) - value, err = d.Error() - d.Return() - p.Content.Reset() - }) - assert.Equal(t, float64(2), n) - - Put(p) -} - -func TestDecoderBool(t *testing.T) { - t.Parallel() - - p := Get() - Encoder(p).Bool(true) - - d := GetDecoder(p.Content.B) - value, err := d.Bool() - assert.NoError(t, err) - assert.True(t, value) - - value, err = d.Bool() - assert.ErrorIs(t, err, InvalidBool) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Bool(true) - d = GetDecoder(p.Content.B) - value, err = d.Bool() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderUint8(t *testing.T) { - t.Parallel() - - p := Get() - v := uint8(32) - - Encoder(p).Uint8(v) - - d := GetDecoder(p.Content.B) - value, err := d.Uint8() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Uint8() - assert.ErrorIs(t, err, InvalidUint8) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint8(v) - d = GetDecoder(p.Content.B) - value, err = d.Uint8() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderUint16(t *testing.T) { - t.Parallel() - - p := Get() - v := uint16(1024) - - Encoder(p).Uint16(v) - - d := GetDecoder(p.Content.B) - value, err := d.Uint16() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Uint16() - assert.ErrorIs(t, err, InvalidUint16) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint16(v) - d = GetDecoder(p.Content.B) - value, err = d.Uint16() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderUint32(t *testing.T) { - t.Parallel() - - p := Get() - v := uint32(4294967290) - - Encoder(p).Uint32(v) - - d := GetDecoder(p.Content.B) - value, err := d.Uint32() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Uint32() - assert.ErrorIs(t, err, InvalidUint32) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint32(v) - d = GetDecoder(p.Content.B) - value, err = d.Uint32() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderUint64(t *testing.T) { - t.Parallel() - - p := Get() - v := uint64(18446744073709551610) - - Encoder(p).Uint64(v) - - d := GetDecoder(p.Content.B) - value, err := d.Uint64() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Uint64() - assert.ErrorIs(t, err, InvalidUint64) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint64(v) - d = GetDecoder(p.Content.B) - value, err = d.Uint64() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderInt32(t *testing.T) { - t.Parallel() - - p := Get() - v := int32(-2147483648) - - Encoder(p).Int32(v) - - d := GetDecoder(p.Content.B) - value, err := d.Int32() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Int32() - assert.ErrorIs(t, err, InvalidInt32) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Int32(v) - d = GetDecoder(p.Content.B) - value, err = d.Int32() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderInt64(t *testing.T) { - t.Parallel() - - p := Get() - v := int64(-9223372036854775808) - - Encoder(p).Int64(v) - - d := GetDecoder(p.Content.B) - value, err := d.Int64() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Int64() - assert.ErrorIs(t, err, InvalidInt64) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Int64(v) - d = GetDecoder(p.Content.B) - value, err = d.Int64() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderFloat32(t *testing.T) { - t.Parallel() - - p := Get() - v := float32(-2147483.648) - - Encoder(p).Float32(v) - - d := GetDecoder(p.Content.B) - value, err := d.Float32() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Float32() - assert.ErrorIs(t, err, InvalidFloat32) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Float32(v) - d = GetDecoder(p.Content.B) - value, err = d.Float32() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestDecoderFloat64(t *testing.T) { - t.Parallel() - - p := Get() - v := -922337203.477580 - - Encoder(p).Float64(v) - - d := GetDecoder(p.Content.B) - value, err := d.Float64() - assert.NoError(t, err) - assert.Equal(t, v, value) - - value, err = d.Float64() - assert.ErrorIs(t, err, InvalidFloat64) - - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Float64(v) - d = GetDecoder(p.Content.B) - value, err = d.Float64() - d.Return() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} diff --git a/pkg/packet/encode.go b/pkg/packet/encode.go deleted file mode 100644 index c43a8cc..0000000 --- a/pkg/packet/encode.go +++ /dev/null @@ -1,123 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "math" - "reflect" - "unsafe" -) - -var ( - falseBool = byte(0) - trueBool = byte(1) -) - -func encodeNil(p *Packet) { - p.Content.Write(NilKind) -} - -func encodeMap(p *Packet, size uint32, keyKind, valueKind Kind) { - p.Content.Write(MapKind) - p.Content.Write(keyKind) - p.Content.Write(valueKind) - encodeUint32(p, size) -} - -func encodeSlice(p *Packet, size uint32, kind Kind) { - p.Content.Write(SliceKind) - p.Content.Write(kind) - encodeUint32(p, size) -} - -func encodeBytes(p *Packet, value []byte) { - p.Content.Write(BytesKind) - encodeUint32(p, uint32(len(value))) - p.Content.Write(value) -} - -func encodeString(p *Packet, value string) { - var b []byte - /* #nosec G103 */ - bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - /* #nosec G103 */ - sh := (*reflect.StringHeader)(unsafe.Pointer(&value)) - bh.Data = sh.Data - bh.Cap = sh.Len - bh.Len = sh.Len - p.Content.Write(StringKind) - encodeUint32(p, uint32(len(b))) - p.Content.Write(b) -} - -func encodeError(p *Packet, err error) { - p.Content.Write(ErrorKind) - encodeString(p, err.Error()) -} - -func encodeBool(p *Packet, value bool) { - p.Content.Write(BoolKind) - if value { - p.Content.B = append(p.Content.B, trueBool) - } else { - p.Content.B = append(p.Content.B, falseBool) - } -} - -func encodeUint8(p *Packet, value uint8) { - p.Content.Write(Uint8Kind) - p.Content.B = append(p.Content.B, value) -} - -func encodeUint16(p *Packet, value uint16) { - p.Content.Write(Uint16Kind) - p.Content.B = append(p.Content.B, byte(value>>8), byte(value)) -} - -func encodeUint32(p *Packet, value uint32) { - p.Content.Write(Uint32Kind) - p.Content.B = append(p.Content.B, byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) -} - -func encodeUint64(p *Packet, value uint64) { - p.Content.Write(Uint64Kind) - p.Content.B = append(p.Content.B, byte(value>>56), byte(value>>48), byte(value>>40), byte(value>>32), byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) -} - -func encodeInt32(p *Packet, value int32) { - p.Content.Write(Int32Kind) - castValue := uint32(value) - p.Content.B = append(p.Content.B, byte(castValue>>24), byte(castValue>>16), byte(castValue>>8), byte(castValue)) -} - -func encodeInt64(p *Packet, value int64) { - p.Content.Write(Int64Kind) - castValue := uint64(value) - p.Content.B = append(p.Content.B, byte(castValue>>56), byte(castValue>>48), byte(castValue>>40), byte(castValue>>32), byte(castValue>>24), byte(castValue>>16), byte(castValue>>8), byte(castValue)) -} - -func encodeFloat32(p *Packet, value float32) { - p.Content.Write(Float32Kind) - castValue := math.Float32bits(value) - p.Content.B = append(p.Content.B, byte(castValue>>24), byte(castValue>>16), byte(castValue>>8), byte(castValue)) -} - -func encodeFloat64(p *Packet, value float64) { - p.Content.Write(Float64Kind) - castValue := math.Float64bits(value) - p.Content.B = append(p.Content.B, byte(castValue>>56), byte(castValue>>48), byte(castValue>>40), byte(castValue>>32), byte(castValue>>24), byte(castValue>>16), byte(castValue>>8), byte(castValue)) -} diff --git a/pkg/packet/encode_test.go b/pkg/packet/encode_test.go deleted file mode 100644 index 0098bf5..0000000 --- a/pkg/packet/encode_test.go +++ /dev/null @@ -1,344 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "math" - "testing" -) - -func TestEncodeNil(t *testing.T) { - t.Parallel() - - p := Get() - encodeNil(p) - - assert.Equal(t, 1, len(p.Content.B)) - assert.Equal(t, NilKind, Kind(p.Content.B)) - - n := testing.AllocsPerRun(100, func() { - encodeNil(p) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeMap(t *testing.T) { - t.Parallel() - - p := Get() - encodeMap(p, 32, StringKind, Uint32Kind) - - assert.Equal(t, 1+1+1+1+4, len(p.Content.B)) - assert.Equal(t, MapKind, Kind(p.Content.B[0:1])) - assert.Equal(t, StringKind, Kind(p.Content.B[1:2])) - assert.Equal(t, Uint32Kind, Kind(p.Content.B[2:3])) - assert.Equal(t, Uint32Kind, Kind(p.Content.B[3:4])) - - n := testing.AllocsPerRun(100, func() { - encodeMap(p, 32, StringKind, Uint32Kind) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeSlice(t *testing.T) { - t.Parallel() - - p := Get() - encodeSlice(p, 32, StringKind) - - assert.Equal(t, 1+1+1+4, len(p.Content.B)) - assert.Equal(t, SliceKind, Kind(p.Content.B[0:1])) - assert.Equal(t, StringKind, Kind(p.Content.B[1:2])) - assert.Equal(t, Uint32Kind, Kind(p.Content.B[2:3])) - - n := testing.AllocsPerRun(100, func() { - encodeSlice(p, 32, StringKind) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeBytes(t *testing.T) { - t.Parallel() - - p := Get() - v := []byte("Test String") - - encodeBytes(p, v) - - assert.Equal(t, 1+1+4+len(v), len(p.Content.B)) - assert.Equal(t, v, p.Content.B[1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeBytes(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - Put(p) -} - -func TestEncodeString(t *testing.T) { - t.Parallel() - - p := Get() - v := "Test String" - e := []byte(v) - - encodeString(p, v) - - assert.Equal(t, 1+1+4+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeString(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeError(t *testing.T) { - t.Parallel() - - p := Get() - v := errors.New("Test Error") - e := []byte(v.Error()) - - encodeError(p, v) - - assert.Equal(t, 1+1+1+4+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1+1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeError(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeBool(t *testing.T) { - t.Parallel() - - p := Get() - e := []byte{trueBool} - - encodeBool(p, true) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeBool(p, true) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeUint8(t *testing.T) { - t.Parallel() - - p := Get() - v := uint8(32) - e := []byte{v} - - encodeUint8(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint8(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeUint16(t *testing.T) { - t.Parallel() - - p := Get() - v := uint16(1024) - e := []byte{byte(v >> 8), byte(v)} - - encodeUint16(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint16(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeUint32(t *testing.T) { - t.Parallel() - - p := Get() - v := uint32(4294967290) - e := []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - - encodeUint32(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint32(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeUint64(t *testing.T) { - t.Parallel() - - p := Get() - v := uint64(18446744073709551610) - e := []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - - encodeUint64(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeUint64(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeInt32(t *testing.T) { - t.Parallel() - - p := Get() - v := int32(-2147483648) - e := []byte{byte(uint32(v) >> 24), byte(uint32(v) >> 16), byte(uint32(v) >> 8), byte(uint32(v))} - - encodeInt32(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeInt32(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeInt64(t *testing.T) { - t.Parallel() - - p := Get() - v := int64(-9223372036854775808) - e := []byte{byte(uint64(v) >> 56), byte(uint64(v) >> 48), byte(uint64(v) >> 40), byte(uint64(v) >> 32), byte(uint64(v) >> 24), byte(uint64(v) >> 16), byte(uint64(v) >> 8), byte(uint64(v))} - - encodeInt64(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeInt64(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeFloat32(t *testing.T) { - t.Parallel() - - p := Get() - v := float32(-214648.34432) - e := []byte{byte(math.Float32bits(v) >> 24), byte(math.Float32bits(v) >> 16), byte(math.Float32bits(v) >> 8), byte(math.Float32bits(v))} - - encodeFloat32(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeFloat32(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncodeFloat64(t *testing.T) { - t.Parallel() - - p := Get() - v := -922337203685.2345 - e := []byte{byte(math.Float64bits(v) >> 56), byte(math.Float64bits(v) >> 48), byte(math.Float64bits(v) >> 40), byte(math.Float64bits(v) >> 32), byte(math.Float64bits(v) >> 24), byte(math.Float64bits(v) >> 16), byte(math.Float64bits(v) >> 8), byte(math.Float64bits(v))} - - encodeFloat64(p, v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeFloat64(p, v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} diff --git a/pkg/packet/encoder.go b/pkg/packet/encoder.go deleted file mode 100644 index 8881504..0000000 --- a/pkg/packet/encoder.go +++ /dev/null @@ -1,98 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -type encoder Packet - -func Encoder(p *Packet) *encoder { - return (*encoder)(p) -} - -func (e *encoder) Nil() *encoder { - encodeNil((*Packet)(e)) - return e -} - -func (e *encoder) Map(size uint32, keyKind, valueKind Kind) *encoder { - encodeMap((*Packet)(e), size, keyKind, valueKind) - return e -} - -func (e *encoder) Slice(size uint32, kind Kind) *encoder { - encodeSlice((*Packet)(e), size, kind) - return e -} - -func (e *encoder) Bytes(value []byte) *encoder { - encodeBytes((*Packet)(e), value) - return e -} - -func (e *encoder) String(value string) *encoder { - encodeString((*Packet)(e), value) - return e -} - -func (e *encoder) Error(value error) *encoder { - encodeError((*Packet)(e), value) - return e -} - -func (e *encoder) Bool(value bool) *encoder { - encodeBool((*Packet)(e), value) - return e -} - -func (e *encoder) Uint8(value uint8) *encoder { - encodeUint8((*Packet)(e), value) - return e -} - -func (e *encoder) Uint16(value uint16) *encoder { - encodeUint16((*Packet)(e), value) - return e -} - -func (e *encoder) Uint32(value uint32) *encoder { - encodeUint32((*Packet)(e), value) - return e -} - -func (e *encoder) Uint64(value uint64) *encoder { - encodeUint64((*Packet)(e), value) - return e -} - -func (e *encoder) Int32(value int32) *encoder { - encodeInt32((*Packet)(e), value) - return e -} - -func (e *encoder) Int64(value int64) *encoder { - encodeInt64((*Packet)(e), value) - return e -} - -func (e *encoder) Float32(value float32) *encoder { - encodeFloat32((*Packet)(e), value) - return e -} - -func (e *encoder) Float64(value float64) *encoder { - encodeFloat64((*Packet)(e), value) - return e -} diff --git a/pkg/packet/encoder_test.go b/pkg/packet/encoder_test.go deleted file mode 100644 index 3882c40..0000000 --- a/pkg/packet/encoder_test.go +++ /dev/null @@ -1,364 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package packet - -import ( - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "math" - "testing" -) - -func TestEncoderNil(t *testing.T) { - t.Parallel() - - p := Get() - - Encoder(p).Nil() - - assert.Equal(t, 1, len(p.Content.B)) - assert.Equal(t, NilKind, Kind(p.Content.B)) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Nil() - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderMap(t *testing.T) { - t.Parallel() - - p := Get() - m := make(map[string]uint32) - m["1"] = 1 - m["2"] = 2 - m["3"] = 3 - - e := Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v := range m { - e.String(k).Uint32(v) - } - - assert.Equal(t, 1+1+1+1+4+len(m)*(1+1+4+1+1+4), len(p.Content.B)) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - e = Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v := range m { - e.String(k).Uint32(v) - } - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderSlice(t *testing.T) { - t.Parallel() - - p := Get() - m := make(map[string]uint32) - m["1"] = 1 - m["2"] = 2 - m["3"] = 3 - - e := Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v := range m { - e.String(k).Uint32(v) - } - - assert.Equal(t, 1+1+1+1+4+len(m)*(1+1+4+1+1+4), len(p.Content.B)) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - e = Encoder(p).Map(uint32(len(m)), StringKind, Uint32Kind) - for k, v := range m { - e.String(k).Uint32(v) - } - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderBytes(t *testing.T) { - t.Parallel() - - p := Get() - v := []byte("Test String") - - Encoder(p).Bytes(v) - - assert.Equal(t, 1+1+4+len(v), len(p.Content.B)) - assert.Equal(t, v, p.Content.B[1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Bytes(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderString(t *testing.T) { - t.Parallel() - - p := Get() - v := "Test String" - e := []byte(v) - - Encoder(p).String(v) - - assert.Equal(t, 1+1+4+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).String(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderError(t *testing.T) { - t.Parallel() - - p := Get() - v := errors.New("Test String") - e := []byte(v.Error()) - - Encoder(p).Error(v) - - assert.Equal(t, 1+1+1+4+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1+1+1+4:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Error(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderBool(t *testing.T) { - t.Parallel() - - p := Get() - e := []byte{trueBool} - - Encoder(p).Bool(true) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Bool(true) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderUint8(t *testing.T) { - t.Parallel() - - p := Get() - v := uint8(32) - e := []byte{v} - - Encoder(p).Uint8(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint8(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderUint16(t *testing.T) { - t.Parallel() - - p := Get() - v := uint16(1024) - e := []byte{byte(v >> 8), byte(v)} - - Encoder(p).Uint16(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint16(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderUint32(t *testing.T) { - t.Parallel() - - p := Get() - v := uint32(4294967290) - e := []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - - Encoder(p).Uint32(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint32(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderUint64(t *testing.T) { - t.Parallel() - - p := Get() - v := uint64(18446744073709551610) - e := []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - - Encoder(p).Uint64(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Uint64(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderInt32(t *testing.T) { - t.Parallel() - - p := Get() - v := int32(-2147483648) - e := []byte{byte(uint32(v) >> 24), byte(uint32(v) >> 16), byte(uint32(v) >> 8), byte(uint32(v))} - - Encoder(p).Int32(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Int32(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderInt64(t *testing.T) { - t.Parallel() - - p := Get() - v := int64(-9223372036854775808) - e := []byte{byte(uint64(v) >> 56), byte(uint64(v) >> 48), byte(uint64(v) >> 40), byte(uint64(v) >> 32), byte(uint64(v) >> 24), byte(uint64(v) >> 16), byte(uint64(v) >> 8), byte(uint64(v))} - - Encoder(p).Int64(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Int64(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderFloat32(t *testing.T) { - t.Parallel() - - p := Get() - v := float32(-214648.34432) - e := []byte{byte(math.Float32bits(v) >> 24), byte(math.Float32bits(v) >> 16), byte(math.Float32bits(v) >> 8), byte(math.Float32bits(v))} - - Encoder(p).Float32(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Float32(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} - -func TestEncoderFloat64(t *testing.T) { - t.Parallel() - - p := Get() - v := -922337203685.2345 - e := []byte{byte(math.Float64bits(v) >> 56), byte(math.Float64bits(v) >> 48), byte(math.Float64bits(v) >> 40), byte(math.Float64bits(v) >> 32), byte(math.Float64bits(v) >> 24), byte(math.Float64bits(v) >> 16), byte(math.Float64bits(v) >> 8), byte(math.Float64bits(v))} - - Encoder(p).Float64(v) - - assert.Equal(t, 1+len(e), len(p.Content.B)) - assert.Equal(t, e, p.Content.B[1:]) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Float64(v) - p.Content.Reset() - }) - assert.Zero(t, n) - - Put(p) -} diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 14cff87..0331c3b 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -17,8 +17,8 @@ package packet import ( - "github.com/loopholelabs/frisbee/pkg/content" - "github.com/loopholelabs/frisbee/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/polyglot-go" ) // Packet is the structured frisbee data packet, and contains the following: @@ -36,7 +36,7 @@ import ( // delivered with the frisbee packet (see the Async.WritePacket function for more details), and the Operation field must be greater than uint16(9). type Packet struct { Metadata *metadata.Metadata - Content *content.Content + Content *polyglot.Buffer } func (p *Packet) Reset() { @@ -46,33 +46,9 @@ func (p *Packet) Reset() { p.Content.Reset() } -type Kind []byte - -var ( - NilKind = Kind([]byte{0}) - SliceKind = Kind([]byte{1}) - MapKind = Kind([]byte{2}) - AnyKind = Kind([]byte{3}) - BytesKind = Kind([]byte{4}) - StringKind = Kind([]byte{5}) - ErrorKind = Kind([]byte{6}) - BoolKind = Kind([]byte{7}) - Uint8Kind = Kind([]byte{8}) - Uint16Kind = Kind([]byte{9}) - Uint32Kind = Kind([]byte{10}) - Uint64Kind = Kind([]byte{11}) - Int32Kind = Kind([]byte{11}) - Int64Kind = Kind([]byte{12}) - Float32Kind = Kind([]byte{13}) - Float64Kind = Kind([]byte{14}) -) - -type Error string - -func (e Error) Error() string { - return string(e) -} - -func (e Error) Is(err error) bool { - return e.Error() == err.Error() +func New() *Packet { + return &Packet{ + Metadata: new(metadata.Metadata), + Content: polyglot.NewBuffer(), + } } diff --git a/pkg/packet/packet_test.go b/pkg/packet/packet_test.go index ce55b86..8fd2399 100644 --- a/pkg/packet/packet_test.go +++ b/pkg/packet/packet_test.go @@ -18,7 +18,7 @@ package packet import ( "crypto/rand" - "github.com/pkg/errors" + "github.com/loopholelabs/polyglot-go" "github.com/stretchr/testify/assert" "testing" ) @@ -33,7 +33,7 @@ func TestNew(t *testing.T) { assert.Equal(t, uint16(0), p.Metadata.Id) assert.Equal(t, uint16(0), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, []byte{}, p.Content.B) + assert.Equal(t, polyglot.Buffer{}, *p.Content) Put(p) } @@ -48,12 +48,12 @@ func TestWrite(t *testing.T) { assert.NoError(t, err) p.Content.Write(b) - assert.Equal(t, b, p.Content.B) + assert.Equal(t, polyglot.Buffer(b), *p.Content) p.Reset() - assert.NotEqual(t, b, p.Content.B) - assert.Equal(t, 0, len(p.Content.B)) - assert.Equal(t, 512, cap(p.Content.B)) + assert.NotEqual(t, b, *p.Content) + assert.Equal(t, 0, len(*p.Content)) + assert.Equal(t, 512, cap(*p.Content)) b = make([]byte, 1024) _, err = rand.Read(b) @@ -61,384 +61,8 @@ func TestWrite(t *testing.T) { p.Content.Write(b) - assert.Equal(t, b, p.Content.B) - assert.Equal(t, 1024, len(p.Content.B)) - assert.GreaterOrEqual(t, cap(p.Content.B), 1024) + assert.Equal(t, polyglot.Buffer(b), *p.Content) + assert.Equal(t, 1024, len(*p.Content)) + assert.GreaterOrEqual(t, cap(*p.Content), 1024) } - -type embedStruct struct { - test string - b []byte -} - -type testStruct struct { - err error - test string - b []byte - num1 uint8 - num2 uint16 - num3 uint32 - num4 uint64 - num5 int32 - num6 int64 - num7 float32 - num8 float64 - truth bool - slice []string - m map[uint32]*embedStruct -} - -func TestChain(t *testing.T) { - t.Parallel() - - test := &testStruct{ - err: errors.New("Test Error"), - test: "Test String", - b: []byte("Test Bytes"), - num1: 32, - num2: 1024, - num3: 4294967290, - num4: 18446744073709551610, - num5: -123531252, - num6: -123514361905132059, - num7: -21239.343, - num8: -129403505932.823, - truth: true, - slice: []string{"te", "s", "t"}, - m: make(map[uint32]*embedStruct), - } - - test.m[0] = &embedStruct{ - test: "Embed String", - b: []byte("embed Bytes"), - } - - test.m[1] = &embedStruct{ - test: "Other Embed String", - b: []byte("other embed Bytes"), - } - - p := Get() - encodeError(p, test.err) - encodeString(p, test.test) - encodeBytes(p, test.b) - encodeUint8(p, test.num1) - encodeUint16(p, test.num2) - encodeUint32(p, test.num3) - encodeUint64(p, test.num4) - encodeInt32(p, test.num5) - encodeInt64(p, test.num6) - encodeFloat32(p, test.num7) - encodeFloat64(p, test.num8) - encodeBool(p, test.truth) - encodeSlice(p, uint32(len(test.slice)), StringKind) - for _, s := range test.slice { - encodeString(p, s) - } - encodeMap(p, uint32(len(test.m)), Uint32Kind, AnyKind) - for k, v := range test.m { - encodeUint32(p, k) - encodeString(p, v.test) - encodeBytes(p, v.b) - } - encodeNil(p) - - val := new(testStruct) - var err error - var remaining []byte - - remaining, val.err, err = decodeError(p.Content.B) - assert.NoError(t, err) - assert.ErrorIs(t, val.err, test.err) - - remaining, val.test, err = decodeString(remaining) - assert.NoError(t, err) - assert.Equal(t, test.test, val.test) - - remaining, val.b, err = decodeBytes(remaining, nil) - assert.NoError(t, err) - assert.Equal(t, test.b, val.b) - - remaining, val.num1, err = decodeUint8(remaining) - assert.NoError(t, err) - assert.Equal(t, val.num1, val.num1) - - remaining, val.num2, err = decodeUint16(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num2, val.num2) - - remaining, val.num3, err = decodeUint32(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num3, val.num3) - - remaining, val.num4, err = decodeUint64(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num4, val.num4) - - remaining, val.num5, err = decodeInt32(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num5, val.num5) - - remaining, val.num6, err = decodeInt64(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num6, val.num6) - - remaining, val.num7, err = decodeFloat32(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num7, val.num7) - - remaining, val.num8, err = decodeFloat64(remaining) - assert.NoError(t, err) - assert.Equal(t, test.num8, val.num8) - - remaining, val.truth, err = decodeBool(remaining) - assert.NoError(t, err) - assert.Equal(t, test.truth, val.truth) - - var size uint32 - remaining, size, err = decodeSlice(remaining, StringKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(test.slice)), size) - - val.slice = make([]string, size) - for i := range val.slice { - remaining, val.slice[i], err = decodeString(remaining) - assert.NoError(t, err) - assert.Equal(t, test.slice[i], val.slice[i]) - } - assert.Equal(t, test.slice, val.slice) - - remaining, size, err = decodeMap(remaining, Uint32Kind, AnyKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(test.m)), size) - val.m = make(map[uint32]*embedStruct, size) - k := uint32(0) - v := new(embedStruct) - for i := uint32(0); i < size; i++ { - remaining, k, err = decodeUint32(remaining) - assert.NoError(t, err) - remaining, v.test, err = decodeString(remaining) - assert.NoError(t, err) - remaining, v.b, err = decodeBytes(remaining, v.b) - assert.NoError(t, err) - val.m[k] = v - v = new(embedStruct) - } - assert.Equal(t, test.m, val.m) - - var isNil bool - remaining, isNil = decodeNil(remaining) - assert.True(t, isNil) - - assert.Equal(t, 0, len(remaining)) - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - encodeError(p, test.err) - encodeString(p, test.test) - encodeBytes(p, test.b) - encodeUint8(p, test.num1) - encodeUint16(p, test.num2) - encodeUint32(p, test.num3) - encodeUint64(p, test.num4) - encodeBool(p, test.truth) - encodeNil(p) - remaining, val.err, err = decodeError(p.Content.B) - remaining, val.test, err = decodeString(remaining) - remaining, val.b, err = decodeBytes(remaining, val.b) - remaining, val.num1, err = decodeUint8(remaining) - remaining, val.num2, err = decodeUint16(remaining) - remaining, val.num3, err = decodeUint32(remaining) - remaining, val.num4, err = decodeUint64(remaining) - remaining, val.truth, err = decodeBool(remaining) - remaining, isNil = decodeNil(remaining) - p.Content.Reset() - }) - assert.Equal(t, float64(3), n) - Put(p) -} - -func TestCompleteChain(t *testing.T) { - t.Parallel() - - test := &testStruct{ - err: errors.New("Test Error"), - test: "Test String", - b: []byte("Test Bytes"), - num1: 32, - num2: 1024, - num3: 4294967290, - num4: 18446744073709551610, - num5: -123531252, - num6: -123514361905132059, - num7: -21239.343, - num8: -129403505932.823, - truth: true, - slice: []string{"test1", "test2"}, - m: make(map[uint32]*embedStruct), - } - - test.m[0] = &embedStruct{ - test: "Embed String", - b: []byte("embed Bytes"), - } - - test.m[1] = &embedStruct{ - test: "Other Embed String", - b: []byte("other embed Bytes"), - } - - p := Get() - e := Encoder(p).Error(test.err).String(test.test).Bytes(test.b).Uint8(test.num1).Uint16(test.num2).Uint32(test.num3).Uint64(test.num4).Int32(test.num5).Int64(test.num6).Float32(test.num7).Float64(test.num8).Bool(test.truth).Nil().Slice(uint32(len(test.slice)), StringKind) - for _, s := range test.slice { - e.String(s) - } - e.Map(uint32(len(test.m)), Uint32Kind, AnyKind) - for k, v := range test.m { - e.Uint32(k).String(v.test).Bytes(v.b) - } - - val := new(testStruct) - var err error - - d := GetDecoder(p.Content.B) - - val.err, err = d.Error() - assert.NoError(t, err) - assert.ErrorIs(t, val.err, test.err) - - val.test, err = d.String() - assert.NoError(t, err) - assert.Equal(t, test.test, val.test) - - val.b, err = d.Bytes(nil) - assert.NoError(t, err) - assert.Equal(t, test.b, val.b) - - val.num1, err = d.Uint8() - assert.NoError(t, err) - assert.Equal(t, val.num1, val.num1) - - val.num2, err = d.Uint16() - assert.NoError(t, err) - assert.Equal(t, test.num2, val.num2) - - val.num3, err = d.Uint32() - assert.NoError(t, err) - assert.Equal(t, test.num3, val.num3) - - val.num4, err = d.Uint64() - assert.NoError(t, err) - assert.Equal(t, test.num4, val.num4) - - val.num5, err = d.Int32() - assert.NoError(t, err) - assert.Equal(t, test.num5, val.num5) - - val.num6, err = d.Int64() - assert.NoError(t, err) - assert.Equal(t, test.num6, val.num6) - - val.num7, err = d.Float32() - assert.NoError(t, err) - assert.Equal(t, test.num7, val.num7) - - val.num8, err = d.Float64() - assert.NoError(t, err) - assert.Equal(t, test.num8, val.num8) - - val.truth, err = d.Bool() - assert.NoError(t, err) - assert.Equal(t, test.truth, val.truth) - - isNil := d.Nil() - assert.True(t, isNil) - - var size uint32 - size, err = d.Slice(StringKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(test.slice)), size) - val.slice = make([]string, size) - for i := range val.slice { - val.slice[i], err = d.String() - assert.NoError(t, err) - assert.Equal(t, test.slice[i], val.slice[i]) - } - assert.Equal(t, test.slice, val.slice) - - size, err = d.Map(Uint32Kind, AnyKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(test.m)), size) - val.m = make(map[uint32]*embedStruct, size) - var k uint32 - var v *embedStruct - for i := uint32(0); i < size; i++ { - v = new(embedStruct) - k, err = d.Uint32() - assert.NoError(t, err) - v.test, err = d.String() - assert.NoError(t, err) - v.b, err = d.Bytes(v.b) - assert.NoError(t, err) - val.m[k] = v - } - assert.Equal(t, test.m, val.m) - - assert.Equal(t, 0, len(d.b)) - d.Return() - - p.Content.Reset() - n := testing.AllocsPerRun(100, func() { - Encoder(p).Error(test.err).String(test.test).Bytes(test.b).Uint8(test.num1).Uint16(test.num2).Uint32(test.num3).Uint64(test.num4).Bool(test.truth).Nil() - d = GetDecoder(p.Content.B) - val.err, err = d.Error() - val.test, err = d.String() - val.b, err = d.Bytes(val.b) - val.num1, err = d.Uint8() - val.num2, err = d.Uint16() - val.num3, err = d.Uint32() - val.num4, err = d.Uint64() - val.truth, err = d.Bool() - isNil = d.Nil() - d.Return() - p.Content.Reset() - }) - assert.Equal(t, float64(3), n) - Put(p) -} - -func TestNilSlice(t *testing.T) { - s := make([]string, 0) - p := Get() - Encoder(p).Slice(uint32(len(s)), StringKind) - - d := GetDecoder(p.Content.B) - j, err := d.Slice(StringKind) - assert.NoError(t, err) - assert.Equal(t, uint32(len(s)), j) - - j, err = d.Slice(StringKind) - assert.ErrorIs(t, err, InvalidSlice) - assert.Zero(t, j) -} - -func TestError(t *testing.T) { - t.Parallel() - - v := errors.New("Test Error") - - p := Get() - Encoder(p).Error(v) - - d := GetDecoder(p.Content.B) - _, err := d.String() - assert.ErrorIs(t, err, InvalidString) - - val, err := d.Error() - assert.NoError(t, err) - assert.ErrorIs(t, val, v) - - d.Return() - Put(p) -} diff --git a/pkg/packet/pool.go b/pkg/packet/pool.go index 7e5d1c3..761afbf 100644 --- a/pkg/packet/pool.go +++ b/pkg/packet/pool.go @@ -17,46 +17,21 @@ package packet import ( - "github.com/loopholelabs/frisbee/pkg/content" - "github.com/loopholelabs/frisbee/pkg/metadata" - "sync" + "github.com/loopholelabs/common/pkg/pool" ) var ( - pool = NewPool() + packetPool = NewPool() ) -type Pool struct { - pool sync.Pool -} - -func NewPool() *Pool { - return new(Pool) -} - -func (p *Pool) Get() (s *Packet) { - v := p.pool.Get() - if v == nil { - s = &Packet{ - Metadata: new(metadata.Metadata), - Content: content.New(), - } - return - } - return v.(*Packet) -} - -func (p *Pool) Put(packet *Packet) { - if packet != nil { - packet.Reset() - p.pool.Put(packet) - } +func NewPool() *pool.Pool[Packet, *Packet] { + return pool.NewPool(New) } func Get() (s *Packet) { - return pool.Get() + return packetPool.Get() } func Put(p *Packet) { - pool.Put(p) + packetPool.Put(p) } diff --git a/pkg/packet/pool_test.go b/pkg/packet/pool_test.go index 6769e94..dcb0c38 100644 --- a/pkg/packet/pool_test.go +++ b/pkg/packet/pool_test.go @@ -17,6 +17,7 @@ package packet import ( + "github.com/loopholelabs/polyglot-go" "math/rand" "testing" @@ -35,7 +36,7 @@ func TestRecycle(t *testing.T) { pool.Put(p) p = pool.Get() - testData := make([]byte, cap(p.Content.B)*2) + testData := make([]byte, cap(*p.Content)*2) _, err := rand.Read(testData) assert.NoError(t, err) for { @@ -43,11 +44,11 @@ func TestRecycle(t *testing.T) { assert.Equal(t, uint16(0), p.Metadata.Id) assert.Equal(t, uint16(0), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, []byte{}, p.Content.B) + assert.Equal(t, polyglot.Buffer{}, *p.Content) p.Content.Write(testData) - assert.Equal(t, len(testData), len(p.Content.B)) - assert.GreaterOrEqual(t, cap(p.Content.B), len(testData)) + assert.Equal(t, len(testData), len(*p.Content)) + assert.GreaterOrEqual(t, cap(*p.Content), len(testData)) pool.Put(p) p = pool.Get() @@ -57,11 +58,11 @@ func TestRecycle(t *testing.T) { assert.Equal(t, uint16(0), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - if cap(p.Content.B) < len(testData) { + if cap(*p.Content) < len(testData) { continue } - assert.Equal(t, 0, len(p.Content.B)) - assert.GreaterOrEqual(t, cap(p.Content.B), len(testData)) + assert.Equal(t, 0, len(*p.Content)) + assert.GreaterOrEqual(t, cap(*p.Content), len(testData)) break } diff --git a/protoc-gen-frpc/dockerfile b/protoc-gen-frpc/dockerfile deleted file mode 100644 index 53cbff3..0000000 --- a/protoc-gen-frpc/dockerfile +++ /dev/null @@ -1,19 +0,0 @@ -FROM golang as builder - -ENV GOOS=linux GOARCH=amd64 CGO_ENABLED=0 - -RUN go install github.com/loopholelabs/frisbee/protoc-gen-frpc@v0.5.0 - -# Note, the Docker images must be built for amd64. If the host machine architecture is not amd64 -# you need to cross-compile the binary and move it into /go/bin. -RUN bash -c 'find /go/bin/${GOOS}_${GOARCH}/ -mindepth 1 -maxdepth 1 -exec mv {} /go/bin \;' - -FROM scratch - -# Runtime dependencies -LABEL "build.buf.plugins.runtime_library_versions.0.name"="github.com/loopholelabs/frisbee" -LABEL "build.buf.plugins.runtime_library_versions.0.version"="v0.5.0" - -COPY --from=builder /go/bin / - -ENTRYPOINT ["/protoc-gen-frpc"] diff --git a/protoc-gen-frpc/examples/echo/echo.proto b/protoc-gen-frpc/examples/echo/echo.proto deleted file mode 100644 index 13ea515..0000000 --- a/protoc-gen-frpc/examples/echo/echo.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax = "proto3"; - -option go_package = "/echo"; - -service EchoService { - rpc Echo(Request) returns (Response); -} - -message Request { - string Message = 1; -} - -message Response{ - string Message = 1; -} \ No newline at end of file diff --git a/protoc-gen-frpc/examples/pubsub/pubsub.proto b/protoc-gen-frpc/examples/pubsub/pubsub.proto deleted file mode 100644 index d968bba..0000000 --- a/protoc-gen-frpc/examples/pubsub/pubsub.proto +++ /dev/null @@ -1,39 +0,0 @@ -syntax = "proto3"; - -package pubsub; - -option go_package = "/pubsub"; - -message empty {} - -service PubSub { - rpc Pub(empty) returns (empty); - rpc Sub(empty) returns (empty); -} - -service ExampleService { - rpc ExampleCall1(ExampleMessage1) returns(ReturnType) {} - rpc ExampleCall2(ExampleMessage2) returns(ReturnType) {} -} - -// ExampleMessage1 - Example Leading Comment for ExampleMessage1 -message ExampleMessage1 { - string MyString = 1; -} - -/* -ExampleMessage2 - Example Leading Comment for ExampleMessage2 -*/ -message ExampleMessage2 { - int32 MyInt = 1; - // MyInt - Example trailing Comment - message ExampleNested { - bytes data = 1; - } - ExampleNested nested = 2; -} - -/* -ReturnType - Empty Structure Placeholder -*/ -message ReturnType {} \ No newline at end of file diff --git a/protoc-gen-frpc/examples/simple/simple.proto b/protoc-gen-frpc/examples/simple/simple.proto deleted file mode 100644 index c5e404a..0000000 --- a/protoc-gen-frpc/examples/simple/simple.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax = "proto3"; - -option go_package = "/simple"; - -service SimpleService { - rpc Echo(stream Request) returns (stream Response); -} - -message Request { - string Message = 1; -} - -message Response{ - string Message = 1; -} \ No newline at end of file diff --git a/protoc-gen-frpc/examples/test/test.proto b/protoc-gen-frpc/examples/test/test.proto deleted file mode 100644 index 2918d00..0000000 --- a/protoc-gen-frpc/examples/test/test.proto +++ /dev/null @@ -1,117 +0,0 @@ -syntax = "proto3"; - -option go_package = "/test"; - -service EchoService { - rpc Echo(Request) returns (Response); - rpc Testy(SearchResponse) returns (StockPricesWrapper); -} - -message Request { - string Message = 1; - enum Corpus { - UNIVERSAL = 0; - WEB = 1; - IMAGES = 2; - LOCAL = 3; - NEWS = 4; - PRODUCTS = 5; - VIDEO = 6; - } - Corpus corpus = 4; -} - -message Response{ - string Message = 1; - Data Test = 2; -} - -enum Test { - Potato = 0; - Monkey = 1; -} - -message Data{ - string Message = 1; - Test Checker = 2; -} - -message MyMessage1 { - enum EnumAllowingAlias { - option allow_alias = true; - UNKNOWN = 0; - STARTED = 1; - RUNNING = 1; - } -} -message MyMessage2 { - enum EnumNotAllowingAlias { - UNKNOWN = 0; - STARTED = 1; - } -} - -message SearchResponse { - message Result { - string url = 1; - string title = 2; - repeated string snippets = 3; - } - repeated Result results = 1; - repeated Result results2 = 2; - repeated string snippets = 3; - repeated string snippets2 = 4; -} - -message Resulting { - string url = 1; - string title = 2; - repeated string snippets = 3; -} - -message SomeOtherMessage { - SearchResponse.Result result = 1; -} - -message Outer {// Level 0 - message MiddleAA {// Level 1 - message Inner {// Level 2 - int64 ival = 1; - bool booly = 2; - } - Inner inner = 1; - } - message MiddleBB {// Level 1 - message Inner {// Level 2 - int32 ival = 1; - bool booly = 2; - } - Inner inner = 1; - } - MiddleAA a = 1; - MiddleBB b = 2; -} - -message SampleMessage { - oneof test_oneof { - string name = 4; - string potato = 9; - } -} - -message TestPotato { - map prices = 1; -} - - -message StockPrices { - map prices = 1; -} - -message StockPricesWrapper { - repeated StockPrices sPrices = 1; -} - -message StockPricesSuperWrap { - map prices = 1; -} diff --git a/protoc-gen-frpc/internal/utils/utils.go b/protoc-gen-frpc/internal/utils/utils.go deleted file mode 100644 index 3e5bfa1..0000000 --- a/protoc-gen-frpc/internal/utils/utils.go +++ /dev/null @@ -1,113 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package utils - -import ( - "google.golang.org/protobuf/reflect/protoreflect" - "strings" - "unicode" - "unicode/utf8" -) - -// CamelCase returns the CamelCased name. -// If there is an interior underscore followed by a lower case letter, -// drop the underscore and convert the letter to upper case. -// There is a remote possibility of this rewrite causing a name collision, -// but it's so remote we're prepared to pretend it's nonexistent - since the -// C++ generator lowercases names, it's extremely unlikely to have two fields -// with different capitalizations. -// In short, _my_field_name_2 becomes XMyFieldName_2. -func CamelCase(s string) string { - if s == "" { - return "" - } - t := make([]byte, 0, 32) - i := 0 - if s[0] == '_' { // Keep the initial _ if it exists - i++ - } - // Invariant: if the next letter is lower case, it must be converted - // to upper case. - // That is, we process a word at a time, where words are marked by _ or - // upper case letter. Digits are treated as words. - for ; i < len(s); i++ { - c := s[i] - if c == '_' && i+1 < len(s) && 'a' <= s[i+1] && s[i+1] <= 'z' { - continue // Skip the underscore in s. - } - if c == '.' { - continue - } - if '0' <= c && c <= '9' { - t = append(t, c) - continue - } - // Assume we have a letter now - if not, it's a bogus identifier. - // The next word is a sequence of characters that must start upper case. - if 'a' <= c && c <= 'z' { - c ^= ' ' // Make it a capital letter. - } - t = append(t, c) // Guaranteed not lower case. - // Accept lower case sequence that follows. - for i+1 < len(s) && 'a' <= s[i+1] && s[i+1] <= 'z' { - i++ - t = append(t, s[i]) - } - } - return string(t) -} - -func CamelCaseFullName(name protoreflect.FullName) string { - return CamelCase(string(name)) -} - -func CamelCaseName(name protoreflect.Name) string { - return CamelCase(string(name)) -} - -func AppendString(inputs ...string) string { - builder := new(strings.Builder) - for _, s := range inputs { - builder.WriteString(s) - } - - return builder.String() -} - -func FirstLowerCase(s string) string { - if s == "" { - return "" - } - r, n := utf8.DecodeRuneInString(s) - return string(unicode.ToLower(r)) + s[n:] -} - -func FirstLowerCaseName(name protoreflect.Name) string { - return FirstLowerCase(string(name)) -} - -func MakeIterable(length int) []struct{} { - return make([]struct{}, length) -} - -func Counter(initial int) func() int { - i := initial - return func() int { - i++ - return i - } -} diff --git a/protoc-gen-frpc/internal/version/version.go b/protoc-gen-frpc/internal/version/version.go deleted file mode 100644 index dfec33d..0000000 --- a/protoc-gen-frpc/internal/version/version.go +++ /dev/null @@ -1,21 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package version - -const ( - Version = "v0.1.0" -) diff --git a/protoc-gen-frpc/main.go b/protoc-gen-frpc/main.go deleted file mode 100644 index 1471029..0000000 --- a/protoc-gen-frpc/main.go +++ /dev/null @@ -1,52 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package main - -import ( - "github.com/loopholelabs/frisbee/protoc-gen-frpc/pkg/generator" - "io/ioutil" - "os" -) - -func main() { - gen := generator.New() - - data, err := ioutil.ReadAll(os.Stdin) - if err != nil { - panic(err) - } - - req, err := gen.UnmarshalRequest(data) - if err != nil { - panic(err) - } - - res, err := gen.Generate(req) - if err != nil { - panic(err) - } - - data, err = gen.MarshalResponse(res) - if err != nil { - panic(err) - } - - _, err = os.Stdout.Write(data) - if err != nil { - panic(err) - } -} diff --git a/protoc-gen-frpc/pkg/generator/defaults.go b/protoc-gen-frpc/pkg/generator/defaults.go deleted file mode 100644 index 4d3fcf0..0000000 --- a/protoc-gen-frpc/pkg/generator/defaults.go +++ /dev/null @@ -1,27 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -const ( - extension = ".frpc.go" - pointer = "*" - space = " " - comma = "," - mapSuffix = "Map" - slice = "[]" - packetAnyKind = "packet.AnyKind" -) diff --git a/protoc-gen-frpc/pkg/generator/file.go b/protoc-gen-frpc/pkg/generator/file.go deleted file mode 100644 index cccbeb4..0000000 --- a/protoc-gen-frpc/pkg/generator/file.go +++ /dev/null @@ -1,21 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -type File interface { - P(v ...interface{}) -} diff --git a/protoc-gen-frpc/pkg/generator/generator.go b/protoc-gen-frpc/pkg/generator/generator.go deleted file mode 100644 index 7738e33..0000000 --- a/protoc-gen-frpc/pkg/generator/generator.go +++ /dev/null @@ -1,105 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -import ( - "text/template" - - "github.com/loopholelabs/frisbee/protoc-gen-frpc/internal/utils" - "github.com/loopholelabs/frisbee/protoc-gen-frpc/internal/version" - "github.com/loopholelabs/frisbee/protoc-gen-frpc/templates" - "google.golang.org/protobuf/compiler/protogen" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/pluginpb" -) - -type Generator struct { - options *protogen.Options -} - -var templ *template.Template - -func init() { - templ = template.Must(template.New("main").Funcs(template.FuncMap{ - "CamelCase": utils.CamelCaseFullName, - "CamelCaseName": utils.CamelCaseName, - "MakeIterable": utils.MakeIterable, - "Counter": utils.Counter, - "FirstLowerCase": utils.FirstLowerCase, - "FirstLowerCaseName": utils.FirstLowerCaseName, - "FindValue": findValue, - "GetKind": getKind, - "GetLUTEncoder": getLUTEncoder, - "GetLUTDecoder": getLUTDecoder, - "GetEncodingFields": getEncodingFields, - "GetDecodingFields": getDecodingFields, - "GetKindLUT": getKindLUT, - "GetServerFields": getServerFields, - }).ParseFS(templates.FS, "*")) -} - -func New() *Generator { - return &Generator{ - options: &protogen.Options{ - ParamFunc: func(name string, value string) error { return nil }, - ImportRewriteFunc: func(path protogen.GoImportPath) protogen.GoImportPath { return path }, - }, - } -} - -func (*Generator) UnmarshalRequest(buf []byte) (*pluginpb.CodeGeneratorRequest, error) { - req := new(pluginpb.CodeGeneratorRequest) - return req, proto.Unmarshal(buf, req) -} - -func (*Generator) MarshalResponse(res *pluginpb.CodeGeneratorResponse) ([]byte, error) { - return proto.Marshal(res) -} - -func (g *Generator) Generate(req *pluginpb.CodeGeneratorRequest) (res *pluginpb.CodeGeneratorResponse, err error) { - plugin, err := g.options.New(req) - if err != nil { - return nil, err - } - - for _, f := range plugin.Files { - if !f.Generate { - continue - } - genFile := plugin.NewGeneratedFile(fileName(f.GeneratedFilenamePrefix), f.GoImportPath) - - packageName := string(f.Desc.Package().Name()) - if packageName == "" { - packageName = string(f.GoPackageName) - } - - err = templ.ExecuteTemplate(genFile, "base.templ", map[string]interface{}{ - "pluginVersion": version.Version, - "sourcePath": f.Desc.Path(), - "package": packageName, - "imports": requiredImports, - "enums": f.Desc.Enums(), - "messages": f.Desc.Messages(), - "services": f.Desc.Services(), - }) - if err != nil { - return nil, err - } - } - - return plugin.Response(), nil -} diff --git a/protoc-gen-frpc/pkg/generator/headers.go b/protoc-gen-frpc/pkg/generator/headers.go deleted file mode 100644 index e652052..0000000 --- a/protoc-gen-frpc/pkg/generator/headers.go +++ /dev/null @@ -1,25 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -import ( - "github.com/loopholelabs/frisbee/protoc-gen-frpc/internal/utils" -) - -func fileName(name string) string { - return utils.AppendString(name, extension) -} diff --git a/protoc-gen-frpc/pkg/generator/imports.go b/protoc-gen-frpc/pkg/generator/imports.go deleted file mode 100644 index 5d584ee..0000000 --- a/protoc-gen-frpc/pkg/generator/imports.go +++ /dev/null @@ -1,30 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -var ( - requiredImports = []string{ - "github.com/loopholelabs/frisbee", - "github.com/loopholelabs/frisbee/pkg/packet", - "github.com/rs/zerolog", - "crypto/tls", - "github.com/pkg/errors", - "context", - "sync", - "sync/atomic", - } -) diff --git a/protoc-gen-frpc/pkg/generator/server.go b/protoc-gen-frpc/pkg/generator/server.go deleted file mode 100644 index 59f3e4c..0000000 --- a/protoc-gen-frpc/pkg/generator/server.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -import ( - "github.com/loopholelabs/frisbee/protoc-gen-frpc/internal/utils" - "google.golang.org/protobuf/reflect/protoreflect" - "strings" -) - -func getServerFields(services protoreflect.ServiceDescriptors) string { - builder := new(strings.Builder) - for i := 0; i < services.Len(); i++ { - service := services.Get(i) - serviceName := utils.CamelCase(string(service.Name())) - builder.WriteString(utils.FirstLowerCase(serviceName)) - builder.WriteString(space) - builder.WriteString(serviceName) - builder.WriteString(comma) - builder.WriteString(space) - } - serverFields := builder.String() - serverFields = serverFields[:len(serverFields)-2] - return serverFields -} diff --git a/protoc-gen-frpc/pkg/generator/structs.go b/protoc-gen-frpc/pkg/generator/structs.go deleted file mode 100644 index a06d92f..0000000 --- a/protoc-gen-frpc/pkg/generator/structs.go +++ /dev/null @@ -1,243 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package generator - -import ( - "errors" - "fmt" - "github.com/loopholelabs/frisbee/protoc-gen-frpc/internal/utils" - "google.golang.org/protobuf/reflect/protoreflect" -) - -var ( - errUnknownKind = errors.New("unknown or unsupported protoreflect.Kind") - errUnknownCardinality = errors.New("unknown or unsupported protoreflect.Cardinality") -) - -var ( - typeLUT = map[protoreflect.Kind]string{ - protoreflect.BoolKind: "bool", - protoreflect.Int32Kind: "int32", - protoreflect.Sint32Kind: "int32", - protoreflect.Uint32Kind: "uint32", - protoreflect.Int64Kind: "int64", - protoreflect.Sint64Kind: "int64", - protoreflect.Uint64Kind: "uint64", - protoreflect.Sfixed32Kind: "int32", - protoreflect.Sfixed64Kind: "int64", - protoreflect.Fixed32Kind: "uint32", - protoreflect.Fixed64Kind: "uint64", - protoreflect.FloatKind: "float32", - protoreflect.DoubleKind: "float64", - protoreflect.StringKind: "string", - protoreflect.BytesKind: "[]byte", - } - - encodeLUT = map[protoreflect.Kind]string{ - protoreflect.BoolKind: ".Bool", - protoreflect.Int32Kind: ".Int32", - protoreflect.Sint32Kind: ".Int32", - protoreflect.Uint32Kind: ".Uint32", - protoreflect.Int64Kind: ".Int64", - protoreflect.Sint64Kind: ".Int64", - protoreflect.Uint64Kind: ".Uint64", - protoreflect.Sfixed32Kind: ".Int32", - protoreflect.Sfixed64Kind: ".Int64", - protoreflect.Fixed32Kind: ".Uint32", - protoreflect.Fixed64Kind: ".Uint64", - protoreflect.StringKind: ".String", - protoreflect.FloatKind: ".Float32", - protoreflect.DoubleKind: ".Float64", - protoreflect.BytesKind: ".Bytes", - protoreflect.EnumKind: ".Uint32", - } - - decodeLUT = map[protoreflect.Kind]string{ - protoreflect.BoolKind: ".Bool", - protoreflect.Int32Kind: ".Int32", - protoreflect.Sint32Kind: ".Int32", - protoreflect.Uint32Kind: ".Uint32", - protoreflect.Int64Kind: ".Int64", - protoreflect.Sint64Kind: ".Int64", - protoreflect.Uint64Kind: ".Uint64", - protoreflect.Sfixed32Kind: ".Int32", - protoreflect.Sfixed64Kind: ".Int64", - protoreflect.Fixed32Kind: ".Uint32", - protoreflect.Fixed64Kind: ".Uint64", - protoreflect.StringKind: ".String", - protoreflect.FloatKind: ".Float32", - protoreflect.DoubleKind: ".Float64", - protoreflect.BytesKind: ".Bytes", - protoreflect.EnumKind: ".Uint32", - } - - kindLUT = map[protoreflect.Kind]string{ - protoreflect.BoolKind: "packet.BoolKind", - protoreflect.Int32Kind: "packet.Int32Kind", - protoreflect.Sint32Kind: "packet.Int32Kind", - protoreflect.Uint32Kind: "packet.Uint32Kind", - protoreflect.Int64Kind: "packet.Int64Kind", - protoreflect.Sint64Kind: "packet.Int64Kind", - protoreflect.Uint64Kind: "packet.Uint64Kind", - protoreflect.Sfixed32Kind: "packet.Int32Kind", - protoreflect.Sfixed64Kind: "packet.Int64Kind", - protoreflect.Fixed32Kind: "packet.Uint32Kind", - protoreflect.Fixed64Kind: "packet.Uint64Kind", - protoreflect.StringKind: "packet.StringKind", - protoreflect.FloatKind: "packet.Float32Kind", - protoreflect.DoubleKind: "packet.Float64Kind", - protoreflect.BytesKind: "packet.BytesKind", - protoreflect.EnumKind: "packet.Uint32Kind", - } -) - -func findValue(field protoreflect.FieldDescriptor) string { - if kind, ok := typeLUT[field.Kind()]; !ok { - switch field.Kind() { - case protoreflect.EnumKind: - switch field.Cardinality() { - case protoreflect.Optional, protoreflect.Required: - return utils.CamelCase(string(field.Enum().FullName())) - case protoreflect.Repeated: - return utils.CamelCase(utils.AppendString(slice, string(field.Enum().FullName()))) - default: - panic(errUnknownCardinality) - } - case protoreflect.MessageKind: - if field.IsMap() { - return utils.CamelCase(utils.AppendString(string(field.FullName()), mapSuffix)) - } else { - switch field.Cardinality() { - case protoreflect.Optional, protoreflect.Required: - return utils.AppendString(pointer, utils.CamelCase(string(field.Message().FullName()))) - case protoreflect.Repeated: - return utils.AppendString(slice, pointer, utils.CamelCase(string(field.Message().FullName()))) - default: - panic(errUnknownCardinality) - } - } - default: - panic(errUnknownKind) - } - } else { - if field.Cardinality() == protoreflect.Repeated { - kind = slice + kind - } - return kind - } -} - -type encodingFields struct { - MessageFields []protoreflect.FieldDescriptor - SliceFields []protoreflect.FieldDescriptor - Values []string -} - -func getEncodingFields(fields protoreflect.FieldDescriptors) encodingFields { - var messageFields []protoreflect.FieldDescriptor - var sliceFields []protoreflect.FieldDescriptor - var values []string - - for i := 0; i < fields.Len(); i++ { - field := fields.Get(i) - if field.Cardinality() == protoreflect.Repeated && !field.IsMap() { - sliceFields = append(sliceFields, field) - } else { - if encoder, ok := encodeLUT[field.Kind()]; !ok { - switch field.Kind() { - case protoreflect.MessageKind: - messageFields = append(messageFields, field) - default: - panic(errUnknownKind) - } - } else { - if field.Kind() == protoreflect.EnumKind { - values = append(values, fmt.Sprintf("%s(uint32(x.%s))", encoder, utils.CamelCase(string(field.Name())))) - } else { - values = append(values, fmt.Sprintf("%s(x.%s)", encoder, utils.CamelCase(string(field.Name())))) - } - } - } - } - return encodingFields{ - MessageFields: messageFields, - SliceFields: sliceFields, - Values: values, - } -} - -type decodingFields struct { - MessageFields []protoreflect.FieldDescriptor - SliceFields []protoreflect.FieldDescriptor - Other []protoreflect.FieldDescriptor -} - -func getDecodingFields(fields protoreflect.FieldDescriptors) decodingFields { - var messageFields []protoreflect.FieldDescriptor - var sliceFields []protoreflect.FieldDescriptor - var other []protoreflect.FieldDescriptor - - for i := 0; i < fields.Len(); i++ { - field := fields.Get(i) - if field.Cardinality() == protoreflect.Repeated && !field.IsMap() { - sliceFields = append(sliceFields, field) - } else { - if _, ok := decodeLUT[field.Kind()]; !ok { - switch field.Kind() { - case protoreflect.MessageKind: - messageFields = append(messageFields, field) - default: - panic(errUnknownKind) - } - } else { - other = append(other, field) - } - } - } - - return decodingFields{ - MessageFields: messageFields, - SliceFields: sliceFields, - Other: other, - } -} - -func getKind(kind protoreflect.Kind) string { - var outKind string - var ok bool - if outKind, ok = kindLUT[kind]; !ok { - switch kind { - case protoreflect.MessageKind: - outKind = packetAnyKind - default: - panic(errUnknownKind) - } - } - return outKind -} - -func getLUTEncoder(kind protoreflect.Kind) string { - return encodeLUT[kind] -} - -func getLUTDecoder(kind protoreflect.Kind) string { - return decodeLUT[kind] -} - -func getKindLUT(kind protoreflect.Kind) string { - return kindLUT[kind] -} diff --git a/protoc-gen-frpc/templates/base.templ b/protoc-gen-frpc/templates/base.templ deleted file mode 100644 index 7b67c00..0000000 --- a/protoc-gen-frpc/templates/base.templ +++ /dev/null @@ -1,17 +0,0 @@ -{{template "headers" .}} - -{{template "imports" .}} - -var ( - NilDecode = errors.New("cannot decode into a nil root struct") -) - -{{template "enums" .}} - -{{template "messages" .}} - -{{template "interfaces" .}} - -{{template "server" .}} - -{{template "client" .}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/client.templ b/protoc-gen-frpc/templates/client.templ deleted file mode 100644 index 96c42a8..0000000 --- a/protoc-gen-frpc/templates/client.templ +++ /dev/null @@ -1,126 +0,0 @@ -{{define "client"}} -type Client struct { - *frisbee.Client - {{ range $i, $v := (MakeIterable .services.Len) -}} - {{ $service := $.services.Get $i -}} - {{ range $i, $v := (MakeIterable $service.Methods.Len) -}} - {{ $method := $service.Methods.Get $i -}} - next{{ CamelCaseName $method.Name }} atomic.Value - inflight{{ CamelCaseName $method.Name }}Mu sync.RWMutex - inflight{{ CamelCaseName $method.Name }} map[uint16]chan *{{ CamelCase $method.Output.FullName }} - {{end -}} - {{end -}} -} - -func NewClient (tlsConfig *tls.Config, logger *zerolog.Logger) (*Client, error) { - c := new(Client) - table := make(frisbee.HandlerTable) - {{template "clienthandlers" .services -}} - - var err error - if tlsConfig != nil { - c.Client, err = frisbee.NewClient(table, context.Background(), frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) - if err != nil { - return nil, err - } - } else { - c.Client, err = frisbee.NewClient(table, context.Background(), frisbee.WithLogger(logger)) - if err != nil { - return nil, err - } - } - - {{ range $i, $v := (MakeIterable .services.Len) -}} - {{ $service := $.services.Get $i -}} - {{ range $i, $v := (MakeIterable $service.Methods.Len) -}} - {{ $method := $service.Methods.Get $i -}} - c.next{{ CamelCaseName $method.Name }}.Store(uint16(0)) - c.inflight{{ CamelCaseName $method.Name }} = make(map[uint16]chan *{{ CamelCase $method.Output.FullName }}) - {{end -}} - {{end -}} - return c, nil -} - -{{template "clientmethods" .services }} -{{end}} - -{{define "clienthandlers"}} -{{ $counter := Counter 9 -}} -{{ range $i, $v := (MakeIterable .Len) }} - {{ $service := $.Get $i -}} - {{ range $i, $v := (MakeIterable $service.Methods.Len) -}} - {{ $method := $service.Methods.Get $i -}} - {{ $count := call $counter -}} - table[{{ $count }}] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { - c.inflight{{ CamelCaseName $method.Name }}Mu.RLock() - if ch, ok := c.inflight{{ CamelCaseName $method.Name }}[incoming.Metadata.Id]; ok { - c.inflight{{ CamelCaseName $method.Name }}Mu.RUnlock() - res := New{{ CamelCase $method.Output.FullName }}() - res.Decode(incoming.Content.B) - ch <- res - } else { - c.inflight{{ CamelCaseName $method.Name }}Mu.RUnlock() - } - return - } - {{end -}} -{{end -}} -{{end}} - -{{define "clientmethods"}} -{{ $counter := Counter 9 -}} -{{ range $i, $v := (MakeIterable .Len) -}} - {{ $service := $.Get $i -}} - {{ range $i, $v := (MakeIterable $service.Methods.Len) }} - {{ $method := $service.Methods.Get $i -}} - {{ $opIndex := call $counter -}} - func (c *Client) {{ CamelCaseName $method.Name }}(ctx context.Context, req *{{ CamelCase $method.Input.FullName }}) (res *{{ CamelCase $method.Output.FullName }}, err error) { - ch := make(chan *{{ CamelCase $method.Output.FullName }}, 1) - p := packet.Get() - p.Metadata.Operation = {{ $opIndex }} - LOOP: - p.Metadata.Id = c.next{{ CamelCaseName $method.Name }}.Load().(uint16) - if !c.next{{ CamelCaseName $method.Name }}.CompareAndSwap(p.Metadata.Id, p.Metadata.Id+1) { - goto LOOP - } - req.Encode(p) - p.Metadata.ContentLength = uint32(len(p.Content.B)) - c.inflight{{ CamelCaseName $method.Name }}Mu.Lock() - c.inflight{{ CamelCaseName $method.Name }}[p.Metadata.Id] = ch - c.inflight{{ CamelCaseName $method.Name }}Mu.Unlock() - err = c.Client.WritePacket(p) - if err != nil { - packet.Put(p) - return - } - select { - case res = <- ch: - err = res.error - case <- ctx.Done(): - err = ctx.Err() - } - c.inflight{{ CamelCaseName $method.Name }}Mu.Lock() - delete(c.inflight{{ CamelCaseName $method.Name }}, p.Metadata.Id) - c.inflight{{ CamelCaseName $method.Name }}Mu.Unlock() - packet.Put(p) - return - } - - func (c *Client) {{ CamelCaseName $method.Name }}Ignore(ctx context.Context, req *{{ CamelCase $method.Input.FullName }}) (err error) { - p := packet.Get() - p.Metadata.Operation = {{ $opIndex }} - LOOP: - p.Metadata.Id = c.next{{ CamelCaseName $method.Name }}.Load().(uint16) - if !c.next{{ CamelCaseName $method.Name }}.CompareAndSwap(p.Metadata.Id, p.Metadata.Id+1) { - goto LOOP - } - req.ignore = true - req.Encode(p) - p.Metadata.ContentLength = uint32(len(p.Content.B)) - err = c.Client.WritePacket(p) - packet.Put(p) - return - } - {{end -}} -{{end -}} -{{end}} diff --git a/protoc-gen-frpc/templates/decode.templ b/protoc-gen-frpc/templates/decode.templ deleted file mode 100644 index 8dc6e7d..0000000 --- a/protoc-gen-frpc/templates/decode.templ +++ /dev/null @@ -1,97 +0,0 @@ -{{define "decode"}} -func (x *{{ CamelCase .FullName }}) Decode (b []byte) error { - if x == nil { - return NilDecode - } - d := packet.GetDecoder(b) - defer d.Return() - return x.decode(d) -} -{{end}} - -{{define "internalDecode"}} -func (x *{{CamelCase .FullName}}) decode(d *packet.Decoder) error { - if d.Nil() { - return nil - } - var err error - x.error, err = d.Error() - if err != nil { - x.ignore, err = d.Bool() - if err != nil { - return err - } - {{ $decoding := GetDecodingFields .Fields -}} - {{ range $field := $decoding.Other -}} - {{ $decoder := GetLUTDecoder $field.Kind -}} - {{ if eq $field.Kind 12 -}} {{/* protoreflect.BytesKind */ -}} - x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}(nil) - {{ else if eq $field.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - var {{ CamelCaseName $field.Name }}Temp uint32 - {{ CamelCaseName $field.Name }}Temp, err = d{{ $decoder }}() - x.{{ CamelCaseName $field.Name }} = {{ FindValue $field }}({{ CamelCaseName $field.Name }}Temp) - {{ else -}} - x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}() - {{end -}} - if err != nil { - return err - } - {{end -}} - - {{ if $decoding.SliceFields -}} - var sliceSize uint32 - {{end -}} - {{ range $field := $decoding.SliceFields -}} - {{ $kind := GetKind $field.Kind -}} - sliceSize, err = d.Slice({{ $kind }}) - if err != nil { - return err - } - if uint32(len(x.{{ CamelCaseName $field.Name }})) != sliceSize { - x.{{ CamelCaseName $field.Name }} = make({{ FindValue $field }}, sliceSize) - } - for i := uint32(0); i < sliceSize; i++ { - {{ $decoder := GetLUTDecoder $field.Kind -}} - {{ if eq $field.Kind 11 -}} {{/* protoreflect.MessageKind */ -}} - if x.{{ CamelCaseName $field.Name }}[i] == nil { - x.{{ CamelCaseName $field.Name }}[i] = New{{ CamelCase $field.Message.FullName }}() - } - err = x.{{ CamelCaseName $field.Name }}[i].decode(d) - {{ else -}} - x.{{ CamelCaseName $field.Name }}[i], err = d{{ $decoder }}() - {{end -}} - if err != nil { - return err - } - } - {{end -}} - {{ range $field := $decoding.MessageFields -}} - {{ if $field.IsMap -}} - if !d.Nil() { - {{ $keyKind := GetKind $field.MapKey.Kind -}} - {{ $valKind := GetKind $field.MapValue.Kind -}} - - {{ CamelCaseName $field.Name }}Size, err := d.Map({{ $keyKind }}, {{ $valKind }}) - if err != nil { - return err - } - x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.FullName }}Map({{ CamelCaseName $field.Name }}Size) - err = x.{{ CamelCaseName $field.Name }}.decode(d, {{ CamelCaseName $field.Name }}Size) - if err != nil { - return err - } - } - {{ else -}} - if x.{{ CamelCaseName $field.Name }} == nil { - x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.Message.FullName }}() - } - err = x.{{ CamelCaseName $field.Name }}.decode(d) - if err != nil { - return err - } - {{end -}} - {{end -}} - } - return nil -} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/decodeMap.templ b/protoc-gen-frpc/templates/decodeMap.templ deleted file mode 100644 index 2cb0dfe..0000000 --- a/protoc-gen-frpc/templates/decodeMap.templ +++ /dev/null @@ -1,51 +0,0 @@ -{{define "decodeMap"}} -func (x {{CamelCase .FullName}}Map) decode(d *packet.Decoder, size uint32) error { - if size == 0 { - return nil - } - var k {{ FindValue .MapKey }} - {{ if eq .MapKey.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - var {{ CamelCase .MapKey.Name }}Temp uint32 - {{end -}} - var v {{ FindValue .MapValue }} - {{ if eq .MapValue.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - var {{ CamelCaseName .MapValue.Name }}Temp uint32 - {{end -}} - var err error - for i := uint32(0); i < size; i++ { - {{ $keyDecoder := GetLUTDecoder .MapKey.Kind -}} - {{ if and (eq $keyDecoder "") (eq .MapKey.Kind 11) -}} {{/* protoreflect.MessageKind */ -}} - k = New{{ CamelCase .MapKey.Message.FullName }}() - err = k.decode(d) - {{else -}} - {{ if eq .MapKey.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - {{ CamelCase .MapKey.Name }}Temp, err = d{{$keyDecoder}}() - k = {{ FindValue .MapKey }}({{ CamelCase .MapKey.Name }}Temp) - {{else -}} - k, err = d{{$keyDecoder}}() - {{end -}} - {{end -}} - if err != nil { - return err - } - {{ $valDecoder := GetLUTDecoder .MapValue.Kind -}} - {{ if and (eq $valDecoder "") (eq .MapValue.Kind 11) -}} {{/* protoreflect.MessageKind */ -}} - v = New{{ CamelCase .MapValue.Message.FullName }}() - err = v.decode(d) - {{else -}} - {{ if eq .MapValue.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - {{CamelCaseName .MapValue.Name}}Temp, err = d{{$valDecoder}}() - v = {{ FindValue .MapValue }}({{ CamelCaseName .MapValue.Name }}Temp) - {{else -}} - v, err = d{{$valDecoder}}() - {{end -}} - {{end -}} - - if err != nil { - return err - } - x[k] = v - } - return nil -} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/encode.templ b/protoc-gen-frpc/templates/encode.templ deleted file mode 100644 index 4021ed6..0000000 --- a/protoc-gen-frpc/templates/encode.templ +++ /dev/null @@ -1,41 +0,0 @@ -{{define "encode"}} -func (x *{{ CamelCase .FullName }}) Encode (p *packet.Packet) { - if x == nil { - packet.Encoder(p).Nil() - } else if x.error != nil { - packet.Encoder(p).Error(x.error) - } else { - {{ $encoding := GetEncodingFields .Fields -}} - packet.Encoder(p).Bool(x.ignore){{ range $val := $encoding.Values -}}{{ $val -}}{{end -}} - {{ if $encoding.SliceFields -}} - {{template "encodeSlices" $encoding -}} - {{end -}} - {{ if $encoding.MessageFields -}} - {{template "encodeMessages" $encoding -}} - {{end -}} - } -} -{{end}} - -{{define "encodeSlices"}} - {{ range $field := .SliceFields -}} - {{ $encoder := GetLUTEncoder $field.Kind -}} - {{ if and (eq $encoder "") (eq $field.Kind 11) -}} {{/* protoreflect.MessageKind */ -}} - packet.Encoder(p).Slice(uint32(len(x.{{ CamelCaseName $field.Name }})), packet.AnyKind) - for _, v := range x.{{CamelCaseName $field.Name}} { - v.Encode(p) - } - {{else -}} - packet.Encoder(p).Slice(uint32(len(x.{{ CamelCaseName $field.Name }})), {{ GetKindLUT $field.Kind }}) - for _, v := range x.{{ CamelCaseName $field.Name }} { - packet.Encoder(p){{$encoder}}(v) - } - {{end -}} - {{end -}} -{{end}} - -{{define "encodeMessages"}} - {{ range $field := .MessageFields -}} - x.{{ CamelCaseName $field.Name }}.Encode(p) - {{end -}} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/encodeMap.templ b/protoc-gen-frpc/templates/encodeMap.templ deleted file mode 100644 index c25ffe9..0000000 --- a/protoc-gen-frpc/templates/encodeMap.templ +++ /dev/null @@ -1,33 +0,0 @@ -{{define "encodeMap"}} - func (x {{ CamelCase .FullName }}Map) Encode (p *packet.Packet) { - if x == nil { - packet.Encoder(p).Nil() - } else { - {{ $keyKind := GetKind .MapKey.Kind -}} - {{ $valKind := GetKind .MapValue.Kind -}} - packet.Encoder(p).Map(uint32(len(x)), {{$keyKind}}, {{$valKind}}) - for k, v := range x { - {{ $keyEncoder := GetLUTEncoder .MapKey.Kind -}} - {{ if and (eq $keyEncoder "") (eq .MapKey.Kind 11) -}} {{/* protoreflect.MessageKind */ -}} - k.Encode(p) - {{else -}} - {{ if eq .MapKey.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - packet.Encoder(p) {{$keyEncoder}} (uint32(k)) - {{else -}} - packet.Encoder(p) {{$keyEncoder}} (k) - {{end -}} - {{end -}} - {{ $valEncoder := GetLUTEncoder .MapValue.Kind -}} - {{ if and (eq $valEncoder "") (eq .MapValue.Kind 11) -}} {{/* protoreflect.MessageKind */ -}} - v.Encode(p) - {{else -}} - {{ if eq .MapValue.Kind 14 -}} {{/* protoreflect.EnumKind */ -}} - packet.Encoder(p) {{$valEncoder}} (uint32(v)) - {{else -}} - packet.Encoder(p) {{$valEncoder}} (v) - {{end -}} - {{end -}} - } - } - } -{{end}} diff --git a/protoc-gen-frpc/templates/enums.templ b/protoc-gen-frpc/templates/enums.templ deleted file mode 100644 index 18cd31a..0000000 --- a/protoc-gen-frpc/templates/enums.templ +++ /dev/null @@ -1,19 +0,0 @@ -{{define "enums"}} -{{range $i, $e := (MakeIterable .enums.Len) -}} -{{ $enum := ($.enums.Get $i) }} -{{template "enum" $enum}} -{{end -}} -{{end}} - -{{define "enum"}} -{{ $enumName := (CamelCase $.FullName) }} -type {{ $enumName }} uint32 - -const ( -{{range $i, $v := (MakeIterable $.Values.Len) -}} - {{ $val := ($.Values.Get $i) -}} - {{CamelCase $val.FullName}} = {{ $enumName }}({{ $i }}) -{{end -}} -) -{{end}} - diff --git a/protoc-gen-frpc/templates/headers.templ b/protoc-gen-frpc/templates/headers.templ deleted file mode 100644 index 2467a79..0000000 --- a/protoc-gen-frpc/templates/headers.templ +++ /dev/null @@ -1,6 +0,0 @@ -{{define "headers"}} -// Code generated by FRPC {{ .pluginVersion }}, DO NOT EDIT. -// source: {{ .sourcePath }} - -package {{ .package }} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/imports.templ b/protoc-gen-frpc/templates/imports.templ deleted file mode 100644 index 6179937..0000000 --- a/protoc-gen-frpc/templates/imports.templ +++ /dev/null @@ -1,7 +0,0 @@ -{{define "imports"}} -import ( -{{range $im := .imports -}} - "{{$im}}" -{{end -}} -) -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/interfaces.templ b/protoc-gen-frpc/templates/interfaces.templ deleted file mode 100644 index 2194c80..0000000 --- a/protoc-gen-frpc/templates/interfaces.templ +++ /dev/null @@ -1,16 +0,0 @@ -{{define "interfaces"}} -{{range $i, $v := (MakeIterable .services.Len) -}} -{{ $service := ($.services.Get $i) -}} -{{template "interface" $service}} -{{end -}} -{{end}} - - -{{define "interface"}} -type {{ CamelCaseName .Name }} interface { - {{ range $i, $v := MakeIterable .Methods.Len -}} - {{ $method := $.Methods.Get $i -}} - {{ CamelCaseName $method.Name }} (context.Context, *{{ CamelCase $method.Input.FullName }}) (*{{ CamelCase $method.Output.FullName }}, error) - {{ end -}} -} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/messages.templ b/protoc-gen-frpc/templates/messages.templ deleted file mode 100644 index 1c8c74b..0000000 --- a/protoc-gen-frpc/templates/messages.templ +++ /dev/null @@ -1,10 +0,0 @@ -{{define "messages"}} -{{range $i, $e := (MakeIterable .messages.Len) -}} - {{ $message := $.messages.Get $i }} - {{range $i, $e := (MakeIterable $message.Enums.Len) -}} - {{ $enum := ($message.Enums.Get $i) }} - {{template "enum" $enum}} - {{end}} - {{template "structs" $message}} -{{end}} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/server.templ b/protoc-gen-frpc/templates/server.templ deleted file mode 100644 index 17baaa8..0000000 --- a/protoc-gen-frpc/templates/server.templ +++ /dev/null @@ -1,57 +0,0 @@ -{{define "server"}} -type Server struct { - *frisbee.Server -} - func NewServer({{ GetServerFields .services }}, tlsConfig *tls.Config, logger *zerolog.Logger) (*Server, error) { - table := make(frisbee.HandlerTable) - {{template "serverhandlers" .services -}} - var s *frisbee.Server - var err error - if tlsConfig != nil { - s, err = frisbee.NewServer(table, frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) - if err != nil { - return nil, err - } - } else { - s, err = frisbee.NewServer(table, frisbee.WithLogger(logger)) - if err != nil { - return nil, err - } - } - return &Server{ - Server: s, - }, nil - } -{{end}} - -{{define "serverhandlers"}} - {{ $counter := Counter 9 -}} - {{ range $i, $v := (MakeIterable .Len) -}} - {{ $service := $.Get $i -}} - {{ range $i, $v := (MakeIterable $service.Methods.Len) -}} - {{ $method := $service.Methods.Get $i -}} - {{ $count := call $counter -}} - table[{{ $count }}] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { - req := New{{ CamelCase $method.Input.FullName }}() - err := req.Decode(incoming.Content.B) - if err == nil { - if req.ignore { - {{ FirstLowerCaseName $service.Name }}.{{ CamelCaseName $method.Name }}(ctx, req) - } else { - var res *{{ CamelCase $method.Output.FullName }} - outgoing = incoming - outgoing.Content.Reset() - res, err = {{ FirstLowerCase (CamelCaseName $service.Name) }}.{{ CamelCaseName $method.Name }}(ctx, req) - if err != nil { - res.Error(outgoing, err) - } else { - res.Encode(outgoing) - } - outgoing.Metadata.ContentLength = uint32(len(outgoing.Content.B)) - } - } - return - } - {{end -}} - {{end -}} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/structs.templ b/protoc-gen-frpc/templates/structs.templ deleted file mode 100644 index d758d61..0000000 --- a/protoc-gen-frpc/templates/structs.templ +++ /dev/null @@ -1,57 +0,0 @@ -{{define "structs"}} - {{ range $i, $v := (MakeIterable $.Messages.Len) }} - {{ $message := $.Messages.Get $i }} - {{ if not $message.IsMapEntry }} - {{template "structs" $message}} - {{end}} - {{end}} - {{ range $i, $v := (MakeIterable $.Fields.Len) -}} - {{ $field := $.Fields.Get $i }} - {{ if $field.IsMap }} - {{ $mapKeyValue := FindValue $field.MapKey }} - {{ $mapValueValue := FindValue $field.MapValue }} - type {{ CamelCase $field.FullName }}Map map[{{ $mapKeyValue }}]{{ $mapValueValue }} - func New{{ CamelCase $field.FullName }}Map (size uint32) map[{{ $mapKeyValue }}]{{$mapValueValue}} { - return make(map[{{ $mapKeyValue }}]{{ $mapValueValue }}, size) - } - - {{template "encodeMap" $field}} - {{template "decodeMap" $field}} - {{end}} - {{end -}} - type {{ CamelCase .FullName }} struct { - error error - ignore bool - - {{ range $i, $v := (MakeIterable $.Fields.Len) -}} - {{ $field := $.Fields.Get $i -}} - {{ $value := FindValue $field -}} - {{ CamelCaseName $field.Name }} {{ $value }} - {{end -}} - } - - {{template "getFunc" .}} - {{template "error" .}} - {{template "encode" .}} - {{template "decode" .}} - {{template "internalDecode" .}} -{{end}} - -{{define "getFunc"}} -func New{{ CamelCase .FullName }}() *{{ CamelCase .FullName }} { - return &{{ CamelCase .FullName }}{ - {{ range $i, $v := (MakeIterable .Fields.Len) -}} - {{ $field := $.Fields.Get $i -}} - {{ if and (eq $field.Kind 11) (ne $field.Cardinality 3) -}} {{/* protoreflect.MessageKind protoreflect.Repeated */ -}} - {{ CamelCaseName $field.Name }}: New{{ CamelCase $field.Message.FullName }}(), - {{end -}} - {{end -}} - } -} -{{end}} - -{{define "error"}} -func (x *{{CamelCase .FullName}}) Error(p *packet.Packet, err error) { - packet.Encoder(p).Error(err) -} -{{end}} \ No newline at end of file diff --git a/protoc-gen-frpc/templates/templates.go b/protoc-gen-frpc/templates/templates.go deleted file mode 100644 index fa88068..0000000 --- a/protoc-gen-frpc/templates/templates.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - Copyright 2022 Loophole Labs - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package templates - -import "embed" - -//go:embed * -var FS embed.FS diff --git a/server.go b/server.go index 293d575..41e9cf9 100644 --- a/server.go +++ b/server.go @@ -23,7 +23,7 @@ import ( "sync" "time" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" "go.uber.org/atomic" @@ -52,6 +52,7 @@ type Server struct { wg sync.WaitGroup connections map[*Async]struct{} connectionsMu sync.Mutex + startedCh chan struct{} // baseContext is used to define the base context for this Server and all incoming connections baseContext func() context.Context @@ -97,6 +98,7 @@ func NewServer(handlerTable HandlerTable, opts ...Option) (*Server, error) { options: options, shutdown: atomic.NewBool(false), connections: make(map[*Async]struct{}), + startedCh: make(chan struct{}), baseContext: defaultBaseContext, onClosed: defaultOnClosed, preWrite: defaultPreWrite, @@ -144,18 +146,25 @@ func (s *Server) Start(addr string) error { if err != nil { return err } - + s.wg.Add(1) + close(s.startedCh) return s.handleListener() } +// started returns a channel that will be closed when the server has successfully started +// +// This is meant to only be used for testing purposes. +func (s *Server) started() <-chan struct{} { + return s.startedCh +} + func (s *Server) handleListener() error { var backoff time.Duration - var newConn net.Conn - var err error for { - newConn, err = s.listener.Accept() + newConn, err := s.listener.Accept() if err != nil { if s.shutdown.Load() { + s.wg.Done() return nil } if ne, ok := err.(temporary); ok && ne.Temporary() { @@ -170,10 +179,12 @@ func (s *Server) handleListener() error { s.Logger().Warn().Err(err).Msgf("Temporary Accept Error, retrying in %s", backoff) time.Sleep(backoff) if s.shutdown.Load() { + s.wg.Done() return nil } continue } + s.wg.Done() return err } backoff = 0 @@ -212,7 +223,7 @@ HANDLE: packetCtx = s.PacketContext(packetCtx, p) } outgoing, action = handlerFunc(packetCtx, p) - if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(outgoing.Content.B)) { + if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) { s.preWrite() err = frisbeeConn.WritePacket(outgoing) if outgoing != p { diff --git a/server_test.go b/server_test.go index 4197bb3..d51f3ae 100644 --- a/server_test.go +++ b/server_test.go @@ -19,15 +19,16 @@ package frisbee import ( "context" "crypto/rand" - "github.com/loopholelabs/frisbee/pkg/metadata" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/packet" + "github.com/loopholelabs/polyglot-go" + "github.com/loopholelabs/testing/conn" "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" "net" - "runtime" "sync" "testing" ) @@ -54,8 +55,8 @@ func TestServerRaw(t *testing.T) { var rawServerConn, rawClientConn net.Conn serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { - conn := ctx.Value(serverConnContextKey).(*Async) - rawServerConn = conn.Raw() + c := ctx.Value(serverConnContextKey).(*Async) + rawServerConn = c.Raw() serverIsRaw <- struct{}{} return } @@ -92,7 +93,7 @@ func TestServerRaw(t *testing.T) { p.Content.Write(data) p.Metadata.ContentLength = packetSize p.Metadata.Operation = metadata.PacketPing - assert.Equal(t, data, p.Content.B) + assert.Equal(t, polyglot.Buffer(data), *p.Content) for q := 0; q < testSize; q++ { p.Metadata.Id = uint16(q) @@ -101,7 +102,7 @@ func TestServerRaw(t *testing.T) { } p.Reset() - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) p.Metadata.Operation = metadata.PacketProbe err = c.WritePacket(p) @@ -184,7 +185,7 @@ func TestServerStaleClose(t *testing.T) { p.Content.Write(data) p.Metadata.ContentLength = packetSize p.Metadata.Operation = metadata.PacketPing - assert.Equal(t, data, p.Content.B) + assert.Equal(t, polyglot.Buffer(data), *p.Content) for q := 0; q < testSize; q++ { p.Metadata.Id = uint16(q) @@ -204,6 +205,103 @@ func TestServerStaleClose(t *testing.T) { assert.NoError(t, err) } +func TestServerMultipleConnections(t *testing.T) { + t.Parallel() + + const testSize = 100 + const packetSize = 512 + + runner := func(t *testing.T, num int) { + finished := make([]chan struct{}, num) + clientTables := make([]HandlerTable, num) + for i := 0; i < num; i++ { + idx := i + finished[idx] = make(chan struct{}, 1) + clientTables[i] = make(HandlerTable) + clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { + finished[idx] <- struct{}{} + return + } + } + serverHandlerTable := make(HandlerTable) + serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { + if incoming.Metadata.Id == testSize-1 { + outgoing = incoming + action = CLOSE + } + return + } + + emptyLogger := zerolog.New(ioutil.Discard) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) + require.NoError(t, err) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := s.Start(conn.Listen) + require.NoError(t, err) + wg.Done() + }() + + <-s.started() + listenAddr := s.listener.Addr().String() + + clients := make([]*Client, num) + for i := 0; i < num; i++ { + clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(&emptyLogger)) + assert.NoError(t, err) + _, err = clients[i].Raw() + assert.ErrorIs(t, ConnectionNotInitialized, err) + + err = clients[i].Connect(listenAddr) + require.NoError(t, err) + } + + data := make([]byte, packetSize) + _, err = rand.Read(data) + assert.NoError(t, err) + + var clientWg sync.WaitGroup + for i := 0; i < num; i++ { + idx := i + clientWg.Add(1) + go func() { + p := packet.Get() + p.Content.Write(data) + p.Metadata.ContentLength = packetSize + p.Metadata.Operation = metadata.PacketPing + assert.Equal(t, polyglot.Buffer(data), *p.Content) + for q := 0; q < testSize; q++ { + p.Metadata.Id = uint16(q) + err := clients[idx].WritePacket(p) + assert.NoError(t, err) + } + <-finished[idx] + err := clients[idx].Close() + assert.NoError(t, err) + clientWg.Done() + packet.Put(p) + }() + } + + clientWg.Wait() + + err = s.Shutdown() + assert.NoError(t, err) + wg.Wait() + + } + + t.Run("1", func(t *testing.T) { runner(t, 1) }) + t.Run("2", func(t *testing.T) { runner(t, 2) }) + t.Run("3", func(t *testing.T) { runner(t, 3) }) + t.Run("5", func(t *testing.T) { runner(t, 5) }) + t.Run("10", func(t *testing.T) { runner(t, 10) }) + t.Run("100", func(t *testing.T) { runner(t, 100) }) +} + func BenchmarkThroughputServer(b *testing.B) { const testSize = 1<<16 - 1 const packetSize = 512 @@ -346,104 +444,3 @@ func BenchmarkThroughputResponseServer(b *testing.B) { b.Fatal(err) } } - -func BenchmarkAsyncThroughputNetworkMultiple(b *testing.B) { - const testSize = 100 - - throughputRunner := func(testSize uint32, packetSize uint32, readerConn Conn, writerConn Conn) func(b *testing.B) { - return func(b *testing.B) { - var err error - - randomData := make([]byte, packetSize) - - p := packet.Get() - p.Metadata.Id = 64 - p.Metadata.Operation = 32 - p.Content.Write(randomData) - p.Metadata.ContentLength = packetSize - for i := 0; i < b.N; i++ { - done := make(chan struct{}, 1) - errCh := make(chan error, 1) - go func() { - for i := uint32(0); i < testSize; i++ { - p, err := readerConn.ReadPacket() - if err != nil { - errCh <- err - return - } - packet.Put(p) - } - done <- struct{}{} - }() - for i := uint32(0); i < testSize; i++ { - select { - case err = <-errCh: - b.Fatal(err) - default: - err = writerConn.WritePacket(p) - if err != nil { - b.Fatal(err) - } - } - } - select { - case <-done: - continue - case err = <-errCh: - b.Fatal(err) - } - } - - packet.Put(p) - } - } - - runner := func(numClients int, packetSize uint32) func(b *testing.B) { - return func(b *testing.B) { - var wg sync.WaitGroup - wg.Add(numClients) - b.SetBytes(int64(testSize * packetSize)) - b.ReportAllocs() - for i := 0; i < numClients; i++ { - go func() { - emptyLogger := zerolog.New(ioutil.Discard) - - reader, writer, err := pair.New() - if err != nil { - b.Error(err) - } - - readerConn := NewAsync(reader, &emptyLogger) - writerConn := NewAsync(writer, &emptyLogger) - throughputRunner(testSize, packetSize, readerConn, writerConn)(b) - - _ = readerConn.Close() - _ = writerConn.Close() - wg.Done() - }() - } - wg.Wait() - } - } - - b.Run("1 Pair, 32 Bytes", runner(1, 32)) - b.Run("2 Pair, 32 Bytes", runner(2, 32)) - b.Run("5 Pair, 32 Bytes", runner(5, 32)) - b.Run("10 Pair, 32 Bytes", runner(10, 32)) - b.Run("Half CPU Pair, 32 Bytes", runner(runtime.NumCPU()/2, 32)) - b.Run("CPU Pair, 32 Bytes", runner(runtime.NumCPU(), 32)) - - b.Run("1 Pair, 512 Bytes", runner(1, 512)) - b.Run("2 Pair, 512 Bytes", runner(2, 512)) - b.Run("5 Pair, 512 Bytes", runner(5, 512)) - b.Run("10 Pair, 512 Bytes", runner(10, 512)) - b.Run("Half CPU Pair, 512 Bytes", runner(runtime.NumCPU()/2, 512)) - b.Run("CPU Pair, 512 Bytes", runner(runtime.NumCPU(), 512)) - - b.Run("1 Pair, 4096 Bytes", runner(1, 4096)) - b.Run("2 Pair, 4096 Bytes", runner(2, 4096)) - b.Run("5 Pair, 4096 Bytes", runner(5, 4096)) - b.Run("10 Pair, 4096 Bytes", runner(10, 4096)) - b.Run("Half CPU Pair, 4096 Bytes", runner(runtime.NumCPU()/2, 4096)) - b.Run("CPU Pair, 4096 Bytes", runner(runtime.NumCPU(), 4096)) -} diff --git a/sync.go b/sync.go index 7d28835..f7c4cb4 100644 --- a/sync.go +++ b/sync.go @@ -25,9 +25,9 @@ import ( "sync" "time" - "github.com/loopholelabs/frisbee/internal/dialer" - "github.com/loopholelabs/frisbee/pkg/metadata" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/internal/dialer" + "github.com/loopholelabs/frisbee-go/pkg/metadata" + "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" "go.uber.org/atomic" @@ -141,7 +141,7 @@ func (c *Sync) RemoteAddr() net.Addr { // // If packet.Metadata.ContentLength == 0, then the content array must be nil. Otherwise, it is required that packet.Metadata.ContentLength == len(content). func (c *Sync) WritePacket(p *packet.Packet) error { - if int(p.Metadata.ContentLength) != len(p.Content.B) { + if int(p.Metadata.ContentLength) != len(*p.Content) { return InvalidContentLength } @@ -168,7 +168,7 @@ func (c *Sync) WritePacket(p *packet.Packet) error { return c.closeWithError(err) } if p.Metadata.ContentLength != 0 { - _, err = c.conn.Write(p.Content.B[:p.Metadata.ContentLength]) + _, err = c.conn.Write((*p.Content)[:p.Metadata.ContentLength]) if err != nil { c.Unlock() if c.closed.Load() { @@ -208,11 +208,11 @@ func (c *Sync) ReadPacket() (*packet.Packet, error) { p.Metadata.ContentLength = binary.BigEndian.Uint32(encodedPacket[metadata.ContentLengthOffset : metadata.ContentLengthOffset+metadata.ContentLengthSize]) if p.Metadata.ContentLength > 0 { - for cap(p.Content.B) < int(p.Metadata.ContentLength) { - p.Content.B = append(p.Content.B[:cap(p.Content.B)], 0) + for cap(*p.Content) < int(p.Metadata.ContentLength) { + *p.Content = append((*p.Content)[:cap(*p.Content)], 0) } - p.Content.B = p.Content.B[:p.Metadata.ContentLength] - _, err = io.ReadAtLeast(c.conn, p.Content.B, int(p.Metadata.ContentLength)) + *p.Content = (*p.Content)[:p.Metadata.ContentLength] + _, err = io.ReadAtLeast(c.conn, *p.Content, int(p.Metadata.ContentLength)) if err != nil { if c.closed.Load() { c.Logger().Debug().Err(ConnectionClosed).Msg("error while reading from underlying net.Conn") diff --git a/sync_test.go b/sync_test.go index 5659144..f434288 100644 --- a/sync_test.go +++ b/sync_test.go @@ -18,7 +18,8 @@ package frisbee import ( "crypto/rand" - "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/frisbee-go/pkg/packet" + "github.com/loopholelabs/polyglot-go" "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -55,7 +56,7 @@ func TestNewSync(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) end <- struct{}{} packet.Put(p) }() @@ -79,8 +80,8 @@ func TestNewSync(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, packetSize, len(p.Content.B)) - assert.Equal(t, data, p.Content.B) + assert.Equal(t, packetSize, len(*p.Content)) + assert.Equal(t, polyglot.Buffer(data), *p.Content) end <- struct{}{} packet.Put(p) }() @@ -130,8 +131,8 @@ func TestSyncLargeWrite(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, packetSize, len(p.Content.B)) - assert.Equal(t, randomData[i], p.Content.B) + assert.Equal(t, packetSize, len(*p.Content)) + assert.Equal(t, polyglot.Buffer(randomData[i]), *p.Content) packet.Put(p) } end <- struct{}{} @@ -191,8 +192,8 @@ func TestSyncRawConn(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) - assert.Equal(t, packetSize, len(p.Content.B)) - assert.Equal(t, randomData, p.Content.B) + assert.Equal(t, packetSize, len(*p.Content)) + assert.Equal(t, polyglot.Buffer(randomData), *p.Content) packet.Put(p) } end <- struct{}{} @@ -255,7 +256,7 @@ func TestSyncReadClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) end <- struct{}{} packet.Put(p) }() @@ -305,7 +306,7 @@ func TestSyncWriteClose(t *testing.T) { assert.Equal(t, uint16(64), p.Metadata.Id) assert.Equal(t, uint16(32), p.Metadata.Operation) assert.Equal(t, uint32(0), p.Metadata.ContentLength) - assert.Equal(t, 0, len(p.Content.B)) + assert.Equal(t, 0, len(*p.Content)) packet.Put(p) end <- struct{}{} }()