diff --git a/s2/reader.go b/s2/reader.go index daaa33fbb0..46ead58fe0 100644 --- a/s2/reader.go +++ b/s2/reader.go @@ -104,6 +104,8 @@ func ReaderIgnoreStreamIdentifier() ReaderOption { // For each chunk with the ID, the callback is called with the content. // Any returned non-nil error will abort decompression. // Only one callback per ID is supported, latest sent will be used. +// You can peek the stream, triggering the callback, by doing a Read with a 0 +// byte buffer. func ReaderSkippableCB(id uint8, fn func(r io.Reader) error) ReaderOption { return func(r *Reader) error { if id < 0x80 || id > 0xfd { @@ -1053,6 +1055,8 @@ func (r *Reader) ReadByte() (byte, error) { // Any returned non-nil error will abort decompression. // Only one callback per ID is supported, latest sent will be used. // Sending a nil function will disable previous callbacks. +// You can peek the stream, triggering the callback, by doing a Read with a 0 +// byte buffer. func (r *Reader) SkippableCB(id uint8, fn func(r io.Reader) error) error { if id < 0x80 || id >= chunkTypePadding { return fmt.Errorf("ReaderSkippableCB: Invalid id provided, must be 0x80-0xfe (inclusive)") diff --git a/s2/reader_test.go b/s2/reader_test.go new file mode 100644 index 0000000000..4e05788494 --- /dev/null +++ b/s2/reader_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2019+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package s2 + +import ( + "bytes" + "io" + "testing" +) + +func TestLeadingSkippableBlock(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.AddSkippableBlock(0x80, []byte("skippable block")); err != nil { + t.Fatalf("w.AddSkippableBlock: %v", err) + } + if _, err := w.Write([]byte("some data")); err != nil { + t.Fatalf("w.Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("w.Close: %v", err) + } + r := NewReader(&buf) + var sb []byte + r.SkippableCB(0x80, func(sr io.Reader) error { + var err error + sb, err = io.ReadAll(sr) + return err + }) + if _, err := r.Read([]byte{}); err != nil { + t.Errorf("empty read failed: %v", err) + } + if !bytes.Equal(sb, []byte("skippable block")) { + t.Errorf("didn't get correct data from skippable block: %q", string(sb)) + } + data, err := io.ReadAll(r) + if err != nil { + t.Fatalf("r.Read: %v", err) + } + if !bytes.Equal(data, []byte("some data")) { + t.Errorf("didn't get correct compressed data: %q", string(data)) + } +}