diff --git a/Cargo.toml b/Cargo.toml index d7cfa33..8ae7ad3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-io-typed" -version = "1.0.2" +version = "1.0.3" edition = "2021" license = "MIT OR Apache-2.0" description = "Adapts any AsyncRead or AsyncWrite type to send serde compatible types" @@ -21,4 +21,6 @@ thiserror = "1.0.37" [dev-dependencies] rand = "0.8" -tokio = { version = "1.22.0", features = ["rt-multi-thread"]} \ No newline at end of file +tokio = { version = "1.22.0", features = ["rt-multi-thread", "sync", "macros", "time"]} +tokio-util = "0.7" +futures-util = { version = "0.3", features = ["io"] } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 8609561..894cc6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,9 @@ use futures_io::{AsyncRead, AsyncWrite}; use futures_util::{stream::Stream, Sink, SinkExt}; use serde::{de::DeserializeOwned, Serialize}; +#[cfg(test)] +mod tests; + const U16_MARKER: u8 = 252; const U32_MARKER: u8 = 253; const U64_MARKER: u8 = 254; @@ -313,24 +316,22 @@ impl AsyncReadTyp let mut buf = [0; 8]; let accumulated = *len_in_progress_assigned as usize; let slice = match len_read_mode { - LenReadMode::U16 => &mut buf[0..(2 - accumulated)], - LenReadMode::U32 => &mut buf[0..(4 - accumulated)], - LenReadMode::U64 => &mut buf[0..(8 - accumulated)], + LenReadMode::U16 => &mut buf[accumulated..2], + LenReadMode::U32 => &mut buf[accumulated..4], + LenReadMode::U64 => &mut buf[accumulated..8], }; let len = futures_core::ready!(Pin::new(&mut raw).poll_read(cx, slice))?; - len_in_progress[accumulated..(accumulated + slice.len())] + len_in_progress[accumulated..(accumulated + len)] .copy_from_slice(&slice[..len]); *len_in_progress_assigned += len as u8; if len == slice.len() { let new_len = match len_read_mode { LenReadMode::U16 => u16::from_le_bytes( (&len_in_progress[0..2]).try_into().expect("infallible"), - ) - as u64, + ) as u64, LenReadMode::U32 => u32::from_le_bytes( (&len_in_progress[0..4]).try_into().expect("infallible"), - ) - as u64, + ) as u64, LenReadMode::U64 => u64::from_le_bytes(*len_in_progress), }; if new_len > size_limit { @@ -344,7 +345,9 @@ impl AsyncReadTyp } AsyncReadState::ReadingItem { ref mut len_read } => { while *len_read < item_buffer.len() { - let len = futures_core::ready!(Pin::new(&mut raw).poll_read(cx, &mut item_buffer[*len_read..]))?; + let len = futures_core::ready!( + Pin::new(&mut raw).poll_read(cx, &mut item_buffer[*len_read..]) + )?; *len_read += len; if *len_read == item_buffer.len() { break; @@ -386,7 +389,8 @@ enum AsyncWriteState { Idle, WritingLen { current_len: [u8; 9], - len_to_be_sent: u8, + len_to_be_sent: usize, + len_sent: usize, }, WritingValue { bytes_sent: usize, @@ -513,15 +517,19 @@ impl AsyncWriteT }; *state = AsyncWriteState::WritingLen { current_len: new_current_len, - len_to_be_sent: to_be_sent as u8, + len_to_be_sent: to_be_sent, + len_sent: 0, }; - let len = futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &new_current_len[0..to_be_sent]))?; + let len = futures_core::ready!( + Pin::new(&mut *raw).poll_write(cx, &new_current_len[0..to_be_sent]) + )?; *state = if len == to_be_sent { AsyncWriteState::WritingValue { bytes_sent: 0 } } else { AsyncWriteState::WritingLen { current_len: new_current_len, - len_to_be_sent: (to_be_sent - len) as u8, + len_to_be_sent: to_be_sent, + len_sent: len, } }; continue; @@ -533,20 +541,22 @@ impl AsyncWriteT } } AsyncWriteState::WritingLen { - current_len, - len_to_be_sent, + ref current_len, + ref len_to_be_sent, + ref mut len_sent, } => { let len = futures_core::ready!(Pin::new(&mut *raw) - .poll_write(cx, ¤t_len[0..(*len_to_be_sent as usize)]))?; - if len == *len_to_be_sent as usize { + .poll_write(cx, ¤t_len[(*len_sent)..(*len_to_be_sent)]))?; + *len_sent += len; + if *len_sent == *len_to_be_sent { *state = AsyncWriteState::WritingValue { bytes_sent: 0 }; - } else { - *len_to_be_sent -= len as u8; } continue; } AsyncWriteState::WritingValue { bytes_sent } => { - let len = futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &write_buffer[*bytes_sent..]))?; + let len = futures_core::ready!( + Pin::new(&mut *raw).poll_write(cx, &write_buffer[*bytes_sent..]) + )?; *bytes_sent += len; if *bytes_sent == write_buffer.len() { *state = AsyncWriteState::Idle; @@ -564,7 +574,7 @@ impl AsyncWriteT } else { continue; } - }, + } AsyncWriteState::Closed => Poll::Ready(Ok(None)), }; } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..4024bf9 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,421 @@ +// Copyright 2022 Jacob Kiesel +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures_io::{AsyncRead, AsyncWrite}; +use futures_util::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::mpsc::{self, Receiver}; +use tokio_util::sync::PollSender; + +// What follows is an intentionally obnoxious `AsyncRead` and `AsyncWrite` implementation. Please don't use this outside of tests. +struct BasicChannelSender { + max_size_per_write: usize, + sender: PollSender>, +} + +impl AsyncWrite for BasicChannelSender { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if futures_core::ready!(self.sender.poll_reserve(cx)).is_err() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "remote hung up", + ))); + } + let write_len = self.max_size_per_write.min(buf.len()); + self.sender + .send_item((&buf[..write_len]).to_vec()) + .expect("receiver hung up!"); + Poll::Ready(Ok(write_len)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.sender.close(); + Poll::Ready(Ok(())) + } +} + +struct BasicChannelReceiver { + receiver: Receiver>, + last_received: Vec, +} + +impl AsyncRead for BasicChannelReceiver { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut len_written = 0; + loop { + if self.last_received.len() > 0 { + let copy_len = self.last_received.len().min(buf.len() - len_written); + buf[len_written..(len_written + copy_len)] + .copy_from_slice(&self.last_received[..copy_len]); + self.last_received = self.last_received.split_off(copy_len); + len_written += copy_len; + if len_written == buf.len() { + return Poll::Ready(Ok(buf.len())); + } + } else { + self.last_received = match self.receiver.poll_recv(cx) { + Poll::Ready(Some(v)) => v, + Poll::Ready(None) => { + return if len_written > 0 { + Poll::Ready(Ok(len_written)) + } else { + Poll::Pending + } + } + Poll::Pending => { + return if len_written > 0 { + Poll::Ready(Ok(len_written)) + } else { + Poll::Pending + } + } + } + } + } + } +} + +fn basic_channel(max_size_per_write: usize) -> (BasicChannelSender, BasicChannelReceiver) { + let (sender, receiver) = mpsc::channel(32); + ( + BasicChannelSender { + sender: PollSender::new(sender), + max_size_per_write, + }, + BasicChannelReceiver { + receiver, + last_received: Vec::new(), + }, + ) +} + +// This tests our testing equipment, just makes sure the above implementations are correct. +#[tokio::test(flavor = "multi_thread")] +async fn basic_channel_test() { + for i in (1..10).chain(Some(usize::MAX)) { + { + let (mut sender, mut receiver) = basic_channel(i); + let message = "Hello World!".as_bytes(); + let mut read_buf = vec![0; message.len()]; + let write = tokio::spawn(async move { sender.write_all(message).await }); + tokio::time::timeout(Duration::from_secs(2), receiver.read_exact(&mut read_buf)) + .await + .unwrap() + .unwrap(); + write.await.unwrap().unwrap(); + assert_eq!(message, read_buf); + } + { + let (sender, mut receiver) = basic_channel(i); + let mut sender = Some(sender); + for _ in 0..10 { + let message = (0..255).collect::>(); + let mut read_buf = vec![0; message.len()]; + let message_clone = message.clone(); + let mut sender_inner = sender.take().unwrap(); + let write = tokio::spawn(async move { + sender_inner.write_all(&message_clone).await.unwrap(); + sender_inner + }); + tokio::time::timeout(Duration::from_secs(2), receiver.read_exact(&mut read_buf)) + .await + .unwrap() + .unwrap(); + sender = Some(write.await.unwrap()); + assert_eq!(message, read_buf); + } + } + } +} + +use std::mem; + +use bincode::Options; +use futures_util::{SinkExt, StreamExt}; +use tokio::task::JoinHandle; + +use super::*; + +fn start_send_helper( + mut s: AsyncWriteTyped, + value: T, +) -> JoinHandle<(AsyncWriteTyped, Result<(), Error>)> { + tokio::spawn(async move { + let ret = s.send(value).await; + (s, ret) + }) +} + +fn make_channel( + max_size_per_write: usize, +) -> ( + Option>, + AsyncReadTyped, +) { + let (sender, receiver) = basic_channel(max_size_per_write); + ( + Some(AsyncWriteTyped::new(sender)), + AsyncReadTyped::new(receiver), + ) +} + +// Copy paste of the options from async-io-typed. +fn bincode_options(size_limit: u64) -> impl Options { + // Two of these are defaults, so you might say this is over specified. I say it's future proof, as + // bincode default changes won't introduce accidental breaking changes. + bincode::DefaultOptions::new() + .with_limit(size_limit) + .with_little_endian() + .with_varint_encoding() + .reject_trailing_bytes() +} + +fn interesting_sizes() -> impl Iterator { + (1..=3).chain(8..=10).chain(Some(1024usize.pow(2))) +} + +#[tokio::test] +async fn zero_len_message() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let fut = start_send_helper(server_stream.take().unwrap(), ()); + client_stream.next().await.unwrap().unwrap(); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn zero_len_messages() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for _ in 0..100 { + let fut = start_send_helper(server_stream.take().unwrap(), ()); + client_stream.next().await.unwrap().unwrap(); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +} + +#[tokio::test] +async fn hello_world() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let message = "Hello, world!".as_bytes().to_vec(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn shutdown_after_hello_world() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let message = "Hello, world!".as_bytes().to_vec(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (server_stream, result) = fut.await.unwrap(); + result.unwrap(); + mem::drop(server_stream); + let next = client_stream.next().await; + assert!(next.is_none(), "{next:?} was not none"); + } +} + +#[tokio::test] +async fn hello_worlds() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for i in 0..100 { + let message = format!("Hello, world {}!", i).into_bytes(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +} + +#[tokio::test] +async fn u16_marker_len_message() { + for size in interesting_sizes() { + let bincode_config = bincode_options(1024); + + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..248).map(|_| 1).chain(Some(300)).collect::>(); + assert_eq!(bincode_config.serialize(&message).unwrap().len(), 252); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn u32_marker_len_message() { + for size in interesting_sizes() { + let bincode_config = bincode_options(1024); + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..249).map(|_| 1).chain(Some(300)).collect::>(); + assert_eq!(bincode_config.serialize(&message).unwrap().len(), 253); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn u64_marker_len_message() { + for size in interesting_sizes() { + let bincode_config = bincode_options(1024); + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..251).map(|_| 240u8).collect::>(); + assert_eq!(bincode_config.serialize(&message).unwrap().len(), 254); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn zst_marker_len_message() { + for size in interesting_sizes() { + let bincode_config = bincode_options(1024); + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0u8..252).collect::>(); + assert_eq!(bincode_config.serialize(&message).unwrap().len(), 255); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn random_len_test() { + use rand::Rng; + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for _ in 0..800 { + let message = (0..(rand::thread_rng().gen_range(0..(u8::MAX as u32 + 1) / 4))) + .collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +} + +#[tokio::test] +async fn u16_len_message() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..(u8::MAX as u16 + 1) / 2).collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn u16_len_messages() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for _ in 0..10 { + let message = (0..(u8::MAX as u16 + 1) / 2) + .map(|_| 258u16) + .collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +} + +#[tokio::test] +async fn u32_len_message() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..(u16::MAX as u32 + 1) / 4).collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[tokio::test] +async fn u32_len_messages() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for _ in 0..10 { + let message = (0..(u16::MAX as u32 + 1) / 4).collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +} + +// It takes a ridiculous amount of time to run the u64 tests +#[ignore] +#[tokio::test] +async fn u64_len_message() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + let message = (0..(u32::MAX as u64 + 1) / 8).collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + fut.await.unwrap().1.unwrap(); + } +} + +#[ignore] +#[tokio::test] +async fn u64_len_messages() { + for size in interesting_sizes() { + let (mut server_stream, mut client_stream) = make_channel(size); + for _ in 0..10 { + let message = (0..(u32::MAX as u64 + 1) / 8).collect::>(); + let fut = start_send_helper(server_stream.take().unwrap(), message.clone()); + assert_eq!(client_stream.next().await.unwrap().unwrap(), message); + let (stream, result) = fut.await.unwrap(); + server_stream = Some(stream); + result.unwrap(); + } + } +}