Skip to content

Commit

Permalink
init events once
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Dec 27, 2024
1 parent d7eab97 commit 5b1c039
Showing 1 changed file with 23 additions and 32 deletions.
55 changes: 23 additions & 32 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ pub struct ServerActor {
disable_persistence: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
exchange_events: Vec<Vec<CUevent>>,
phase2_events: Vec<Vec<CUevent>>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down Expand Up @@ -326,6 +329,11 @@ impl ServerActor {
let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];

// Create all needed events
let dot_events = vec![device_manager.create_events(); 2];
let exchange_events = vec![device_manager.create_events(); 2];
let phase2_events = vec![device_manager.create_events(); 2];

for dev in device_manager.devices() {
dev.synchronize().unwrap();
}
Expand Down Expand Up @@ -363,6 +371,9 @@ impl ServerActor {
disable_persistence,
code_chunk_buffers,
mask_chunk_buffers,
dot_events,
exchange_events,
phase2_events,
})
}

Expand Down Expand Up @@ -1111,14 +1122,6 @@ impl ServerActor {
tracing::info!(party_id = self.party_id, "Finished batch deduplication");
// ---- END BATCH DEDUP ----

// Create all needed events
let current_dot_events = vec![self.device_manager.create_events(); 2];
let next_dot_events = vec![self.device_manager.create_events(); 2];
let current_exchange_events = vec![self.device_manager.create_events(); 2];
let next_exchange_events = vec![self.device_manager.create_events(); 2];
let current_phase2_events = vec![self.device_manager.create_events(); 2];
let next_phase2_events = vec![self.device_manager.create_events(); 2];

let chunk_sizes = |chunk_idx: usize| {
self.current_db_sizes
.iter()
Expand Down Expand Up @@ -1180,11 +1183,11 @@ impl ServerActor {
// First stream doesn't need to wait
if db_chunk_idx == 0 {
self.device_manager
.record_event(request_streams, &current_dot_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.dot_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);
}

// Prefetch next chunk
Expand Down Expand Up @@ -1214,7 +1217,7 @@ impl ServerActor {
);

self.device_manager
.await_event(request_streams, &current_dot_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.dot_events[db_chunk_idx % 2]);

// ---- START PHASE 1 ----
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
Expand All @@ -1232,7 +1235,7 @@ impl ServerActor {

// wait for the exchange result buffers to be ready
self.device_manager
.await_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1253,7 +1256,7 @@ impl ServerActor {
);

self.device_manager
.record_event(request_streams, &next_dot_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1271,7 +1274,7 @@ impl ServerActor {
// ---- END PHASE 1 ----

self.device_manager
.await_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);

// ---- START PHASE 2 ----
let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
Expand Down Expand Up @@ -1303,8 +1306,10 @@ impl ServerActor {
// we can now record the exchange event since the phase 2 is no longer using the
// code_dots/mask_dots which are just reinterpretations of the exchange result
// buffers
self.device_manager
.record_event(request_streams, &next_exchange_events[db_chunk_idx % 2]);
self.device_manager.record_event(
request_streams,
&self.exchange_events[(db_chunk_idx + 1) % 2],
);

let res = self.phase2.take_result_buffer();
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
Expand All @@ -1325,7 +1330,7 @@ impl ServerActor {
});
}
self.device_manager
.record_event(request_streams, &next_phase2_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]);

// ---- END PHASE 2 ----

Expand All @@ -1345,20 +1350,6 @@ impl ServerActor {
self.device_manager.await_streams(&self.streams[1]);
tracing::info!(party_id = self.party_id, "db search finished");

// Destroy all events
// for events in [
// current_dot_events,
// next_dot_events,
// current_exchange_events,
// next_exchange_events,
// current_phase2_events,
// next_phase2_events,
// ] {
// for event in events {
// self.device_manager.destroy_events(event);
// }
// }

// Reset the results buffers for reuse
for dst in &[&self.results, &self.batch_results, &self.final_results] {
reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);
Expand Down

0 comments on commit 5b1c039

Please sign in to comment.