Skip to content

Commit

Permalink
refactor: clean up and improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Sep 26, 2024
1 parent 082d2cf commit aae11f5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 125 deletions.
17 changes: 9 additions & 8 deletions broker/src/serve_sockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ async fn connect_socket(
// This Result is just an Either type. An error value does not mean something went wrong
) -> Result<Response, StatusCode> {
let mut body_stream = body.into_data_stream();
let (is_read, token, remaining) = read_header(&mut body_stream).await?;
let result = shared::crypto_jwt::verify_with_extended_header::<MsgEmpty>(&mut parts, String::from_utf8_lossy(&token).as_ref()).await;
let (is_read, jwt, remaining) = read_header(&mut body_stream).await?;
let result = shared::crypto_jwt::verify_with_extended_header::<MsgEmpty>(&mut parts, String::from_utf8_lossy(&jwt).as_ref()).await;
let msg = match result {
Ok(msg) => msg.msg,
Err(e) => return Ok(e.into_response()),
Expand Down Expand Up @@ -153,18 +153,19 @@ async fn connect_socket(
tx
},
};
let (tx, rx) = oneshot::channel::<()>();
let mut wrapped = Some(tx);
let (tx, write_done) = oneshot::channel::<()>();
let mut notify_write_done = Some(tx);
let send_res = sender.send(stream::once(futures_util::future::ready(Ok(remaining.freeze()))).chain(body_stream).chain(stream::poll_fn(move |_| {
_ = wrapped.take().unwrap().send(());
_ = notify_write_done.take().unwrap().send(());
std::task::Poll::Ready(None)
})).boxed());
let Ok(()) = send_res else {
warn!(%task_id, "Failed to send socket body. Reciever dropped");
return Err(StatusCode::GONE);
};
debug!("Write half over stream");
_ = rx.await;
debug!("Write half send the stream to the read half");
_ = write_done.await;
debug!("Write half has written all its data");
Err(StatusCode::OK)
}
}
Expand Down Expand Up @@ -203,6 +204,6 @@ async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut),
}
buf.put(packet);
}
dbg!("Not enough data?", state, buf);
warn!(?state, ?buf, "Failed to read header");
Err(StatusCode::BAD_GATEWAY)
}
143 changes: 32 additions & 111 deletions proxy/src/serve_sockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
};

