chore(gpu): refactor lwe_chunk_size
agnesLeroy committed Sep 30, 2024
1 parent e9d3e21 commit 0fc2412
Showing 10 changed files with 93 additions and 112 deletions.
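What changed, in short: lwe_chunk_size used to be recomputed by every bootstrap call via get_lwe_chunk_size; it is now computed once, when the scratch function builds the pbs_buffer, and cached in the buffer itself. Consequently, lwe_dimension and grouping_factor drop out of the scratch signatures, and lwe_chunk_size drops out of the per-call helper signatures. A minimal caller-side sketch of the new entry point (the prototype is taken from the header diff below; the wrapper function is hypothetical):

#include <cstdint>

// Post-refactor prototype, copied from the header diff below.
void scratch_cuda_multi_bit_programmable_bootstrap_64(
    void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
    uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

// Hypothetical caller: note there is no lwe_dimension and no grouping_factor
// argument anymore; the chunk size is derived inside scratch and cached in
// the buffer.
void prepare_multi_bit_pbs(void *stream, uint32_t gpu_index, int8_t **buffer,
                           uint32_t glwe_dimension, uint32_t polynomial_size,
                           uint32_t level_count,
                           uint32_t input_lwe_ciphertext_count) {
  scratch_cuda_multi_bit_programmable_bootstrap_64(
      stream, gpu_index, buffer, glwe_dimension, polynomial_size, level_count,
      input_lwe_ciphertext_count, /*allocate_gpu_memory=*/true);
}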
==== changed file (path not captured) ====
@@ -17,8 +17,7 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64(

void scratch_cuda_multi_bit_programmable_bootstrap_64(
void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64(
@@ -48,8 +47,7 @@ bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit(
template <typename Torus>
void scratch_cuda_tbc_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

template <typename Torus>
@@ -82,8 +80,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
template <typename Torus>
void scratch_cuda_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **pbs_buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

template <typename Torus>
@@ -130,7 +127,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
int8_t *d_mem_acc_step_two = NULL;
int8_t *d_mem_acc_cg = NULL;
int8_t *d_mem_acc_tbc = NULL;
-
+ uint32_t lwe_chunk_size;
double2 *keybundle_fft;
Torus *global_accumulator;
double2 *global_accumulator_fft;
@@ -142,6 +139,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size,
PBS_VARIANT pbs_variant, bool allocate_gpu_memory) {
this->pbs_variant = pbs_variant;
+ this->lwe_chunk_size = lwe_chunk_size;
auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);

// default
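The struct change above is the core of the refactor: the buffer owns the chunk size, set once in the constructor. A self-contained sketch of the pattern with simplified stand-in types (the real pbs_buffer also holds the keybundle FFT and accumulator scratch shown above):

#include <algorithm>
#include <cstdint>

// Simplified stand-in for pbs_buffer<Torus, PBS_TYPE::MULTI_BIT>.
struct multi_bit_buffer {
  uint32_t lwe_chunk_size;
  explicit multi_bit_buffer(uint32_t chunk_size) : lwe_chunk_size(chunk_size) {}
};

// Helpers now read the cached value instead of receiving it as a parameter,
// e.g. to clamp the final chunk of the keybundle loop exactly as the diff does.
uint32_t clamped_chunk_size(const multi_bit_buffer &buffer,
                            uint32_t lwe_dimension, uint32_t grouping_factor,
                            uint32_t lwe_offset) {
  return std::min(buffer.lwe_chunk_size,
                  (lwe_dimension / grouping_factor) - lwe_offset);
}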
==== changed file (path not captured) ====
@@ -271,9 +271,8 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
if (grouping_factor == 0)
PANIC("Multi-bit PBS error: grouping factor should be > 0.")
scratch_cuda_multi_bit_programmable_bootstrap_64(
- stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
- polynomial_size, level_count, grouping_factor,
- input_lwe_ciphertext_count, allocate_gpu_memory);
+ stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size,
+ level_count, input_lwe_ciphertext_count, allocate_gpu_memory);
break;
case CLASSICAL:
scratch_cuda_programmable_bootstrap_64(
==== changed file (path not captured) ====
@@ -229,9 +229,9 @@ __host__ void execute_cg_external_product_loop(
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
- uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count,
- uint32_t lut_stride) {
+ uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) {

+ auto lwe_chunk_size = buffer->lwe_chunk_size;
uint64_t full_dm =
get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(
polynomial_size);
@@ -314,8 +314,7 @@ __host__ void host_cg_multi_bit_programmable_bootstrap(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t lut_count, uint32_t lut_stride) {

- auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
-     gpu_index, num_samples, polynomial_size);
+ auto lwe_chunk_size = buffer->lwe_chunk_size;

for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
lwe_offset += lwe_chunk_size) {
@@ -324,15 +323,15 @@
execute_compute_keybundle<Torus, params>(
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
- grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
+ grouping_factor, level_count, lwe_offset);

// Accumulate
execute_cg_external_product_loop<Torus, params>(
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
- grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset,
- lut_count, lut_stride);
+ grouping_factor, base_log, level_count, lwe_offset, lut_count,
+ lut_stride);
}
}

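Both the CG loop above and the default loop later in this commit now share the same shape. A compile-checkable skeleton of it (kernel launches replaced by comments; names follow the diff):

#include <cstdint>

struct buffer_t {
  uint32_t lwe_chunk_size; // cached at scratch time
};

void bootstrap_over_chunks(const buffer_t *buffer, uint32_t lwe_dimension,
                           uint32_t grouping_factor) {
  // Was: get_lwe_chunk_size<Torus, params>(gpu_index, num_samples,
  // polynomial_size) on every call; now a cached field read.
  auto lwe_chunk_size = buffer->lwe_chunk_size;
  for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
       lwe_offset += lwe_chunk_size) {
    // execute_compute_keybundle(...)        : no lwe_chunk_size argument
    // execute_cg_external_product_loop(...) : reads buffer->lwe_chunk_size
  }
}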
==== changed file (path not captured) ====
@@ -331,52 +331,51 @@ void scratch_cuda_cg_multi_bit_programmable_bootstrap(
template <typename Torus>
void scratch_cuda_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

switch (polynomial_size) {
case 256:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<256>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 512:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<512>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 1024:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<1024>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 2048:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<2048>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 4096:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<4096>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 8192:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<8192>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 16384:
scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<16384>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
default:
PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported "
@@ -386,19 +385,18 @@ void scratch_cuda_multi_bit_programmable_bootstrap(
}

void scratch_cuda_multi_bit_programmable_bootstrap_64(
- void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension,
- uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
- uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count,
- bool allocate_gpu_memory) {
+ void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t glwe_dimension,
+ uint32_t polynomial_size, uint32_t level_count,
+ uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

#if (CUDA_ARCH >= 900)
if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
level_count))
scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
- lwe_dimension, glwe_dimension, polynomial_size, level_count,
- grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory);
+ glwe_dimension, polynomial_size, level_count,
+ input_lwe_ciphertext_count, allocate_gpu_memory);
else
#endif
if (supports_cooperative_groups_on_multibit_programmable_bootstrap<
@@ -411,8 +409,8 @@ void scratch_cuda_multi_bit_programmable_bootstrap_64(
else
scratch_cuda_multi_bit_programmable_bootstrap<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
- lwe_dimension, glwe_dimension, polynomial_size, level_count,
- grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory);
+ glwe_dimension, polynomial_size, level_count,
+ input_lwe_ciphertext_count, allocate_gpu_memory);
}

void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream,
@@ -490,10 +488,9 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,

template void scratch_cuda_multi_bit_programmable_bootstrap<uint64_t>(
void *stream, uint32_t gpu_index,
- pbs_buffer<uint64_t, MULTI_BIT> **pbs_buffer, uint32_t lwe_dimension,
- uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
- uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count,
- bool allocate_gpu_memory);
+ pbs_buffer<uint64_t, MULTI_BIT> **pbs_buffer, uint32_t glwe_dimension,
+ uint32_t polynomial_size, uint32_t level_count,
+ uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

template void
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector<uint64_t>(
@@ -532,52 +529,51 @@ has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
template <typename Torus>
void scratch_cuda_tbc_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

switch (polynomial_size) {
case 256:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<256>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 512:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<512>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 1024:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<1024>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 2048:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<2048>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 4096:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<4096>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 8192:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<8192>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
case 16384:
scratch_tbc_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<16384>>(
- static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
- glwe_dimension, polynomial_size, level_count,
- input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+ static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+ polynomial_size, level_count, input_lwe_ciphertext_count,
+ allocate_gpu_memory);
break;
default:
PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported "
@@ -679,8 +675,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(

template void scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
void *stream, uint32_t gpu_index, pbs_buffer<uint64_t, MULTI_BIT> **buffer,
- uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t level_count, uint32_t grouping_factor,
+ uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

template void
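For completeness, a hedged sketch of calling the refactored TBC scratch template (stand-in types so it compiles standalone; parameter values are hypothetical, only the signature comes from the instantiation above):

#include <cstdint>

enum PBS_TYPE { MULTI_BIT };                                // simplified stand-in
template <typename Torus, PBS_TYPE type> struct pbs_buffer; // opaque here

// Post-refactor signature, as instantiated for uint64_t in this diff.
template <typename Torus>
void scratch_cuda_tbc_multi_bit_programmable_bootstrap(
    void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
    uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

void example(void *stream, uint32_t gpu_index) {
  pbs_buffer<uint64_t, MULTI_BIT> *buffer = nullptr;
  // Values below are illustrative only.
  scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
      stream, gpu_index, &buffer, /*glwe_dimension=*/1,
      /*polynomial_size=*/2048, /*level_count=*/1,
      /*input_lwe_ciphertext_count=*/4, /*allocate_gpu_memory=*/true);
}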
==== changed file (path not captured) ====
@@ -385,10 +385,9 @@ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two(
template <typename Torus, typename params>
__host__ void scratch_multi_bit_programmable_bootstrap(
cudaStream_t stream, uint32_t gpu_index,
- pbs_buffer<Torus, MULTI_BIT> **buffer, uint32_t lwe_dimension,
- uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
- uint32_t input_lwe_ciphertext_count, uint32_t grouping_factor,
- bool allocate_gpu_memory) {
+ pbs_buffer<Torus, MULTI_BIT> **buffer, uint32_t glwe_dimension,
+ uint32_t polynomial_size, uint32_t level_count,
+ uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
gpu_index, input_lwe_ciphertext_count, polynomial_size);
@@ -404,9 +403,9 @@ __host__ void execute_compute_keybundle(
Torus *lwe_input_indexes, Torus *bootstrapping_key,
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
- uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
- uint32_t lwe_chunk_size, uint32_t lwe_offset) {
+ uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) {

+ auto lwe_chunk_size = buffer->lwe_chunk_size;
uint32_t chunk_size =
std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
if (chunk_size == 0)
@@ -507,9 +506,9 @@ __host__ void execute_step_two(
Torus *lwe_output_indexes, pbs_buffer<Torus, MULTI_BIT> *buffer,
uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count,
- uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size,
- uint32_t lut_count, uint32_t lut_stride) {
+ uint32_t j, uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) {

+ auto lwe_chunk_size = buffer->lwe_chunk_size;
uint64_t full_sm_accumulate_step_two =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
polynomial_size);
@@ -555,8 +554,7 @@ __host__ void host_multi_bit_programmable_bootstrap(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t lut_count, uint32_t lut_stride) {

- auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
-     gpu_index, num_samples, polynomial_size);
+ auto lwe_chunk_size = buffer->lwe_chunk_size;

for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
lwe_offset += lwe_chunk_size) {
@@ -565,7 +563,7 @@
execute_compute_keybundle<Torus, params>(
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
- grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
+ grouping_factor, level_count, lwe_offset);
// Accumulate
uint32_t chunk_size = std::min(
lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
@@ -578,8 +576,7 @@
execute_step_two<Torus, params>(
stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
- grouping_factor, level_count, j, lwe_offset, lwe_chunk_size,
- lut_count, lut_stride);
+ grouping_factor, level_count, j, lwe_offset, lut_count, lut_stride);
}
}
}
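Note where the chunk size is computed now: get_lwe_chunk_size is called exactly once, inside scratch_multi_bit_programmable_bootstrap (kept above), and the value is handed to the pbs_buffer constructor, which stores it (see the struct change in the first file). A sketch of that flow; the leading constructor arguments and the variant value are assumptions, since only the trailing parameters (lwe_chunk_size, pbs_variant, allocate_gpu_memory) are visible in this diff:

// Inside scratch (simplified sketch, not repository code):
auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
    gpu_index, input_lwe_ciphertext_count, polynomial_size);
*buffer = new pbs_buffer<Torus, MULTI_BIT>(
    stream, gpu_index, glwe_dimension, polynomial_size,  // leading args assumed
    level_count, input_lwe_ciphertext_count, lwe_chunk_size,
    PBS_VARIANT::DEFAULT,  // variant value hypothetical
    allocate_gpu_memory);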
==== remaining changed files not loaded in this view ====
