Skip to content

Commit

Permalink
fix(sockets): correctly handle pending writes to prevent double write
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Nov 8, 2024
1 parent 8a82238 commit 7875067
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 22 deletions.
3 changes: 3 additions & 0 deletions proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ sockets = ["dep:chacha20poly1305", "dep:dashmap", "tokio-util/codec", "tokio-uti

[build-dependencies]
build-data = "0"

[dev-dependencies]
rand = "0.8.5"
57 changes: 35 additions & 22 deletions proxy/src/serve_sockets.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use std::{
io::{self, Write},
ops::{Deref, DerefMut},
pin::Pin,
task::Poll,
time::{Duration, SystemTime, Instant}, sync::Arc,
io::{self, Write}, ops::{Deref, DerefMut}, pin::Pin, sync::{atomic::AtomicUsize, Arc}, task::Poll, time::{Duration, Instant, SystemTime}
};

use axum::{
Expand Down Expand Up @@ -33,7 +29,7 @@ use shared::{
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf};
use tokio_util::{
codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite},
compat::{Compat, FuturesAsyncReadCompatExt},
compat::{Compat, FuturesAsyncReadCompatExt}, io::SinkWriter,
};
use tracing::{warn, debug};

Expand Down Expand Up @@ -296,7 +292,7 @@ impl DerefMut for SocketEncKey {
struct EncryptedSocket<S: AsyncRead + AsyncWrite> {
// inner: Framed<S, EncryptorCodec>,
read: Compat<IntoAsyncRead<FramedRead<ReadHalf<S>, DecryptorCodec>>>,
write: FramedWrite<WriteHalf<S>, EncryptorCodec>,
write: SinkWriter<FramedWrite<WriteHalf<S>, EncryptorCodec>>,
}

struct EncryptorCodec {
Expand Down Expand Up @@ -380,7 +376,8 @@ impl<'a> Buffer for EncBuffer<'a> {

// This should only be called when decrypting
fn truncate(&mut self, len: usize) {
self.buf.truncate(len)
warn!("Buffer got truncated. This should never happen as it should be perfectly sized");
self.buf.truncate(self.enc_idx + len)
}
}

Expand Down Expand Up @@ -415,7 +412,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> EncryptedSocket<S> {
let (r, w) = tokio::io::split(inner);
let read = FramedRead::new(r, DecryptorCodec { decryptor });
let read = read.into_async_read().compat();
let write = FramedWrite::new(w, EncryptorCodec { encryptor });
let write = SinkWriter::new(FramedWrite::new(w, EncryptorCodec { encryptor }));

Ok(Self { read, write })
}
Expand All @@ -437,27 +434,28 @@ impl<S: AsyncWrite + AsyncRead + Unpin> AsyncWrite for EncryptedSocket<S> {
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
self.write.send(buf).poll_unpin(cx).map_ok(|_| buf.len())
Pin::new(&mut self.write).poll_write(cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.write.poll_flush_unpin(cx)
Pin::new(&mut self.write).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.write.poll_close_unpin(cx)
Pin::new(&mut self.write).poll_shutdown(cx)
}
}

#[cfg(test)]
mod tests {
use chacha20poly1305::aead::stream::{Decryptor, Encryptor, EncryptorLE31};
use rand::Rng;
use tokio::net::{TcpListener, TcpStream};

use super::*;
Expand All @@ -466,20 +464,35 @@ mod tests {
async fn test_encryption() {
let mut key = GenericArray::default();
OsRng.fill_bytes(&mut key);
const N: usize = 2_usize.pow(13);
let test_data: &mut [u8; N] = &mut [0; N];
OsRng.fill_bytes(test_data);
let mut read_buf = [0; N];
let data: Arc<Vec<_>> = (0..13337).map(|_| {
let mut chunk = vec![0; OsRng.gen_range(1..9999)];
OsRng.fill_bytes(&mut chunk);
chunk
}).collect::<Vec<_>>().into();

start_test_broker().await;
let (mut client1, mut client2) = tokio::join!(client(&key), client(&key));

client1.write_all(test_data).await.unwrap();
client2.read_exact(&mut read_buf).await.unwrap();
assert_eq!(test_data, &read_buf);
client2.write_all(test_data).await.unwrap();
client1.read_exact(&mut read_buf).await.unwrap();
assert_eq!(test_data, &read_buf);
let data_cp = data.clone();
let a = tokio::spawn(async move {
for c in data_cp.iter() {
client1.write_all(&c).await.unwrap();
client1.flush().await.unwrap();
}
});
let data_cp = data.clone();
let b = tokio::spawn(async move {
for (i, c) in data_cp.iter().enumerate() {
let mut buf = vec![0; c.len()];
client2.read_exact(&mut buf).await.unwrap();
if &buf != c {
let mut remaining = Vec::new();
client2.read_to_end(&mut remaining).await.unwrap();
assert_eq!(&buf, c, "{i}: {remaining:?}");
}
}
});
tokio::try_join!(a, b).unwrap();
}

async fn start_test_broker() {
Expand Down

0 comments on commit 7875067

Please sign in to comment.