Skip to content

Commit

Permalink
Merge pull request #1 from Danaozhong/Bug/FixBitshiftOverflowPanic
Browse files Browse the repository at this point in the history
Fix bitshift overflow panic, and cargo fmt
  • Loading branch information
Danaozhong authored Sep 1, 2023
2 parents c9be2e0 + 0104a45 commit 9cc2599
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
35 changes: 19 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![no_std]
cfg_if::cfg_if!{
cfg_if::cfg_if! {
if #[cfg(feature = "std")] {
extern crate std;
use std::io::Result;
Expand All @@ -25,10 +25,9 @@ pub struct BitWriter {

// number of unwritten bits in cache
bits: u8,

}

impl BitWriter{
impl BitWriter {
pub fn new() -> BitWriter {
BitWriter {
data: Vec::new(),
Expand All @@ -38,31 +37,31 @@ impl BitWriter{
}
}
/// Read at most 8 bits into a u8.
pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> {
pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 8)
}

pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> {
pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 16)
}

pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> {
pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 32)
}

pub fn write_u64(&mut self, v: u64, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v, bit_count, 64)
}

pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> {
pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 8)
}

pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> {
pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 16)
}

pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> {
pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 32)
}

Expand Down Expand Up @@ -94,7 +93,7 @@ impl BitWriter{
self.skip(bits_to_skip)
}

pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> {
pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> {
if n == 0 {
return Ok(());
}
Expand All @@ -114,7 +113,12 @@ impl BitWriter{
return Err(Error::new(ErrorKind::Unsupported, "too many bits to write"));
}
// mask all upper bits out to be 0
v &= (1 << n) - 1;
if n == 64 {
// avoid bitshift overflow exception
v &= u64::MAX;
} else {
v &= (1 << n) - 1;
}

self.bit_count += n as u64;

Expand All @@ -125,7 +129,7 @@ impl BitWriter{
self.bits = new_bits;
return Ok(());
}

if new_bits >= 8 {
// write all bytes, by first taking the existing buffer, form a complete byte,
// and write that first.
Expand All @@ -140,13 +144,13 @@ impl BitWriter{
self.data.push((v >> n) as u8);
}
}

// Whatever is left is smaller than a byte, and will be put into the cache
self.cache = 0;
self.bits = n;
if n > 0 {
let mask = ((1<<n) as u8) - 1;
self.cache = ((v as u8) & mask) << (8-n);
let mask = ((1 << n) as u8) - 1;
self.cache = ((v as u8) & mask) << (8 - n);
}
Ok(())
}
Expand All @@ -165,5 +169,4 @@ impl BitWriter{
pub fn data(&self) -> &Vec<u8> {
&self.data
}

}
20 changes: 18 additions & 2 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ fn simple_writing() {
let mut writer = BitWriter::new();

writer.write_bool(true).expect("failed to write bool");
writer.write_u32(178956970, 28).expect("failed to write u28");
writer.write_i32(-22369622, 28).expect("failed to write i28");
writer
.write_u32(178956970, 28)
.expect("failed to write u28");
writer
.write_i32(-22369622, 28)
.expect("failed to write i28");
assert_eq!(writer.bit_count, 1 + 28 + 28);

writer.close().expect("failed to close byte vector");
Expand All @@ -15,3 +19,15 @@ fn simple_writing() {
let expected = Vec::<u8>::from([0xD5, 0x55, 0x55, 0x57, 0x55, 0x55, 0x55, 0x00]);
assert_eq!(writer.data, expected);
}

#[test]
fn test_bitshift_overflow() {
let mut writer = BitWriter::new();
writer
.write_u64(0xFFFFFFFFFFFFFFFF, 64)
.expect("failed to u64");
writer.write_u64(0x0, 64).expect("failed to write u64");
writer.write_i64(0x0, 64).expect("failed to write i64");
assert_eq!(writer.bit_count, 3 * 64);
writer.close().expect("failed to close byte vector");
}

0 comments on commit 9cc2599

Please sign in to comment.