From cdcbceaf312b6f9c3a147491014b212f09bb9b64 Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Thu, 2 Jan 2025 23:05:32 -0800 Subject: [PATCH] allow disabling of stream timers (#868) * allow disabling of stream timers * bump stage image --------- Co-authored-by: Ertugrul Aypek --- deploy/stage/common-values-iris-mpc.yaml | 2 +- iris-mpc-common/src/config/mod.rs | 3 + iris-mpc-gpu/src/server/actor.rs | 164 +++++++++++++++-------- iris-mpc-gpu/tests/e2e.rs | 3 + iris-mpc/src/bin/server.rs | 1 + 5 files changed, 115 insertions(+), 58 deletions(-) diff --git a/deploy/stage/common-values-iris-mpc.yaml b/deploy/stage/common-values-iris-mpc.yaml index 276b120a1..69e69d88e 100644 --- a/deploy/stage/common-values-iris-mpc.yaml +++ b/deploy/stage/common-values-iris-mpc.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.13.6" +image: "ghcr.io/worldcoin/iris-mpc:v0.13.7" environment: stage replicaCount: 1 diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index 666884068..971425b05 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -79,6 +79,9 @@ pub struct Config { #[serde(default)] pub disable_persistence: bool, + #[serde(default)] + pub enable_debug_timing: bool, + #[serde(default, deserialize_with = "deserialize_yaml_json_string")] pub node_hostnames: Vec, diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 1a31267a2..e1972f172 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -37,14 +37,18 @@ use std::{collections::HashMap, mem, sync::Arc, time::Instant}; use tokio::sync::{mpsc, oneshot}; macro_rules! record_stream_time { - ($manager:expr, $streams:expr, $map:expr, $label:expr, $block:block) => {{ - let evt0 = $manager.create_events(); - let evt1 = $manager.create_events(); - $manager.record_event($streams, &evt0); - let res = $block; - $manager.record_event($streams, &evt1); - $map.entry($label).or_default().extend(vec![evt0, evt1]); - res + ($manager:expr, $streams:expr, $map:expr, $label:expr, $enable_timing:expr, $block:block) => {{ + if $enable_timing { + let evt0 = $manager.create_events(); + let evt1 = $manager.create_events(); + $manager.record_event($streams, &evt0); + let res = $block; + $manager.record_event($streams, &evt1); + $map.entry($label).or_default().extend(vec![evt0, evt1]); + res + } else { + $block + } }}; } @@ -105,6 +109,7 @@ pub struct ServerActor { max_db_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, code_chunk_buffers: Vec, mask_chunk_buffers: Vec, dot_events: Vec>, @@ -124,6 +129,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let device_manager = Arc::new(DeviceManager::init()); Self::new_with_device_manager( @@ -135,6 +141,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, ) } #[allow(clippy::too_many_arguments)] @@ -147,6 +154,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let ids = device_manager.get_ids_from_magic(0); let comms = device_manager.instantiate_network_from_ids(party_id, &ids)?; @@ -160,6 +168,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, ) } @@ -174,6 +183,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let (tx, rx) = mpsc::channel(job_queue_size); let actor = Self::init( @@ -186,6 +196,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, )?; Ok((actor, ServerActorHandle { job_queue: tx })) } @@ -201,6 +212,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result { assert!(max_batch_size != 0); let mut kdf_nonce = 0; @@ -373,6 +385,7 @@ impl ServerActor { max_db_size, return_partial_results, disable_persistence, + enable_debug_timing, code_chunk_buffers, mask_chunk_buffers, dot_events, @@ -616,6 +629,7 @@ impl ServerActor { &self.streams[0], events, "query_preprocess", + self.enable_debug_timing, { // This needs to be max_batch_size, even though the query can be shorter to have // enough padding for GEMM @@ -662,6 +676,7 @@ impl ServerActor { &self.streams[0], events, "query_preprocess", + self.enable_debug_timing, { // This needs to be MAX_BATCH_SIZE, even though the query can be shorter to have // enough padding for GEMM @@ -903,6 +918,7 @@ impl ServerActor { &self.streams[0], events, "db_write", + self.enable_debug_timing, { for i in 0..self.device_manager.device_count() { self.device_manager.device(i).bind_to_thread().unwrap(); @@ -987,7 +1003,9 @@ impl ServerActor { ); // ---- END RESULT PROCESSING ---- - log_timers(events); + if self.enable_debug_timing { + log_timers(events); + } let processed_mil_elements_per_second = (self.max_batch_size * previous_total_db_size) as f64 / now.elapsed().as_secs_f64() @@ -1055,34 +1073,42 @@ impl ServerActor { // ---- START BATCH DEDUP ---- tracing::info!(party_id = self.party_id, "Starting batch deduplication"); - record_stream_time!(&self.device_manager, batch_streams, events, "batch_dot", { - tracing::info!(party_id = self.party_id, "batch_dot start"); + record_stream_time!( + &self.device_manager, + batch_streams, + events, + "batch_dot", + self.enable_debug_timing, + { + tracing::info!(party_id = self.party_id, "batch_dot start"); - compact_device_queries.compute_dot_products( - &mut self.batch_codes_engine, - &mut self.batch_masks_engine, - &self.query_db_size, - 0, - batch_streams, - batch_cublas, - ); - tracing::info!(party_id = self.party_id, "compute_dot_reducers start"); + compact_device_queries.compute_dot_products( + &mut self.batch_codes_engine, + &mut self.batch_masks_engine, + &self.query_db_size, + 0, + batch_streams, + batch_cublas, + ); + tracing::info!(party_id = self.party_id, "compute_dot_reducers start"); - compact_device_sums.compute_dot_reducers( - &mut self.batch_codes_engine, - &mut self.batch_masks_engine, - &self.query_db_size, - 0, - batch_streams, - ); - tracing::info!(party_id = self.party_id, "batch_dot end"); - }); + compact_device_sums.compute_dot_reducers( + &mut self.batch_codes_engine, + &mut self.batch_masks_engine, + &self.query_db_size, + 0, + batch_streams, + ); + tracing::info!(party_id = self.party_id, "batch_dot end"); + } + ); record_stream_time!( &self.device_manager, batch_streams, events, "batch_reshare", + self.enable_debug_timing, { tracing::info!(party_id = self.party_id, "batch_reshare start"); self.batch_codes_engine @@ -1104,6 +1130,7 @@ impl ServerActor { batch_streams, events, "batch_threshold", + self.enable_debug_timing, { tracing::info!(party_id = self.party_id, "batch_threshold start"); self.phase2_batch.compare_threshold_masked_many( @@ -1149,6 +1176,7 @@ impl ServerActor { &self.streams[0], events, "prefetch_db_chunk", + self.enable_debug_timing, { self.codes_engine.prefetch_db_chunk( code_db_slices, @@ -1211,6 +1239,7 @@ impl ServerActor { next_request_streams, events, "prefetch_db_chunk", + self.enable_debug_timing, { self.codes_engine.prefetch_db_chunk( code_db_slices, @@ -1235,18 +1264,29 @@ impl ServerActor { .await_event(request_streams, &self.dot_events[db_chunk_idx % 2]); // ---- START PHASE 1 ---- - record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", { - compact_device_queries.dot_products_against_db( - &mut self.codes_engine, - &mut self.masks_engine, - &CudaVec2DSlicerRawPointer::from(&self.code_chunk_buffers[db_chunk_idx % 2]), - &CudaVec2DSlicerRawPointer::from(&self.mask_chunk_buffers[db_chunk_idx % 2]), - &dot_chunk_size, - 0, - request_streams, - request_cublas_handles, - ); - }); + record_stream_time!( + &self.device_manager, + batch_streams, + events, + "db_dot", + self.enable_debug_timing, + { + compact_device_queries.dot_products_against_db( + &mut self.codes_engine, + &mut self.masks_engine, + &CudaVec2DSlicerRawPointer::from( + &self.code_chunk_buffers[db_chunk_idx % 2], + ), + &CudaVec2DSlicerRawPointer::from( + &self.mask_chunk_buffers[db_chunk_idx % 2], + ), + &dot_chunk_size, + 0, + request_streams, + request_cublas_handles, + ); + } + ); // wait for the exchange result buffers to be ready self.device_manager @@ -1257,6 +1297,7 @@ impl ServerActor { request_streams, events, "db_reduce", + self.enable_debug_timing, { compact_device_sums.compute_dot_reducer_against_db( &mut self.codes_engine, @@ -1278,6 +1319,7 @@ impl ServerActor { request_streams, events, "db_reshare", + self.enable_debug_timing, { self.codes_engine .reshare_results(&dot_chunk_size, request_streams); @@ -1310,6 +1352,7 @@ impl ServerActor { request_streams, events, "db_threshold", + self.enable_debug_timing, { self.phase2.compare_threshold_masked_many( &code_dots, @@ -1327,22 +1370,29 @@ impl ServerActor { ); let res = self.phase2.take_result_buffer(); - record_stream_time!(&self.device_manager, request_streams, events, "db_open", { - open( - &mut self.phase2, - &res, - &self.distance_comparator, - db_match_bitmap, - max_chunk_size * self.max_batch_size * ROTATIONS / 64, - &dot_chunk_size, - &chunk_size, - offset, - &self.current_db_sizes, - &ignore_device_results, - request_streams, - ); - self.phase2.return_result_buffer(res); - }); + record_stream_time!( + &self.device_manager, + request_streams, + events, + "db_open", + self.enable_debug_timing, + { + open( + &mut self.phase2, + &res, + &self.distance_comparator, + db_match_bitmap, + max_chunk_size * self.max_batch_size * ROTATIONS / 64, + &dot_chunk_size, + &chunk_size, + offset, + &self.current_db_sizes, + &ignore_device_results, + request_streams, + ); + self.phase2.return_result_buffer(res); + } + ); } self.device_manager .record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]); diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 7b9e47e7a..df377fa08 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -129,6 +129,7 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db0.0, &db0.1), &(&db0.0, &db0.1), DB_SIZE); @@ -157,6 +158,7 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db1.0, &db1.1), &(&db1.0, &db1.1), DB_SIZE); @@ -185,6 +187,7 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db2.0, &db2.1), &(&db2.0, &db2.1), DB_SIZE); diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 6a75509cc..adf7beddb 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -973,6 +973,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { config.max_batch_size, config.return_partial_results, config.disable_persistence, + config.enable_debug_timing, ) { Ok((mut actor, handle)) => { let res = if config.fake_db_size > 0 {