From 8993002a9143ed15aade9569ff413ac5860c63a0 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Fri, 27 Sep 2024 17:33:28 +0200 Subject: [PATCH] chore(gpu): refactor lwe_chunk_size --- .../include/programmable_bootstrap_multibit.h | 12 +- .../cuda/src/pbs/programmable_bootstrap.cuh | 5 +- .../programmable_bootstrap_cg_multibit.cuh | 13 +- .../pbs/programmable_bootstrap_multibit.cu | 115 +++++++++--------- .../pbs/programmable_bootstrap_multibit.cuh | 23 ++-- .../programmable_bootstrap_tbc_multibit.cuh | 21 ++-- .../benchmarks/benchmark_pbs.cpp | 8 +- .../tests/test_multibit_pbs.cpp | 4 +- backends/tfhe-cuda-backend/src/cuda_bind.rs | 2 - tfhe/src/core_crypto/gpu/mod.rs | 2 - 10 files changed, 93 insertions(+), 112 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h index c72ee20b15..51b2a62040 100644 --- a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h +++ b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h @@ -17,8 +17,7 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64( void scratch_cuda_multi_bit_programmable_bootstrap_64( void *stream, uint32_t gpu_index, int8_t **pbs_buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( @@ -48,8 +47,7 @@ bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( template void scratch_cuda_tbc_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); template @@ -82,8 +80,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( template void scratch_cuda_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **pbs_buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); template @@ -130,7 +127,7 @@ template struct pbs_buffer { int8_t *d_mem_acc_step_two = NULL; int8_t *d_mem_acc_cg = NULL; int8_t *d_mem_acc_tbc = NULL; - + uint32_t lwe_chunk_size; double2 *keybundle_fft; Torus *global_accumulator; double2 *global_accumulator_fft; @@ -142,6 +139,7 @@ template struct pbs_buffer { uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size, PBS_VARIANT pbs_variant, bool allocate_gpu_memory) { this->pbs_variant = pbs_variant; + this->lwe_chunk_size = lwe_chunk_size; auto max_shared_memory = cuda_get_max_shared_memory(gpu_index); // default diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh index e3bf1471b7..459a496d11 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh @@ -271,9 +271,8 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index, if (grouping_factor == 0) PANIC("Multi-bit PBS error: grouping factor should be > 0.") scratch_cuda_multi_bit_programmable_bootstrap_64( - stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension, - polynomial_size, level_count, grouping_factor, - input_lwe_ciphertext_count, allocate_gpu_memory); + stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + level_count, input_lwe_ciphertext_count, allocate_gpu_memory); break; case CLASSICAL: scratch_cuda_programmable_bootstrap_64( diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh index d17a953151..f26f45b810 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh @@ -229,9 +229,9 @@ __host__ void execute_cg_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, - uint32_t lut_stride) { + uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint64_t full_dm = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( polynomial_size); @@ -314,8 +314,7 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -324,15 +323,15 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate execute_cg_external_product_loop( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, - lut_count, lut_stride); + grouping_factor, base_log, level_count, lwe_offset, lut_count, + lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu index 98f8074a09..a538695f63 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu @@ -331,52 +331,51 @@ void scratch_cuda_cg_multi_bit_programmable_bootstrap( template void scratch_cuda_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { switch (polynomial_size) { case 256: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 512: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 1024: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 2048: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 4096: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 8192: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 16384: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported " @@ -386,10 +385,9 @@ void scratch_cuda_multi_bit_programmable_bootstrap( } void scratch_cuda_multi_bit_programmable_bootstrap_64( - void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count, - bool allocate_gpu_memory) { + void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { #if (CUDA_ARCH >= 900) if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( @@ -397,8 +395,8 @@ void scratch_cuda_multi_bit_programmable_bootstrap_64( level_count)) scratch_cuda_tbc_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)buffer, - lwe_dimension, glwe_dimension, polynomial_size, level_count, - grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory); + glwe_dimension, polynomial_size, level_count, + input_lwe_ciphertext_count, allocate_gpu_memory); else #endif if (supports_cooperative_groups_on_multibit_programmable_bootstrap< @@ -411,8 +409,8 @@ void scratch_cuda_multi_bit_programmable_bootstrap_64( else scratch_cuda_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)buffer, - lwe_dimension, glwe_dimension, polynomial_size, level_count, - grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory); + glwe_dimension, polynomial_size, level_count, + input_lwe_ciphertext_count, allocate_gpu_memory); } void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream, @@ -490,10 +488,9 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs, template void scratch_cuda_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, - pbs_buffer **pbs_buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count, - bool allocate_gpu_memory); + pbs_buffer **pbs_buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); template void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( @@ -532,52 +529,51 @@ has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( template void scratch_cuda_tbc_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { switch (polynomial_size) { case 256: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 512: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 1024: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 2048: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 4096: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 8192: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 16384: scratch_tbc_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported " @@ -679,8 +675,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( template void scratch_cuda_tbc_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); template void diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index 74b3669479..455233f057 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -385,10 +385,9 @@ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( template __host__ void scratch_multi_bit_programmable_bootstrap( cudaStream_t stream, uint32_t gpu_index, - pbs_buffer **buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t input_lwe_ciphertext_count, uint32_t grouping_factor, - bool allocate_gpu_memory) { + pbs_buffer **buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { auto lwe_chunk_size = get_lwe_chunk_size( gpu_index, input_lwe_ciphertext_count, polynomial_size); @@ -404,9 +403,9 @@ __host__ void execute_compute_keybundle( Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset) { + uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint32_t chunk_size = std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); if (chunk_size == 0) @@ -507,9 +506,9 @@ __host__ void execute_step_two( Torus *lwe_output_indexes, pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count, - uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size, - uint32_t lut_count, uint32_t lut_stride) { + uint32_t j, uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint64_t full_sm_accumulate_step_two = get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( polynomial_size); @@ -555,8 +554,7 @@ __host__ void host_multi_bit_programmable_bootstrap( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -565,7 +563,7 @@ __host__ void host_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate uint32_t chunk_size = std::min( lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); @@ -578,8 +576,7 @@ __host__ void host_multi_bit_programmable_bootstrap( execute_step_two( stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, level_count, j, lwe_offset, lwe_chunk_size, - lut_count, lut_stride); + grouping_factor, level_count, j, lwe_offset, lut_count, lut_stride); } } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh index b1fac308ac..1e839e6c5b 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh @@ -199,10 +199,9 @@ uint64_t get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap( template __host__ void scratch_tbc_multi_bit_programmable_bootstrap( cudaStream_t stream, uint32_t gpu_index, - pbs_buffer **buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t input_lwe_ciphertext_count, uint32_t grouping_factor, - bool allocate_gpu_memory) { + pbs_buffer **buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { auto lwe_chunk_size = get_lwe_chunk_size( gpu_index, input_lwe_ciphertext_count, polynomial_size); @@ -220,9 +219,9 @@ __host__ void execute_tbc_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, - uint32_t lut_stride) { + uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; auto supports_dsm = supports_distributed_shared_memory_on_multibit_programmable_bootstrap< Torus>(polynomial_size); @@ -325,9 +324,7 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( uint32_t lut_count, uint32_t lut_stride) { cudaSetDevice(gpu_index); - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); - + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -335,15 +332,15 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate execute_tbc_external_product_loop( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, - lut_count, lut_stride); + grouping_factor, base_log, level_count, lwe_offset, lut_count, + lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp index 8ad5831eeb..315dbe6a8c 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp @@ -177,8 +177,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, TbcMultiBit) scratch_cuda_tbc_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)&buffer, - lwe_dimension, glwe_dimension, polynomial_size, pbs_level, - grouping_factor, input_lwe_ciphertext_count, true); + glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, + true); uint32_t lut_count = 1; uint32_t lut_stride = 0; for (auto _ : st) { @@ -231,8 +231,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, DefaultMultiBit) (benchmark::State &st) { scratch_cuda_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)&buffer, - lwe_dimension, glwe_dimension, polynomial_size, pbs_level, - grouping_factor, input_lwe_ciphertext_count, true); + glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, + true); uint32_t lut_count = 1; uint32_t lut_stride = 0; for (auto _ : st) { diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp index 4d05d87412..1adfd773b0 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp @@ -92,8 +92,8 @@ class MultiBitProgrammableBootstrapTestPrimitives_u64 &payload_modulus, &delta, number_of_inputs, repetitions, samples); scratch_cuda_multi_bit_programmable_bootstrap_64( - stream, gpu_index, &pbs_buffer, lwe_dimension, glwe_dimension, - polynomial_size, pbs_level, grouping_factor, number_of_inputs, true); + stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + pbs_level, number_of_inputs, true); lwe_ct_out_array = (uint64_t *)malloc((glwe_dimension * polynomial_size + 1) * diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 10dc381787..d6ca49755a 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -1194,11 +1194,9 @@ extern "C" { stream: *mut c_void, gpu_index: u32, pbs_buffer: *mut *mut i8, - lwe_dimension: u32, glwe_dimension: u32, polynomial_size: u32, level_count: u32, - grouping_factor: u32, input_lwe_ciphertext_count: u32, allocate_gpu_memory: bool, ); diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 899a32cbc5..19868235e5 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -180,11 +180,9 @@ pub unsafe fn programmable_bootstrap_multi_bit_async( streams.ptr[0], streams.gpu_indexes[0], std::ptr::addr_of_mut!(pbs_buffer), - lwe_dimension.0 as u32, glwe_dimension.0 as u32, polynomial_size.0 as u32, level.0 as u32, - grouping_factor.0 as u32, num_samples, true, );