Skip to content

Commit

Permalink
reuse events
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Dec 27, 2024
1 parent e699756 commit f6d60f8
Showing 1 changed file with 30 additions and 29 deletions.
59 changes: 30 additions & 29 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1111,13 +1111,13 @@ impl ServerActor {
tracing::info!(party_id = self.party_id, "Finished batch deduplication");
// ---- END BATCH DEDUP ----

// Create new initial events
let mut current_dot_event = self.device_manager.create_events();
let mut next_dot_event = self.device_manager.create_events();
let mut current_exchange_event = self.device_manager.create_events();
let mut next_exchange_event = self.device_manager.create_events();
let mut current_phase2_event = self.device_manager.create_events();
let mut next_phase2_event = self.device_manager.create_events();
// Create all needed events
let current_dot_events = vec![self.device_manager.create_events(); 2];
let next_dot_events = vec![self.device_manager.create_events(); 2];
let current_exchange_events = vec![self.device_manager.create_events(); 2];
let next_exchange_events = vec![self.device_manager.create_events(); 2];
let current_phase2_events = vec![self.device_manager.create_events(); 2];
let next_phase2_events = vec![self.device_manager.create_events(); 2];

let chunk_sizes = |chunk_idx: usize| {
self.current_db_sizes
Expand Down Expand Up @@ -1180,11 +1180,11 @@ impl ServerActor {
// First stream doesn't need to wait
if db_chunk_idx == 0 {
self.device_manager
.record_event(request_streams, &current_dot_event);
.record_event(request_streams, &current_dot_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_exchange_event);
.record_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_phase2_event);
.record_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);
}

// Prefetch next chunk
Expand Down Expand Up @@ -1214,7 +1214,7 @@ impl ServerActor {
);

self.device_manager
.await_event(request_streams, &current_dot_event);
.await_event(request_streams, &current_dot_events[db_chunk_idx % 2]);

// ---- START PHASE 1 ----
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
Expand All @@ -1232,7 +1232,7 @@ impl ServerActor {

// wait for the exchange result buffers to be ready
self.device_manager
.await_event(request_streams, &current_exchange_event);
.await_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1253,7 +1253,7 @@ impl ServerActor {
);

self.device_manager
.record_event(request_streams, &next_dot_event);
.record_event(request_streams, &next_dot_events[db_chunk_idx % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1271,7 +1271,7 @@ impl ServerActor {
// ---- END PHASE 1 ----

self.device_manager
.await_event(request_streams, &current_phase2_event);
.await_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);

// ---- START PHASE 2 ----
let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
Expand Down Expand Up @@ -1304,7 +1304,7 @@ impl ServerActor {
// code_dots/mask_dots which are just reinterpretations of the exchange result
// buffers
self.device_manager
.record_event(request_streams, &next_exchange_event);
.record_event(request_streams, &next_exchange_events[db_chunk_idx % 2]);

let res = self.phase2.take_result_buffer();
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
Expand All @@ -1325,23 +1325,10 @@ impl ServerActor {
});
}
self.device_manager
.record_event(request_streams, &next_phase2_event);
.record_event(request_streams, &next_phase2_events[db_chunk_idx % 2]);

// ---- END PHASE 2 ----

// Destroy events
self.device_manager.destroy_events(current_dot_event);
self.device_manager.destroy_events(current_exchange_event);
self.device_manager.destroy_events(current_phase2_event);

// Update events for synchronization
current_dot_event = next_dot_event;
current_exchange_event = next_exchange_event;
current_phase2_event = next_phase2_event;
next_dot_event = self.device_manager.create_events();
next_exchange_event = self.device_manager.create_events();
next_phase2_event = self.device_manager.create_events();

// Increment chunk index
db_chunk_idx += 1;

Expand All @@ -1358,6 +1345,20 @@ impl ServerActor {
self.device_manager.await_streams(&self.streams[1]);
tracing::info!(party_id = self.party_id, "db search finished");

// Destroy all events
for events in [
current_dot_events,
next_dot_events,
current_exchange_events,
next_exchange_events,
current_phase2_events,
next_phase2_events,
] {
for event in events {
self.device_manager.destroy_events(event);
}
}

// Reset the results buffers for reuse
for dst in &[&self.results, &self.batch_results, &self.final_results] {
reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);
Expand Down

0 comments on commit f6d60f8

Please sign in to comment.