From 1bbcf490ef20f08cae12b429b8bfbec2330dc29b Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 15 Feb 2023 11:49:44 +0100 Subject: [PATCH] Fix matrix buffer sizes (#241) Fixes #240 - retract v1.11.6 Add re-entrant tests. --- galois.go | 4 +- go.mod | 1 + reedsolomon.go | 40 +++++++++----- reedsolomon_test.go | 127 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 153 insertions(+), 19 deletions(-) diff --git a/galois.go b/galois.go index fbb16e16..479fa447 100644 --- a/galois.go +++ b/galois.go @@ -6,7 +6,9 @@ package reedsolomon -import "encoding/binary" +import ( + "encoding/binary" +) const ( // The number of elements in the field. diff --git a/go.mod b/go.mod index 0a42e9a1..98f3ca41 100644 --- a/go.mod +++ b/go.mod @@ -10,4 +10,5 @@ require golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect retract ( v1.11.2 // https://github.com/klauspost/reedsolomon/pull/229 [v1.11.3, v1.11.5] // https://github.com/klauspost/reedsolomon/pull/238 + v1.11.6 // https://github.com/klauspost/reedsolomon/issues/240 ) diff --git a/reedsolomon.go b/reedsolomon.go index 6b11a62e..acf6da33 100644 --- a/reedsolomon.go +++ b/reedsolomon.go @@ -13,6 +13,7 @@ package reedsolomon import ( "bytes" "errors" + "fmt" "io" "runtime" "sync" @@ -171,7 +172,8 @@ type reedSolomon struct { tree *inversionTree parity [][]byte o options - mPool sync.Pool + mPoolSz int + mPool sync.Pool // Pool for temp matrices, etc } var _ = Extensions(&reedSolomon{}) @@ -571,12 +573,28 @@ func New(dataShards, parityShards int, opts ...Option) (Encoder, error) { if avx2CodeGen && r.o.useAVX2 { sz := r.dataShards * r.parityShards * 2 * 32 r.mPool.New = func() interface{} { - return make([]byte, sz) + return AllocAligned(1, sz)[0] } + r.mPoolSz = sz } return &r, err } +func (r *reedSolomon) getTmpSlice() []byte { + return r.mPool.Get().([]byte) +} + +func (r *reedSolomon) putTmpSlice(b []byte) { + if b != nil && cap(b) >= r.mPoolSz { + r.mPool.Put(b[:r.mPoolSz]) + return + } + if false { + // Sanity check + panic(fmt.Sprintf("got short tmp returned, want %d, got %d", r.mPoolSz, cap(b))) + } +} + // ErrTooFewShards is returned if too few shards where given to // Encode/Verify/Reconstruct/Update. It will also be returned from Reconstruct // if there were too few shards to reconstruct the missing data. @@ -806,16 +824,16 @@ func (r *reedSolomon) codeSomeShards(matrixRows, inputs, outputs [][]byte, byteC start += galMulSlicesGFNI(m, inputs, outputs, 0, byteCount) end = len(inputs[0]) } else if r.canAVX2C(byteCount, len(inputs), len(outputs)) { - m := genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.mPool.Get().([]byte)) + m := genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.getTmpSlice()) start += galMulSlicesAvx2(m, inputs, outputs, 0, byteCount) - r.mPool.Put(m) + r.putTmpSlice(m) end = len(inputs[0]) } else if len(inputs)+len(outputs) > avx2CodeGenMinShards && r.canAVX2C(byteCount, maxAvx2Inputs, maxAvx2Outputs) { var gfni [maxAvx2Inputs * maxAvx2Outputs]uint64 end = len(inputs[0]) inIdx := 0 - m := r.mPool.Get().([]byte) - defer r.mPool.Put(m) + m := r.getTmpSlice() + defer r.putTmpSlice(m) ins := inputs for len(ins) > 0 { inPer := ins @@ -888,8 +906,8 @@ func (r *reedSolomon) codeSomeShardsP(matrixRows, inputs, outputs [][]byte, byte var tmp [maxAvx2Inputs * maxAvx2Outputs]uint64 gfniMatrix = genGFNIMatrix(matrixRows, len(inputs), 0, len(outputs), tmp[:]) } else if useAvx2 { - avx2Matrix = genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.mPool.Get().([]byte)) - defer r.mPool.Put(avx2Matrix) + avx2Matrix = genAvx2Matrix(matrixRows, len(inputs), 0, len(outputs), r.getTmpSlice()) + defer r.putTmpSlice(avx2Matrix) } else if r.o.useGFNI && byteCount < 10<<20 && len(inputs)+len(outputs) > avx2CodeGenMinShards && r.canAVX2C(byteCount/4, maxAvx2Inputs, maxAvx2Outputs) { // It appears there is a switchover point at around 10MB where @@ -977,10 +995,8 @@ func (r *reedSolomon) codeSomeShardsAVXP(matrixRows, inputs, outputs [][]byte, b // Make a plan... plan := make([]state, 0, ((len(inputs)+maxAvx2Inputs-1)/maxAvx2Inputs)*((len(outputs)+maxAvx2Outputs-1)/maxAvx2Outputs)) - tmp := r.mPool.Get().([]byte) - defer func(b []byte) { - r.mPool.Put(b) - }(tmp) + tmp := r.getTmpSlice() + defer r.putTmpSlice(tmp) // Flips between input first to output first. // We put the smallest data load in the inner loop. diff --git a/reedsolomon_test.go b/reedsolomon_test.go index f7f491f8..2932787c 100644 --- a/reedsolomon_test.go +++ b/reedsolomon_test.go @@ -275,16 +275,26 @@ func TestEncoding(t *testing.T) { // matrix sizes to test. // note that par1 matrix will fail on some combinations. -var testSizes = [][2]int{ - {1, 0}, {3, 0}, {5, 0}, {8, 0}, {10, 0}, {12, 0}, {14, 0}, {41, 0}, {49, 0}, - {1, 1}, {1, 2}, {3, 3}, {3, 1}, {5, 3}, {8, 4}, {10, 30}, {12, 10}, {14, 7}, {41, 17}, {49, 1}, {5, 20}, - {256, 20}, {500, 300}, {2945, 129}, +func testSizes() [][2]int { + if testing.Short() { + return [][2]int{ + {3, 0}, + {1, 1}, {1, 2}, {8, 4}, {10, 30}, {41, 17}, + {256, 20}, {500, 300}, + } + } + return [][2]int{ + {1, 0}, {10, 0}, {12, 0}, {49, 0}, + {1, 1}, {1, 2}, {3, 3}, {3, 1}, {5, 3}, {8, 4}, {10, 30}, {12, 10}, {14, 7}, {41, 17}, {49, 1}, {5, 20}, + {256, 20}, {500, 300}, {2945, 129}, + } } + var testDataSizes = []int{10, 100, 1000, 10001, 100003, 1000055} var testDataSizesShort = []int{10, 10001, 100003} func testEncoding(t *testing.T, o ...Option) { - for _, size := range testSizes { + for _, size := range testSizes() { data, parity := size[0], size[1] rng := rand.New(rand.NewSource(0xabadc0cac01a)) t.Run(fmt.Sprintf("%dx%d", data, parity), func(t *testing.T) { @@ -398,7 +408,7 @@ func testEncoding(t *testing.T, o ...Option) { } func testEncodingIdx(t *testing.T, o ...Option) { - for _, size := range testSizes { + for _, size := range testSizes() { data, parity := size[0], size[1] rng := rand.New(rand.NewSource(0xabadc0cac01a)) t.Run(fmt.Sprintf("%dx%d", data, parity), func(t *testing.T) { @@ -2100,3 +2110,108 @@ func BenchmarkParallel_8x8x32M(b *testing.B) { benchmarkParallel(b, 8, 8, 32<< func BenchmarkParallel_8x3x1M(b *testing.B) { benchmarkParallel(b, 8, 3, 1<<20) } func BenchmarkParallel_8x4x1M(b *testing.B) { benchmarkParallel(b, 8, 4, 1<<20) } func BenchmarkParallel_8x5x1M(b *testing.B) { benchmarkParallel(b, 8, 5, 1<<20) } + +func TestReentrant(t *testing.T) { + for optN, o := range testOpts() { + for _, size := range testSizes() { + data, parity := size[0], size[1] + rng := rand.New(rand.NewSource(0xabadc0cac01a)) + t.Run(fmt.Sprintf("opt-%d-%dx%d", optN, data, parity), func(t *testing.T) { + perShard := 16384 + 1 + if testing.Short() { + perShard = 1024 + 1 + } + r, err := New(data, parity, testOptions(o...)...) + if err != nil { + t.Fatal(err) + } + x := r.(Extensions) + if want, got := data, x.DataShards(); want != got { + t.Errorf("DataShards returned %d, want %d", got, want) + } + if want, got := parity, x.ParityShards(); want != got { + t.Errorf("ParityShards returned %d, want %d", got, want) + } + if want, got := parity+data, x.TotalShards(); want != got { + t.Errorf("TotalShards returned %d, want %d", got, want) + } + mul := x.ShardSizeMultiple() + if mul <= 0 { + t.Fatalf("Got unexpected ShardSizeMultiple: %d", mul) + } + perShard = ((perShard + mul - 1) / mul) * mul + runs := 10 + if testing.Short() { + runs = 2 + } + for i := 0; i < runs; i++ { + shards := AllocAligned(data+parity, perShard) + + err = r.Encode(shards) + if err != nil { + t.Fatal(err) + } + ok, err := r.Verify(shards) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verification failed") + } + + if parity == 0 { + // Check that Reconstruct and ReconstructData do nothing + err = r.ReconstructData(shards) + if err != nil { + t.Fatal(err) + } + err = r.Reconstruct(shards) + if err != nil { + t.Fatal(err) + } + + // Skip integrity checks + continue + } + + // Delete one in data + idx := rng.Intn(data) + want := shards[idx] + shards[idx] = nil + + err = r.ReconstructData(shards) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(shards[idx], want) { + t.Fatal("did not ReconstructData correctly") + } + + // Delete one randomly + idx = rng.Intn(data + parity) + want = shards[idx] + shards[idx] = nil + err = r.Reconstruct(shards) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(shards[idx], want) { + t.Fatal("did not Reconstruct correctly") + } + + err = r.Encode(make([][]byte, 1)) + if err != ErrTooFewShards { + t.Errorf("expected %v, got %v", ErrTooFewShards, err) + } + + // Make one too short. + shards[idx] = shards[idx][:perShard-1] + err = r.Encode(shards) + if err != ErrShardSize { + t.Errorf("expected %v, got %v", ErrShardSize, err) + } + } + }) + } + } +}