From 785d0a67c97d67e55d894a407c65d43af61edb55 Mon Sep 17 00:00:00 2001 From: D3v Date: Wed, 29 May 2024 21:46:47 +0200 Subject: [PATCH 1/2] Brotli fix, tests refact --- .golangci.yml | 1 - compression/compression.go | 17 ++++- compression/compression_test.go | 117 +++++++++++++++++--------------- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index d631a67..5a0b0b6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -283,4 +283,3 @@ linters: - unused - gomnd - forcetypeassert - - golint \ No newline at end of file diff --git a/compression/compression.go b/compression/compression.go index cf5d8e0..33401e9 100644 --- a/compression/compression.go +++ b/compression/compression.go @@ -2,6 +2,7 @@ package compression import ( "bytes" + "errors" "io" "github.com/andybalholm/brotli" @@ -47,6 +48,8 @@ const ( BrotliBestSpeed int = 0 ) +var ErrMissingCompressionLevel = errors.New("missing compression level parameter") + type Compressor interface { Compress([]byte) ([]byte, error) Decompress([]byte) ([]byte, error) @@ -156,6 +159,10 @@ func (z *Zstd) Compress(in []byte) ([]byte, error) { } func (z *Zstd) CompressStream(in io.Reader, out io.Writer) error { + if z.Level == 0 { + return ErrMissingCompressionLevel + } + enc, err := zstd.NewWriter(out, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(z.Level))) if err != nil { return err @@ -282,7 +289,6 @@ func (f *Flate) GetModes() []int { DefaultCompression, BestCompression, HuffmanOnly, - StatelessCompression, } } @@ -360,7 +366,6 @@ func (zl *Zlib) GetModes() []int { DefaultCompression, BestCompression, HuffmanOnly, - StatelessCompression, } } @@ -385,6 +390,10 @@ func (b *Brotli) Compress(in []byte) ([]byte, error) { } func (b *Brotli) CompressStream(in io.Reader, out io.Writer) error { + if b.bw == nil { + b.bw = brotli.NewWriterLevel(nil, b.Level) + } + b.bw.Reset(out) _, err := io.Copy(b.bw, in) if err != nil { @@ -411,6 +420,10 @@ func (b *Brotli) Decompress(in []byte) ([]byte, error) { } func (b *Brotli) DecompressStream(in io.Reader, out io.Writer) error { + if b.br == nil { + b.br = brotli.NewReader(nil) + } + if err := b.br.Reset(in); err != nil { return err } diff --git a/compression/compression_test.go b/compression/compression_test.go index de604e7..ee3fc7b 100644 --- a/compression/compression_test.go +++ b/compression/compression_test.go @@ -11,45 +11,14 @@ import ( r "github.com/stretchr/testify/require" ) -type compressor struct { - name string - compressor compression.Compressor - modes []int -} - -func compressionArlgorithms() []compressor { - genericModes := []int{compression.BestCompression, compression.BestSpeed, compression.NoCompression, compression.DefaultCompression, compression.HuffmanOnly} - zstdModes := []int{compression.ZstdSpeedBestCompression, compression.ZstdSpeedBetterCompression, compression.ZstdSpeedDefault, compression.ZstdSpeedFastest} - brotliModes := []int{compression.BrotliBestCompression, compression.BrotliDefaultCompression, compression.BrotliBestSpeed} - - compressors := []compressor{ - { - name: "zlib", - compressor: &compression.Zlib{}, - modes: genericModes, - }, - { - name: "gzip", - compressor: &compression.Gzip{}, - modes: genericModes, - }, - { - name: "zstd", - compressor: &compression.Zstd{}, - modes: zstdModes, - }, - { - name: "generic", - compressor: &compression.Flate{}, - modes: genericModes, - }, - { - name: "brotli", - compressor: &compression.Brotli{}, - modes: brotliModes, - }, +func compressors() []compression.Compressor { + return []compression.Compressor{ + &compression.Zlib{}, + &compression.Gzip{}, + &compression.Zstd{}, + &compression.Flate{}, + &compression.Brotli{}, } - return compressors } type compressionSample struct { @@ -93,16 +62,16 @@ func compressionSamples() []compressionSample { } func BenchmarkRoundTrip(b *testing.B) { - compressors := compressionArlgorithms() + compressors := compressors() compressionSamples := compressionSamples() for _, compressor := range compressors { - for _, mode := range compressor.modes { - compressor.compressor.SetLevel(mode) + for _, mode := range compressor.GetModes() { + compressor.SetLevel(mode) for _, sample := range compressionSamples { - b.Run(compressor.name+"-"+strconv.Itoa(mode)+"-"+sample.name, func(b *testing.B) { + b.Run(compressor.GetName()+"/"+strconv.Itoa(mode)+"/"+sample.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - benchmarkRoundTrip(b, compressor.compressor, sample.data) + benchmarkRoundTrip(b, compressor, sample.data) } }) } @@ -138,21 +107,38 @@ func benchmarkRoundTrip(b *testing.B, compressor compression.Compressor, data [] } func TestRoundTrips(t *testing.T) { - compressors := compressionArlgorithms() + compressors := compressors() compressionSamples := compressionSamples() for _, compressor := range compressors { - for _, mode := range compressor.modes { - compressor.compressor.SetLevel(mode) + for _, mode := range compressor.GetModes() { + compressor.SetLevel(mode) for _, sample := range compressionSamples { - t.Run(compressor.name+"-"+strconv.Itoa(mode)+"-"+sample.name, func(t *testing.T) { - testRoundTrip(t, compressor.compressor, sample.data) + t.Run(compressor.GetName()+"/"+strconv.Itoa(mode)+"/"+sample.name, func(t *testing.T) { + testRoundTrip(t, compressor, sample.data) }) } } } } +func TestInterfacelessRoundTrip(t *testing.T) { + compressors := []compression.Compressor{ + &compression.Zlib{Level: compression.DefaultCompression}, + &compression.Gzip{Level: compression.DefaultCompression}, + &compression.Zstd{Level: compression.ZstdSpeedDefault}, + &compression.Flate{Level: compression.DefaultCompression}, + &compression.Brotli{Level: compression.BrotliDefaultCompression}, + } + samples := compressionSamples() + + for _, compressor := range compressors { + t.Run(compressor.GetName(), func(t *testing.T) { + testRoundTrip(t, compressor, samples[0].data) + }) + } +} + func testRoundTrip(t *testing.T, compressor compression.Compressor, data []byte) { compressed, err := compressor.Compress(data) r.NoError(t, err) @@ -166,11 +152,11 @@ func testRoundTrip(t *testing.T, compressor compression.Compressor, data []byte) r.Equal(t, compressed, compressedBuff.Bytes()) - t.Log("Compressor name: ", compressor.GetName()) - t.Log("Data sample: ", data[:16]) - t.Log("Orignal size: ", len(data)) - t.Log("Compressed size: ", compressedBuff.Len()) - t.Log("Compression mode: ", compressor.GetLevel()) + t.Log("Compressor name:", compressor.GetName()) + t.Log("Data sample:", data[:16]) + t.Log("Orignal size:", len(data)) + t.Log("Compressed size:", compressedBuff.Len()) + t.Log("Compression mode:", compressor.GetLevel()) t.Log("---") compressedReader := bytes.NewReader(compressedBuff.Bytes()) @@ -326,3 +312,28 @@ func TestZstdWrongDecompressData(t *testing.T) { err = compressor.DecompressStream(reader, &compressedBuff) r.Error(t, err) } + +func TestMissingCompressLevels(t *testing.T) { + compressors := []compression.Compressor{ + &compression.Zstd{}, + } + samples := compressionSamples() + + for _, compressor := range compressors { + t.Run(compressor.GetName(), func(t *testing.T) { + + out, err := compressor.Compress(samples[0].data) + r.ErrorIs(t, err, compression.ErrMissingCompressionLevel) + r.Nil(t, out) + }) + + t.Run("streaming/"+compressor.GetName(), func(t *testing.T) { + var out bytes.Buffer + reader := bytes.NewReader(samples[0].data) + + err := compressor.CompressStream(reader, &out) + r.ErrorIs(t, err, compression.ErrMissingCompressionLevel) + r.Nil(t, out.Bytes()) + }) + } +} From 9789e4005c4bde2a7cedc940e08e274a389877dc Mon Sep 17 00:00:00 2001 From: D3v Date: Wed, 29 May 2024 22:33:24 +0200 Subject: [PATCH 2/2] Fix aged test --- aged/age_bind_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aged/age_bind_test.go b/aged/age_bind_test.go index 9bf231a..b95d80d 100644 --- a/aged/age_bind_test.go +++ b/aged/age_bind_test.go @@ -148,7 +148,7 @@ func TestRoundTrips(t *testing.T) { name: "Compress with Zstd, no obfuscate", parameter: aged.Parameters{ Data: config.plainData, - Compressor: &compression.Zstd{}, + Compressor: &compression.Zstd{Level: compression.ZstdSpeedDefault}, }, }, { @@ -156,7 +156,7 @@ func TestRoundTrips(t *testing.T) { parameter: aged.Parameters{ Data: config.plainData, Obfuscator: &aged.AgeV1Obf{}, - Compressor: &compression.Zstd{}, + Compressor: &compression.Zstd{Level: compression.ZstdSpeedDefault}, }, }, {