Skip to content

Commit

Permalink
reuse events (#861)
Browse files Browse the repository at this point in the history
* reuse events

* assign context

* dbg

* init events once

* trigger image build

* deploy test image to stage

---------

Co-authored-by: Ertugrul Aypek <ertugrul.aypek@toolsforhumanity.com>
  • Loading branch information
philsippl and eaypek-tfh authored Dec 28, 2024
1 parent 6d5b2f2 commit ba9d5fc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/temp-branch-build-and-push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Branch - Build and push docker image
on:
push:
branches:
- "ps/host-mem-alloc"
- "ps/reuse-events"

concurrency:
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
Expand Down
2 changes: 1 addition & 1 deletion deploy/stage/common-values-iris-mpc.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: "ghcr.io/worldcoin/iris-mpc:6b358589c25ef528f58ba02980103670b037a614"
image: "ghcr.io/worldcoin/iris-mpc:v0.13.6"

environment: stage
replicaCount: 1
Expand Down
5 changes: 3 additions & 2 deletions iris-mpc-gpu/src/helpers/device_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ impl DeviceManager {
}

pub fn destroy_events(&self, events: Vec<CUevent>) {
for event in events {
unsafe { event::destroy(event).unwrap() };
for (device_idx, event) in events.iter().enumerate() {
self.device(device_idx).bind_to_thread().unwrap();
unsafe { event::destroy(*event).unwrap() };
}
}

Expand Down
54 changes: 23 additions & 31 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ pub struct ServerActor {
disable_persistence: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
exchange_events: Vec<Vec<CUevent>>,
phase2_events: Vec<Vec<CUevent>>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down Expand Up @@ -330,6 +333,11 @@ impl ServerActor {
let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];

// Create all needed events
let dot_events = vec![device_manager.create_events(); 2];
let exchange_events = vec![device_manager.create_events(); 2];
let phase2_events = vec![device_manager.create_events(); 2];

for dev in device_manager.devices() {
dev.synchronize().unwrap();
}
Expand Down Expand Up @@ -367,6 +375,9 @@ impl ServerActor {
disable_persistence,
code_chunk_buffers,
mask_chunk_buffers,
dot_events,
exchange_events,
phase2_events,
})
}

Expand Down Expand Up @@ -1126,14 +1137,6 @@ impl ServerActor {
tracing::info!(party_id = self.party_id, "Finished batch deduplication");
// ---- END BATCH DEDUP ----

// Create new initial events
let mut current_dot_event = self.device_manager.create_events();
let mut next_dot_event = self.device_manager.create_events();
let mut current_exchange_event = self.device_manager.create_events();
let mut next_exchange_event = self.device_manager.create_events();
let mut current_phase2_event = self.device_manager.create_events();
let mut next_phase2_event = self.device_manager.create_events();

let chunk_sizes = |chunk_idx: usize| {
self.current_db_sizes
.iter()
Expand Down Expand Up @@ -1195,11 +1198,11 @@ impl ServerActor {
// First stream doesn't need to wait
if db_chunk_idx == 0 {
self.device_manager
.record_event(request_streams, &current_dot_event);
.record_event(request_streams, &self.dot_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_exchange_event);
.record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);
self.device_manager
.record_event(request_streams, &current_phase2_event);
.record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);
}

// Prefetch next chunk
Expand Down Expand Up @@ -1229,7 +1232,7 @@ impl ServerActor {
);

self.device_manager
.await_event(request_streams, &current_dot_event);
.await_event(request_streams, &self.dot_events[db_chunk_idx % 2]);

// ---- START PHASE 1 ----
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
Expand All @@ -1247,7 +1250,7 @@ impl ServerActor {

// wait for the exchange result buffers to be ready
self.device_manager
.await_event(request_streams, &current_exchange_event);
.await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1268,7 +1271,7 @@ impl ServerActor {
);

self.device_manager
.record_event(request_streams, &next_dot_event);
.record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]);

record_stream_time!(
&self.device_manager,
Expand All @@ -1286,7 +1289,7 @@ impl ServerActor {
// ---- END PHASE 1 ----

self.device_manager
.await_event(request_streams, &current_phase2_event);
.await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);

// ---- START PHASE 2 ----
let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
Expand Down Expand Up @@ -1318,8 +1321,10 @@ impl ServerActor {
// we can now record the exchange event since the phase 2 is no longer using the
// code_dots/mask_dots which are just reinterpretations of the exchange result
// buffers
self.device_manager
.record_event(request_streams, &next_exchange_event);
self.device_manager.record_event(
request_streams,
&self.exchange_events[(db_chunk_idx + 1) % 2],
);

let res = self.phase2.take_result_buffer();
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
Expand All @@ -1340,23 +1345,10 @@ impl ServerActor {
});
}
self.device_manager
.record_event(request_streams, &next_phase2_event);
.record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]);

// ---- END PHASE 2 ----

// Destroy events
self.device_manager.destroy_events(current_dot_event);
self.device_manager.destroy_events(current_exchange_event);
self.device_manager.destroy_events(current_phase2_event);

// Update events for synchronization
current_dot_event = next_dot_event;
current_exchange_event = next_exchange_event;
current_phase2_event = next_phase2_event;
next_dot_event = self.device_manager.create_events();
next_exchange_event = self.device_manager.create_events();
next_phase2_event = self.device_manager.create_events();

// Increment chunk index
db_chunk_idx += 1;

Expand Down

0 comments on commit ba9d5fc

Please sign in to comment.