From f6d60f85b44859ea007d2ea78a772dc9eb01bd4e Mon Sep 17 00:00:00 2001 From: philsippl Date: Fri, 27 Dec 2024 14:43:38 +0100 Subject: [PATCH] reuse events --- iris-mpc-gpu/src/server/actor.rs | 59 ++++++++++++++++---------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index cfb58f96f..3d377c21a 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -1111,13 +1111,13 @@ impl ServerActor { tracing::info!(party_id = self.party_id, "Finished batch deduplication"); // ---- END BATCH DEDUP ---- - // Create new initial events - let mut current_dot_event = self.device_manager.create_events(); - let mut next_dot_event = self.device_manager.create_events(); - let mut current_exchange_event = self.device_manager.create_events(); - let mut next_exchange_event = self.device_manager.create_events(); - let mut current_phase2_event = self.device_manager.create_events(); - let mut next_phase2_event = self.device_manager.create_events(); + // Create all needed events + let current_dot_events = vec![self.device_manager.create_events(); 2]; + let next_dot_events = vec![self.device_manager.create_events(); 2]; + let current_exchange_events = vec![self.device_manager.create_events(); 2]; + let next_exchange_events = vec![self.device_manager.create_events(); 2]; + let current_phase2_events = vec![self.device_manager.create_events(); 2]; + let next_phase2_events = vec![self.device_manager.create_events(); 2]; let chunk_sizes = |chunk_idx: usize| { self.current_db_sizes @@ -1180,11 +1180,11 @@ impl ServerActor { // First stream doesn't need to wait if db_chunk_idx == 0 { self.device_manager - .record_event(request_streams, ¤t_dot_event); + .record_event(request_streams, ¤t_dot_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_exchange_event); + .record_event(request_streams, ¤t_exchange_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_phase2_event); + .record_event(request_streams, ¤t_phase2_events[db_chunk_idx % 2]); } // Prefetch next chunk @@ -1214,7 +1214,7 @@ impl ServerActor { ); self.device_manager - .await_event(request_streams, ¤t_dot_event); + .await_event(request_streams, ¤t_dot_events[db_chunk_idx % 2]); // ---- START PHASE 1 ---- record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", { @@ -1232,7 +1232,7 @@ impl ServerActor { // wait for the exchange result buffers to be ready self.device_manager - .await_event(request_streams, ¤t_exchange_event); + .await_event(request_streams, ¤t_exchange_events[db_chunk_idx % 2]); record_stream_time!( &self.device_manager, @@ -1253,7 +1253,7 @@ impl ServerActor { ); self.device_manager - .record_event(request_streams, &next_dot_event); + .record_event(request_streams, &next_dot_events[db_chunk_idx % 2]); record_stream_time!( &self.device_manager, @@ -1271,7 +1271,7 @@ impl ServerActor { // ---- END PHASE 1 ---- self.device_manager - .await_event(request_streams, ¤t_phase2_event); + .await_event(request_streams, ¤t_phase2_events[db_chunk_idx % 2]); // ---- START PHASE 2 ---- let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap(); @@ -1304,7 +1304,7 @@ impl ServerActor { // code_dots/mask_dots which are just reinterpretations of the exchange result // buffers self.device_manager - .record_event(request_streams, &next_exchange_event); + .record_event(request_streams, &next_exchange_events[db_chunk_idx % 2]); let res = self.phase2.take_result_buffer(); record_stream_time!(&self.device_manager, request_streams, events, "db_open", { @@ -1325,23 +1325,10 @@ impl ServerActor { }); } self.device_manager - .record_event(request_streams, &next_phase2_event); + .record_event(request_streams, &next_phase2_events[db_chunk_idx % 2]); // ---- END PHASE 2 ---- - // Destroy events - self.device_manager.destroy_events(current_dot_event); - self.device_manager.destroy_events(current_exchange_event); - self.device_manager.destroy_events(current_phase2_event); - - // Update events for synchronization - current_dot_event = next_dot_event; - current_exchange_event = next_exchange_event; - current_phase2_event = next_phase2_event; - next_dot_event = self.device_manager.create_events(); - next_exchange_event = self.device_manager.create_events(); - next_phase2_event = self.device_manager.create_events(); - // Increment chunk index db_chunk_idx += 1; @@ -1358,6 +1345,20 @@ impl ServerActor { self.device_manager.await_streams(&self.streams[1]); tracing::info!(party_id = self.party_id, "db search finished"); + // Destroy all events + for events in [ + current_dot_events, + next_dot_events, + current_exchange_events, + next_exchange_events, + current_phase2_events, + next_phase2_events, + ] { + for event in events { + self.device_manager.destroy_events(event); + } + } + // Reset the results buffers for reuse for dst in &[&self.results, &self.batch_results, &self.final_results] { reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);