diff --git a/.github/workflows/temp-branch-build-and-push.yaml b/.github/workflows/temp-branch-build-and-push.yaml index 678d37592..696d980ff 100644 --- a/.github/workflows/temp-branch-build-and-push.yaml +++ b/.github/workflows/temp-branch-build-and-push.yaml @@ -3,7 +3,7 @@ name: Branch - Build and push docker image on: push: branches: - - "ps/host-mem-alloc" + - "ps/reuse-events" concurrency: group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' diff --git a/deploy/stage/common-values-iris-mpc.yaml b/deploy/stage/common-values-iris-mpc.yaml index 39e7a9b19..276b120a1 100644 --- a/deploy/stage/common-values-iris-mpc.yaml +++ b/deploy/stage/common-values-iris-mpc.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:6b358589c25ef528f58ba02980103670b037a614" +image: "ghcr.io/worldcoin/iris-mpc:v0.13.6" environment: stage replicaCount: 1 diff --git a/iris-mpc-gpu/src/helpers/device_manager.rs b/iris-mpc-gpu/src/helpers/device_manager.rs index fe3f7563f..4053324a4 100644 --- a/iris-mpc-gpu/src/helpers/device_manager.rs +++ b/iris-mpc-gpu/src/helpers/device_manager.rs @@ -103,8 +103,9 @@ impl DeviceManager { } pub fn destroy_events(&self, events: Vec) { - for event in events { - unsafe { event::destroy(event).unwrap() }; + for (device_idx, event) in events.iter().enumerate() { + self.device(device_idx).bind_to_thread().unwrap(); + unsafe { event::destroy(*event).unwrap() }; } } diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 5dd0d0d61..1a31267a2 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -107,6 +107,9 @@ pub struct ServerActor { disable_persistence: bool, code_chunk_buffers: Vec, mask_chunk_buffers: Vec, + dot_events: Vec>, + exchange_events: Vec>, + phase2_events: Vec>, } const NON_MATCH_ID: u32 = u32::MAX; @@ -330,6 +333,11 @@ impl ServerActor { let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + // Create all needed events + let dot_events = vec![device_manager.create_events(); 2]; + let exchange_events = vec![device_manager.create_events(); 2]; + let phase2_events = vec![device_manager.create_events(); 2]; + for dev in device_manager.devices() { dev.synchronize().unwrap(); } @@ -367,6 +375,9 @@ impl ServerActor { disable_persistence, code_chunk_buffers, mask_chunk_buffers, + dot_events, + exchange_events, + phase2_events, }) } @@ -1126,14 +1137,6 @@ impl ServerActor { tracing::info!(party_id = self.party_id, "Finished batch deduplication"); // ---- END BATCH DEDUP ---- - // Create new initial events - let mut current_dot_event = self.device_manager.create_events(); - let mut next_dot_event = self.device_manager.create_events(); - let mut current_exchange_event = self.device_manager.create_events(); - let mut next_exchange_event = self.device_manager.create_events(); - let mut current_phase2_event = self.device_manager.create_events(); - let mut next_phase2_event = self.device_manager.create_events(); - let chunk_sizes = |chunk_idx: usize| { self.current_db_sizes .iter() @@ -1195,11 +1198,11 @@ impl ServerActor { // First stream doesn't need to wait if db_chunk_idx == 0 { self.device_manager - .record_event(request_streams, ¤t_dot_event); + .record_event(request_streams, &self.dot_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_exchange_event); + .record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_phase2_event); + .record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); } // Prefetch next chunk @@ -1229,7 +1232,7 @@ impl ServerActor { ); self.device_manager - .await_event(request_streams, ¤t_dot_event); + .await_event(request_streams, &self.dot_events[db_chunk_idx % 2]); // ---- START PHASE 1 ---- record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", { @@ -1247,7 +1250,7 @@ impl ServerActor { // wait for the exchange result buffers to be ready self.device_manager - .await_event(request_streams, ¤t_exchange_event); + .await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); record_stream_time!( &self.device_manager, @@ -1268,7 +1271,7 @@ impl ServerActor { ); self.device_manager - .record_event(request_streams, &next_dot_event); + .record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]); record_stream_time!( &self.device_manager, @@ -1286,7 +1289,7 @@ impl ServerActor { // ---- END PHASE 1 ---- self.device_manager - .await_event(request_streams, ¤t_phase2_event); + .await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); // ---- START PHASE 2 ---- let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap(); @@ -1318,8 +1321,10 @@ impl ServerActor { // we can now record the exchange event since the phase 2 is no longer using the // code_dots/mask_dots which are just reinterpretations of the exchange result // buffers - self.device_manager - .record_event(request_streams, &next_exchange_event); + self.device_manager.record_event( + request_streams, + &self.exchange_events[(db_chunk_idx + 1) % 2], + ); let res = self.phase2.take_result_buffer(); record_stream_time!(&self.device_manager, request_streams, events, "db_open", { @@ -1340,23 +1345,10 @@ impl ServerActor { }); } self.device_manager - .record_event(request_streams, &next_phase2_event); + .record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]); // ---- END PHASE 2 ---- - // Destroy events - self.device_manager.destroy_events(current_dot_event); - self.device_manager.destroy_events(current_exchange_event); - self.device_manager.destroy_events(current_phase2_event); - - // Update events for synchronization - current_dot_event = next_dot_event; - current_exchange_event = next_exchange_event; - current_phase2_event = next_phase2_event; - next_dot_event = self.device_manager.create_events(); - next_exchange_event = self.device_manager.create_events(); - next_phase2_event = self.device_manager.create_events(); - // Increment chunk index db_chunk_idx += 1;