use axum::{
body::Body, extract::{Path, Request, State}, http::{self, HeaderMap, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router
body::Body, extract::{Path, Request, State}, http::{self, header, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crypto_secretstream::{Header, Key, PullStream, PushStream};
Expand Down Expand Up @@ -121,39 +121,30 @@ async fn create_socket_con(
warn!("Failed to serialize MsgSocketRequest");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
};
let new_req = Request::post("/v1/sockets").body(axum::body::Body::from(body));
let post_socket_task_req = match new_req {
Ok(req) => req,
Err(e) => {
warn!("Failed to construct request: {e}");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
let post_socket_task_req = Request::post("/v1/sockets").body(axum::body::Body::from(body)).unwrap();
let res = match forward_request(post_socket_task_req, &state.config, &sender, &state.client).await {
Ok(res) => res,
Err(err) => {
warn!("Failed to create post socket request: {err:?}");
return err.into_response();
}
};

let res =
match forward_request(post_socket_task_req, &state.config, &sender, &state.client).await {
Ok(res) => res,
Err(err) => {
warn!("Failed to create post socket request: {err:?}");
return err.into_response();
}
};

if res.status() != StatusCode::CREATED {
warn!(
"Failed to post MsgSocketRequest to broker. Statuscode: {}",
res.status()
);
return (res.status(), "Failed to post MsgSocketRequest to broker").into_response();
}
let req = match prepare_socket_request(sender, task_id, og_req.headers().clone(), &state).await {
let req = match prepare_socket_request(sender, task_id, &state).await {
Ok(req) => req,
Err(e) => return e,
};
// Connect write
let req = req.map(|b| {
let n = b.as_bytes().len();
let mut body = Vec::with_capacity(n + 5);
// This 0 signals write interest
body.push(0);
body.extend(u32::to_be_bytes(n as _));
body.extend(b.as_bytes());
Expand All @@ -163,31 +154,33 @@ async fn create_socket_con(
let Some(key) = task_secret_map.remove(&task_id) else {
return StatusCode::GONE.into_response();
};
let (parts, body) = req.into_parts();
let (mut parts, body) = req.into_parts();
parts.headers.append(header::TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
let stream = stream::once(ready(Ok(body))).chain(Encrypter::new(key).encrypt(og_req.into_body().into_data_stream()));
let req = Request::from_parts(parts, reqwest::Body::wrap_stream(stream));
let res = match state.client.execute(req.try_into().expect("Conversion to reqwest::Request should always work")).await {
Ok(res) => res,
Err(e) => return (StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
};
let (parts, body) = http::Response::from(res).into_parts();
Response::from_parts(parts, Body::from_stream(reqwest::Response::from(http::Response::new(body)).bytes_stream()))
match state.client.execute(req.try_into().expect("Conversion to reqwest::Request should always work")).await {
Ok(res) => http::Response::try_from(res).expect("reqwest::Response to http::Response should always work").map(Body::new),
Err(e) => {
warn!("Failed to stream data to broker: {e}");
(StatusCode::BAD_GATEWAY, e.to_string()).into_response()
},
}
}

async fn connect_read(
AuthenticatedApp(sender): AuthenticatedApp,
state: State<TasksState>,
Extension(task_secret_map): Extension<MsgSecretMap>,
Path(task_id): Path<MsgId>,
req: Request,
) -> Response {
let req = match prepare_socket_request(sender, task_id, req.headers().clone(), &state).await {
let req = match prepare_socket_request(sender, task_id, &state).await {
Ok(value) => value,
Err(e) => return e,
};
let req = req.map(|b| {
let n = b.as_bytes().len();
let mut body = Vec::with_capacity(n + 5);
let mut body = Vec::with_capacity(n + 4 + 1);
// This 1 signals read interest
body.push(1);
body.extend(u32::to_be_bytes(n as _));
body.extend(b.as_bytes());
Expand All @@ -196,19 +189,18 @@ async fn connect_read(

let res = match state.client.execute(req.try_into().expect("Conversion to reqwest::Request should always work")).await {
Ok(res) => res,
Err(e) => return (StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
Err(e) => {
warn!("Failed to read from broker: {e}");
return (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
},
};
let Some(key) = task_secret_map.remove(&task_id) else {
return StatusCode::GONE.into_response();
};
let headers = res.headers().clone();
// let stream = Decrypter::new(key).decrypt(res.bytes_stream());
let mut res = Response::new(Body::from_stream(Decrypter::new(key).decrypt(res.bytes_stream())));
*res.headers_mut() = headers;
res
Response::new(Body::from_stream(Decrypter::new(key).decrypt(res.bytes_stream())))
}

async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, headers: HeaderMap, state: &State<TasksState>) -> Result<http::Request<String>, http::Response<Body>> {
async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, state: &State<TasksState>) -> Result<http::Request<String>, http::Response<Body>> {
let msg_empty = MsgEmpty {
from: AppOrProxyId::App(sender.clone()),
};
Expand All @@ -217,22 +209,20 @@ async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, headers
return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
};
let new_req = Request::get(format!("/v1/sockets/{task_id}")).body(axum::body::Body::from(body));
let mut get_socket_con_req = match new_req {
let get_socket_con_req = match new_req {
Ok(req) => req,
Err(e) => {
warn!("Failed to construct request: {e}");
return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
}
};
*get_socket_con_req.headers_mut() = headers;
let req = match prepare_request(get_socket_con_req, &state.config, &sender).await {
Ok(req) => req,
match prepare_request(get_socket_con_req, &state.config, &sender).await {
Ok(req) => Ok(req),
Err(err) => {
warn!("Failed to create socket connect request: {err:?}");
return Err(err.into_response());
Err(err.into_response())
}
};
Ok(req)
}
}

struct Encrypter {
Expand Down Expand Up @@ -314,72 +304,3 @@ impl Decrypter {
})
}
}


#[cfg(never)]
mod tests {
use tokio::net::{TcpListener, TcpStream};

use super::*;

#[tokio::test]
async fn test_encryption() {
let mut key = GenericArray::default();
OsRng.fill_bytes(&mut key);
const N: usize = 2_usize.pow(13);
let test_data: &mut [u8; N] = &mut [0; N];
OsRng.fill_bytes(test_data);
let mut read_buf = [0; N];

start_test_broker().await;
let (mut client1, mut client2) = tokio::join!(client(&key), client(&key));

client1.write_all(test_data).await.unwrap();
client2.read_exact(&mut read_buf).await.unwrap();
assert_eq!(test_data, &read_buf);
client2.write_all(test_data).await.unwrap();
client1.read_exact(&mut read_buf).await.unwrap();
assert_eq!(test_data, &read_buf);
}

async fn start_test_broker() {
let server = TcpListener::bind("127.0.0.1:1337").await.unwrap();
tokio::spawn(async move {
let ((mut a, _), (mut b, _)) = tokio::try_join!(server.accept(), server.accept()).unwrap();
tokio::io::copy_bidirectional(&mut a, &mut b).await.unwrap();
});
}

async fn client(key: &GenericArray<u8, U32>) -> impl AsyncRead + AsyncWrite {
// Wait for server to start
tokio::time::sleep(Duration::from_millis(100)).await;
let stream = TcpStream::connect("127.0.0.1:1337").await.unwrap();
EncryptedSocket::new(stream, key).await.unwrap()
}

#[test]
fn normal_enc() {
let mut key = GenericArray::default();
OsRng.fill_bytes(&mut key);
const N: usize = 2_usize.pow(10);
let test_data: &mut [u8; N] = &mut [0; N];
OsRng.fill_bytes(test_data);

let mut nonce = GenericArray::<u8, U20>::default();
OsRng.fill_bytes(&mut nonce);
let aead = XChaCha20Poly1305::new(&key);
let client1 = StreamLE31::from_aead(aead, &nonce);
let mut encryptor = Encryptor::from_stream_primitive(client1);

// let mut nonce = GenericArray::<u8, U20>::default();
// OsRng.fill_bytes(&mut nonce);
let aead = XChaCha20Poly1305::new(&key);
let client2 = StreamLE31::from_aead(aead, &nonce);
let mut decrypter = Decryptor::from_stream_primitive(client2);

let cipher_text = encryptor.encrypt_next(test_data.as_slice()).unwrap();
dbg!(cipher_text.len());
let a = decrypter.decrypt_next(cipher_text.as_slice()).unwrap();
assert_eq!(test_data, a.as_slice());
}
}
19 changes: 13 additions & 6 deletions tests/src/socket_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@ use std::{convert::Infallible, time::Duration};

use futures::{stream, StreamExt};
use anyhow::Result;
use rand::{Rng, RngCore};
use crate::*;

#[tokio::test]
async fn test_full() -> Result<()> {
let metadata: &'static _ = Box::leak(Box::new(serde_json::json!({
"foo": vec![1, 2, 3],
})));
let range = 1..10_000;
let stream = stream::iter(range.clone())
.map(|i| Ok::<_, Infallible>(u32::to_be_bytes(i).to_vec()))
.then(|b| async {
tokio::time::sleep(Duration::from_millis(1)).await;
let data = Vec::from_iter((0..1000).map(|_| {
let mut chunk = vec![0; 1024];
rand::thread_rng().fill_bytes(&mut chunk);
chunk
}));
let stream = stream::iter(data.clone())
.map(Ok::<_, Infallible>)
.then(move |b| async {
if rand::thread_rng().gen_ratio(1, 10) {
tokio::time::sleep(Duration::from_millis(10)).await;
}
b
});
let app1 = async move {
Expand All @@ -29,7 +36,7 @@ async fn test_full() -> Result<()> {
.ok_or(anyhow::anyhow!("Failed to get a socket task"))?;
assert_eq!(&task.metadata, metadata);
let s = client2().connect_socket(&task.id).await?;
let expected = range.map(u32::to_be_bytes).flatten().collect::<Vec<_>>();
let expected = data.into_iter().flatten().collect::<Vec<_>>();
let mut buf = Vec::with_capacity(expected.len());
s.for_each(|b| {
buf.extend(b.unwrap());
Expand Down

0 comments on commit aae11f5

Please sign in to comment.