Skip to content

Commit

Permalink
zstd: Improve memory usage on small streaming encodes
Browse files Browse the repository at this point in the history
Very small streams use EncodeAll internally when the stream is closed before any header has been written.

This will pull a new encoder from the async buffer.

Instead, re-use the stream's encoder.

Before:
```
BenchmarkMem/flush-32         	    1359	    837989 ns/op	 7376959 B/op	     109 allocs/op
BenchmarkMem/no-flush-32      	     129	   8884753 ns/op	112044489 B/op	     254 allocs/op
```

After:
```
BenchmarkMem/flush-32         	    1254	    922593 ns/op	 7376966 B/op	     109 allocs/op
BenchmarkMem/no-flush-32      	    1488	    841270 ns/op	 7374164 B/op	      29 allocs/op
```

The benchmark is pretty much the worst case, but it shows the issue nicely.
  • Loading branch information
klauspost committed Sep 18, 2024
1 parent 51aa0ec commit c3e680a
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 13 deletions.
4 changes: 3 additions & 1 deletion huff0/_generate/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/klauspost/compress/s2/_generate

go 1.19
go 1.21

toolchain go1.22.4

require (
github.com/klauspost/compress v1.15.15
Expand Down
4 changes: 3 additions & 1 deletion s2/cmd/_s2sx/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/klauspost/compress/s2/cmd/s2sx

go 1.19
go 1.21

toolchain go1.22.4

require github.com/klauspost/compress v1.11.9

Expand Down
4 changes: 3 additions & 1 deletion zstd/_generate/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/klauspost/compress/s2/_generate

go 1.19
go 1.21

toolchain go1.22.4

require (
github.com/klauspost/compress v1.15.15
Expand Down
4 changes: 3 additions & 1 deletion zstd/blockdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,9 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
printf("RLE set to 0x%x, code: %v", symb, v)
}
case compModeFSE:
println("Reading table for", tableIndex(i))
if debugDecoder {
println("Reading table for", tableIndex(i))
}
if seq.fse == nil || seq.fse.preDefined {
seq.fse = fseDecoderPool.Get().(*fseDecoder)
}
Expand Down
19 changes: 11 additions & 8 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (e *Encoder) nextBlock(final bool) error {
return nil
}
if final && len(s.filling) > 0 {
s.current = e.EncodeAll(s.filling, s.current[:0])
s.current = e.encodeAll(s.encoder, s.filling, s.current[:0])
var n2 int
n2, s.err = s.w.Write(s.current)
if s.err != nil {
Expand Down Expand Up @@ -469,6 +469,15 @@ func (e *Encoder) Close() error {
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	// Run the one-time setup before touching the encoder channel.
	e.init.Do(e.initialize)

	// Borrow an encoder for the duration of this call; the deferred send
	// hands it back once encodeAll has returned.
	borrowed := <-e.encoders
	defer func() { e.encoders <- borrowed }()

	return e.encodeAll(borrowed, src, dst)
}

func (e *Encoder) encodeAll(enc encoder, src, dst []byte) []byte {
if len(src) == 0 {
if e.o.fullZero {
// Add frame header.
Expand All @@ -491,13 +500,7 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte {
}
return dst
}
e.init.Do(e.initialize)
enc := <-e.encoders
defer func() {
// Release encoder reference to last block.
// If a non-single block is needed the encoder will reset again.
e.encoders <- enc
}()

// Use single segments when above minimum window and below window size.
single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
if e.o.single != nil {
Expand Down
4 changes: 3 additions & 1 deletion zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ func (d *frameDec) reset(br byteBuffer) error {
}
return err
}
printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
if debugDecoder {
printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
}
windowLog := 10 + (wd >> 3)
windowBase := uint64(1) << windowLog
windowAdd := (windowBase / 8) * uint64(wd&0x7)
Expand Down
112 changes: 112 additions & 0 deletions zstd/zstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ package zstd
import (
"flag"
"fmt"
"io"
"log"
"os"
"runtime"
"runtime/pprof"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -59,3 +62,112 @@ func TestMatchLen(t *testing.T) {
a[l] = ^a[l]
}
}

// TestWriterMemUsage logs how much heap is retained after repeated EncodeAll
// calls at every encoder level from SpeedFastest to SpeedBestCompression.
// It only reports numbers via t.Log; it makes no assertions about them.
func TestWriterMemUsage(t *testing.T) {
	// testMem runs fn between two forced garbage collections and logs the
	// change in in-use heap and in live heap objects.
	testMem := func(t *testing.T, fn func()) {
		var before, after runtime.MemStats
		var w io.Writer
		// Flip to true to also dump a heap profile into the working directory.
		if false {
			f, err := os.Create(strings.ReplaceAll(fmt.Sprintf("%s.pprof", t.Name()), "/", "_"))
			if err != nil {
				log.Fatal(err)
			}
			defer f.Close()
			w = f
			t.Logf("opened memory profile %s", t.Name())
		}
		runtime.GC()
		runtime.ReadMemStats(&before)
		fn()
		runtime.GC()
		runtime.ReadMemStats(&after)
		if w != nil {
			if err := pprof.WriteHeapProfile(w); err != nil {
				t.Errorf("writing heap profile: %v", err)
			} else {
				t.Log("wrote profile")
			}
		}
		// MemStats counters are uint64. Subtract as signed values so a heap
		// that shrank between the two samples is reported as a negative delta
		// instead of wrapping around to a huge number.
		usedMB := (int64(after.HeapInuse) - int64(before.HeapInuse)) / 1024 / 1024
		objects := int64(after.HeapObjects) - int64(before.HeapObjects)
		t.Logf("%s: Memory Used: %dMB, %d allocs", t.Name(), usedMB, objects)
	}
	// Named payload rather than data so it does not shadow the package-level
	// data variable.
	payload := make([]byte, 10<<20)

	t.Run("enc-all-lower", func(t *testing.T) {
		for level := SpeedFastest; level <= SpeedBestCompression; level++ {
			t.Run(fmt.Sprint("level-", level), func(t *testing.T) {
				var zr *Encoder
				var err error
				// Allocate the destination outside the measured function so
				// it is not counted in the delta.
				dst := make([]byte, 0, len(payload)*2)
				testMem(t, func() {
					zr, err = NewWriter(io.Discard, WithEncoderConcurrency(32), WithEncoderLevel(level), WithLowerEncoderMem(false), WithWindowSize(1<<20))
					if err != nil {
						t.Fatal(err)
					}
					for i := 0; i < 100; i++ {
						_ = zr.EncodeAll(payload, dst[:0])
					}
				})
				if err := zr.Close(); err != nil {
					t.Error(err)
				}
			})
		}
	})
}

// data is the tiny payload that each BenchmarkMem stream writes.
var data = []byte{1, 2, 3}

// newZstdWriter returns an Encoder that discards its output, configured with
// SpeedBetterCompression, 16 concurrent encoders, lower-memory mode disabled
// and a 1 MiB window.
func newZstdWriter() (*Encoder, error) {
	enc, err := NewWriter(
		io.Discard,
		WithEncoderLevel(SpeedBetterCompression),
		// Pinned to 16: the concurrency implicitly chosen on a 16-core CPU.
		WithEncoderConcurrency(16),
		WithLowerEncoderMem(false),
		WithWindowSize(1<<20),
	)
	return enc, err
}

// BenchmarkMem reports allocations for many very small stream encodes.
// Each iteration creates one Encoder via newZstdWriter and runs 16
// Reset/Write/Close cycles on it, once with an explicit Flush before Close
// ("flush") and once without ("no-flush").
func BenchmarkMem(b *testing.B) {
	variants := []struct {
		name  string
		flush bool
	}{
		{name: "flush", flush: true},
		// Will use encodeAll for block.
		{name: "no-flush", flush: false},
	}

	for _, v := range variants {
		b.Run(v.name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				w, err := newZstdWriter()
				if err != nil {
					b.Fatal(err)
				}

				for j := 0; j < 16; j++ {
					w.Reset(io.Discard)

					if _, err := w.Write(data); err != nil {
						b.Fatal(err)
					}

					if v.flush {
						if err := w.Flush(); err != nil {
							b.Fatal(err)
						}
					}

					if err := w.Close(); err != nil {
						b.Fatal(err)
					}
				}
			}
		})
	}
}

0 comments on commit c3e680a

Please sign in to comment.