From d1d5d395d270221688e13cb3e0932920a4b1d14e Mon Sep 17 00:00:00 2001 From: Thomas Braun Date: Sun, 3 Mar 2024 15:48:47 -0500 Subject: [PATCH] Add RequestID macro --- .../src/connector.rs | 4 +- citadel-internal-service-macros/src/lib.rs | 48 ++++++++++++++++++- citadel-internal-service-types/src/lib.rs | 22 +++++++-- citadel-internal-service/src/kernel/mod.rs | 6 ++- .../src/kernel/request_handler.rs | 7 ++- .../tests/file_transfer.rs | 1 + 6 files changed, 80 insertions(+), 8 deletions(-) diff --git a/citadel-internal-service-connector/src/connector.rs b/citadel-internal-service-connector/src/connector.rs index d7f381f..10c1e1b 100644 --- a/citadel-internal-service-connector/src/connector.rs +++ b/citadel-internal-service-connector/src/connector.rs @@ -72,7 +72,9 @@ impl InternalServiceConnector { .ok_or(ClientError::InternalServiceDisconnected)??; if matches!( greeter_packet, - InternalServicePayload::Response(InternalServiceResponse::ServiceConnectionAccepted(_)) + InternalServicePayload::Response( + InternalServiceResponse::ServiceConnectionAccepted { .. } + ) ) { let stream = WrappedStream { inner: stream }; let sink = WrappedSink { inner: sink }; diff --git a/citadel-internal-service-macros/src/lib.rs b/citadel-internal-service-macros/src/lib.rs index 1bb4af3..ee228c3 100644 --- a/citadel-internal-service-macros/src/lib.rs +++ b/citadel-internal-service-macros/src/lib.rs @@ -31,7 +31,7 @@ fn generate_function(input: TokenStream, contains: &str, function_name: &str) -> // Generate match arms for each enum variant let match_arms = generate_match_arms(name, &data, contains); - // Generate the implementation of the `is_error` method + // Generate the implementation of the `$function_name` method let expanded = quote! { impl #name { pub fn #function_name(&self) -> bool { @@ -66,3 +66,49 @@ fn generate_match_arms( }) .collect() } + +#[proc_macro_derive(RequestId)] +pub fn request_ids(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = parse_macro_input!(input as DeriveInput); + + // Extract the identifier and data from the input + let name = &input.ident; + let data = if let Data::Enum(data) = input.data { + data + } else { + // This macro only supports enums + panic!("RequestId can only be derived for enums"); + }; + + // Generate match arms for each enum variant + let match_arms = generate_match_arms_uuid(name, &data); + + // Generate the implementation of the `$function_name` method + let expanded = quote! { + impl #name { + pub fn request_id(&self) -> Option<&Uuid> { + match self { + #(#match_arms)* + } + } + } + }; + + // Convert into a TokenStream and return it + TokenStream::from(expanded) +} + +fn generate_match_arms_uuid(name: &Ident, data_enum: &DataEnum) -> Vec { + data_enum + .variants + .iter() + .map(|variant| { + let variant_ident = &variant.ident; + // Match against each variant, ignoring any inner data + quote! { + #name::#variant_ident(inner) => inner.request_id.as_ref(), + } + }) + .collect() +} diff --git a/citadel-internal-service-types/src/lib.rs b/citadel-internal-service-types/src/lib.rs index 29579f9..2d8a2df 100644 --- a/citadel-internal-service-types/src/lib.rs +++ b/citadel-internal-service-types/src/lib.rs @@ -1,5 +1,5 @@ use bytes::BytesMut; -use citadel_internal_service_macros::{IsError, IsNotification}; +use citadel_internal_service_macros::{IsError, IsNotification, RequestId}; pub use citadel_types::prelude::{ ConnectMode, MemberState, MessageGroupKey, ObjectTransferStatus, SecBuffer, SecurityLevel, SessionSecuritySettings, TransferType, UdpMode, UserIdentifier, VirtualObjectMetadata, @@ -35,7 +35,9 @@ pub struct RegisterFailure { } #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ServiceConnectionAccepted; +pub struct ServiceConnectionAccepted { + pub request_id: Option, +} #[derive(Serialize, Deserialize, Debug, Clone)] pub struct MessageSendSuccess { @@ -542,6 +544,7 @@ pub struct FileTransferRequestNotification { pub cid: u64, pub peer_cid: u64, pub metadata: VirtualObjectMetadata, + pub request_id: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -559,9 +562,10 @@ pub struct FileTransferTickNotification { pub cid: u64, pub peer_cid: u64, pub status: ObjectTransferStatus, + pub request_id: Option, } -#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification)] +#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification, RequestId)] pub enum InternalServiceResponse { ConnectSuccess(ConnectSuccess), ConnectFailure(ConnectFailure), @@ -827,6 +831,7 @@ pub enum InternalServiceRequest { pub struct SessionInformation { pub cid: u64, pub peer_connections: HashMap, + pub request_id: Option, } #[derive(Serialize, Deserialize, Clone, Debug)] @@ -840,6 +845,7 @@ pub struct AccountInformation { pub username: String, pub full_name: String, pub peers: HashMap, + pub request_id: Option, } #[derive(Serialize, Deserialize, Clone, Debug)] @@ -889,4 +895,14 @@ mod tests { assert!(!success_response.is_notification()); assert!(notification_response.is_notification()); } + + #[test] + fn test_request_id_macro() { + let request_id = Uuid::new_v4(); + let response = InternalServiceResponse::ConnectSuccess(ConnectSuccess { + cid: 0, + request_id: Some(request_id), + }); + assert_eq!(response.request_id(), Some(&request_id)); + } } diff --git a/citadel-internal-service/src/kernel/mod.rs b/citadel-internal-service/src/kernel/mod.rs index 3660823..c417915 100644 --- a/citadel-internal-service/src/kernel/mod.rs +++ b/citadel-internal-service/src/kernel/mod.rs @@ -329,6 +329,7 @@ impl NetKernel for CitadelWorkspaceService { cid: implicated_cid, peer_cid, metadata, + request_id: None, }, ); send_response_to_tcp_client(&self.tcp_connection_map, response, uuid) @@ -565,7 +566,9 @@ fn handle_connection( let write_task = async move { let response = - InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted); + InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted { + request_id: None, + }); if let Err(err) = sink_send_payload(response, &mut sink).await { error!(target: "citadel", "Failed to send to client: {err:?}"); @@ -823,6 +826,7 @@ fn spawn_tick_updater( cid: implicated_cid, peer_cid, status: status_message, + request_id: None, }, ); match entry.send(message.clone()) { diff --git a/citadel-internal-service/src/kernel/request_handler.rs b/citadel-internal-service/src/kernel/request_handler.rs index 89f61bb..ca9aa88 100644 --- a/citadel-internal-service/src/kernel/request_handler.rs +++ b/citadel-internal-service/src/kernel/request_handler.rs @@ -106,6 +106,7 @@ pub async fn handle_request( accounts_ret: &mut HashMap, account: CNACMetadata, remote: &NodeRemote, + request_id: Uuid, ) { let username = account.username.clone(); let full_name = account.full_name.clone(); @@ -143,6 +144,7 @@ pub async fn handle_request( username, full_name, peers, + request_id: Some(request_id), }, ); } @@ -164,11 +166,11 @@ pub async fn handle_request( if let Some(cid) = cid { let account = filtered_accounts.into_iter().find(|r| r.cid == cid); if let Some(account) = account { - add_account_to_map(&mut accounts_ret, account, remote).await; + add_account_to_map(&mut accounts_ret, account, remote, request_id).await; } } else { for account in filtered_accounts { - add_account_to_map(&mut accounts_ret, account, remote).await; + add_account_to_map(&mut accounts_ret, account, remote, request_id).await; } } @@ -188,6 +190,7 @@ pub async fn handle_request( let mut session = SessionInformation { cid: *cid, peer_connections: HashMap::new(), + request_id: Some(request_id), }; for (peer_cid, conn) in connection.peers.iter() { session.peer_connections.insert( diff --git a/citadel-internal-service/tests/file_transfer.rs b/citadel-internal-service/tests/file_transfer.rs index aaf74b7..6d966a1 100644 --- a/citadel-internal-service/tests/file_transfer.rs +++ b/citadel-internal-service/tests/file_transfer.rs @@ -431,6 +431,7 @@ mod tests { cid: _, peer_cid: _, status, + request_id: None, }, ) => match status { ObjectTransferStatus::ReceptionBeginning(file_path, vfm) => {