diff --git a/crates/sip-core/src/endpoint.rs b/crates/sip-core/src/endpoint.rs index b23b755..e58bccf 100644 --- a/crates/sip-core/src/endpoint.rs +++ b/crates/sip-core/src/endpoint.rs @@ -106,14 +106,14 @@ impl Endpoint { /// Create a [`ServerTsx`] from an [`IncomingRequest`]. The returned transaction /// can be used to form and send responses to the request. - pub fn create_server_tsx(&self, request: &IncomingRequest) -> ServerTsx { - ServerTsx::new(self.clone(), request) + pub fn create_server_tsx(&self, request: &mut IncomingRequest) -> ServerTsx { + ServerTsx::new(request) } /// Create a [`ServerInvTsx`] from an INVITE [`IncomingRequest`]. The returned transaction /// can be used to form and send responses to the request. - pub fn create_server_inv_tsx(&self, request: &IncomingRequest) -> ServerInvTsx { - ServerInvTsx::new(self.clone(), request) + pub fn create_server_inv_tsx(&self, request: &mut IncomingRequest) -> ServerInvTsx { + ServerInvTsx::new(request) } /// Returns all ALLOW headers this endpoint supports @@ -360,32 +360,40 @@ impl Endpoint { } }; + let mut tsx = None; + // Try to find a transaction that might be able to handle the message - if let Some(handler) = self.transactions().get_handler(&tsx_key) { - let tsx_message = TsxMessage { - tp_info: message.tp_info, - line: message.line, - base_headers, - headers: message.headers, - body: message.body, - }; + match self.transactions().get_handler(&self, &tsx_key) { + Ok(handler) => { + let tsx_message = TsxMessage { + tp_info: message.tp_info, + line: message.line, + base_headers, + headers: message.headers, + body: message.body, + }; - log::debug!("delegating message to transaction {}", tsx_key); + log::debug!("delegating message to transaction {}", tsx_key); - if let Some(rejected_tsx_message) = handler(tsx_message) { - log::trace!("transaction {} rejected message", tsx_key); + if let Some(rejected_tsx_message) = handler(tsx_message) { + log::trace!("transaction {} rejected message", tsx_key); - // TsxMessage was rejected, restore previous state - base_headers = rejected_tsx_message.base_headers; - message = ReceivedMessage { - tp_info: rejected_tsx_message.tp_info, - line: rejected_tsx_message.line, - headers: rejected_tsx_message.headers, - body: rejected_tsx_message.body, - }; - } else { - // Handled - return; + // TsxMessage was rejected, restore previous state + base_headers = rejected_tsx_message.base_headers; + message = ReceivedMessage { + tp_info: rejected_tsx_message.tp_info, + line: rejected_tsx_message.line, + headers: rejected_tsx_message.headers, + body: rejected_tsx_message.body, + }; + } else { + // Handled + return; + } + } + Err(registration) => { + log::debug!("no transaction for {tsx_key} found, created registration"); + tsx = Some(registration); } } @@ -401,6 +409,7 @@ impl Endpoint { let incoming = IncomingRequest { tp_info: message.tp_info, + tsx, line, base_headers, headers: message.headers, @@ -433,7 +442,7 @@ impl Endpoint { } } - async fn handle_unwanted_request(&self, request: IncomingRequest) -> Result<()> { + async fn handle_unwanted_request(&self, mut request: IncomingRequest) -> Result<()> { if request.line.method == Method::ACK { // Cannot respond to unhandled ACK requests return Ok(()); @@ -443,11 +452,11 @@ impl Endpoint { self.create_response(&request, Code::CALL_OR_TRANSACTION_DOES_NOT_EXIST, None); if request.line.method == Method::INVITE { - let tsx = self.create_server_inv_tsx(&request); + let tsx = self.create_server_inv_tsx(&mut request); tsx.respond_failure(response).await } else { - let tsx = self.create_server_tsx(&request); + let tsx = self.create_server_tsx(&mut request); tsx.respond(response).await } diff --git a/crates/sip-core/src/lib.rs b/crates/sip-core/src/lib.rs index f53c8ac..9d640f0 100644 --- a/crates/sip-core/src/lib.rs +++ b/crates/sip-core/src/lib.rs @@ -14,7 +14,7 @@ use sip_types::print::AppendCtx; use sip_types::uri::Uri; use sip_types::{Headers, Method, Name}; use std::fmt; -use transaction::TsxKey; +use transaction::{TsxKey, TsxRegistration}; use transport::MessageTpInfo; #[macro_use] @@ -103,6 +103,7 @@ impl BaseHeaders { pub struct IncomingRequest { pub tp_info: MessageTpInfo, pub tsx_key: TsxKey, + tsx: Option, pub line: RequestLine, pub base_headers: BaseHeaders, @@ -116,6 +117,17 @@ impl fmt::Display for IncomingRequest { } } +impl IncomingRequest { + #[track_caller] + fn take_tsx_registration(&mut self) -> TsxRegistration { + let Some(tsx) = self.tsx.take() else { + panic!("Tried to create transaction for {:?}, which is an already handled message or isn't a transaction creating request", self.tsx_key); + }; + + tsx + } +} + /// Layers are extensions to the endpoint. /// /// They can be added to the endpoint in the building stage bay calling diff --git a/crates/sip-core/src/transaction/mod.rs b/crates/sip-core/src/transaction/mod.rs index cd6997f..baef9dc 100644 --- a/crates/sip-core/src/transaction/mod.rs +++ b/crates/sip-core/src/transaction/mod.rs @@ -1,13 +1,14 @@ use crate::transport::MessageTpInfo; -use crate::BaseHeaders; +use crate::{BaseHeaders, Endpoint}; use bytes::Bytes; use bytesstr::BytesStr; -use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard}; -use registration::TsxRegistration; +use parking_lot::lock_api::MutexGuard; +use parking_lot::{MappedMutexGuard, Mutex}; use sip_types::msg::{MessageLine, StatusLine}; use sip_types::Headers; use std::collections::hash_map::Entry; use std::collections::HashMap; +use tokio::sync::mpsc; mod client; mod client_inv; @@ -32,34 +33,55 @@ pub use key::TsxKey; pub use server::ServerTsx; pub use server_inv::{Accepted, ServerInvTsx}; +pub(crate) use registration::TsxRegistration; + pub(crate) type TsxHandler = Box Option + Send + Sync>; #[derive(Default)] pub(crate) struct Transactions { - map: RwLock>, + map: Mutex>, } impl Transactions { - pub fn get_handler<'a: 'k, 'k>( + pub(crate) fn get_handler<'a: 'k, 'k>( &'a self, + endoint: &Endpoint, tsx_key: &TsxKey, - ) -> Option> { - let map = self.map.read(); - RwLockReadGuard::try_map(map, |map| map.get(tsx_key)).ok() + ) -> Result, TsxRegistration> { + let map = self.map.lock(); + + let mut map = match MutexGuard::try_map(map, |map| map.get_mut(tsx_key)) { + Ok(handler) => return Ok(handler), + Err(map) => map, + }; + + let (sender, receiver) = mpsc::unbounded_channel(); + + map.insert( + tsx_key.clone(), + Box::new(move |msg| sender.send(msg).map_err(|e| e.0).err()), + ); + + Err(TsxRegistration { + endpoint: endoint.clone(), + tsx_key: tsx_key.clone(), + receiver, + }) } - pub fn register_transaction(&self, key: TsxKey, handler: TsxHandler) { - let mut map = self.map.write(); + pub(crate) fn register_transaction(&self, key: TsxKey, handler: TsxHandler) { + let mut map = self.map.lock(); match map.entry(key) { - // See https://github.com/kbalt/ezk/issues/16 - Entry::Occupied(e) => panic!("Tried to create a second transaction for {:?}. This can happen if a retransmission of message is received before creating a transaction for the original one.", e.key()), - Entry::Vacant(e) => { e.insert(handler); }, + Entry::Occupied(e) => panic!("Tried to create a second transaction for {:?}", e.key()), + Entry::Vacant(e) => { + e.insert(handler); + } } } - pub fn remove_transaction(&self, key: &TsxKey) { - self.map.write().remove(key); + pub(crate) fn remove_transaction(&self, key: &TsxKey) { + self.map.lock().remove(key); } } diff --git a/crates/sip-core/src/transaction/registration.rs b/crates/sip-core/src/transaction/registration.rs index 54ec01b..4d0795e 100644 --- a/crates/sip-core/src/transaction/registration.rs +++ b/crates/sip-core/src/transaction/registration.rs @@ -12,10 +12,10 @@ use tokio::sync::mpsc; /// transactional messages from it #[derive(Debug)] pub(crate) struct TsxRegistration { - pub endpoint: Endpoint, - pub tsx_key: TsxKey, + pub(crate) endpoint: Endpoint, + pub(crate) tsx_key: TsxKey, - receiver: mpsc::UnboundedReceiver, + pub(super) receiver: mpsc::UnboundedReceiver, } impl TsxRegistration { @@ -41,7 +41,7 @@ impl TsxRegistration { F: Fn(&TsxMessage) -> bool + Send + Sync + 'static, { let transactions = self.endpoint.transactions(); - let mut tsx_map = transactions.map.write(); + let mut tsx_map = transactions.map.lock(); let handler = tsx_map .get_mut(&self.tsx_key) .expect("registration is responsible of handler lifetime inside endpoint"); diff --git a/crates/sip-core/src/transaction/server.rs b/crates/sip-core/src/transaction/server.rs index 761432f..6225f14 100644 --- a/crates/sip-core/src/transaction/server.rs +++ b/crates/sip-core/src/transaction/server.rs @@ -1,7 +1,7 @@ use super::consts::T1; use super::TsxRegistration; use crate::transport::OutgoingResponse; -use crate::{Endpoint, IncomingRequest, Result}; +use crate::{IncomingRequest, Result}; use sip_types::{CodeKind, Method}; use std::time::Instant; use tokio::time::timeout_at; @@ -19,16 +19,16 @@ pub struct ServerTsx { impl ServerTsx { /// Internal: Used by [Endpoint::create_server_tsx] - pub(crate) fn new(endpoint: Endpoint, request: &IncomingRequest) -> Self { + pub(crate) fn new(request: &mut IncomingRequest) -> Self { assert!( !matches!(request.line.method, Method::INVITE | Method::ACK), "tried to create server transaction from {} request", request.line.method ); - let registration = TsxRegistration::create(endpoint, request.tsx_key.clone()); - - Self { registration } + Self { + registration: request.take_tsx_registration(), + } } /// Respond with a provisional response (1XX) diff --git a/crates/sip-core/src/transaction/server_inv.rs b/crates/sip-core/src/transaction/server_inv.rs index 5a85a4a..0bbbc7d 100644 --- a/crates/sip-core/src/transaction/server_inv.rs +++ b/crates/sip-core/src/transaction/server_inv.rs @@ -2,7 +2,7 @@ use crate::error::Error; use crate::transaction::consts::{T1, T2}; use crate::transaction::TsxRegistration; use crate::transport::OutgoingResponse; -use crate::{Endpoint, IncomingRequest, Result}; +use crate::{IncomingRequest, Result}; use sip_types::msg::MessageLine; use sip_types::{CodeKind, Method}; use std::io; @@ -23,7 +23,7 @@ pub struct ServerInvTsx { impl ServerInvTsx { /// Internal: Used by [Endpoint::create_server_inv_tsx] - pub(crate) fn new(endpoint: Endpoint, request: &IncomingRequest) -> Self { + pub(crate) fn new(request: &mut IncomingRequest) -> Self { assert_eq!( request.line.method, Method::INVITE, @@ -31,9 +31,9 @@ impl ServerInvTsx { request.line.method ); - let registration = TsxRegistration::create(endpoint, request.tsx_key.clone()); - - Self { registration } + Self { + registration: request.take_tsx_registration(), + } } /// Respond with a provisional response (1XX) diff --git a/crates/sip-ua/src/dialog/layer.rs b/crates/sip-ua/src/dialog/layer.rs index d7bb351..0450d07 100644 --- a/crates/sip-ua/src/dialog/layer.rs +++ b/crates/sip-ua/src/dialog/layer.rs @@ -151,7 +151,7 @@ impl DialogLayer { async fn handle_unwanted_request( &self, endpoint: &Endpoint, - request: IncomingRequest, + mut request: IncomingRequest, ) -> Result<()> { if request.line.method == Method::ACK { // Cannot respond to ACK request @@ -161,11 +161,11 @@ impl DialogLayer { let response = endpoint.create_response(&request, Code::NOT_FOUND, None); if request.line.method == Method::INVITE { - let tsx = endpoint.create_server_inv_tsx(&request); + let tsx = endpoint.create_server_inv_tsx(&mut request); tsx.respond_failure(response).await } else { - let tsx = endpoint.create_server_tsx(&request); + let tsx = endpoint.create_server_tsx(&mut request); tsx.respond(response).await } diff --git a/crates/sip-ua/src/invite/acceptor.rs b/crates/sip-ua/src/invite/acceptor.rs index efceb31..333cf0e 100644 --- a/crates/sip-ua/src/invite/acceptor.rs +++ b/crates/sip-ua/src/invite/acceptor.rs @@ -58,7 +58,7 @@ impl Acceptor { pub fn new( dialog: Dialog, invite_layer: LayerKey, - invite: IncomingRequest, + mut invite: IncomingRequest, ) -> Result { assert_eq!( invite.line.method, @@ -88,7 +88,7 @@ impl Acceptor { let dialog_layer = dialog.dialog_layer; // Create Inner shared state - let tsx = endpoint.create_server_inv_tsx(&invite); + let tsx = endpoint.create_server_inv_tsx(&mut invite); let inner = Arc::new(Inner { invite_layer, state: Mutex::new(InviteSessionState::UasProvisional { diff --git a/crates/sip-ua/src/invite/mod.rs b/crates/sip-ua/src/invite/mod.rs index 63a2f9f..371e08e 100644 --- a/crates/sip-ua/src/invite/mod.rs +++ b/crates/sip-ua/src/invite/mod.rs @@ -174,8 +174,8 @@ impl InviteLayer { // Transaction found but completed: respond 200 to cancel // No matching transaction: don't handle it, endpoint will respond accordingly if let Some(inner) = inner { - let cancel = cancel.take(); - let cancel_tsx = endpoint.create_server_tsx(&cancel); + let mut cancel = cancel.take(); + let cancel_tsx = endpoint.create_server_tsx(&mut cancel); if let Some((dialog, invite_tsx, invite)) = inner.state.lock().await.set_cancelled() { let invite_response = @@ -303,10 +303,10 @@ impl InviteUsage { dialog: Dialog, invite_tsx: ServerInvTsx, invite: IncomingRequest, - bye: IncomingRequest, + mut bye: IncomingRequest, ) -> Result<()> { let bye_response = dialog.create_response(&invite, Code::OK, None)?; - let bye_tsx = endpoint.create_server_tsx(&bye); + let bye_tsx = endpoint.create_server_tsx(&mut bye); let invite_response = dialog.create_response(&invite, Code::REQUEST_TERMINATED, None)?; diff --git a/crates/sip-ua/src/invite/prack.rs b/crates/sip-ua/src/invite/prack.rs index 881275d..76986f1 100644 --- a/crates/sip-ua/src/invite/prack.rs +++ b/crates/sip-ua/src/invite/prack.rs @@ -24,7 +24,7 @@ impl InviteUsage { endpoint: &Endpoint, request: MayTake<'_, IncomingRequest>, ) -> Result<()> { - let (prack, awaited_prack) = { + let (mut prack, awaited_prack) = { let mut awaited_prack_opt = self.inner.awaited_prack.lock(); if let Some(awaited_prack) = awaited_prack_opt.take() { let rack = request.headers.get_named::()?; @@ -40,7 +40,7 @@ impl InviteUsage { } }; - let prack_tsx = endpoint.create_server_tsx(&prack); + let prack_tsx = endpoint.create_server_tsx(&mut prack); let response = endpoint.create_response(&prack, Code::OK, None); diff --git a/crates/sip-ua/src/invite/session.rs b/crates/sip-ua/src/invite/session.rs index 094dd39..3c58c39 100644 --- a/crates/sip-ua/src/invite/session.rs +++ b/crates/sip-ua/src/invite/session.rs @@ -189,8 +189,8 @@ impl Session { }; match evt { - UsageEvent::Bye(request) => { - let transaction = self.endpoint.create_server_tsx(&request); + UsageEvent::Bye(mut request) => { + let transaction = self.endpoint.create_server_tsx(&mut request); Ok(Event::Bye(ByeEvent { session: self, @@ -198,10 +198,10 @@ impl Session { transaction, })) } - UsageEvent::ReInvite(invite) => { + UsageEvent::ReInvite(mut invite) => { self.session_timer.reset(); - let transaction = self.endpoint.create_server_inv_tsx(&invite); + let transaction = self.endpoint.create_server_inv_tsx(&mut invite); Ok(Event::ReInviteReceived(ReInviteReceived { session: self,