fix(gpu): fix multi-gpu mult
agnesLeroy committed Jul 19, 2024
1 parent 8dffbbc commit c5dccf3
Showing 6 changed files with 136 additions and 138 deletions.
38 changes: 20 additions & 18 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -583,21 +583,21 @@ template <typename Torus> struct int_radix_lut {
     /// With multiple GPUs we allocate arrays to be pushed to the vectors and
     /// copy data on each GPU then when we gather data to GPU 0 we can copy
     /// back to the original indexing
-    multi_gpu_alloc_lwe(streams, gpu_indexes, active_gpu_count,
-                        lwe_array_in_vec, num_radix_blocks,
-                        params.big_lwe_dimension + 1, false);
-    multi_gpu_alloc_lwe(streams, gpu_indexes, active_gpu_count,
-                        lwe_after_ks_vec, num_radix_blocks,
-                        params.small_lwe_dimension + 1, false);
-    multi_gpu_alloc_lwe(streams, gpu_indexes, active_gpu_count,
-                        lwe_after_pbs_vec, num_radix_blocks,
-                        params.big_lwe_dimension + 1, false);
-    multi_gpu_alloc_array(streams, gpu_indexes, active_gpu_count,
-                          lwe_trivial_indexes_vec, num_radix_blocks, false);
+    multi_gpu_alloc_lwe_async(streams, gpu_indexes, active_gpu_count,
+                              lwe_array_in_vec, num_radix_blocks,
+                              params.big_lwe_dimension + 1);
+    multi_gpu_alloc_lwe_async(streams, gpu_indexes, active_gpu_count,
+                              lwe_after_ks_vec, num_radix_blocks,
+                              params.small_lwe_dimension + 1);
+    multi_gpu_alloc_lwe_async(streams, gpu_indexes, active_gpu_count,
+                              lwe_after_pbs_vec, num_radix_blocks,
+                              params.big_lwe_dimension + 1);
+    multi_gpu_alloc_array_async(streams, gpu_indexes, active_gpu_count,
+                                lwe_trivial_indexes_vec, num_radix_blocks);
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
-    multi_gpu_copy_array(streams, gpu_indexes, active_gpu_count,
-                         lwe_trivial_indexes_vec, lwe_trivial_indexes,
-                         num_radix_blocks, false);
+    multi_gpu_copy_array_async(streams, gpu_indexes, active_gpu_count,
+                               lwe_trivial_indexes_vec, lwe_trivial_indexes,
+                               num_radix_blocks);

     // Keyswitch
     Torus big_size =
@@ -778,10 +778,12 @@ template <typename Torus> struct int_radix_lut {
   }
   buffer.clear();

-    multi_gpu_release(streams, gpu_indexes, lwe_array_in_vec, false);
-    multi_gpu_release(streams, gpu_indexes, lwe_after_ks_vec, false);
-    multi_gpu_release(streams, gpu_indexes, lwe_after_pbs_vec, false);
-    multi_gpu_release(streams, gpu_indexes, lwe_trivial_indexes_vec);
+    multi_gpu_release_async(streams, gpu_indexes, lwe_array_in_vec);
+    multi_gpu_release_async(streams, gpu_indexes, lwe_after_ks_vec);
+    multi_gpu_release_async(streams, gpu_indexes, lwe_after_pbs_vec);
+    multi_gpu_release_async(streams, gpu_indexes, lwe_trivial_indexes_vec);
+    for (uint i = 0; i < active_gpu_count; i++)
+      cuda_synchronize_stream(streams[i], gpu_indexes[i]);
     lwe_array_in_vec.clear();
     lwe_after_ks_vec.clear();
     lwe_after_pbs_vec.clear();
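Note on the hunks above: the allocation, copy, and release helpers now carry an _async suffix because they only enqueue work on the per-GPU streams, and the caller synchronizes explicitly (stream 0 after the allocations, every active stream after the releases). Below is a minimal standalone sketch of that contract in plain CUDA; the helper names and signatures are illustrative assumptions, not the tfhe-cuda-backend API, and only the cuda* runtime calls are real (cudaMallocAsync and cudaFreeAsync need CUDA 11.2 or newer).

#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring multi_gpu_alloc_lwe_async: enqueue one
// stream-ordered allocation per active GPU without blocking the host.
template <typename Torus>
void alloc_per_gpu_async(std::vector<Torus *> &bufs, cudaStream_t *streams,
                         const uint32_t *gpu_indexes, uint32_t active_gpu_count,
                         size_t elems_per_gpu) {
  bufs.assign(active_gpu_count, nullptr);
  for (uint32_t i = 0; i < active_gpu_count; i++) {
    cudaSetDevice(gpu_indexes[i]);
    // The memory is usable only by work ordered after this call on streams[i].
    cudaMallocAsync(reinterpret_cast<void **>(&bufs[i]),
                    elems_per_gpu * sizeof(Torus), streams[i]);
  }
}

// Hypothetical helper mirroring multi_gpu_release_async plus the explicit
// sync loop this commit adds: frees are enqueued per stream, then the caller
// joins each stream before clearing the host-side vector.
template <typename Torus>
void release_per_gpu_async(std::vector<Torus *> &bufs, cudaStream_t *streams,
                           const uint32_t *gpu_indexes) {
  for (uint32_t i = 0; i < bufs.size(); i++) {
    cudaSetDevice(gpu_indexes[i]);
    cudaFreeAsync(bufs[i], streams[i]);
  }
  for (uint32_t i = 0; i < bufs.size(); i++) {
    cudaSetDevice(gpu_indexes[i]);
    cudaStreamSynchronize(streams[i]);
  }
  bufs.clear();
}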
18 changes: 9 additions & 9 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -123,15 +123,15 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
 }

 template <typename Torus>
-void execute_keyswitch(cudaStream_t *streams, uint32_t *gpu_indexes,
-                       uint32_t gpu_count,
-                       const LweArrayVariant<Torus> &lwe_array_out,
-                       const LweArrayVariant<Torus> &lwe_output_indexes,
-                       const LweArrayVariant<Torus> &lwe_array_in,
-                       const LweArrayVariant<Torus> &lwe_input_indexes,
-                       Torus **ksks, uint32_t lwe_dimension_in,
-                       uint32_t lwe_dimension_out, uint32_t base_log,
-                       uint32_t level_count, uint32_t num_samples) {
+void execute_keyswitch_async(cudaStream_t *streams, uint32_t *gpu_indexes,
+                             uint32_t gpu_count,
+                             const LweArrayVariant<Torus> &lwe_array_out,
+                             const LweArrayVariant<Torus> &lwe_output_indexes,
+                             const LweArrayVariant<Torus> &lwe_array_in,
+                             const LweArrayVariant<Torus> &lwe_input_indexes,
+                             Torus **ksks, uint32_t lwe_dimension_in,
+                             uint32_t lwe_dimension_out, uint32_t base_log,
+                             uint32_t level_count, uint32_t num_samples) {

   /// If the number of radix blocks is lower than the number of GPUs, not all
   /// GPUs will be active and there will be 1 input per GPU
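The rename is a contract change rather than a behavior tweak: execute_keyswitch_async enqueues one keyswitch per active GPU on that GPU's stream and returns to the host immediately, and, as the comment above says, no more GPUs are active than there are inputs. A standalone toy showing both points, with concurrent enqueue across min(inputs, GPUs) devices and a single join at the end (the kernel and the sizes are invented):

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the keyswitch/PBS kernels; everything here except the CUDA
// runtime calls is invented for illustration.
__global__ void toy_work(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = x[i] * 2.0f + 1.0f;
}

int main() {
  int gpu_count = 0;
  cudaGetDeviceCount(&gpu_count);
  const int num_inputs = 3; // e.g. radix blocks
  // Mirror the comment above: no more active GPUs than inputs.
  const int active = std::min(gpu_count, num_inputs);
  const int n = 1 << 20;

  cudaStream_t streams[3];
  float *bufs[3];
  for (int i = 0; i < active; i++) {
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]);
    cudaMalloc(&bufs[i], n * sizeof(float));
    // Async enqueue: the host moves on immediately, the GPUs run concurrently.
    toy_work<<<(n + 255) / 256, 256, 0, streams[i]>>>(bufs[i], n);
  }
  // A single join point at the end, as in the explicit sync loops this
  // commit adds after gather/release.
  for (int i = 0; i < active; i++) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(bufs[i]);
    cudaStreamDestroy(streams[i]);
  }
  printf("ran on %d of %d GPU(s)\n", active, gpu_count);
  return 0;
}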
100 changes: 50 additions & 50 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -164,42 +164,42 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(

   auto active_gpu_count = get_active_gpu_count(num_radix_blocks, gpu_count);
   if (active_gpu_count == 1) {
-    execute_keyswitch<Torus>(streams, gpu_indexes, 1, lwe_after_ks_vec[0],
-                             lwe_trivial_indexes_vec[0], lwe_array_in,
-                             lut->lwe_indexes_in, ksks, big_lwe_dimension,
-                             small_lwe_dimension, ks_base_log, ks_level,
-                             num_radix_blocks);
+    execute_keyswitch_async<Torus>(streams, gpu_indexes, 1, lwe_after_ks_vec[0],
+                                   lwe_trivial_indexes_vec[0], lwe_array_in,
+                                   lut->lwe_indexes_in, ksks, big_lwe_dimension,
+                                   small_lwe_dimension, ks_base_log, ks_level,
+                                   num_radix_blocks);

     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(streams, gpu_indexes, 1, lwe_array_out,
-                       lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
-                       lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
-                       lut->buffer, glwe_dimension, small_lwe_dimension,
-                       polynomial_size, pbs_base_log, pbs_level,
-                       grouping_factor, num_radix_blocks, 1, 0,
-                       cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
+    execute_pbs_async<Torus>(
+        streams, gpu_indexes, 1, lwe_array_out, lut->lwe_indexes_out,
+        lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
+        lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension,
+        small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
+        grouping_factor, num_radix_blocks, 1, 0,
+        cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
   } else {
     /// Make sure all data that should be on GPU 0 is indeed there
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);

     /// With multiple GPUs we push to the vectors on each GPU then when we
     /// gather data to GPU 0 we can copy back to the original indexing
-    multi_gpu_scatter_lwe<Torus>(
+    multi_gpu_scatter_lwe_async<Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_array_in_vec, lwe_array_in,
         lut->h_lwe_indexes_in, lut->using_trivial_lwe_indexes, num_radix_blocks,
         big_lwe_dimension + 1);

     /// Apply KS to go from a big LWE dimension to a small LWE dimension
-    execute_keyswitch<Torus>(streams, gpu_indexes, active_gpu_count,
-                             lwe_after_ks_vec, lwe_trivial_indexes_vec,
-                             lwe_array_in_vec, lwe_trivial_indexes_vec, ksks,
-                             big_lwe_dimension, small_lwe_dimension,
-                             ks_base_log, ks_level, num_radix_blocks);
+    execute_keyswitch_async<Torus>(streams, gpu_indexes, active_gpu_count,
+                                   lwe_after_ks_vec, lwe_trivial_indexes_vec,
+                                   lwe_array_in_vec, lwe_trivial_indexes_vec,
+                                   ksks, big_lwe_dimension, small_lwe_dimension,
+                                   ks_base_log, ks_level, num_radix_blocks);

     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(
+    execute_pbs_async<Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer,
@@ -208,11 +208,11 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
         cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);

     /// Copy data back to GPU 0 and release vecs
-    multi_gpu_gather_lwe<Torus>(streams, gpu_indexes, active_gpu_count,
-                                lwe_array_out, lwe_after_pbs_vec,
-                                lut->h_lwe_indexes_out,
-                                lut->using_trivial_lwe_indexes,
-                                num_radix_blocks, big_lwe_dimension + 1);
+    multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count,
+                                      lwe_array_out, lwe_after_pbs_vec,
+                                      lut->h_lwe_indexes_out,
+                                      lut->using_trivial_lwe_indexes,
+                                      num_radix_blocks, big_lwe_dimension + 1);

     /// Synchronize all GPUs
     for (uint i = 0; i < active_gpu_count; i++) {
@@ -257,38 +257,38 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(

   auto active_gpu_count = get_active_gpu_count(num_radix_blocks, gpu_count);
   if (active_gpu_count == 1) {
-    execute_keyswitch<Torus>(streams, gpu_indexes, 1, lwe_after_ks_vec[0],
-                             lwe_trivial_indexes_vec[0], lwe_array_pbs_in,
-                             lut->lwe_indexes_in, ksks, big_lwe_dimension,
-                             small_lwe_dimension, ks_base_log, ks_level,
-                             num_radix_blocks);
+    execute_keyswitch_async<Torus>(streams, gpu_indexes, 1, lwe_after_ks_vec[0],
+                                   lwe_trivial_indexes_vec[0], lwe_array_pbs_in,
+                                   lut->lwe_indexes_in, ksks, big_lwe_dimension,
+                                   small_lwe_dimension, ks_base_log, ks_level,
+                                   num_radix_blocks);

     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(streams, gpu_indexes, 1, lwe_array_out,
-                       lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
-                       lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
-                       lut->buffer, glwe_dimension, small_lwe_dimension,
-                       polynomial_size, pbs_base_log, pbs_level,
-                       grouping_factor, num_radix_blocks, 1, 0,
-                       cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
+    execute_pbs_async<Torus>(
+        streams, gpu_indexes, 1, lwe_array_out, lut->lwe_indexes_out,
+        lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
+        lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension,
+        small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
+        grouping_factor, num_radix_blocks, 1, 0,
+        cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
   } else {
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
-    multi_gpu_scatter_lwe<Torus>(
+    multi_gpu_scatter_lwe_async<Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_array_in_vec,
         lwe_array_pbs_in, lut->h_lwe_indexes_in, lut->using_trivial_lwe_indexes,
         num_radix_blocks, big_lwe_dimension + 1);

     /// Apply KS to go from a big LWE dimension to a small LWE dimension
-    execute_keyswitch<Torus>(streams, gpu_indexes, active_gpu_count,
-                             lwe_after_ks_vec, lwe_trivial_indexes_vec,
-                             lwe_array_in_vec, lwe_trivial_indexes_vec, ksks,
-                             big_lwe_dimension, small_lwe_dimension,
-                             ks_base_log, ks_level, num_radix_blocks);
+    execute_keyswitch_async<Torus>(streams, gpu_indexes, active_gpu_count,
+                                   lwe_after_ks_vec, lwe_trivial_indexes_vec,
+                                   lwe_array_in_vec, lwe_trivial_indexes_vec,
+                                   ksks, big_lwe_dimension, small_lwe_dimension,
+                                   ks_base_log, ks_level, num_radix_blocks);

     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(
+    execute_pbs_async<Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer,
@@ -297,11 +297,11 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
         cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);

     /// Copy data back to GPU 0 and release vecs
-    multi_gpu_gather_lwe<Torus>(streams, gpu_indexes, active_gpu_count,
-                                lwe_array_out, lwe_after_pbs_vec,
-                                lut->h_lwe_indexes_out,
-                                lut->using_trivial_lwe_indexes,
-                                num_radix_blocks, big_lwe_dimension + 1);
+    multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count,
+                                      lwe_array_out, lwe_after_pbs_vec,
+                                      lut->h_lwe_indexes_out,
+                                      lut->using_trivial_lwe_indexes,
+                                      num_radix_blocks, big_lwe_dimension + 1);

     /// Synchronize all GPUs
     for (uint i = 0; i < active_gpu_count; i++) {
@@ -672,7 +672,7 @@ void host_full_propagate_inplace(cudaStream_t *streams, uint32_t *gpu_indexes,

     cudaSetDevice(gpu_indexes[0]);
     /// Since the keyswitch is done on one input only, use only 1 GPU
-    execute_keyswitch<Torus>(
+    execute_keyswitch_async<Torus>(
         streams, gpu_indexes, 1, mem_ptr->tmp_small_lwe_vector,
         mem_ptr->lut->lwe_trivial_indexes, cur_input_block,
         mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
@@ -683,7 +683,7 @@ void host_full_propagate_inplace(cudaStream_t *streams, uint32_t *gpu_indexes,
                                small_lwe_size * sizeof(Torus), streams[0],
                                gpu_indexes[0]);

-    execute_pbs<Torus>(
+    execute_pbs_async<Torus>(
         streams, gpu_indexes, 1, mem_ptr->tmp_big_lwe_vector,
         mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
         mem_ptr->lut->lut_indexes_vec, mem_ptr->tmp_small_lwe_vector,
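Both LUT paths in this file follow the same choreography: synchronize stream 0 so GPU 0's inputs are complete, scatter to the active GPUs, enqueue keyswitch and PBS per GPU, gather back to GPU 0, and only then join every active stream. Below is a standalone sketch of just that data movement, with ciphertexts replaced by floats, the per-GPU kernels elided, and an even split assumed for simplicity:

#include <cuda_runtime.h>
#include <vector>

int main() {
  int gpu_count = 0;
  cudaGetDeviceCount(&gpu_count);
  if (gpu_count == 0) return 0;
  const int n = 1 << 10;             // entries held on GPU 0
  const int per_gpu = n / gpu_count; // even split assumed for the sketch

  cudaSetDevice(0);
  cudaStream_t stream0;
  cudaStreamCreate(&stream0);
  float *src = nullptr;
  cudaMalloc(&src, n * sizeof(float));
  // ... earlier work producing src would be enqueued on stream0 here ...
  cudaStreamSynchronize(stream0);    // "make sure GPU 0 data is there"

  std::vector<cudaStream_t> streams(gpu_count);
  std::vector<float *> chunks(gpu_count);
  for (int i = 0; i < gpu_count; i++) {
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]);
    cudaMalloc(&chunks[i], per_gpu * sizeof(float));
    // Scatter: GPU 0 -> GPU i, ordered on GPU i's stream.
    cudaMemcpyPeerAsync(chunks[i], i, src + i * per_gpu, 0,
                        per_gpu * sizeof(float), streams[i]);
    // ... per-GPU keyswitch/PBS kernels would be enqueued here ...
    // Gather: GPU i -> GPU 0, still without blocking the host.
    cudaMemcpyPeerAsync(src + i * per_gpu, 0, chunks[i], i,
                        per_gpu * sizeof(float), streams[i]);
  }
  // Join every active stream before GPU 0 consumes the gathered result.
  for (int i = 0; i < gpu_count; i++) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(chunks[i]);
    cudaStreamDestroy(streams[i]);
  }
  cudaSetDevice(0);
  cudaFree(src);
  cudaStreamDestroy(stream0);
  return 0;
}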
50 changes: 34 additions & 16 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
@@ -355,15 +355,15 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
     /// After this keyswitch execution, we need to synchronize the streams
     /// because the keyswitch and PBS do not operate on the same number of
     /// inputs
-    execute_keyswitch<Torus>(streams, gpu_indexes, 1, small_lwe_vector,
-                             lwe_indexes_in, new_blocks, lwe_indexes_in, ksks,
-                             polynomial_size * glwe_dimension,
-                             small_lwe_dimension, mem_ptr->params.ks_base_log,
-                             mem_ptr->params.ks_level, message_count);
+    execute_keyswitch_async<Torus>(
+        streams, gpu_indexes, 1, small_lwe_vector, lwe_indexes_in, new_blocks,
+        lwe_indexes_in, ksks, polynomial_size * glwe_dimension,
+        small_lwe_dimension, mem_ptr->params.ks_base_log,
+        mem_ptr->params.ks_level, message_count);

     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(
+    execute_pbs_async<Torus>(
         streams, gpu_indexes, 1, new_blocks, lwe_indexes_out,
         luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec,
         small_lwe_vector, lwe_indexes_in, bsks, luts_message_carry->buffer,
@@ -374,25 +374,43 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
   } else {
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);

-    multi_gpu_scatter_lwe<Torus>(
+    multi_gpu_scatter_lwe_async<Torus>(
         streams, gpu_indexes, active_gpu_count, new_blocks_vec, new_blocks,
         luts_message_carry->h_lwe_indexes_in,
-        luts_message_carry->using_trivial_lwe_indexes, total_count,
+        luts_message_carry->using_trivial_lwe_indexes, message_count,
         big_lwe_size);

     /// Apply KS to go from a big LWE dimension to a small LWE dimension
     /// After this keyswitch execution, we need to synchronize the streams
     /// because the keyswitch and PBS do not operate on the same number of
     /// inputs
-    execute_keyswitch<Torus>(
+    execute_keyswitch_async<Torus>(
         streams, gpu_indexes, active_gpu_count, small_lwe_vector_vec,
         lwe_trivial_indexes_vec, new_blocks_vec, lwe_trivial_indexes_vec,
         ksks, big_lwe_dimension, small_lwe_dimension,
         mem_ptr->params.ks_base_log, mem_ptr->params.ks_level, total_count);

+    /// Copy data back to GPU 0, rebuild the lwe array, and scatter again on a
+    /// different configuration
+    multi_gpu_gather_lwe_async<Torus>(
+        streams, gpu_indexes, gpu_count, small_lwe_vector,
+        small_lwe_vector_vec, luts_message_carry->h_lwe_indexes_in,
+        luts_message_carry->using_trivial_lwe_indexes, message_count,
+        small_lwe_size);
+    /// Synchronize all GPUs
+    for (uint i = 0; i < active_gpu_count; i++) {
+      cuda_synchronize_stream(streams[i], gpu_indexes[i]);
+    }
+
+    multi_gpu_scatter_lwe_async<Torus>(
+        streams, gpu_indexes, gpu_count, small_lwe_vector_vec,
+        small_lwe_vector, luts_message_carry->h_lwe_indexes_in,
+        luts_message_carry->using_trivial_lwe_indexes, total_count,
+        small_lwe_size);
+
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs<Torus>(
+    execute_pbs_async<Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, luts_message_carry->lut_vec,
         luts_message_carry->lut_indexes_vec, small_lwe_vector_vec,
@@ -402,13 +402,13 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
         mem_ptr->params.grouping_factor, total_count, 2, 0, max_shared_memory,
         mem_ptr->params.pbs_type);

-    multi_gpu_gather_lwe<Torus>(streams, gpu_indexes, active_gpu_count,
-                                new_blocks, lwe_after_pbs_vec,
-                                luts_message_carry->h_lwe_indexes_out,
-                                luts_message_carry->using_trivial_lwe_indexes,
-                                total_count, big_lwe_size);
+    multi_gpu_gather_lwe_async<Torus>(
+        streams, gpu_indexes, active_gpu_count, new_blocks, lwe_after_pbs_vec,
+        luts_message_carry->h_lwe_indexes_out,
+        luts_message_carry->using_trivial_lwe_indexes, total_count,
+        big_lwe_size);
     /// Synchronize all GPUs
-    for (uint i = 1; i < active_gpu_count; i++) {
+    for (uint i = 0; i < active_gpu_count; i++) {
       cuda_synchronize_stream(streams[i], gpu_indexes[i]);
     }
   }
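The substantive fix in this file sits in the hunk above: the keyswitch consumes message_count inputs while the PBS consumes total_count, so a partition of the blocks across GPUs that is valid for one stage is not valid for the other. The new code therefore gathers the keyswitched blocks back to GPU 0 and re-scatters them under the PBS configuration. A toy illustration of why the two splits disagree; the even-split rule and the counts are invented for the example, the real distribution lives in the multi_gpu_* helpers:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Print how n inputs would land on the active GPUs under an even split.
static void print_split(const char *stage, uint32_t n, uint32_t gpu_count) {
  uint32_t active = std::min(n, gpu_count); // no more active GPUs than inputs
  printf("%s: %u inputs ->", stage, n);
  for (uint32_t g = 0; g < active; g++)
    printf(" GPU%u=%u", g, n / active + (g < n % active ? 1u : 0u));
  printf("\n");
}

int main() {
  const uint32_t gpu_count = 4;
  const uint32_t message_count = 6; // blocks entering the keyswitch
  const uint32_t total_count = 11;  // message + carry blocks entering the PBS
  print_split("keyswitch", message_count, gpu_count); // 2 2 1 1
  print_split("pbs      ", total_count, gpu_count);   // 3 3 3 2
  // The layouts disagree, so buffers scattered for one count cannot be fed
  // to the next stage directly: gather to GPU 0, then scatter for total_count.
  return 0;
}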
@@ -117,7 +117,7 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *join_buffer,
 }

 template <typename Torus>
-void execute_pbs(
+void execute_pbs_async(
     cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count,
     const LweArrayVariant<Torus> &lwe_array_out,
     const LweArrayVariant<Torus> &lwe_output_indexes,
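In the single-GPU paths (for instance host_full_propagate_inplace above), execute_keyswitch_async, an async copy, and execute_pbs_async appear to be chained on streams[0] with no host synchronization in between: work issued to a single stream executes in issue order, so only the final read back needs a join. A standalone illustration of that stream-ordering guarantee (toy kernels and values):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void stage1(float *x) { x[threadIdx.x] += 1.0f; }
__global__ void stage2(float *x) { x[threadIdx.x] *= 2.0f; }

int main() {
  cudaSetDevice(0);
  cudaStream_t s;
  cudaStreamCreate(&s);
  float *d = nullptr;
  cudaMalloc(&d, 32 * sizeof(float));
  cudaMemsetAsync(d, 0, 32 * sizeof(float), s);
  // Operations on one stream run in issue order; no host sync is needed
  // between the stages, only before the host reads the result.
  stage1<<<1, 32, 0, s>>>(d);
  stage2<<<1, 32, 0, s>>>(d);
  float h[32];
  cudaMemcpyAsync(h, d, sizeof(h), cudaMemcpyDeviceToHost, s);
  cudaStreamSynchronize(s);
  printf("h[0] = %f\n", h[0]); // 2.0: (0 + 1) * 2, in issue order
  cudaFree(d);
  cudaStreamDestroy(s);
  return 0;
}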