From 3fd0b03572cc9a46f702d02ce6069fd4802c39f1 Mon Sep 17 00:00:00 2001 From: philsippl Date: Fri, 27 Dec 2024 15:18:16 +0100 Subject: [PATCH] init events once --- iris-mpc-gpu/src/server/actor.rs | 55 +++++++++++++------------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 951ca0b55..1a31267a2 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -107,6 +107,9 @@ pub struct ServerActor { disable_persistence: bool, code_chunk_buffers: Vec, mask_chunk_buffers: Vec, + dot_events: Vec>, + exchange_events: Vec>, + phase2_events: Vec>, } const NON_MATCH_ID: u32 = u32::MAX; @@ -330,6 +333,11 @@ impl ServerActor { let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + // Create all needed events + let dot_events = vec![device_manager.create_events(); 2]; + let exchange_events = vec![device_manager.create_events(); 2]; + let phase2_events = vec![device_manager.create_events(); 2]; + for dev in device_manager.devices() { dev.synchronize().unwrap(); } @@ -367,6 +375,9 @@ impl ServerActor { disable_persistence, code_chunk_buffers, mask_chunk_buffers, + dot_events, + exchange_events, + phase2_events, }) } @@ -1126,14 +1137,6 @@ impl ServerActor { tracing::info!(party_id = self.party_id, "Finished batch deduplication"); // ---- END BATCH DEDUP ---- - // Create all needed events - let current_dot_events = vec![self.device_manager.create_events(); 2]; - let next_dot_events = vec![self.device_manager.create_events(); 2]; - let current_exchange_events = vec![self.device_manager.create_events(); 2]; - let next_exchange_events = vec![self.device_manager.create_events(); 2]; - let current_phase2_events = vec![self.device_manager.create_events(); 2]; - let next_phase2_events = vec![self.device_manager.create_events(); 2]; - let chunk_sizes = |chunk_idx: usize| { self.current_db_sizes .iter() @@ -1195,11 +1198,11 @@ impl ServerActor { // First stream doesn't need to wait if db_chunk_idx == 0 { self.device_manager - .record_event(request_streams, ¤t_dot_events[db_chunk_idx % 2]); + .record_event(request_streams, &self.dot_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_exchange_events[db_chunk_idx % 2]); + .record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_phase2_events[db_chunk_idx % 2]); + .record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); } // Prefetch next chunk @@ -1229,7 +1232,7 @@ impl ServerActor { ); self.device_manager - .await_event(request_streams, ¤t_dot_events[db_chunk_idx % 2]); + .await_event(request_streams, &self.dot_events[db_chunk_idx % 2]); // ---- START PHASE 1 ---- record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", { @@ -1247,7 +1250,7 @@ impl ServerActor { // wait for the exchange result buffers to be ready self.device_manager - .await_event(request_streams, ¤t_exchange_events[db_chunk_idx % 2]); + .await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); record_stream_time!( &self.device_manager, @@ -1268,7 +1271,7 @@ impl ServerActor { ); self.device_manager - .record_event(request_streams, &next_dot_events[db_chunk_idx % 2]); + .record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]); record_stream_time!( &self.device_manager, @@ -1286,7 +1289,7 @@ impl ServerActor { // ---- END PHASE 1 ---- self.device_manager - .await_event(request_streams, ¤t_phase2_events[db_chunk_idx % 2]); + .await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); // ---- START PHASE 2 ---- let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap(); @@ -1318,8 +1321,10 @@ impl ServerActor { // we can now record the exchange event since the phase 2 is no longer using the // code_dots/mask_dots which are just reinterpretations of the exchange result // buffers - self.device_manager - .record_event(request_streams, &next_exchange_events[db_chunk_idx % 2]); + self.device_manager.record_event( + request_streams, + &self.exchange_events[(db_chunk_idx + 1) % 2], + ); let res = self.phase2.take_result_buffer(); record_stream_time!(&self.device_manager, request_streams, events, "db_open", { @@ -1340,7 +1345,7 @@ impl ServerActor { }); } self.device_manager - .record_event(request_streams, &next_phase2_events[db_chunk_idx % 2]); + .record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]); // ---- END PHASE 2 ---- @@ -1360,20 +1365,6 @@ impl ServerActor { self.device_manager.await_streams(&self.streams[1]); tracing::info!(party_id = self.party_id, "db search finished"); - // Destroy all events - // for events in [ - // current_dot_events, - // next_dot_events, - // current_exchange_events, - // next_exchange_events, - // current_phase2_events, - // next_phase2_events, - // ] { - // for event in events { - // self.device_manager.destroy_events(event); - // } - // } - // Reset the results buffers for reuse for dst in &[&self.results, &self.batch_results, &self.final_results] { reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);