Skip to content

Commit

Permalink
Fix matching against empty db (#402)
Browse files Browse the repository at this point in the history
* handle empty db

* skip empty

* change approach

* up

* up

* up

* log

* up

* Revert "up"

This reverts commit fe1b672.
  • Loading branch information
philsippl authored Sep 18, 2024
1 parent 352214d commit 28bf9d2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
9 changes: 9 additions & 0 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,15 @@ impl DistanceComparator {
real_db_sizes: &[usize],
offset: usize,
total_db_sizes: &[usize],
ignore_db_results: &[bool],
streams: &[CudaStream],
) {
for i in 0..self.device_manager.device_count() {
// Those correspond to 0 length dbs, which were just artificially increased to
// length 1 to avoid division by zero in the kernel
if ignore_db_results[i] {
continue;
}
let num_elements = (db_sizes[i] * self.query_length).div_ceil(64);
let threads_per_block = 256;
let blocks_per_grid = num_elements.div_ceil(threads_per_block);
Expand Down Expand Up @@ -179,6 +185,9 @@ impl DistanceComparator {
kernels: &[CudaFunction],
) {
for i in 0..self.device_manager.device_count() {
if db_sizes[i] == 0 {
continue;
}
let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64);
let threads_per_block = 256;
let blocks_per_grid = num_elements.div_ceil(threads_per_block);
Expand Down
15 changes: 14 additions & 1 deletion iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,13 @@ impl ServerActor {
/ now.elapsed().as_secs_f64()
/ 1e6
);

tracing::info!(
"Old DB size: {}, New DB size: {}",
previous_total_db_size,
self.current_db_sizes.iter().sum::<usize>()
);

Ok(())
}

Expand Down Expand Up @@ -957,6 +964,7 @@ impl ServerActor {
&db_sizes_batch,
0,
&db_sizes_batch,
&vec![false; self.device_manager.device_count()],
batch_streams,
);
self.phase2_batch.return_result_buffer(res);
Expand All @@ -974,6 +982,8 @@ impl ServerActor {

// ---- START DATABASE DEDUP ----
tracing::debug!(party_id = self.party_id, "Start DB deduplication");
let ignore_device_results: Vec<bool> =
self.current_db_sizes.iter().map(|&s| s == 0).collect();
let mut db_chunk_idx = 0;
loop {
tracing::debug!(
Expand All @@ -989,7 +999,7 @@ impl ServerActor {
let chunk_size = self
.current_db_sizes
.iter()
.map(|s| (s - DB_CHUNK_SIZE * db_chunk_idx).clamp(0, DB_CHUNK_SIZE))
.map(|s| (s - DB_CHUNK_SIZE * db_chunk_idx).clamp(1, DB_CHUNK_SIZE))
.collect::<Vec<_>>();

// We need to pad the chunk size to be a multiple of 4, because the underlying
Expand Down Expand Up @@ -1126,6 +1136,7 @@ impl ServerActor {
&chunk_size,
offset,
&self.current_db_sizes,
&ignore_device_results,
request_streams,
);
self.phase2.return_result_buffer(res);
Expand Down Expand Up @@ -1298,6 +1309,7 @@ fn open(
real_db_sizes: &[usize],
offset: usize,
total_db_sizes: &[usize],
ignore_db_results: &[bool],
streams: &[CudaStream],
) {
let n_devices = x.len();
Expand Down Expand Up @@ -1333,6 +1345,7 @@ fn open(
real_db_sizes,
offset,
total_db_sizes,
ignore_db_results,
streams,
);
}
Expand Down

0 comments on commit 28bf9d2

Please sign in to comment.