From 6120fab886be483c126e94910a338bc2f9c9d28c Mon Sep 17 00:00:00 2001 From: Beka Barbakadze Date: Wed, 26 Jun 2024 15:55:33 +0400 Subject: [PATCH] feat(gpu): Implement propagate_single_carry_get_input_carries --- .../tfhe-cuda-backend/cuda/include/integer.h | 5 ++ .../cuda/src/integer/integer.cu | 12 +++ .../cuda/src/integer/integer.cuh | 8 +- .../cuda/src/integer/multiplication.cuh | 4 +- backends/tfhe-cuda-backend/src/cuda_bind.rs | 13 +++ tfhe/src/integer/gpu/mod.rs | 85 +++++++++++++++++++ tfhe/src/integer/gpu/server_key/radix/mod.rs | 77 ++++++++++++++++- 7 files changed, 200 insertions(+), 4 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index fef158918e..8789eac9bf 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -272,6 +272,11 @@ void cuda_propagate_single_carry_kb_64_inplace( void *carry_out, int8_t *mem_ptr, void **bsks, void **ksks, uint32_t num_blocks); +void cuda_propagate_single_carry_get_input_carries_kb_64_inplace( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array, + void *carry_out, void *input_carries, int8_t *mem_ptr, void **bsks, + void **ksks, uint32_t num_blocks); + void cleanup_cuda_propagate_single_carry(void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr_void); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu index 748a9ccabf..abaf9a7514 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu @@ -68,6 +68,18 @@ void cuda_propagate_single_carry_kb_64_inplace( host_propagate_single_carry( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(lwe_array), static_cast(carry_out), + nullptr, (int_sc_prop_memory *)mem_ptr, bsks, + (uint64_t **)(ksks), num_blocks); +} + +void cuda_propagate_single_carry_get_input_carries_kb_64_inplace( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array, + void *carry_out, void *input_carries, int8_t *mem_ptr, void **bsks, + void **ksks, uint32_t num_blocks) { + host_propagate_single_carry( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + static_cast(lwe_array), static_cast(carry_out), + static_cast(input_carries), (int_sc_prop_memory *)mem_ptr, bsks, (uint64_t **)(ksks), num_blocks); } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index c82a2c77d5..77b9a77178 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -427,7 +427,7 @@ void scratch_cuda_propagate_single_carry_kb_inplace( template void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, Torus *lwe_array, - Torus *carry_out, + Torus *carry_out, Torus *input_carries, int_sc_prop_memory *mem, void **bsks, Torus **ksks, uint32_t num_blocks) { auto params = mem->params; @@ -482,6 +482,12 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes, cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0], gpu_indexes[0]); + if (input_carries != nullptr) { + cuda_memcpy_async_gpu_to_gpu(input_carries, step_output, + big_lwe_size_bytes * num_blocks, streams[0], + gpu_indexes[0]); + } + host_addition(streams[0], gpu_indexes[0], lwe_array, lwe_array, step_output, glwe_dimension * polynomial_size, num_blocks); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index d85ca6ebaa..cf45df6266 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -368,8 +368,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb( num_blocks); host_propagate_single_carry(streams, gpu_indexes, gpu_count, - radix_lwe_out, nullptr, mem_ptr->scp_mem, - bsks, ksks, num_blocks); + radix_lwe_out, nullptr, nullptr, + mem_ptr->scp_mem, bsks, ksks, num_blocks); } template diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index ece1148f86..215cb37728 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -1068,6 +1068,19 @@ extern "C" { num_blocks: u32, ); + pub fn cuda_propagate_single_carry_get_input_carries_kb_64_inplace( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + radix_lwe: *mut c_void, + carry_out: *mut c_void, + input_carries: *mut c_void, + mem_ptr: *mut i8, + bsks: *const *mut c_void, + ksks: *const *mut c_void, + num_blocks: u32, + ); + pub fn cleanup_cuda_propagate_single_carry( streams: *const *mut c_void, gpu_indexes: *const u32, diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 32fa86b9ff..a0a3884e3e 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -933,6 +933,91 @@ pub unsafe fn propagate_single_carry_assign_async( + streams: &CudaStreams, + radix_lwe_input: &mut CudaVec, + carry_out: &mut CudaVec, + input_carries: &mut CudaVec, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + num_blocks: u32, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_input.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + bootstrapping_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + keyswitch_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32; + scratch_cuda_propagate_single_carry_kb_64_inplace( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + ); + cuda_propagate_single_carry_get_input_carries_kb_64_inplace( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + radix_lwe_input.as_mut_c_ptr(0), + carry_out.as_mut_c_ptr(0), + input_carries.as_mut_c_ptr(0), + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + num_blocks, + ); + cleanup_cuda_propagate_single_carry( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + ); +} + #[allow(clippy::too_many_arguments)] /// # Safety /// diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index fcd1f52848..e8341f823d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -14,7 +14,8 @@ use crate::integer::gpu::ciphertext::{ use crate::integer::gpu::server_key::CudaBootstrappingKey; use crate::integer::gpu::{ apply_univariate_lut_kb_async, full_propagate_assign_async, - propagate_single_carry_assign_async, CudaServerKey, PBSType, + propagate_single_carry_assign_async, propagate_single_carry_get_input_carries_assign_async, + CudaServerKey, PBSType, }; use crate::shortint::ciphertext::{Degree, NoiseLevel}; use crate::shortint::engine::fill_accumulator; @@ -224,6 +225,80 @@ impl CudaServerKey { carry_out } + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronized + #[allow(dead_code)] + pub(crate) unsafe fn propagate_single_carry_get_input_carries_assign_async( + &self, + ct: &mut T, + input_carries: &mut T, + streams: &CudaStreams, + ) -> T + where + T: CudaIntegerRadixCiphertext, + { + let mut carry_out: T = self.create_trivial_zero_radix(1, streams); + let ciphertext = ct.as_mut(); + let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32; + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + propagate_single_carry_get_input_carries_assign_async( + streams, + &mut ciphertext.d_blocks.0.d_vec, + &mut carry_out.as_mut().d_blocks.0.d_vec, + &mut input_carries.as_mut().d_blocks.0.d_vec, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + d_bsk.input_lwe_dimension(), + d_bsk.glwe_dimension(), + d_bsk.polynomial_size(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count(), + d_bsk.decomp_base_log(), + num_blocks, + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + propagate_single_carry_get_input_carries_assign_async( + streams, + &mut ciphertext.d_blocks.0.d_vec, + &mut carry_out.as_mut().d_blocks.0.d_vec, + &mut input_carries.as_mut().d_blocks.0.d_vec, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + d_multibit_bsk.input_lwe_dimension(), + d_multibit_bsk.glwe_dimension(), + d_multibit_bsk.polynomial_size(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count(), + d_multibit_bsk.decomp_base_log(), + num_blocks, + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + }; + ciphertext.info.blocks.iter_mut().for_each(|b| { + b.degree = Degree::new(b.message_modulus.0 - 1); + b.noise_level = NoiseLevel::NOMINAL; + }); + carry_out.as_mut().info.blocks.iter_mut().for_each(|b| { + b.degree = Degree::new(1); + b.noise_level = NoiseLevel::NOMINAL; + }); + carry_out + } + /// # Safety /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must