Skip to content

Commit

Permalink
init events once
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl authored and eaypek-tfh committed Dec 28, 2024
1 parent 6e16499 commit 3fd0b03
Showing 1 changed file with 23 additions and 32 deletions.
55 changes: 23 additions & 32 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ pub struct ServerActor {
disable_persistence: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
exchange_events: Vec<Vec<CUevent>>,
phase2_events: Vec<Vec<CUevent>>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down Expand Up @@ -330,6 +333,11 @@ impl ServerActor {
let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];

// Create all needed events
let dot_events = vec![device_manager.create_events(); 2];
let exchange_events = vec![device_manager.create_events(); 2];
let phase2_events = vec![device_manager.create_events(); 2];

for dev in device_manager.devices() {
dev.synchronize().unwrap();
}
Expand Down Expand Up @@ -367,6 +375,9 @@ impl ServerActor {
disable_persistence,
code_chunk_buffers,
mask_chunk_buffers,
dot_events,
exchange_events,
phase2_events,
})
}

Expand Down Expand Up @@ -1126,14 +1137,6 @@ impl ServerActor {
tracing::info!(party_id = self.party_id, "Finished batch deduplication");
// ---- END BATCH DEDUP ----

// Create all needed events
let current_dot_events = vec![self.device_manager.create_events(); 2];
let next_dot_events = vec![self.device_manager.create_events(); 2];
let current_exchange_events = vec![self.device_manager.create_events(); 2];
let next_exchange_events = vec![self.device_manager.create_events(); 2];
let current_phase2_events = vec![self.device_manager.create_events(); 2];
let next_phase2_events = vec![self.device_manager.create_events(); 2];

let chunk_sizes = |chunk_idx: usize| {
self.current_db_sizes
.iter()
Expand Down Expand Up @@ -1195,11 +1198,11 @@ impl ServerActor {
// First stream doesn't need to wait
if db_chunk_idx == 0 {
self.device_manager
.record_event(request_streams, &current_dot_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.dot_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);
}

// Prefetch next chunk
Expand Down Expand Up @@ -1229,7 +1232,7 @@ impl ServerActor {
);

self.device_manager
.await_event(request_streams, &current_dot_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.dot_events[db_chunk_idx % 2]);

// ---- START PHASE 1 ----
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
Expand All @@ -1247,7 +1250,7 @@ impl ServerActor {

// wait for the exchange result buffers to be ready
self.device_manager
.await_event(request_streams, &current_exchange_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1268,7 +1271,7 @@ impl ServerActor {
);

self.device_manager
.record_event(request_streams, &next_dot_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1286,7 +1289,7 @@ impl ServerActor {
// ---- END PHASE 1 ----

self.device_manager
.await_event(request_streams, &current_phase2_events[db_chunk_idx % 2]);
.await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);

// ---- START PHASE 2 ----
let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
Expand Down Expand Up @@ -1318,8 +1321,10 @@ impl ServerActor {
// we can now record the exchange event since the phase 2 is no longer using the
// code_dots/mask_dots which are just reinterpretations of the exchange result
// buffers
self.device_manager
.record_event(request_streams, &next_exchange_events[db_chunk_idx % 2]);
self.device_manager.record_event(
request_streams,
&self.exchange_events[(db_chunk_idx + 1) % 2],
);

let res = self.phase2.take_result_buffer();
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
Expand All @@ -1340,7 +1345,7 @@ impl ServerActor {
});
}
self.device_manager
.record_event(request_streams, &next_phase2_events[db_chunk_idx % 2]);
.record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]);

// ---- END PHASE 2 ----

Expand All @@ -1360,20 +1365,6 @@ impl ServerActor {
self.device_manager.await_streams(&self.streams[1]);
tracing::info!(party_id = self.party_id, "db search finished");

// Destroy all events
// for events in [
// current_dot_events,
// next_dot_events,
// current_exchange_events,
// next_exchange_events,
// current_phase2_events,
// next_phase2_events,
// ] {
// for event in events {
// self.device_manager.destroy_events(event);
// }
// }

// Reset the results buffers for reuse
for dst in &[&self.results, &self.batch_results, &self.final_results] {
reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);
Expand Down

0 comments on commit 3fd0b03

Please sign in to comment.