diff --git a/srt-protocol/src/protocol/pending_connection/hsv5.rs b/srt-protocol/src/protocol/pending_connection/hsv5.rs index e5e3c152..44639851 100644 --- a/srt-protocol/src/protocol/pending_connection/hsv5.rs +++ b/srt-protocol/src/protocol/pending_connection/hsv5.rs @@ -6,6 +6,8 @@ use std::{ time::Instant, }; +use log::warn; + use crate::{connection::ConnectionSettings, options::*, packet::*, settings::*}; use super::{ConnectError, ConnectionReject}; @@ -65,11 +67,12 @@ pub fn gen_access_control_response( }; // crypto - let cipher = match (&settings.key_settings, &incoming.ext_km) { + let cipher = match (&mut settings.key_settings, &incoming.ext_km) { // ok, both sides have crypto (Some(key_settings), Some(SrtControlPacket::KeyRefreshRequest(km))) => { if key_settings.key_size != incoming.key_size { - unimplemented!("Key size mismatch"); + warn!("Key size mismatch: caller requested {:?}, listener was configured with {:?}. Selecting {:?}", incoming.key_size, key_settings.key_size, incoming.key_size); + key_settings.key_size = incoming.key_size; } let cipher = match CipherSettings::new(key_settings, &settings.key_refresh, km) { diff --git a/srt-tokio/tests/crypto.rs b/srt-tokio/tests/crypto.rs index 884e260b..c40266bd 100644 --- a/srt-tokio/tests/crypto.rs +++ b/srt-tokio/tests/crypto.rs @@ -11,14 +11,16 @@ use futures::{SinkExt, TryStreamExt}; use log::info; use tokio::{select, spawn, time::sleep}; -async fn test_crypto(size: u16) { +async fn test_crypto(size_listen: u16, size_call: u16, port: u16) { let sender = SrtSocket::builder() - .encryption(size, "password123") - .listen_on(":2000"); + .encryption(size_listen, "password123") + .listen_on(port); + + let local_addr = format!("127.0.0.1:{port}"); let recvr = SrtSocket::builder() - .encryption(size, "password123") - .call("127.0.0.1:2000", None); + .encryption(size_call, "password123") + .call(local_addr.as_str(), None); let t = spawn(async move { let mut sender = sender.await.unwrap(); @@ -45,11 +47,16 @@ async fn test_crypto(size: u16) { async fn crypto_exchange() { let _ = pretty_env_logger::try_init(); - test_crypto(16).await; - sleep(Duration::from_millis(100)).await; - test_crypto(24).await; - sleep(Duration::from_millis(100)).await; - test_crypto(32).await; + test_crypto(16, 16, 2000).await; + test_crypto(24, 24, 2001).await; + test_crypto(32, 32, 2002).await; +} + +#[tokio::test] +async fn key_size_mismatch() { + test_crypto(32, 16, 2003).await; + test_crypto(32, 0, 2004).await; + test_crypto(0, 32, 2005).await; } #[tokio::test] @@ -94,5 +101,3 @@ async fn bad_password_rendezvous() { assert_matches!(result, Err(e) if e.kind() == io::ErrorKind::ConnectionRefused); } - -// TODO: mismatch diff --git a/srt-tokio/tests/stransmit_interop.rs b/srt-tokio/tests/stransmit_interop.rs index 637654aa..4e9a086f 100644 --- a/srt-tokio/tests/stransmit_interop.rs +++ b/srt-tokio/tests/stransmit_interop.rs @@ -79,7 +79,7 @@ async fn receiver( i += 1; - info!("Got pack!"); + info!("Got pack {i}!"); // stransmit does not totally care if it sends 100% of it's packets // (which is prob fair), just make sure that we got at least 2/3s of it @@ -502,6 +502,78 @@ async fn bidirectional_interop_encrypt_rekey() -> Result<(), Error> { Ok(()) } +#[tokio::test] +#[cfg(not(target_os = "windows"))] +async fn key_size_mismatch_rust_caller() -> Result<(), Error> { + let _ = pretty_env_logger::try_init(); + + const PACKETS: u32 = 1_000; + + let mut child = allow_not_found!(Command::new("srt-live-transmit") + .arg("udp://:2814") + .arg("srt://:2815?passphrase=password123&pbkeylen=24") + .arg("-a:no") + .arg("-loglevel:debug") + .spawn()); + + let recvr_fut = async move { + let recv = SrtSocket::builder() + .encryption(16, "password123") + .call("127.0.0.1:2815", None) + .await + .unwrap(); + + try_join( + receiver(PACKETS, recv.map(|f| f.unwrap().1)), + udp_sender(PACKETS, 2814), + ) + .await + .unwrap(); + }; + + recvr_fut.await; + child.wait().unwrap(); + + Ok(()) +} + +#[tokio::test] +#[cfg(not(target_os = "windows"))] +async fn key_size_mismatch_rust_listener() -> Result<(), Error> { + let _ = pretty_env_logger::try_init(); + + const PACKETS: u32 = 1_000; + + let mut child = allow_not_found!(Command::new("srt-live-transmit") + .arg("srt://127.0.0.1:2816?passphrase=password123&pbkeylen=24") + .arg("udp://:2817") + .arg("-a:no") + .arg("-loglevel:debug") + .spawn()); + + let sendr = async move { + let mut sender = SrtSocket::builder() + .encryption(16, "password123") + .local_port(2816) + .listen() + .await + .unwrap(); + + let mut stream = + counting_stream(PACKETS, Duration::from_millis(1)).map(|b| Ok((Instant::now(), b))); + sender.send_all(&mut stream).await.unwrap(); + sender.close().await.unwrap(); + + Ok(()) + }; + + try_join(sendr, udp_recvr(PACKETS, 2817)).await.unwrap(); + + child.wait().unwrap(); + + Ok(()) +} + type HaiSocket = i32; const SRTO_SENDER: c_int = 21;