Skip to content

Commit

Permalink
refractor: flesh out semantics of IoVecBuffer[Mut]::read/write_at
Browse files Browse the repository at this point in the history
Before this patch series, IoVecBuffer[Mut] only ever had to deal with u8
slices, which inherently have an "all or nothing" semanticfor Read and
Write. Thus the semantics of read_at and write_at were "return None if
the given offset is too large this for IoVecBuffer[Mut], otherwise copy
either buf.len() or iovec.len() - offset bytes, whichever is less". This
commit changes the second part of this behavior to "copy either
buf.len() bytes, or fail", which is how these functions were used in
praxis. It also brings the semantics more in line with read and write
functions offered elsewhere in the standard library. For now, it still
only operates on u8-slices, but should the future need arise to
generalize to Read/WriteVolatile, all the pieces are there.

Signed-off-by: Patrick Roy <roypat@amazon.co.uk>
  • Loading branch information
roypat committed Dec 14, 2023
1 parent 332fb58 commit 98235b7
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 52 deletions.
120 changes: 87 additions & 33 deletions src/vmm/src/devices/virtio/iovec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,29 @@ impl IoVecBuffer {
///
/// # Returns
///
/// The number of bytes read (if any)
pub fn read_at(&self, mut buf: &mut [u8], offset: usize) -> Option<usize> {
/// `Ok(())` if `buf` was filled by reading from this [`IoVecBuffer`],
/// `Err(VolatileMemoryError::PartialBuffer)` if only part of `buf` could not be filled, and
/// `Err(VolatileMemoryError::OutOfBounds)` if `offset >= self.len()`.
pub fn read_exact_volatile_at(
&self,
mut buf: &mut [u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
// Make sure we only read up to the end of the `IoVecBuffer`.
let size = buf.len().min(self.len() - offset);
// write_volatile for &mut [u8] is infallible
self.read_volatile_at(&mut buf, offset, size).ok()
let expected = buf.len();
let bytes_read = self.read_volatile_at(&mut buf, offset, expected)?;

if bytes_read != expected {
return Err(VolatileMemoryError::PartialBuffer {
expected,
completed: bytes_read,
});
}

Ok(())
} else {
// If `offset` is past size, there's nothing to read.
None
Err(VolatileMemoryError::OutOfBounds { addr: offset })
}
}

Expand Down Expand Up @@ -223,15 +236,29 @@ impl IoVecBufferMut {
///
/// # Returns
///
/// The number of bytes written (if any)
pub fn write_at(&mut self, mut buf: &[u8], offset: usize) -> Option<usize> {
/// `Ok(())` if the entire contents of `buf` could be written to this [`IoVecBufferMut`],
/// `Err(VolatileMemoryError::PartialBuffer)` if only part of `buf` could be transferred, and
/// `Err(VolatileMemoryError::OutOfBounds)` if `offset >= self.len()`.
pub fn write_all_volatile_at(
&mut self,
mut buf: &[u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
// Make sure we only write up to the end of the `IoVecBufferMut`.
let size = buf.len().min(self.len() - offset);
self.write_volatile_at(&mut buf, offset, size).ok()
let expected = buf.len();
let bytes_written = self.write_volatile_at(&mut buf, offset, expected)?;

if bytes_written != expected {
return Err(VolatileMemoryError::PartialBuffer {
expected,
completed: bytes_written,
});
}

Ok(())
} else {
// We cannot write past the end of the `IoVecBufferMut`.
None
Err(VolatileMemoryError::OutOfBounds { addr: offset })
}
}

Expand Down Expand Up @@ -292,6 +319,7 @@ impl IoVecBufferMut {
#[cfg(test)]
mod tests {
use libc::{c_void, iovec};
use vm_memory::VolatileMemoryError;

use super::{IoVecBuffer, IoVecBufferMut};
use crate::devices::virtio::queue::{Queue, VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE};
Expand Down Expand Up @@ -397,7 +425,7 @@ mod tests {
let mem = default_mem();
let (mut q, _) = read_only_chain(&mem);
let head = q.pop(&mem).unwrap();
assert!(IoVecBuffer::from_descriptor_chain(head).is_ok());
IoVecBuffer::from_descriptor_chain(head).unwrap();

let (mut q, _) = write_only_chain(&mem);
let head = q.pop(&mem).unwrap();
Expand All @@ -409,7 +437,7 @@ mod tests {

let (mut q, _) = write_only_chain(&mem);
let head = q.pop(&mem).unwrap();
assert!(IoVecBufferMut::from_descriptor_chain(head).is_ok());
IoVecBufferMut::from_descriptor_chain(head).unwrap();
}

#[test]
Expand Down Expand Up @@ -440,28 +468,48 @@ mod tests {

let iovec = IoVecBuffer::from_descriptor_chain(head).unwrap();

let mut buf = vec![0; 257];
assert_eq!(iovec.read_at(&mut buf[..], 0), Some(256));
let mut buf = vec![0u8; 257];
assert_eq!(
iovec
.read_volatile_at(&mut buf.as_mut_slice(), 0, 257)
.unwrap(),
256
);
assert_eq!(buf[0..256], (0..=255).collect::<Vec<_>>());
assert_eq!(buf[256], 0);

let mut buf = vec![0; 5];
assert_eq!(iovec.read_at(&mut buf[..4], 0), Some(4));
iovec.read_exact_volatile_at(&mut buf[..4], 0).unwrap();
assert_eq!(buf, vec![0u8, 1, 2, 3, 0]);

assert_eq!(iovec.read_at(&mut buf, 0), Some(5));
iovec.read_exact_volatile_at(&mut buf, 0).unwrap();
assert_eq!(buf, vec![0u8, 1, 2, 3, 4]);

assert_eq!(iovec.read_at(&mut buf, 1), Some(5));
iovec.read_exact_volatile_at(&mut buf, 1).unwrap();
assert_eq!(buf, vec![1u8, 2, 3, 4, 5]);

assert_eq!(iovec.read_at(&mut buf, 60), Some(5));
iovec.read_exact_volatile_at(&mut buf, 60).unwrap();
assert_eq!(buf, vec![60u8, 61, 62, 63, 64]);

assert_eq!(iovec.read_at(&mut buf, 252), Some(4));
assert_eq!(
iovec
.read_volatile_at(&mut buf.as_mut_slice(), 252, 5)
.unwrap(),
4
);
assert_eq!(buf[0..4], vec![252u8, 253, 254, 255]);

assert_eq!(iovec.read_at(&mut buf, 256), None);
assert!(matches!(
iovec.read_exact_volatile_at(&mut buf, 252),
Err(VolatileMemoryError::PartialBuffer {
expected: 5,
completed: 4
})
));
assert!(matches!(
iovec.read_exact_volatile_at(&mut buf, 256),
Err(VolatileMemoryError::OutOfBounds { addr: 256 })
));
}

#[test]
Expand All @@ -482,10 +530,10 @@ mod tests {
let mut test_vec4 = vec![0u8; 64];

// Control test: Initially all three regions should be zero
assert_eq!(iovec.write_at(&test_vec1, 0), Some(64));
assert_eq!(iovec.write_at(&test_vec2, 64), Some(64));
assert_eq!(iovec.write_at(&test_vec3, 128), Some(64));
assert_eq!(iovec.write_at(&test_vec4, 192), Some(64));
iovec.write_all_volatile_at(&test_vec1, 0).unwrap();
iovec.write_all_volatile_at(&test_vec2, 64).unwrap();
iovec.write_all_volatile_at(&test_vec3, 128).unwrap();
iovec.write_all_volatile_at(&test_vec4, 192).unwrap();
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
Expand All @@ -494,7 +542,7 @@ mod tests {
// Let's initialize test_vec1 with our buffer.
test_vec1[..buf.len()].copy_from_slice(&buf);
// And write just a part of it
assert_eq!(iovec.write_at(&buf[..3], 0), Some(3));
iovec.write_all_volatile_at(&buf[..3], 0).unwrap();
// Not all 5 bytes from buf should be written in memory,
// just 3 of them.
vq.dtable[0].check_data(&[0u8, 1, 2, 0, 0]);
Expand All @@ -503,7 +551,7 @@ mod tests {
vq.dtable[3].check_data(&test_vec4);
// But if we write the whole `buf` in memory then all
// of it should be observable.
assert_eq!(iovec.write_at(&buf, 0), Some(5));
iovec.write_all_volatile_at(&buf, 0).unwrap();
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
Expand All @@ -512,7 +560,7 @@ mod tests {
// We are now writing with an offset of 1. So, initialize
// the corresponding part of `test_vec1`
test_vec1[1..buf.len() + 1].copy_from_slice(&buf);
assert_eq!(iovec.write_at(&buf, 1), Some(5));
iovec.write_all_volatile_at(&buf, 1).unwrap();
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
Expand All @@ -523,7 +571,7 @@ mod tests {
// first region and one byte on the second
test_vec1[60..64].copy_from_slice(&buf[0..4]);
test_vec2[0] = 4;
assert_eq!(iovec.write_at(&buf, 60), Some(5));
iovec.write_all_volatile_at(&buf, 60).unwrap();
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
Expand All @@ -535,14 +583,20 @@ mod tests {
// Now perform a write that does not fit in the buffer. Try writing
// 5 bytes at offset 252 (only 4 bytes left).
test_vec4[60..64].copy_from_slice(&buf[0..4]);
assert_eq!(iovec.write_at(&buf, 252), Some(4));
assert_eq!(
iovec.write_volatile_at(&mut &*buf, 252, buf.len()).unwrap(),
4
);
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
vq.dtable[3].check_data(&test_vec4);

// Trying to add past the end of the buffer should not write anything
assert_eq!(iovec.write_at(&buf, 256), None);
assert!(matches!(
iovec.write_all_volatile_at(&buf, 256),
Err(VolatileMemoryError::OutOfBounds { addr: 256 })
));
vq.dtable[0].check_data(&test_vec1);
vq.dtable[1].check_data(&test_vec2);
vq.dtable[2].check_data(&test_vec3);
Expand Down
22 changes: 13 additions & 9 deletions src/vmm/src/devices/virtio/net/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,15 @@ impl Net {
guest_mac: Option<MacAddr>,
net_metrics: &NetDeviceMetrics,
) -> Result<bool, NetError> {
// Read the frame headers from the IoVecBuffer. This will return None
// if the frame_iovec is empty.
let header_len = frame_iovec.read_at(headers, 0).ok_or_else(|| {
error!("Received empty TX buffer");
net_metrics.tx_malformed_frames.inc();
NetError::VnetHeaderMissing
})?;
// Read the frame headers from the IoVecBuffer
let max_header_len = headers.len();
let header_len = frame_iovec
.read_volatile_at(&mut &mut *headers, 0, max_header_len)
.map_err(|err| {
error!("Received malformed TX buffer: {:?}", err);
net_metrics.tx_malformed_frames.inc();
NetError::VnetHeaderMissing
})?;

let headers = frame_bytes_from_buf(&headers[..header_len]).map_err(|e| {
error!("VNET headers missing in TX frame");
Expand All @@ -463,7 +465,9 @@ impl Net {
let mut frame = vec![0u8; frame_iovec.len() - vnet_hdr_len()];
// Ok to unwrap here, because we are passing a buffer that has the exact size
// of the `IoVecBuffer` minus the VNET headers.
frame_iovec.read_at(&mut frame, vnet_hdr_len()).unwrap();
frame_iovec
.read_exact_volatile_at(&mut frame, vnet_hdr_len())
.unwrap();
let _ = ns.detour_frame(&frame);
METRICS.mmds.rx_accepted.inc();

Expand Down Expand Up @@ -1510,7 +1514,7 @@ pub mod tests {
let buffer = IoVecBuffer::from(&frame_buf[..frame_len]);

let mut headers = vec![0; frame_hdr_len()];
buffer.read_at(&mut headers, 0).unwrap();
buffer.read_exact_volatile_at(&mut headers, 0).unwrap();

// Call the code which sends the packet to the host or MMDS.
// Validate the frame was consumed by MMDS and that the metrics reflect that.
Expand Down
3 changes: 2 additions & 1 deletion src/vmm/src/devices/virtio/rng/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ impl Entropy {
})?;

// It is ok to unwrap here. We are writing `iovec.len()` bytes at offset 0.
Ok(iovec.write_at(&rand_bytes, 0).unwrap().try_into().unwrap())
iovec.write_all_volatile_at(&rand_bytes, 0).unwrap();
Ok(iovec.len().try_into().unwrap())
}

fn process_entropy_queue(&mut self) {
Expand Down
20 changes: 11 additions & 9 deletions src/vmm/src/devices/virtio/vsock/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use std::fmt::Debug;

use vm_memory::volatile_memory::Error;
use vm_memory::{GuestMemoryError, ReadVolatile, WriteVolatile};

use super::{defs, VsockError};
Expand Down Expand Up @@ -126,9 +127,12 @@ impl VsockPacket {
let buffer = IoVecBuffer::from_descriptor_chain(chain)?;

let mut hdr = VsockPacketHeader::default();
let header_bytes_read = buffer.read_at(hdr.as_mut_slice(), 0).unwrap_or(0);
if header_bytes_read < VSOCK_PKT_HDR_SIZE as usize {
return Err(VsockError::DescChainTooShortForHeader(header_bytes_read));
match buffer.read_exact_volatile_at(hdr.as_mut_slice(), 0) {
Ok(()) => (),
Err(Error::PartialBuffer { completed, .. }) => {
return Err(VsockError::DescChainTooShortForHeader(completed))
}
Err(err) => return Err(VsockError::GuestMemoryMmap(err.into())),
}

if hdr.len > defs::MAX_PKT_BUF_SIZE {
Expand Down Expand Up @@ -190,12 +194,10 @@ impl VsockPacket {
return Err(VsockError::InvalidPktLen(self.hdr.len));
}

let bytes_written = buffer.write_at(self.hdr.as_slice(), 0);

// We check the the buffer has sufficient size in from_rx_virtq_head
debug_assert_eq!(bytes_written, Some(VSOCK_PKT_HDR_SIZE as usize));

Ok(())
buffer
.write_all_volatile_at(self.hdr.as_slice(), 0)
.map_err(GuestMemoryError::from)
.map_err(VsockError::GuestMemoryMmap)
}
}
}
Expand Down

0 comments on commit 98235b7

Please sign in to comment.