Skip to content

Commit

Permalink
swap bool out for more descriptive new type on checksum API
Browse files Browse the repository at this point in the history
  • Loading branch information
Xaeroxe committed Feb 7, 2023
1 parent 6b522bd commit 3fc58ef
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "async-io-typed"
version = "2.0.0"
version = "3.0.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Adapts any AsyncRead or AsyncWrite type to send serde compatible types"
Expand Down
14 changes: 5 additions & 9 deletions src/duplex.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::read::{AsyncReadState, AsyncReadTyped, ChecksumReadState};
use crate::write::{AsyncWriteState, AsyncWriteTyped, MessageFeatures};
use crate::{Error, PROTOCOL_VERSION};
use crate::{ChecksumEnabled, Error, PROTOCOL_VERSION};
use futures_core::Stream;
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{Sink, SinkExt};
Expand Down Expand Up @@ -34,7 +34,7 @@ impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin
/// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
///
/// Be careful, large size limits might create a vulnerability to a Denial of Service attack.
pub fn new_with_limit(rw: RW, size_limit: u64, checksum_enabled: bool) -> Self {
pub fn new_with_limit(rw: RW, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
Self {
rw: Some(rw),
read_state: AsyncReadState::ReadingVersion {
Expand All @@ -48,22 +48,18 @@ impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin
},
write_buffer: Vec::new(),
primed_values: VecDeque::new(),
checksum_read_state: if checksum_enabled {
ChecksumReadState::Yes
} else {
ChecksumReadState::No
},
checksum_read_state: checksum_enabled.into(),
message_features: MessageFeatures {
size_limit,
checksum_enabled,
checksum_enabled: checksum_enabled.into(),
},
}
}

/// Creates a duplex typed reader and writer, initializing it with a default size limit of 1 MB per message.
/// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
/// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
pub fn new(rw: RW, checksum_enabled: bool) -> Self {
pub fn new(rw: RW, checksum_enabled: ChecksumEnabled) -> Self {
Self::new_with_limit(rw, 1024_u64.pow(2), checksum_enabled)
}

Expand Down
31 changes: 31 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,37 @@ pub enum Error {
ChecksumHandshakeFailed { checksum_value: u8 },
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChecksumEnabled {
Yes,
No,
}

impl From<bool> for ChecksumEnabled {
fn from(value: bool) -> Self {
if value {
ChecksumEnabled::Yes
} else {
ChecksumEnabled::No
}
}
}

impl From<ChecksumEnabled> for bool {
fn from(value: ChecksumEnabled) -> Self {
value == ChecksumEnabled::Yes
}
}

impl From<ChecksumEnabled> for ChecksumReadState {
fn from(value: ChecksumEnabled) -> Self {
match value {
ChecksumEnabled::Yes => ChecksumReadState::Yes,
ChecksumEnabled::No => ChecksumReadState::No,
}
}
}

fn bincode_options(size_limit: u64) -> impl Options {
// Two of these are defaults, so you might say this is over specified. I say it's future proof, as
// bincode default changes won't introduce accidental breaking changes.
Expand Down
14 changes: 5 additions & 9 deletions src/read.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER, U32_MARKER,
U64_MARKER, ZST_MARKER,
ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
U32_MARKER, U64_MARKER, ZST_MARKER,
};
use bincode::Options;
use futures_core::Stream;
Expand Down Expand Up @@ -89,7 +89,7 @@ impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncReadTyp
/// Creates a typed reader, initializing it with the given size limit specified in bytes.
///
/// Be careful, large limits might create a vulnerability to a Denial of Service attack.
pub fn new_with_limit(raw: R, size_limit: u64, checksum_enabled: bool) -> Self {
pub fn new_with_limit(raw: R, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
Self {
raw,
size_limit,
Expand All @@ -98,17 +98,13 @@ impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncReadTyp
version_in_progress_assigned: 0,
},
item_buffer: Vec::new(),
checksum_read_state: if checksum_enabled {
ChecksumReadState::Yes
} else {
ChecksumReadState::No
},
checksum_read_state: checksum_enabled.into(),
_phantom: PhantomData,
}
}

/// Creates a typed reader, initializing it with a default size limit of 1 MB.
pub fn new(raw: R, checksum_enabled: bool) -> Self {
pub fn new(raw: R, checksum_enabled: ChecksumEnabled) -> Self {
Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
}

Expand Down
27 changes: 15 additions & 12 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn basic_channel(max_size_per_write: usize) -> (BasicChannelSender, BasicChannel
#[tokio::test]
async fn bad_protocol_version() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, false);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, ChecksumEnabled::No);
// Intentionally send a message with a bad checksum
let sent_value = 5;
let mut message = Vec::from(0u64.to_le_bytes());
Expand All @@ -154,7 +154,7 @@ async fn bad_protocol_version() {
#[tokio::test]
async fn bad_checksum_enabled_value() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, false);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, ChecksumEnabled::No);
// Intentionally send a message with a bad checksum
let sent_value = 5;
const BAD_CHECKSUM_ENABLED_VALUE: u8 = 42;
Expand All @@ -177,7 +177,7 @@ async fn bad_checksum_enabled_value() {
#[tokio::test]
async fn checksum_ignored() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::new(receiver, false);
let mut typed_receiver = AsyncReadTyped::new(receiver, ChecksumEnabled::No);
// Intentionally send a message with a bad checksum
let sent_value = 5;
let mut message = Vec::from(PROTOCOL_VERSION.to_le_bytes());
Expand All @@ -193,7 +193,7 @@ async fn checksum_ignored() {
#[tokio::test]
async fn checksum_unavailable() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, true);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, ChecksumEnabled::Yes);
assert!(typed_receiver.checksum_enabled());
// Send two message without checksums.
const SENT_VALUE: u8 = 5;
Expand All @@ -218,7 +218,7 @@ async fn checksum_unavailable() {
#[tokio::test]
async fn checksum_used() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, true);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, ChecksumEnabled::Yes);
// Intentionally send a message with a bad checksum
const SENT_VALUE: u8 = 5;
const SENT_VALUE_CHECKSUM: u64 = 10536747468361244917;
Expand Down Expand Up @@ -251,7 +251,7 @@ async fn checksum_used() {
#[tokio::test]
async fn checksum_unused() {
let (mut sender, receiver) = basic_channel(1024);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, true);
let mut typed_receiver = AsyncReadTyped::<_, u8>::new(receiver, ChecksumEnabled::Yes);
// Send two messages with no checksum
const SENT_VALUE: u8 = 5;
const SENT_VALUE_2: u8 = 20;
Expand All @@ -273,7 +273,7 @@ async fn checksum_unused() {
#[tokio::test]
async fn checksum_sent() {
let (sender, mut receiver) = basic_channel(1024);
let mut typed_sender = AsyncWriteTyped::new(sender, true);
let mut typed_sender = AsyncWriteTyped::new(sender, ChecksumEnabled::Yes);
const SENT_VALUE: u8 = 5;
const SENT_VALUE_CHECKSUM: u64 = 10536747468361244917;
typed_sender.send(SENT_VALUE).await.unwrap();
Expand All @@ -297,7 +297,7 @@ async fn checksum_sent() {
#[tokio::test]
async fn checksum_not_sent() {
let (sender, mut receiver) = basic_channel(1024);
let mut typed_sender = AsyncWriteTyped::new(sender, false);
let mut typed_sender = AsyncWriteTyped::new(sender, ChecksumEnabled::No);
const SENT_VALUE: u8 = 5;
typed_sender.send(SENT_VALUE).await.unwrap();
const SENT_VALUE_2: u8 = 20;
Expand Down Expand Up @@ -387,8 +387,8 @@ fn make_channel<T: DeserializeOwned + Serialize + Unpin>(
) {
let (sender, receiver) = basic_channel(max_size_per_write);
(
Some(AsyncWriteTyped::new(sender, sender_checksum_enabled)),
AsyncReadTyped::new(receiver, receiver_checksum_enabled),
Some(AsyncWriteTyped::new(sender, sender_checksum_enabled.into())),
AsyncReadTyped::new(receiver, receiver_checksum_enabled.into()),
)
}

Expand Down Expand Up @@ -511,10 +511,13 @@ async fn hello_world_tokio_tcp() {
.await
.unwrap()
.compat(),
true,
ChecksumEnabled::Yes,
);
let (server_stream, _address) = accept_fut.await.unwrap();
let mut server_stream = Some(DuplexStreamTyped::new(server_stream.compat_write(), true));
let mut server_stream = Some(DuplexStreamTyped::new(
server_stream.compat_write(),
ChecksumEnabled::Yes,
));
let message = "Hello, world!".as_bytes().to_vec();
let fut = start_send_helper(server_stream.take().unwrap(), message.clone());
assert_eq!(client_stream.next().await.unwrap().unwrap(), message);
Expand Down
10 changes: 5 additions & 5 deletions src/write.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER, U32_MARKER,
U64_MARKER, ZST_MARKER,
ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
U32_MARKER, U64_MARKER, ZST_MARKER,
};
use bincode::Options;
use futures_io::AsyncWrite;
Expand Down Expand Up @@ -223,7 +223,7 @@ impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteT
/// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
///
/// Be careful, large size limits might create a vulnerability to a Denial of Service attack.
pub fn new_with_limit(raw: W, size_limit: u64, checksum_enabled: bool) -> Self {
pub fn new_with_limit(raw: W, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
Self {
raw: Some(raw),
write_buffer: Vec::new(),
Expand All @@ -233,7 +233,7 @@ impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteT
},
message_features: MessageFeatures {
size_limit,
checksum_enabled,
checksum_enabled: checksum_enabled.into(),
},
primed_values: VecDeque::new(),
}
Expand All @@ -242,7 +242,7 @@ impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteT
/// Creates a typed writer, initializing it with a default size limit of 1 MB per message.
/// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
/// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
pub fn new(raw: W, checksum_enabled: bool) -> Self {
pub fn new(raw: W, checksum_enabled: ChecksumEnabled) -> Self {
Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
}

Expand Down

0 comments on commit 3fc58ef

Please sign in to comment.