diff --git a/src/format.rs b/src/format.rs index 2d5fa40..d44fe2b 100644 --- a/src/format.rs +++ b/src/format.rs @@ -15,10 +15,23 @@ use enum_map::{enum_map, EnumMap}; use once_cell::sync::OnceCell; +use std::ops::AddAssign; use uninit::prelude::*; use crate::message::{Message, MessageType}; +// Accumulator that panics on overflow +struct Accumulator(usize); + +impl AddAssign for Accumulator { + fn add_assign(&mut self, rhs: usize) { + self.0 = self + .0 + .checked_add(rhs) + .expect("message size should not exceed usize::MAX"); + } +} + impl Message where N: AsRef<[u8]>, @@ -69,7 +82,7 @@ where /// /// # Safety /// - /// The target must have size of at least [write_size]. + /// The target must have size of at least [write_size](Self::write_size). pub unsafe fn write_out(&self, mut target: Out<[u8]>) { target = Self::append_byte(target, Self::type_symbol(self.mtype)); target = Self::append_bytes(target, self.name.as_ref()); @@ -100,26 +113,31 @@ where let _ = Self::append_byte(target, b'\n'); } - /// Get the number of bytes needed by [write]. + /// Get the number of bytes needed by [write_out](Self::write_out). + /// + /// # Panics + /// + /// This function will panic if the size overflows [usize]. pub fn write_size(&self) -> usize { - // Type symbol, name, spaces and newline - let mut bytes = 2 + self.name.as_ref().len() + self.arguments.len(); + let mut bytes = Accumulator(2); // name and newline + bytes += self.name.as_ref().len(); + bytes += self.arguments.len(); // spaces between arguments if let Some(mid) = self.mid { let mut buffer = itoa::Buffer::new(); let mid_formatted = buffer.format(mid); - bytes += 2 + mid_formatted.len(); + bytes += 2 + mid_formatted.len(); // 2 for the brackets } let emap = Self::escape_map(); for argument in self.arguments.iter() { let argument = argument.as_ref(); if argument.is_empty() { bytes += 2; // For the \@ - } - for c in argument.iter() { - bytes += if emap[*c] != 0 { 2 } else { 1 }; + } else { + bytes += argument.len(); + bytes += argument.iter().filter(|&&c| emap[c] != 0).count(); // escapes } } - bytes + bytes.0 } pub fn write_size_callback(&self) -> (usize, impl Fn(Out<[u8]>) + '_) { @@ -146,3 +164,36 @@ where vec } } + +#[cfg(test)] +mod test { + use super::*; + + /// Create a Message that requires more than usize bytes. + //#[cfg(not(debug_assertions))] // Too slow in debug builds + #[test] + #[should_panic(expected = "message size should not exceed usize::MAX")] + fn overflow_size() { + /// Zero-size structure that can be used as a message argument + #[derive(Copy, Clone)] + struct ZeroSizeArgument; + + impl AsRef<[u8]> for ZeroSizeArgument { + fn as_ref(&self) -> &[u8] { + b"argument".as_slice() + } + } + + // We need a way to construct the giant vector without iterating + // over the elements (release builds will optimise away the useless + // loop, but debug builds do not). Since there is no memory to + // initialize, set_len should be safe. + let mut arguments = vec![]; + unsafe { + arguments.set_len(usize::MAX - 5); + } + let message: Message<&[u8], ZeroSizeArgument> = + Message::new(MessageType::Request, &b"big message"[..], None, arguments); + message.write_size(); + } +} diff --git a/src/parse.rs b/src/parse.rs index ba8cd0f..4ed83d0 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -383,7 +383,7 @@ impl Parser { /// Number of bytes currently buffered for an incomplete line. /// - /// This is capped at [max_line_length], even if a longer (overflowing) + /// This is capped at `Self::max_line_length`, even if a longer (overflowing) /// line is in progress. pub fn buffer_size(&self) -> usize { self.line_length