Skip to content

Commit

Permalink
fix(gpu): fix scalar rotate and add some checks
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Jul 26, 2024
1 parent 5e49727 commit 2674d17
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
3 changes: 2 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,9 @@ __host__ void pack_blocks(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out, Torus *lwe_array_in,
uint32_t lwe_dimension, uint32_t num_radix_blocks,
uint32_t factor) {
if (num_radix_blocks == 0)
return;
cudaSetDevice(gpu_index);

int num_blocks = 0, num_threads = 0;
int num_entries = (lwe_dimension + 1);
getNumBlocksAndThreads(num_entries, 1024, num_blocks, num_threads);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ template <typename Torus, class params>
__global__ void fill_radix_from_lsb_msb(Torus *result_blocks, Torus *lsb_blocks,
Torus *msb_blocks,
uint32_t glwe_dimension,
uint32_t lsb_count, uint32_t msb_count,
uint32_t num_blocks) {
size_t big_lwe_dimension = glwe_dimension * params::degree + 1;
size_t big_lwe_id = blockIdx.x;
Expand Down Expand Up @@ -321,8 +320,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
h_lwe_idx_in, h_lwe_idx_out);

size_t copy_size = total_count * sizeof(Torus);
copy_size = sm_copy_count * sizeof(int32_t);
size_t copy_size = sm_copy_count * sizeof(int32_t);
cuda_memcpy_async_to_gpu(d_smart_copy_in, h_smart_copy_in, copy_size,
streams[0], gpu_indexes[0]);
cuda_memcpy_async_to_gpu(d_smart_copy_out, h_smart_copy_out, copy_size,
Expand Down Expand Up @@ -548,8 +546,7 @@ __host__ void host_integer_mult_radix_kb(
fill_radix_from_lsb_msb<Torus, params>
<<<num_blocks * num_blocks, params::degree / params::opt, 0,
streams[0]>>>(vector_result_sb, vector_result_lsb, vector_result_msb,
glwe_dimension, lsb_vector_block_count,
msb_vector_block_count, num_blocks);
glwe_dimension, num_blocks);
check_cuda_error(cudaGetLastError());

int terms_degree[2 * num_blocks * num_blocks];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ __host__ void scalar_compare_radix_blocks_kb(
int_comparison_buffer<Torus> *mem_ptr, void **bsks, Torus **ksks,
uint32_t num_radix_blocks) {

if (num_radix_blocks == 0)
return;
auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
auto message_modulus = params.message_modulus;
Expand Down
6 changes: 4 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_rotate.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(

Torus *rotated_buffer = mem->tmp_rotated;

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

// rotate right all the blocks in radix ciphertext
// copy result in new buffer
// 256 threads are used in every block
Expand All @@ -76,6 +74,8 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(
giver_blocks, lwe_array, 1, num_blocks,
big_lwe_size);

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

integer_radix_apply_bivariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
giver_blocks, bsks, ksks, num_blocks, lut_bivariate,
Expand All @@ -100,6 +100,8 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(
host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, giver_blocks,
lwe_array, 1, num_blocks, big_lwe_size);

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

integer_radix_apply_bivariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
giver_blocks, bsks, ksks, num_blocks, lut_bivariate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
auto big_lwe_size = big_lwe_dimension + 1;
auto big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

cudaSetDevice(gpu_indexes[0]);

// Extract all bits
auto bits = mem->tmp_bits;
extract_n_bits<Torus>(streams, gpu_indexes, gpu_count, bits, lwe_array, bsks,
Expand Down

0 comments on commit 2674d17

Please sign in to comment.