diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3afb9469..4fdb4575 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -47,3 +47,6 @@ sockets = ["dep:chacha20poly1305", "dep:dashmap", "tokio-util/codec", "tokio-uti [build-dependencies] build-data = "0" + +[dev-dependencies] +rand = "0.8.5" diff --git a/proxy/src/serve_sockets.rs b/proxy/src/serve_sockets.rs index 1b310751..681cf164 100644 --- a/proxy/src/serve_sockets.rs +++ b/proxy/src/serve_sockets.rs @@ -1,9 +1,5 @@ use std::{ - io::{self, Write}, - ops::{Deref, DerefMut}, - pin::Pin, - task::Poll, - time::{Duration, SystemTime, Instant}, sync::Arc, + io::{self, Write}, ops::{Deref, DerefMut}, pin::Pin, sync::{atomic::AtomicUsize, Arc}, task::Poll, time::{Duration, Instant, SystemTime} }; use axum::{ @@ -33,7 +29,7 @@ use shared::{ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf}; use tokio_util::{ codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}, - compat::{Compat, FuturesAsyncReadCompatExt}, + compat::{Compat, FuturesAsyncReadCompatExt}, io::SinkWriter, }; use tracing::{warn, debug}; @@ -296,7 +292,7 @@ impl DerefMut for SocketEncKey { struct EncryptedSocket { // inner: Framed, read: Compat, DecryptorCodec>>>, - write: FramedWrite, EncryptorCodec>, + write: SinkWriter, EncryptorCodec>>, } struct EncryptorCodec { @@ -380,7 +376,8 @@ impl<'a> Buffer for EncBuffer<'a> { // This should only be called when decrypting fn truncate(&mut self, len: usize) { - self.buf.truncate(len) + warn!("Buffer got truncated. This should never happen as it should be perfectly sized"); + self.buf.truncate(self.enc_idx + len) } } @@ -415,7 +412,7 @@ impl EncryptedSocket { let (r, w) = tokio::io::split(inner); let read = FramedRead::new(r, DecryptorCodec { decryptor }); let read = read.into_async_read().compat(); - let write = FramedWrite::new(w, EncryptorCodec { encryptor }); + let write = SinkWriter::new(FramedWrite::new(w, EncryptorCodec { encryptor })); Ok(Self { read, write }) } @@ -437,27 +434,28 @@ impl AsyncWrite for EncryptedSocket { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - self.write.send(buf).poll_unpin(cx).map_ok(|_| buf.len()) + Pin::new(&mut self.write).poll_write(cx, buf) } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.write.poll_flush_unpin(cx) + Pin::new(&mut self.write).poll_flush(cx) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.write.poll_close_unpin(cx) + Pin::new(&mut self.write).poll_shutdown(cx) } } #[cfg(test)] mod tests { use chacha20poly1305::aead::stream::{Decryptor, Encryptor, EncryptorLE31}; + use rand::Rng; use tokio::net::{TcpListener, TcpStream}; use super::*; @@ -466,20 +464,35 @@ mod tests { async fn test_encryption() { let mut key = GenericArray::default(); OsRng.fill_bytes(&mut key); - const N: usize = 2_usize.pow(13); - let test_data: &mut [u8; N] = &mut [0; N]; - OsRng.fill_bytes(test_data); - let mut read_buf = [0; N]; + let data: Arc> = (0..13337).map(|_| { + let mut chunk = vec![0; OsRng.gen_range(1..9999)]; + OsRng.fill_bytes(&mut chunk); + chunk + }).collect::>().into(); start_test_broker().await; let (mut client1, mut client2) = tokio::join!(client(&key), client(&key)); - client1.write_all(test_data).await.unwrap(); - client2.read_exact(&mut read_buf).await.unwrap(); - assert_eq!(test_data, &read_buf); - client2.write_all(test_data).await.unwrap(); - client1.read_exact(&mut read_buf).await.unwrap(); - assert_eq!(test_data, &read_buf); + let data_cp = data.clone(); + let a = tokio::spawn(async move { + for c in data_cp.iter() { + client1.write_all(&c).await.unwrap(); + client1.flush().await.unwrap(); + } + }); + let data_cp = data.clone(); + let b = tokio::spawn(async move { + for (i, c) in data_cp.iter().enumerate() { + let mut buf = vec![0; c.len()]; + client2.read_exact(&mut buf).await.unwrap(); + if &buf != c { + let mut remaining = Vec::new(); + client2.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(&buf, c, "{i}: {remaining:?}"); + } + } + }); + tokio::try_join!(a, b).unwrap(); } async fn start_test_broker() {