diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index 4117bd180..850e52230 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -67,6 +67,9 @@ pub struct Config { #[serde(default)] pub fake_db_size: usize, + #[serde(default)] + pub return_partial_results: bool, + #[serde(default)] pub disable_persistence: bool, } diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs index 1612106c3..fce73d4c7 100644 --- a/iris-mpc-common/src/helpers/smpc_request.rs +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -302,11 +302,13 @@ impl UniquenessRequest { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct UniquenessResult { - pub node_id: usize, - pub serial_id: Option, - pub is_match: bool, - pub signup_id: String, - pub matched_serial_ids: Option>, + pub node_id: usize, + pub serial_id: Option, + pub is_match: bool, + pub signup_id: String, + pub matched_serial_ids: Option>, + pub matched_serial_ids_left: Option>, + pub matched_serial_ids_right: Option>, } impl UniquenessResult { @@ -316,6 +318,8 @@ impl UniquenessResult { is_match: bool, signup_id: String, matched_serial_ids: Option>, + matched_serial_ids_left: Option>, + matched_serial_ids_right: Option>, ) -> Self { Self { node_id, @@ -323,6 +327,8 @@ impl UniquenessResult { is_match, signup_id, matched_serial_ids, + matched_serial_ids_left, + matched_serial_ids_right, } } } diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs index f9b659d61..c70fd4a8e 100644 --- a/iris-mpc-gpu/src/dot/distance_comparator.rs +++ b/iris-mpc-gpu/src/dot/distance_comparator.rs @@ -27,18 +27,26 @@ pub struct DistanceComparator { pub final_results_init_host: Vec, pub match_counters: Vec>, pub all_matches: Vec>, + pub match_counters_left: Vec>, + pub match_counters_right: Vec>, + pub partial_results_left: Vec>, + pub partial_results_right: Vec>, } impl DistanceComparator { pub fn init(query_length: usize, device_manager: Arc) -> Self { let ptx = compile_ptx(PTX_SRC).unwrap(); - let mut open_kernels = Vec::new(); + let mut open_kernels: Vec = Vec::new(); let mut merge_db_kernels = Vec::new(); let mut merge_batch_kernels = Vec::new(); let mut opened_results = vec![]; let mut final_results = vec![]; - let mut match_counters: Vec> = vec![]; - let mut all_matches: Vec> = vec![]; + let mut match_counters = vec![]; + let mut match_counters_left = vec![]; + let mut match_counters_right = vec![]; + let mut all_matches = vec![]; + let mut partial_results_left = vec![]; + let mut partial_results_right = vec![]; let devices_count = device_manager.device_count(); @@ -63,11 +71,23 @@ impl DistanceComparator { opened_results.push(device.htod_copy(results_init_host.clone()).unwrap()); final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap()); match_counters.push(device.alloc_zeros(query_length / ROTATIONS).unwrap()); + match_counters_left.push(device.alloc_zeros(query_length / ROTATIONS).unwrap()); + match_counters_right.push(device.alloc_zeros(query_length / ROTATIONS).unwrap()); all_matches.push( device .alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS) .unwrap(), ); + partial_results_left.push( + device + .alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS) + .unwrap(), + ); + partial_results_right.push( + device + .alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS) + .unwrap(), + ); open_kernels.push(open_results_function); merge_db_kernels.push(merge_db_results_function); @@ -85,7 +105,11 @@ impl DistanceComparator { results_init_host, final_results_init_host, match_counters, + match_counters_left, + match_counters_right, all_matches, + partial_results_left, + partial_results_right, } } @@ -213,6 +237,10 @@ impl DistanceComparator { num_elements as u64, &self.match_counters[i], &self.all_matches[i], + &self.match_counters_left[i], + &self.match_counters_right[i], + &self.partial_results_left[i], + &self.partial_results_right[i], ), ) .unwrap(); @@ -233,26 +261,30 @@ impl DistanceComparator { results } - pub fn fetch_match_counters(&self) -> Vec> { + pub fn fetch_match_counters(&self, counters: &[CudaSlice]) -> Vec> { let mut results = vec![]; for i in 0..self.device_manager.device_count() { results.push( self.device_manager .device(i) - .dtoh_sync_copy(&self.match_counters[i]) + .dtoh_sync_copy(&counters[i]) .unwrap(), ); } results } - pub fn fetch_all_match_ids(&self, match_counters: Vec>) -> Vec> { + pub fn fetch_all_match_ids( + &self, + match_counters: Vec>, + matches: &[CudaSlice], + ) -> Vec> { let mut results = vec![]; for i in 0..self.device_manager.device_count() { results.push( self.device_manager .device(i) - .dtoh_sync_copy(&self.all_matches[i]) + .dtoh_sync_copy(&matches[i]) .unwrap(), ); } diff --git a/iris-mpc-gpu/src/dot/kernel.cu b/iris-mpc-gpu/src/dot/kernel.cu index 183d17ce0..9c76172a3 100644 --- a/iris-mpc-gpu/src/dot/kernel.cu +++ b/iris-mpc-gpu/src/dot/kernel.cu @@ -51,7 +51,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon } } -extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches) +extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < numElements) @@ -67,6 +67,20 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, if (queryIdx >= queryLength || dbIdx >= dbLength) continue; + // Check for partial results (only used for debugging) + if (matchLeft) + { + unsigned int queryMatchCounter = atomicAdd(&matchCounterLeft[queryIdx], 1); + if (queryMatchCounter < MAX_MATCHES_LEN) + partialResultsLeft[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx; + } + if (matchRight) + { + unsigned int queryMatchCounter = atomicAdd(&matchCounterRight[queryIdx], 1); + if (queryMatchCounter < MAX_MATCHES_LEN) + partialResultsRight[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx; + } + // Current *AND* policy: only match, if both eyes match if (matchLeft && matchRight) { @@ -79,7 +93,7 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, } } -extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches) +extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < numElements) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index ba2e59300..fe28c1737 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -99,6 +99,7 @@ pub struct ServerActor { query_db_size: Vec, max_batch_size: usize, max_db_size: usize, + return_partial_results: bool, disable_persistence: bool, } @@ -112,6 +113,7 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + return_partial_results: bool, disable_persistence: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let device_manager = Arc::new(DeviceManager::init()); @@ -122,6 +124,7 @@ impl ServerActor { job_queue_size, max_db_size, max_batch_size, + return_partial_results, disable_persistence, ) } @@ -133,6 +136,7 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + return_partial_results: bool, disable_persistence: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let ids = device_manager.get_ids_from_magic(0); @@ -145,6 +149,7 @@ impl ServerActor { job_queue_size, max_db_size, max_batch_size, + return_partial_results, disable_persistence, ) } @@ -158,6 +163,7 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + return_partial_results: bool, disable_persistence: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let (tx, rx) = mpsc::channel(job_queue_size); @@ -169,6 +175,7 @@ impl ServerActor { rx, max_db_size, max_batch_size, + return_partial_results, disable_persistence, )?; Ok((actor, ServerActorHandle { job_queue: tx })) @@ -183,6 +190,7 @@ impl ServerActor { job_queue: mpsc::Receiver, max_db_size: usize, max_batch_size: usize, + return_partial_results: bool, disable_persistence: bool, ) -> eyre::Result { assert!(max_batch_size != 0); @@ -343,6 +351,7 @@ impl ServerActor { batch_match_list_right, max_batch_size, max_db_size, + return_partial_results, disable_persistence, }) } @@ -759,7 +768,7 @@ impl ServerActor { // Fetch and truncate the match counters let match_counters_devices = self .distance_comparator - .fetch_match_counters() + .fetch_match_counters(&self.distance_comparator.match_counters) .into_iter() .map(|x| x[..batch_size].to_vec()) .collect::>(); @@ -776,9 +785,10 @@ impl ServerActor { }); // Transfer all match ids - let match_ids = self - .distance_comparator - .fetch_all_match_ids(match_counters_devices); + let match_ids = self.distance_comparator.fetch_all_match_ids( + match_counters_devices, + &self.distance_comparator.all_matches, + ); // Check if there are more matches than we fetch // TODO: In the future we might want to dynamically allocate more memory here @@ -793,6 +803,28 @@ impl ServerActor { } } + let (partial_match_ids_left, partial_match_ids_right) = if self.return_partial_results { + // Transfer the partial results to the host + let partial_match_counters_left = self + .distance_comparator + .fetch_match_counters(&self.distance_comparator.match_counters_left); + let partial_match_counters_right = self + .distance_comparator + .fetch_match_counters(&self.distance_comparator.match_counters_right); + + let partial_results_left = self.distance_comparator.fetch_all_match_ids( + partial_match_counters_left, + &self.distance_comparator.partial_results_left, + ); + let partial_results_right = self.distance_comparator.fetch_all_match_ids( + partial_match_counters_right, + &self.distance_comparator.partial_results_right, + ); + (partial_results_left, partial_results_right) + } else { + (vec![], vec![]) + }; + // Write back to in-memory db let previous_total_db_size = self.current_db_sizes.iter().sum::(); let n_insertions = insertion_list.iter().map(|x| x.len()).sum::(); @@ -854,6 +886,8 @@ impl ServerActor { metadata: batch.metadata, matches, match_ids, + partial_match_ids_left, + partial_match_ids_right, store_left: query_store_left, store_right: query_store_right, deleted_ids: batch.deletion_requests_indices, @@ -881,6 +915,20 @@ impl ServerActor { &self.streams[0], ); + reset_slice( + self.device_manager.devices(), + &self.distance_comparator.match_counters_left, + 0, + &self.streams[0], + ); + + reset_slice( + self.device_manager.devices(), + &self.distance_comparator.match_counters_right, + 0, + &self.streams[0], + ); + // ---- END RESULT PROCESSING ---- log_timers(events); let processed_mil_elements_per_second = (self.max_batch_size * previous_total_db_size) diff --git a/iris-mpc-gpu/src/server/mod.rs b/iris-mpc-gpu/src/server/mod.rs index 51ed3beb7..96c7b53f2 100644 --- a/iris-mpc-gpu/src/server/mod.rs +++ b/iris-mpc-gpu/src/server/mod.rs @@ -89,14 +89,16 @@ pub struct ServerJob { #[derive(Debug, Clone)] pub struct ServerJobResult { - pub merged_results: Vec, - pub request_ids: Vec, - pub metadata: Vec, - pub matches: Vec, - pub match_ids: Vec>, - pub store_left: BatchQueryEntries, - pub store_right: BatchQueryEntries, - pub deleted_ids: Vec, + pub merged_results: Vec, + pub request_ids: Vec, + pub metadata: Vec, + pub matches: Vec, + pub match_ids: Vec>, + pub partial_match_ids_left: Vec>, + pub partial_match_ids_right: Vec>, + pub store_left: BatchQueryEntries, + pub store_right: BatchQueryEntries, + pub deleted_ids: Vec, } enum Eye { diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 3fe47a220..0832d21dd 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -118,6 +118,7 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + true, false, ) { Ok((mut actor, handle)) => { @@ -144,6 +145,7 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + true, false, ) { Ok((mut actor, handle)) => { @@ -170,6 +172,7 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + true, false, ) { Ok((mut actor, handle)) => { @@ -390,15 +393,25 @@ mod e2e_test { request_ids: thread_request_ids, matches, merged_results, + match_ids, + partial_match_ids_left, + partial_match_ids_right, .. } = res; - for ((req_id, &was_match), &idx) in thread_request_ids - .iter() - .zip(matches.iter()) - .zip(merged_results.iter()) + for (((((req_id, was_match), idx), partial_left), partial_right), match_id) in + thread_request_ids + .iter() + .zip(matches.iter()) + .zip(merged_results.iter()) + .zip(partial_match_ids_left.iter()) + .zip(partial_match_ids_right.iter()) + .zip(match_ids.iter()) { assert!(requests.contains_key(req_id)); + assert_eq!(partial_left, partial_right); + assert_eq!(partial_left, match_id); + // This was an invalid query, we should not get a response, but they should be // silently ignored assert!(requests.contains_key(req_id)); @@ -407,11 +420,11 @@ mod e2e_test { if let Some(expected_idx) = expected_idx { assert!(was_match); - assert_eq!(expected_idx, &idx); + assert_eq!(expected_idx, idx); } else { assert!(!was_match); let request = requests.get(req_id).unwrap().clone(); - responses.insert(idx, request); + responses.insert(*idx, request); } } } diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index 51f2d2425..0b5b7497e 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -556,6 +556,8 @@ mod tests { false, "A".repeat(64), None, + None, + None, ))?; let result_events = vec![result_event; count]; diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 35b5c8a12..f663a8f01 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -711,6 +711,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { 8, config.max_db_size, config.max_batch_size, + config.return_partial_results, config.disable_persistence, ) { Ok((mut actor, handle)) => { @@ -805,6 +806,8 @@ async fn server_main(config: Config) -> eyre::Result<()> { metadata, matches, match_ids, + partial_match_ids_left, + partial_match_ids_right, store_left, store_right, deleted_ids, @@ -827,6 +830,24 @@ async fn server_main(config: Config) -> eyre::Result<()> { true => Some(match_ids[i].iter().map(|x| x + 1).collect::>()), false => None, }, + match matches[i] { + true => Some( + partial_match_ids_left[i] + .iter() + .map(|x| x + 1) + .collect::>(), + ), + false => None, + }, + match matches[i] { + true => Some( + partial_match_ids_right[i] + .iter() + .map(|x| x + 1) + .collect::>(), + ), + false => None, + }, ); serde_json::to_string(&result_event).wrap_err("failed to serialize result")