Skip to content

Commit

Permalink
fix(gpu): fix the indexes used in compression
Browse files: browse the repository at this point in the history
- also general minor fixes to compression
  • Loading branch information
pdroalves authored and agnesLeroy committed Oct 3, 2024
1 parent 123c764 commit 51cae3d
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 101 deletions.
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/include/ciphertext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,

void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
void *lwe_array_out, void *glwe_array_in,
uint32_t *nth_array, uint32_t num_glwes,
uint32_t *nth_array, uint32_t num_nths,
uint32_t glwe_dimension,
uint32_t polynomial_size);
};
Expand Down
12 changes: 6 additions & 6 deletions backends/tfhe-cuda-backend/cuda/include/compression.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ void scratch_cuda_integer_compress_radix_ciphertext_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t num_lwes, uint32_t message_modulus, uint32_t carry_modulus,
uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t lwe_per_glwe, uint32_t storage_log_modulus,
bool allocate_gpu_memory);

Expand All @@ -17,7 +17,7 @@ void scratch_cuda_integer_decompress_radix_ciphertext_64(
uint32_t encryption_glwe_dimension, uint32_t encryption_polynomial_size,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t num_lwes, uint32_t message_modulus, uint32_t carry_modulus,
uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t storage_log_modulus, uint32_t body_count,
bool allocate_gpu_memory);

Expand Down Expand Up @@ -96,7 +96,7 @@ template <typename Torus> struct int_decompression {

uint32_t storage_log_modulus;

uint32_t num_lwes;
uint32_t num_radix_blocks;
uint32_t body_count;

Torus *tmp_extracted_glwe;
Expand All @@ -113,7 +113,7 @@ template <typename Torus> struct int_decompression {
this->encryption_params = encryption_params;
this->compression_params = compression_params;
this->storage_log_modulus = storage_log_modulus;
this->num_lwes = num_radix_blocks;
this->num_radix_blocks = num_radix_blocks;
this->body_count = body_count;

if (allocate_gpu_memory) {
Expand All @@ -134,7 +134,7 @@ template <typename Torus> struct int_decompression {
tmp_extracted_lwe = (Torus *)cuda_malloc_async(
num_radix_blocks * lwe_accumulator_size * sizeof(Torus), streams[0],
gpu_indexes[0]);
// Decompression

// Carry extract LUT
auto carry_extract_f = [encryption_params](Torus x) -> Torus {
return x / encryption_params.message_modulus;
Expand All @@ -157,7 +157,7 @@ template <typename Torus> struct int_decompression {
cuda_drop_async(tmp_indexes_array, streams[0], gpu_indexes[0]);

carry_extract_lut->release(streams, gpu_indexes, gpu_count);
delete (carry_extract_lut);
delete carry_extract_lut;
}
};
#endif
16 changes: 8 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,51 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,

void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
void *lwe_array_out, void *glwe_array_in,
uint32_t *nth_array, uint32_t num_glwes,
uint32_t *nth_array, uint32_t num_nths,
uint32_t glwe_dimension,
uint32_t polynomial_size) {

switch (polynomial_size) {
case 256:
host_sample_extract<uint64_t, AmortizedDegree<256>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 512:
host_sample_extract<uint64_t, AmortizedDegree<512>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 1024:
host_sample_extract<uint64_t, AmortizedDegree<1024>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 2048:
host_sample_extract<uint64_t, AmortizedDegree<2048>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 4096:
host_sample_extract<uint64_t, AmortizedDegree<4096>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 8192:
host_sample_extract<uint64_t, AmortizedDegree<8192>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
case 16384:
host_sample_extract<uint64_t, AmortizedDegree<16384>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_nths,
glwe_dimension);
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ void scratch_cuda_integer_compress_radix_ciphertext_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t num_lwes, uint32_t message_modulus, uint32_t carry_modulus,
uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t lwe_per_glwe, uint32_t storage_log_modulus,
bool allocate_gpu_memory) {

Expand All @@ -16,15 +16,16 @@ void scratch_cuda_integer_compress_radix_ciphertext_64(

scratch_cuda_compress_integer_radix_ciphertext<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(int_compression<uint64_t> **)mem_ptr, num_lwes, compression_params,
lwe_per_glwe, storage_log_modulus, allocate_gpu_memory);
(int_compression<uint64_t> **)mem_ptr, num_radix_blocks,
compression_params, lwe_per_glwe, storage_log_modulus,
allocate_gpu_memory);
}
void scratch_cuda_integer_decompress_radix_ciphertext_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr,
uint32_t encryption_glwe_dimension, uint32_t encryption_polynomial_size,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t num_lwes, uint32_t message_modulus, uint32_t carry_modulus,
uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t storage_log_modulus, uint32_t body_count,
bool allocate_gpu_memory) {

Expand All @@ -41,7 +42,7 @@ void scratch_cuda_integer_decompress_radix_ciphertext_64(

scratch_cuda_integer_decompress_radix_ciphertext<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(int_decompression<uint64_t> **)mem_ptr, num_lwes, body_count,
(int_decompression<uint64_t> **)mem_ptr, num_radix_blocks, body_count,
encryption_params, compression_params, storage_log_modulus,
allocate_gpu_memory);
}
Expand Down
Loading

0 comments on commit 51cae3d

Please sign in to comment.