Skip to content

Commit

Permalink
allow disabling of stream timers (#868)
Browse files Browse the repository at this point in the history
* allow disabling of stream timers

* bump stage image

---------

Co-authored-by: Ertugrul Aypek <ertugrul.aypek@toolsforhumanity.com>
  • Loading branch information
philsippl and eaypek-tfh authored Jan 3, 2025
1 parent 884e2a9 commit cdcbcea
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 58 deletions.
2 changes: 1 addition & 1 deletion deploy/stage/common-values-iris-mpc.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: "ghcr.io/worldcoin/iris-mpc:v0.13.6"
image: "ghcr.io/worldcoin/iris-mpc:v0.13.7"

environment: stage
replicaCount: 1
Expand Down
3 changes: 3 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ pub struct Config {
#[serde(default)]
pub disable_persistence: bool,

#[serde(default)]
pub enable_debug_timing: bool,

#[serde(default, deserialize_with = "deserialize_yaml_json_string")]
pub node_hostnames: Vec<String>,

Expand Down
164 changes: 107 additions & 57 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ use std::{collections::HashMap, mem, sync::Arc, time::Instant};
use tokio::sync::{mpsc, oneshot};

/// Records GPU-stream timing around `$block`, gated by a runtime flag.
///
/// When `$enable_timing` is true, a pair of CUDA events is created and
/// recorded on `$streams` immediately before and after `$block` runs, and
/// both events are appended under `$label` in `$map` (a map of label ->
/// Vec of events) for later elapsed-time reporting via `log_timers`.
///
/// When `$enable_timing` is false, `$block` is executed directly: no events
/// are created or recorded, so disabling timing has zero per-batch overhead.
///
/// Both arms evaluate to the value of `$block`, so the macro can be used in
/// expression position regardless of the flag.
macro_rules! record_stream_time {
    ($manager:expr, $streams:expr, $map:expr, $label:expr, $enable_timing:expr, $block:block) => {{
        if $enable_timing {
            let evt0 = $manager.create_events();
            let evt1 = $manager.create_events();
            $manager.record_event($streams, &evt0);
            let res = $block;
            $manager.record_event($streams, &evt1);
            // Store start/end events as a flat pair under the label; the
            // reporting side consumes them two at a time.
            $map.entry($label).or_default().extend(vec![evt0, evt1]);
            res
        } else {
            $block
        }
    }};
}

Expand Down Expand Up @@ -105,6 +109,7 @@ pub struct ServerActor {
max_db_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
Expand All @@ -124,6 +129,7 @@ impl ServerActor {
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let device_manager = Arc::new(DeviceManager::init());
Self::new_with_device_manager(
Expand All @@ -135,6 +141,7 @@ impl ServerActor {
max_batch_size,
return_partial_results,
disable_persistence,
enable_debug_timing,
)
}
#[allow(clippy::too_many_arguments)]
Expand All @@ -147,6 +154,7 @@ impl ServerActor {
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let ids = device_manager.get_ids_from_magic(0);
let comms = device_manager.instantiate_network_from_ids(party_id, &ids)?;
Expand All @@ -160,6 +168,7 @@ impl ServerActor {
max_batch_size,
return_partial_results,
disable_persistence,
enable_debug_timing,
)
}

Expand All @@ -174,6 +183,7 @@ impl ServerActor {
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let (tx, rx) = mpsc::channel(job_queue_size);
let actor = Self::init(
Expand All @@ -186,6 +196,7 @@ impl ServerActor {
max_batch_size,
return_partial_results,
disable_persistence,
enable_debug_timing,
)?;
Ok((actor, ServerActorHandle { job_queue: tx }))
}
Expand All @@ -201,6 +212,7 @@ impl ServerActor {
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
) -> eyre::Result<Self> {
assert!(max_batch_size != 0);
let mut kdf_nonce = 0;
Expand Down Expand Up @@ -373,6 +385,7 @@ impl ServerActor {
max_db_size,
return_partial_results,
disable_persistence,
enable_debug_timing,
code_chunk_buffers,
mask_chunk_buffers,
dot_events,
Expand Down Expand Up @@ -616,6 +629,7 @@ impl ServerActor {
&self.streams[0],
events,
"query_preprocess",
self.enable_debug_timing,
{
// This needs to be max_batch_size, even though the query can be shorter to have
// enough padding for GEMM
Expand Down Expand Up @@ -662,6 +676,7 @@ impl ServerActor {
&self.streams[0],
events,
"query_preprocess",
self.enable_debug_timing,
{
// This needs to be MAX_BATCH_SIZE, even though the query can be shorter to have
// enough padding for GEMM
Expand Down Expand Up @@ -903,6 +918,7 @@ impl ServerActor {
&self.streams[0],
events,
"db_write",
self.enable_debug_timing,
{
for i in 0..self.device_manager.device_count() {
self.device_manager.device(i).bind_to_thread().unwrap();
Expand Down Expand Up @@ -987,7 +1003,9 @@ impl ServerActor {
);

// ---- END RESULT PROCESSING ----
log_timers(events);
if self.enable_debug_timing {
log_timers(events);
}
let processed_mil_elements_per_second = (self.max_batch_size * previous_total_db_size)
as f64
/ now.elapsed().as_secs_f64()
Expand Down Expand Up @@ -1055,34 +1073,42 @@ impl ServerActor {
// ---- START BATCH DEDUP ----
tracing::info!(party_id = self.party_id, "Starting batch deduplication");

record_stream_time!(&self.device_manager, batch_streams, events, "batch_dot", {
tracing::info!(party_id = self.party_id, "batch_dot start");
record_stream_time!(
&self.device_manager,
batch_streams,
events,
"batch_dot",
self.enable_debug_timing,
{
tracing::info!(party_id = self.party_id, "batch_dot start");

compact_device_queries.compute_dot_products(
&mut self.batch_codes_engine,
&mut self.batch_masks_engine,
&self.query_db_size,
0,
batch_streams,
batch_cublas,
);
tracing::info!(party_id = self.party_id, "compute_dot_reducers start");
compact_device_queries.compute_dot_products(
&mut self.batch_codes_engine,
&mut self.batch_masks_engine,
&self.query_db_size,
0,
batch_streams,
batch_cublas,
);
tracing::info!(party_id = self.party_id, "compute_dot_reducers start");

compact_device_sums.compute_dot_reducers(
&mut self.batch_codes_engine,
&mut self.batch_masks_engine,
&self.query_db_size,
0,
batch_streams,
);
tracing::info!(party_id = self.party_id, "batch_dot end");
});
compact_device_sums.compute_dot_reducers(
&mut self.batch_codes_engine,
&mut self.batch_masks_engine,
&self.query_db_size,
0,
batch_streams,
);
tracing::info!(party_id = self.party_id, "batch_dot end");
}
);

record_stream_time!(
&self.device_manager,
batch_streams,
events,
"batch_reshare",
self.enable_debug_timing,
{
tracing::info!(party_id = self.party_id, "batch_reshare start");
self.batch_codes_engine
Expand All @@ -1104,6 +1130,7 @@ impl ServerActor {
batch_streams,
events,
"batch_threshold",
self.enable_debug_timing,
{
tracing::info!(party_id = self.party_id, "batch_threshold start");
self.phase2_batch.compare_threshold_masked_many(
Expand Down Expand Up @@ -1149,6 +1176,7 @@ impl ServerActor {
&self.streams[0],
events,
"prefetch_db_chunk",
self.enable_debug_timing,
{
self.codes_engine.prefetch_db_chunk(
code_db_slices,
Expand Down Expand Up @@ -1211,6 +1239,7 @@ impl ServerActor {
next_request_streams,
events,
"prefetch_db_chunk",
self.enable_debug_timing,
{
self.codes_engine.prefetch_db_chunk(
code_db_slices,
Expand All @@ -1235,18 +1264,29 @@ impl ServerActor {
.await_event(request_streams, &self.dot_events[db_chunk_idx % 2]);

// ---- START PHASE 1 ----
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
compact_device_queries.dot_products_against_db(
&mut self.codes_engine,
&mut self.masks_engine,
&CudaVec2DSlicerRawPointer::from(&self.code_chunk_buffers[db_chunk_idx % 2]),
&CudaVec2DSlicerRawPointer::from(&self.mask_chunk_buffers[db_chunk_idx % 2]),
&dot_chunk_size,
0,
request_streams,
request_cublas_handles,
);
});
record_stream_time!(
&self.device_manager,
batch_streams,
events,
"db_dot",
self.enable_debug_timing,
{
compact_device_queries.dot_products_against_db(
&mut self.codes_engine,
&mut self.masks_engine,
&CudaVec2DSlicerRawPointer::from(
&self.code_chunk_buffers[db_chunk_idx % 2],
),
&CudaVec2DSlicerRawPointer::from(
&self.mask_chunk_buffers[db_chunk_idx % 2],
),
&dot_chunk_size,
0,
request_streams,
request_cublas_handles,
);
}
);

// wait for the exchange result buffers to be ready
self.device_manager
Expand All @@ -1257,6 +1297,7 @@ impl ServerActor {
request_streams,
events,
"db_reduce",
self.enable_debug_timing,
{
compact_device_sums.compute_dot_reducer_against_db(
&mut self.codes_engine,
Expand All @@ -1278,6 +1319,7 @@ impl ServerActor {
request_streams,
events,
"db_reshare",
self.enable_debug_timing,
{
self.codes_engine
.reshare_results(&dot_chunk_size, request_streams);
Expand Down Expand Up @@ -1310,6 +1352,7 @@ impl ServerActor {
request_streams,
events,
"db_threshold",
self.enable_debug_timing,
{
self.phase2.compare_threshold_masked_many(
&code_dots,
Expand All @@ -1327,22 +1370,29 @@ impl ServerActor {
);

let res = self.phase2.take_result_buffer();
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
open(
&mut self.phase2,
&res,
&self.distance_comparator,
db_match_bitmap,
max_chunk_size * self.max_batch_size * ROTATIONS / 64,
&dot_chunk_size,
&chunk_size,
offset,
&self.current_db_sizes,
&ignore_device_results,
request_streams,
);
self.phase2.return_result_buffer(res);
});
record_stream_time!(
&self.device_manager,
request_streams,
events,
"db_open",
self.enable_debug_timing,
{
open(
&mut self.phase2,
&res,
&self.distance_comparator,
db_match_bitmap,
max_chunk_size * self.max_batch_size * ROTATIONS / 64,
&dot_chunk_size,
&chunk_size,
offset,
&self.current_db_sizes,
&ignore_device_results,
request_streams,
);
self.phase2.return_result_buffer(res);
}
);
}
self.device_manager
.record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]);
Expand Down
3 changes: 3 additions & 0 deletions iris-mpc-gpu/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ mod e2e_test {
MAX_BATCH_SIZE,
true,
false,
false,
) {
Ok((mut actor, handle)) => {
actor.load_full_db(&(&db0.0, &db0.1), &(&db0.0, &db0.1), DB_SIZE);
Expand Down Expand Up @@ -157,6 +158,7 @@ mod e2e_test {
MAX_BATCH_SIZE,
true,
false,
false,
) {
Ok((mut actor, handle)) => {
actor.load_full_db(&(&db1.0, &db1.1), &(&db1.0, &db1.1), DB_SIZE);
Expand Down Expand Up @@ -185,6 +187,7 @@ mod e2e_test {
MAX_BATCH_SIZE,
true,
false,
false,
) {
Ok((mut actor, handle)) => {
actor.load_full_db(&(&db2.0, &db2.1), &(&db2.0, &db2.1), DB_SIZE);
Expand Down
1 change: 1 addition & 0 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
config.max_batch_size,
config.return_partial_results,
config.disable_persistence,
config.enable_debug_timing,
) {
Ok((mut actor, handle)) => {
let res = if config.fake_db_size > 0 {
Expand Down

0 comments on commit cdcbcea

Please sign in to comment.