Skip to content

Commit

Permalink
feat(gpu): Implement propagate_single_carry_get_input_carries
Browse files Browse the repository at this point in the history
  • Loading branch information
bbarbakadze authored and agnesLeroy committed Jun 26, 2024
1 parent 53b6861 commit 6120fab
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 4 deletions.
5 changes: 5 additions & 0 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ void cuda_propagate_single_carry_kb_64_inplace(
void *carry_out, int8_t *mem_ptr, void **bsks, void **ksks,
uint32_t num_blocks);

void cuda_propagate_single_carry_get_input_carries_kb_64_inplace(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
void *carry_out, void *input_carries, int8_t *mem_ptr, void **bsks,
void **ksks, uint32_t num_blocks);

void cleanup_cuda_propagate_single_carry(void **streams, uint32_t *gpu_indexes,
uint32_t gpu_count,
int8_t **mem_ptr_void);
Expand Down
12 changes: 12 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ void cuda_propagate_single_carry_kb_64_inplace(
host_propagate_single_carry<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array), static_cast<uint64_t *>(carry_out),
nullptr, (int_sc_prop_memory<uint64_t> *)mem_ptr, bsks,
(uint64_t **)(ksks), num_blocks);
}

void cuda_propagate_single_carry_get_input_carries_kb_64_inplace(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
void *carry_out, void *input_carries, int8_t *mem_ptr, void **bsks,
void **ksks, uint32_t num_blocks) {
host_propagate_single_carry<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array), static_cast<uint64_t *>(carry_out),
static_cast<uint64_t *>(input_carries),
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
num_blocks);
}
Expand Down
8 changes: 7 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ void scratch_cuda_propagate_single_carry_kb_inplace(
template <typename Torus>
void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
uint32_t gpu_count, Torus *lwe_array,
Torus *carry_out,
Torus *carry_out, Torus *input_carries,
int_sc_prop_memory<Torus> *mem, void **bsks,
Torus **ksks, uint32_t num_blocks) {
auto params = mem->params;
Expand Down Expand Up @@ -482,6 +482,12 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
gpu_indexes[0]);

if (input_carries != nullptr) {
cuda_memcpy_async_gpu_to_gpu(input_carries, step_output,
big_lwe_size_bytes * num_blocks, streams[0],
gpu_indexes[0]);
}

host_addition(streams[0], gpu_indexes[0], lwe_array, lwe_array, step_output,
glwe_dimension * polynomial_size, num_blocks);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
num_blocks);

host_propagate_single_carry<Torus>(streams, gpu_indexes, gpu_count,
radix_lwe_out, nullptr, mem_ptr->scp_mem,
bsks, ksks, num_blocks);
radix_lwe_out, nullptr, nullptr,
mem_ptr->scp_mem, bsks, ksks, num_blocks);
}

template <typename Torus, typename STorus, class params>
Expand Down
13 changes: 13 additions & 0 deletions backends/tfhe-cuda-backend/src/cuda_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,19 @@ extern "C" {
num_blocks: u32,
);

pub fn cuda_propagate_single_carry_get_input_carries_kb_64_inplace(
streams: *const *mut c_void,
gpu_indexes: *const u32,
gpu_count: u32,
radix_lwe: *mut c_void,
carry_out: *mut c_void,
input_carries: *mut c_void,
mem_ptr: *mut i8,
bsks: *const *mut c_void,
ksks: *const *mut c_void,
num_blocks: u32,
);

pub fn cleanup_cuda_propagate_single_carry(
streams: *const *mut c_void,
gpu_indexes: *const u32,
Expand Down
85 changes: 85 additions & 0 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,91 @@ pub unsafe fn propagate_single_carry_assign_async<T: UnsignedInteger, B: Numeric
);
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function
/// as soon as synchronization is required
pub unsafe fn propagate_single_carry_get_input_carries_assign_async<
T: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
radix_lwe_input: &mut CudaVec<T>,
carry_out: &mut CudaVec<T>,
input_carries: &mut CudaVec<T>,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
) {
assert_eq!(
streams.gpu_indexes[0],
radix_lwe_input.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
assert_eq!(
streams.gpu_indexes[0],
bootstrapping_key.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
assert_eq!(
streams.gpu_indexes[0],
keyswitch_key.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32;
scratch_cuda_propagate_single_carry_kb_64_inplace(
streams.ptr.as_ptr(),
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
);
cuda_propagate_single_carry_get_input_carries_kb_64_inplace(
streams.ptr.as_ptr(),
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
radix_lwe_input.as_mut_c_ptr(0),
carry_out.as_mut_c_ptr(0),
input_carries.as_mut_c_ptr(0),
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
num_blocks,
);
cleanup_cuda_propagate_single_carry(
streams.ptr.as_ptr(),
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
Expand Down
77 changes: 76 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::integer::gpu::ciphertext::{
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
apply_univariate_lut_kb_async, full_propagate_assign_async,
propagate_single_carry_assign_async, CudaServerKey, PBSType,
propagate_single_carry_assign_async, propagate_single_carry_get_input_carries_assign_async,
CudaServerKey, PBSType,
};
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::engine::fill_accumulator;
Expand Down Expand Up @@ -224,6 +225,80 @@ impl CudaServerKey {
carry_out
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronized
#[allow(dead_code)]
pub(crate) unsafe fn propagate_single_carry_get_input_carries_assign_async<T>(
&self,
ct: &mut T,
input_carries: &mut T,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let mut carry_out: T = self.create_trivial_zero_radix(1, streams);
let ciphertext = ct.as_mut();
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
propagate_single_carry_get_input_carries_assign_async(
streams,
&mut ciphertext.d_blocks.0.d_vec,
&mut carry_out.as_mut().d_blocks.0.d_vec,
&mut input_carries.as_mut().d_blocks.0.d_vec,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
propagate_single_carry_get_input_carries_assign_async(
streams,
&mut ciphertext.d_blocks.0.d_vec,
&mut carry_out.as_mut().d_blocks.0.d_vec,
&mut input_carries.as_mut().d_blocks.0.d_vec,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
);
}
};
ciphertext.info.blocks.iter_mut().for_each(|b| {
b.degree = Degree::new(b.message_modulus.0 - 1);
b.noise_level = NoiseLevel::NOMINAL;
});
carry_out.as_mut().info.blocks.iter_mut().for_each(|b| {
b.degree = Degree::new(1);
b.noise_level = NoiseLevel::NOMINAL;
});
carry_out
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
Expand Down

0 comments on commit 6120fab

Please sign in to comment.