Skip to content

Commit

Permalink
Add RequestID macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Braun committed Mar 3, 2024
1 parent dae1195 commit d1d5d39
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 8 deletions.
4 changes: 3 additions & 1 deletion citadel-internal-service-connector/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ impl InternalServiceConnector {
.ok_or(ClientError::InternalServiceDisconnected)??;
if matches!(
greeter_packet,
InternalServicePayload::Response(InternalServiceResponse::ServiceConnectionAccepted(_))
InternalServicePayload::Response(
InternalServiceResponse::ServiceConnectionAccepted { .. }
)
) {
let stream = WrappedStream { inner: stream };
let sink = WrappedSink { inner: sink };
Expand Down
48 changes: 47 additions & 1 deletion citadel-internal-service-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn generate_function(input: TokenStream, contains: &str, function_name: &str) ->
// Generate match arms for each enum variant
let match_arms = generate_match_arms(name, &data, contains);

// Generate the implementation of the `is_error` method
// Generate the implementation of the `$function_name` method
let expanded = quote! {
impl #name {
pub fn #function_name(&self) -> bool {
Expand Down Expand Up @@ -66,3 +66,49 @@ fn generate_match_arms(
})
.collect()
}

#[proc_macro_derive(RequestId)]
pub fn request_ids(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);

// Extract the identifier and data from the input
let name = &input.ident;
let data = if let Data::Enum(data) = input.data {
data
} else {
// This macro only supports enums
panic!("RequestId can only be derived for enums");
};

// Generate match arms for each enum variant
let match_arms = generate_match_arms_uuid(name, &data);

// Generate the implementation of the `$function_name` method
let expanded = quote! {
impl #name {
pub fn request_id(&self) -> Option<&Uuid> {
match self {
#(#match_arms)*
}
}
}
};

// Convert into a TokenStream and return it
TokenStream::from(expanded)
}

fn generate_match_arms_uuid(name: &Ident, data_enum: &DataEnum) -> Vec<proc_macro2::TokenStream> {
data_enum
.variants
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
// Match against each variant, ignoring any inner data
quote! {
#name::#variant_ident(inner) => inner.request_id.as_ref(),
}
})
.collect()
}
22 changes: 19 additions & 3 deletions citadel-internal-service-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bytes::BytesMut;
use citadel_internal_service_macros::{IsError, IsNotification};
use citadel_internal_service_macros::{IsError, IsNotification, RequestId};
pub use citadel_types::prelude::{
ConnectMode, MemberState, MessageGroupKey, ObjectTransferStatus, SecBuffer, SecurityLevel,
SessionSecuritySettings, TransferType, UdpMode, UserIdentifier, VirtualObjectMetadata,
Expand Down Expand Up @@ -35,7 +35,9 @@ pub struct RegisterFailure {
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ServiceConnectionAccepted;
pub struct ServiceConnectionAccepted {
pub request_id: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageSendSuccess {
Expand Down Expand Up @@ -542,6 +544,7 @@ pub struct FileTransferRequestNotification {
pub cid: u64,
pub peer_cid: u64,
pub metadata: VirtualObjectMetadata,
pub request_id: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand All @@ -559,9 +562,10 @@ pub struct FileTransferTickNotification {
pub cid: u64,
pub peer_cid: u64,
pub status: ObjectTransferStatus,
pub request_id: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification)]
#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification, RequestId)]
pub enum InternalServiceResponse {
ConnectSuccess(ConnectSuccess),
ConnectFailure(ConnectFailure),
Expand Down Expand Up @@ -827,6 +831,7 @@ pub enum InternalServiceRequest {
pub struct SessionInformation {
pub cid: u64,
pub peer_connections: HashMap<u64, PeerSessionInformation>,
pub request_id: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
Expand All @@ -840,6 +845,7 @@ pub struct AccountInformation {
pub username: String,
pub full_name: String,
pub peers: HashMap<u64, PeerSessionInformation>,
pub request_id: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
Expand Down Expand Up @@ -889,4 +895,14 @@ mod tests {
assert!(!success_response.is_notification());
assert!(notification_response.is_notification());
}

#[test]
fn test_request_id_macro() {
let request_id = Uuid::new_v4();
let response = InternalServiceResponse::ConnectSuccess(ConnectSuccess {
cid: 0,
request_id: Some(request_id),
});
assert_eq!(response.request_id(), Some(&request_id));
}
}
6 changes: 5 additions & 1 deletion citadel-internal-service/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ impl NetKernel for CitadelWorkspaceService {
cid: implicated_cid,
peer_cid,
metadata,
request_id: None,
},
);
send_response_to_tcp_client(&self.tcp_connection_map, response, uuid)
Expand Down Expand Up @@ -565,7 +566,9 @@ fn handle_connection(

let write_task = async move {
let response =
InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted);
InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted {
request_id: None,
});

if let Err(err) = sink_send_payload(response, &mut sink).await {
error!(target: "citadel", "Failed to send to client: {err:?}");
Expand Down Expand Up @@ -823,6 +826,7 @@ fn spawn_tick_updater(
cid: implicated_cid,
peer_cid,
status: status_message,
request_id: None,
},
);
match entry.send(message.clone()) {
Expand Down
7 changes: 5 additions & 2 deletions citadel-internal-service/src/kernel/request_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pub async fn handle_request(
accounts_ret: &mut HashMap<u64, AccountInformation>,
account: CNACMetadata,
remote: &NodeRemote,
request_id: Uuid,
) {
let username = account.username.clone();
let full_name = account.full_name.clone();
Expand Down Expand Up @@ -143,6 +144,7 @@ pub async fn handle_request(
username,
full_name,
peers,
request_id: Some(request_id),
},
);
}
Expand All @@ -164,11 +166,11 @@ pub async fn handle_request(
if let Some(cid) = cid {
let account = filtered_accounts.into_iter().find(|r| r.cid == cid);
if let Some(account) = account {
add_account_to_map(&mut accounts_ret, account, remote).await;
add_account_to_map(&mut accounts_ret, account, remote, request_id).await;
}
} else {
for account in filtered_accounts {
add_account_to_map(&mut accounts_ret, account, remote).await;
add_account_to_map(&mut accounts_ret, account, remote, request_id).await;
}
}

Expand All @@ -188,6 +190,7 @@ pub async fn handle_request(
let mut session = SessionInformation {
cid: *cid,
peer_connections: HashMap::new(),
request_id: Some(request_id),
};
for (peer_cid, conn) in connection.peers.iter() {
session.peer_connections.insert(
Expand Down
1 change: 1 addition & 0 deletions citadel-internal-service/tests/file_transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ mod tests {
cid: _,
peer_cid: _,
status,
request_id: None,
},
) => match status {
ObjectTransferStatus::ReceptionBeginning(file_path, vfm) => {
Expand Down

0 comments on commit d1d5d39

Please sign in to comment.