Skip to content

Commit

Permalink
Add listen_tcp_on_free_port to return a test port
Browse files Browse the repository at this point in the history
Previously in tests downstream ohttp_relay was initiated with a port
that may no longer be free by the time it got bound. By having this
code bind on and return the port the indirection is removed.
  • Loading branch information
DanGould committed Dec 30, 2024
1 parent e55e2e3 commit 25261c1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["bootstrap"]
bootstrap = ["connect-bootstrap", "ws-bootstrap"]
connect-bootstrap = []
ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]
_test-util = []

[dependencies]
futures = { version = "0.3", optional = true }
Expand Down
61 changes: 40 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,62 @@ pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError>
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await?;
println!("OHTTP relay listening on tcp://{}", addr);
ohttp_relay(listener, gateway_origin).await
ohttp_relay(listener, gateway_origin).await?;
Ok(())
}

#[instrument]
pub async fn listen_socket(socket_path: &str, gateway_origin: Uri) -> Result<(), BoxError> {
let listener = UnixListener::bind(socket_path)?;
info!("OHTTP relay listening on socket: {}", socket_path);
ohttp_relay(listener, gateway_origin).await
ohttp_relay(listener, gateway_origin).await?;
Ok(())
}

#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
gateway_origin: Uri,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", listener.local_addr()?);
let handle = ohttp_relay(listener, gateway_origin).await?;
Ok((port, handle))
}

#[instrument(skip(listener))]
async fn ohttp_relay<L>(mut listener: L, gateway_origin: Uri) -> Result<(), BoxError>
async fn ohttp_relay<L>(
mut listener: L,
gateway_origin: Uri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
where
L: Listener + Unpin,
L: Listener + Unpin + Send + 'static,
L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let gateway_origin = GatewayUri::new(gateway_origin)?;
let gateway_origin: Arc<GatewayUri> = Arc::new(gateway_origin);

while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});

Ok(())
Ok(handle)
}

#[instrument]
Expand Down

0 comments on commit 25261c1

Please sign in to comment.