diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index f9fbea2e..4e51a2bc 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -1450,9 +1450,9 @@ dependencies = [ [[package]] name = "ohttp-relay" -version = "0.0.8" +version = "0.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632" +checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2" dependencies = [ "futures", "http", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index f9fbea2e..4e51a2bc 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -1450,9 +1450,9 @@ dependencies = [ [[package]] name = "ohttp-relay" -version = "0.0.8" +version = "0.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632" +checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2" dependencies = [ "futures", "http", diff --git a/payjoin-cli/Cargo.toml b/payjoin-cli/Cargo.toml index bd7fa1fa..672481b3 100644 --- a/payjoin-cli/Cargo.toml +++ b/payjoin-cli/Cargo.toml @@ -50,7 +50,7 @@ url = { version = "2.3.1", features = ["serde"] } [dev-dependencies] bitcoind = { version = "0.36.0", features = ["0_21_2"] } http = "1" -ohttp-relay = "0.0.8" +ohttp-relay = { version = "0.0.9", features = ["_test-util"] } once_cell = "1" payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] } testcontainers = "0.15.0" diff --git a/payjoin-cli/tests/e2e.rs b/payjoin-cli/tests/e2e.rs index 439dd90b..bee2f564 100644 --- a/payjoin-cli/tests/e2e.rs +++ b/payjoin-cli/tests/e2e.rs @@ -151,6 +151,11 @@ mod e2e { payjoin_sent.unwrap().unwrap_or(Some(false)).unwrap(), "Payjoin send was not detected" ); + + fn find_free_port() -> u16 { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + listener.local_addr().unwrap().port() + } } #[cfg(feature = "v2")] @@ -170,6 +175,7 @@ mod e2e { use url::Url; type Error = Box; + type BoxSendSyncError = Box; type Result = std::result::Result; static INIT_TRACING: OnceCell<()> = OnceCell::new(); @@ -178,18 +184,26 @@ mod e2e { init_tracing(); let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + let (port, directory_handle) = + init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory"); + let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); + let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); + let (ohttp_relay_port, ohttp_relay_handle) = + ohttp_relay::listen_tcp_on_free_port(gateway_origin) + .await + .expect("Failed to init ohttp relay"); + let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); let temp_dir = env::temp_dir(); let receiver_db_path = temp_dir.join("receiver_db"); let sender_db_path = temp_dir.join("sender_db"); let result: Result<()> = tokio::select! { - res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()), - res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()), + res = ohttp_relay_handle => Err(format!("Ohttp relay is long running: {:?}", res).into()), + res = directory_handle => Err(format!("Directory server is long running: {:?}", res).into()), res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()), }; @@ -476,13 +490,17 @@ mod e2e { Err("Timeout waiting for service to be ready".into()) } - async fn init_directory(port: u16, local_cert_key: (Vec, Vec)) -> Result<()> { - let docker: Cli = Cli::default(); + async fn init_directory( + db_host: String, + local_cert_key: (Vec, Vec), + ) -> std::result::Result< + (u16, tokio::task::JoinHandle>), + BoxSendSyncError, + > { + println!("Database running on {}", db_host); let timeout = Duration::from_secs(2); - let db = docker.run(Redis); - let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); - println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await + payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key) + .await } // generates or gets a DER encoded localhost cert and key. @@ -521,11 +539,6 @@ mod e2e { } } - fn find_free_port() -> u16 { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - listener.local_addr().unwrap().port() - } - async fn cleanup_temp_file(path: &std::path::Path) { if let Err(e) = fs::remove_dir_all(path).await { eprintln!("Failed to remove {:?}: {}", path, e); diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 6d633e22..9a1c651c 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -36,6 +36,67 @@ const ID_LENGTH: usize = 13; mod db; use crate::db::DbPool; +#[cfg(feature = "_danger-local-https")] +type BoxError = Box; + +#[cfg(feature = "_danger-local-https")] +pub async fn listen_tcp_with_tls_on_free_port( + db_host: String, + timeout: Duration, + cert_key: (Vec, Vec), +) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { + let listener = tokio::net::TcpListener::bind("[::]:0").await?; + let port = listener.local_addr()?.port(); + println!("Directory server binding to port {}", listener.local_addr()?); + let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?; + Ok((port, handle)) +} + +// Helper function to avoid code duplication +#[cfg(feature = "_danger-local-https")] +async fn listen_tcp_with_tls_on_listener( + listener: tokio::net::TcpListener, + db_host: String, + timeout: Duration, + tls_config: (Vec, Vec), +) -> Result>, BoxError> { + let pool = DbPool::new(timeout, db_host).await?; + let ohttp = Arc::new(Mutex::new(init_ohttp()?)); + let tls_acceptor = init_tls_acceptor(tls_config)?; + // Spawn the connection handling loop in a separate task + let handle = tokio::spawn(async move { + while let Ok((stream, _)) = listener.accept().await { + let pool = pool.clone(); + let ohttp = ohttp.clone(); + let tls_acceptor = tls_acceptor.clone(); + tokio::spawn(async move { + let tls_stream = match tls_acceptor.accept(stream).await { + Ok(tls_stream) => tls_stream, + Err(e) => { + error!("TLS accept error: {}", e); + return; + } + }; + if let Err(err) = http1::Builder::new() + .serve_connection( + TokioIo::new(tls_stream), + service_fn(move |req| { + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) + }), + ) + .with_upgrades() + .await + { + error!("Error serving connection: {:?}", err); + } + }); + } + Ok(()) + }); + Ok(handle) +} + +// Modify existing listen_tcp_with_tls to use the new helper pub async fn listen_tcp( port: u16, db_host: String, @@ -73,41 +134,11 @@ pub async fn listen_tcp_with_tls( port: u16, db_host: String, timeout: Duration, - tls_config: (Vec, Vec), -) -> Result<(), Box> { - let pool = DbPool::new(timeout, db_host).await?; - let ohttp = Arc::new(Mutex::new(init_ohttp()?)); - let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); - let tls_acceptor = init_tls_acceptor(tls_config)?; - let listener = TcpListener::bind(bind_addr).await?; - while let Ok((stream, _)) = listener.accept().await { - let pool = pool.clone(); - let ohttp = ohttp.clone(); - let tls_acceptor = tls_acceptor.clone(); - tokio::spawn(async move { - let tls_stream = match tls_acceptor.accept(stream).await { - Ok(tls_stream) => tls_stream, - Err(e) => { - error!("TLS accept error: {}", e); - return; - } - }; - if let Err(err) = http1::Builder::new() - .serve_connection( - TokioIo::new(tls_stream), - service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone()) - }), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); - } - - Ok(()) + cert_key: (Vec, Vec), +) -> Result>, BoxError> { + let addr = format!("0.0.0.0:{}", port); + let listener = tokio::net::TcpListener::bind(&addr).await?; + listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await } #[cfg(feature = "_danger-local-https")] diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index de01a1f7..fe11e5f5 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -41,7 +41,7 @@ serde_json = "1.0.108" bitcoind = { version = "0.36.0", features = ["0_21_2"] } http = "1" payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] } -ohttp-relay = "0.0.8" +ohttp-relay = { version = "0.0.9", features = ["_test-util"] } once_cell = "1" rcgen = { version = "0.11" } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index a82e581a..01a41d1b 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -187,6 +187,8 @@ mod integration { use super::*; + type BoxSendSyncError = Box; + static TESTS_TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(20)); static WAIT_SERVICE_INTERVAL: Lazy = Lazy::new(|| Duration::from_secs(3)); @@ -197,10 +199,17 @@ mod integration { .expect("Invalid OhttpKeys"); let (cert, key) = local_cert_key(); - let port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); + tokio::select!( - _ = init_directory(port, (cert.clone(), key)) => panic!("Directory server is long running"), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => { assert_eq!( res.unwrap().headers().get("content-type").unwrap(), @@ -231,15 +240,24 @@ mod integration { async fn test_session_expiration() { init_tracing(); let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); + let (ohttp_relay_port, ohttp_relay_handle) = + ohttp_relay::listen_tcp_on_free_port(gateway_origin) + .await + .expect("Failed to init ohttp relay"); + let ohttp_relay = + Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); tokio::select!( - _ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"), - _ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"), + err = ohttp_relay_handle => panic!("Ohttp relay exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -300,15 +318,24 @@ mod integration { async fn v2_to_v2() { init_tracing(); let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); + let (ohttp_relay_port, ohttp_relay_handle) = + ohttp_relay::listen_tcp_on_free_port(gateway_origin) + .await + .expect("Failed to init ohttp relay"); + let ohttp_relay = + Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); tokio::select!( - _ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"), - _ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"), + err = ohttp_relay_handle => panic!("Ohttp relay exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -429,15 +456,24 @@ mod integration { async fn v2_to_v2_mixed_input_script_types() { init_tracing(); let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); + let (ohttp_relay_port, ohttp_relay_handle) = + ohttp_relay::listen_tcp_on_free_port(gateway_origin) + .await + .expect("Failed to init ohttp relay"); + let ohttp_relay = + Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); tokio::select!( - _ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"), - _ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"), + err = ohttp_relay_handle => panic!("Ohttp relay exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -641,15 +677,23 @@ mod integration { async fn v1_to_v2() { init_tracing(); let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); + let (ohttp_relay_port, ohttp_relay_handle) = + ohttp_relay::listen_tcp_on_free_port(gateway_origin) + .await + .expect("Failed to init ohttp relay"); + let ohttp_relay = + Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); tokio::select!( - _ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"), - _ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"), + err = ohttp_relay_handle => panic!("Ohttp relay exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()), ); @@ -771,15 +815,14 @@ mod integration { } async fn init_directory( - port: u16, + db_host: String, local_cert_key: (Vec, Vec), - ) -> Result<(), BoxError> { - let docker: Cli = Cli::default(); + ) -> Result<(u16, tokio::task::JoinHandle>), BoxSendSyncError> + { + println!("Database running on {}", db_host); let timeout = Duration::from_secs(2); - let db = docker.run(Redis); - let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); - println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await + payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key) + .await } // generates or gets a DER encoded localhost cert and key. @@ -905,11 +948,6 @@ mod integration { )) } - fn find_free_port() -> u16 { - let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap(); - listener.local_addr().unwrap().port() - } - async fn wait_for_service_ready( service_url: Url, agent: Arc, @@ -920,7 +958,6 @@ mod integration { while start.elapsed() < *TESTS_TIMEOUT { let request_result = agent.get(health_url.as_str()).send().await.map_err(|_| "Bad request")?; - match request_result.status() { StatusCode::OK => return Ok(()), StatusCode::NOT_FOUND => return Err("Endpoint not found"),