Skip to content

Commit

Permalink
zstd: Allow to ignore checksum checking (#572)
Browse files Browse the repository at this point in the history
Fixes #571
  • Loading branch information
WojciechMula authored Apr 28, 2022
1 parent 595e86d commit 75b1f22
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 17 deletions.
6 changes: 3 additions & 3 deletions zstd/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) {
println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp)
}

if len(next.b) > 0 {
if !d.o.ignoreChecksum && len(next.b) > 0 {
n, err := d.current.crc.Write(next.b)
if err == nil {
if n != len(next.b) {
Expand All @@ -451,7 +451,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) {
got := d.current.crc.Sum64()
var tmp [4]byte
binary.LittleEndian.PutUint32(tmp[:], uint32(got))
if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC {
if !d.o.ignoreChecksum && !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC {
if debugDecoder {
println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)")
}
Expand Down Expand Up @@ -534,7 +534,7 @@ func (d *Decoder) nextBlockSync() (ok bool) {
}

// Update/Check CRC
if d.frame.HasCheckSum {
if !d.o.ignoreChecksum && d.frame.HasCheckSum {
d.frame.crc.Write(d.current.b)
if d.current.d.Last {
d.current.err = d.frame.checkCRC()
Expand Down
9 changes: 9 additions & 0 deletions zstd/decoder_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type decoderOptions struct {
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
}

func (o *decoderOptions) setDefault() {
Expand Down Expand Up @@ -112,3 +113,11 @@ func WithDecoderMaxWindow(size uint64) DOption {
return nil
}
}

// IgnoreChecksum returns a decoder option that records whether checksum
// verification should be skipped. Passing true stores the "ignore" flag in
// the decoder options; passing false clears it. The option itself never fails.
func IgnoreChecksum(b bool) DOption {
	apply := func(opts *decoderOptions) error {
		opts.ignoreChecksum = b
		return nil
	}
	return apply
}
51 changes: 51 additions & 0 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,57 @@ func TestResetNil(t *testing.T) {
}
}

// TestIgnoreChecksum verifies that a frame whose payload does not match its
// stored checksum is rejected by default, and is decoded successfully when the
// decoder is created with IgnoreChecksum(true).
func TestIgnoreChecksum(t *testing.T) {
	// zstd file containing text "compress\n" and has a xxhash checksum
	zstdBlob := []byte{0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x09, 0x49, 0x00, 0x00, 'c', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n', 0x79, 0x6e, 0xe0, 0xd2}

	// Replace letter 'c' with 'C' so the payload no longer matches the
	// checksum stored at the end of the frame; default decoding should fail.
	zstdBlob[9] = 'C'

	// Subtests are used so that each deferred Close runs as soon as its own
	// scenario finishes rather than at the end of the whole test function.
	t.Run("default", func(t *testing.T) {
		// Check if the file is indeed incorrect
		dec, err := NewReader(nil)
		if err != nil {
			t.Fatal(err)
		}
		defer dec.Close()

		if err := dec.Reset(bytes.NewBuffer(zstdBlob)); err != nil {
			t.Fatal(err)
		}

		_, err = ioutil.ReadAll(dec)
		if err == nil {
			t.Fatal("Expected decoding error")
		}

		if !errors.Is(err, ErrCRCMismatch) {
			t.Fatalf("Expected checksum error, got '%s'", err)
		}
	})

	t.Run("ignore", func(t *testing.T) {
		// Ignore CRC error and decompress the content
		dec, err := NewReader(nil, IgnoreChecksum(true))
		if err != nil {
			t.Fatal(err)
		}
		defer dec.Close()

		if err := dec.Reset(bytes.NewBuffer(zstdBlob)); err != nil {
			t.Fatal(err)
		}

		res, err := ioutil.ReadAll(dec)
		if err != nil {
			t.Fatalf("Unexpected error: '%s'", err)
		}

		// The modified payload must come back verbatim.
		want := []byte{'C', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n'}
		if !bytes.Equal(res, want) {
			t.Logf("want: %s", want)
			t.Logf("got: %s", res)
			t.Fatalf("Wrong output")
		}
	})
}

func timeout(after time.Duration) (cancel func()) {
if isRaceTest {
return func() {}
Expand Down
50 changes: 36 additions & 14 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,6 @@ func (d *frameDec) checkCRC() error {
if !d.HasCheckSum {
return nil
}
var tmp [4]byte
got := d.crc.Sum64()
// Flip to match file order.
tmp[0] = byte(got >> 0)
tmp[1] = byte(got >> 8)
tmp[2] = byte(got >> 16)
tmp[3] = byte(got >> 24)

// We can overwrite upper tmp now
want, err := d.rawInput.readSmall(4)
Expand All @@ -305,6 +298,18 @@ func (d *frameDec) checkCRC() error {
return err
}

if d.o.ignoreChecksum {
return nil
}

var tmp [4]byte
got := d.crc.Sum64()
// Flip to match file order.
tmp[0] = byte(got >> 0)
tmp[1] = byte(got >> 8)
tmp[2] = byte(got >> 16)
tmp[3] = byte(got >> 24)

if !bytes.Equal(tmp[:], want) && !ignoreCRC {
if debugDecoder {
println("CRC Check Failed:", tmp[:], "!=", want)
Expand All @@ -317,6 +322,19 @@ func (d *frameDec) checkCRC() error {
return nil
}

// consumeCRC reads and discards the 4 checksum bytes from the raw input when
// the frame header says a checksum is present. Frames without a checksum are
// left untouched. A failure to read the bytes is reported and returned.
func (d *frameDec) consumeCRC() error {
	// Nothing to skip when the frame carries no checksum.
	if !d.HasCheckSum {
		return nil
	}

	if _, readErr := d.rawInput.readSmall(4); readErr != nil {
		println("CRC missing?", readErr)
		return readErr
	}

	return nil
}

// runDecoder will create a sync decoder that will decode a block of data.
func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
saved := d.history.b
Expand Down Expand Up @@ -373,13 +391,17 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
err = ErrFrameSizeMismatch
} else if d.HasCheckSum {
var n int
n, err = d.crc.Write(dst[crcStart:])
if err == nil {
if n != len(dst)-crcStart {
err = io.ErrShortWrite
} else {
err = d.checkCRC()
if d.o.ignoreChecksum {
err = d.consumeCRC()
} else {
var n int
n, err = d.crc.Write(dst[crcStart:])
if err == nil {
if n != len(dst)-crcStart {
err = io.ErrShortWrite
} else {
err = d.checkCRC()
}
}
}
}
Expand Down

0 comments on commit 75b1f22

Please sign in to comment.