From 3492e81a876a15b544896de02079d6f1235078ab Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Mon, 6 Jan 2025 08:50:02 -0300 Subject: [PATCH] fix(gpu): fix delta calculation when Torus is not a 64-bit type --- .../cuda/src/integer/integer.cuh | 18 ++++++++++++------ .../ciphertext/compressed_ciphertext_list.rs | 16 +++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index b5319ba9f5..f6e53b5050 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -655,7 +655,9 @@ void generate_lookup_table_with_encoding(Torus *acc, uint32_t glwe_dimension, uint32_t input_modulus_sup = input_message_modulus * input_carry_modulus; uint32_t output_modulus_sup = output_message_modulus * output_carry_modulus; uint32_t box_size = polynomial_size / input_modulus_sup; - Torus output_delta = (1ul << 63) / output_modulus_sup; + auto nbits = sizeof(Torus) * 8; + Torus output_delta = + (static_cast(1) << (nbits - 1)) / output_modulus_sup; memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus)); @@ -698,7 +700,8 @@ void generate_many_lookup_table( uint32_t modulus_sup = message_modulus * carry_modulus; uint32_t box_size = polynomial_size / modulus_sup; - Torus delta = (1ul << 63) / modulus_sup; + auto nbits = sizeof(Torus) * 8; + Torus delta = (static_cast(1) << (nbits - 1)) / modulus_sup; memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus)); @@ -1099,7 +1102,8 @@ void host_compute_propagation_simulators_and_group_carries( message_modulus, carry_modulus); uint32_t modulus_sup = message_modulus * carry_modulus; - Torus delta = (1ull << 63) / modulus_sup; + auto nbits = sizeof(Torus) * 8; + Torus delta = (static_cast(1) << (nbits - 1)) / modulus_sup; auto simulators = mem->simulators; auto grouping_pgns = mem->grouping_pgns; host_radix_split_simulators_and_grouping_pgns( @@ -1426,8 +1430,8 @@ __host__ void create_trivial_radix(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out, Torus const *scalar_array, uint32_t lwe_dimension, uint32_t num_radix_blocks, - uint32_t num_scalar_blocks, uint64_t message_modulus, - uint64_t carry_modulus) { + uint32_t num_scalar_blocks, Torus message_modulus, + Torus carry_modulus) { cudaSetDevice(gpu_index); size_t radix_size = (lwe_dimension + 1) * num_radix_blocks; @@ -1447,7 +1451,9 @@ create_trivial_radix(cudaStream_t stream, uint32_t gpu_index, // Value of the shift we multiply our messages by // If message_modulus and carry_modulus are always powers of 2 we can simplify // this - uint64_t delta = ((uint64_t)1 << 63) / (message_modulus * carry_modulus); + auto nbits = sizeof(Torus) * 8; + Torus delta = (static_cast(1) << (nbits - 1)) / + (message_modulus * carry_modulus); device_create_trivial_radix<<>>( lwe_array_out, scalar_array, num_scalar_blocks, lwe_dimension, delta); diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index d6c6714f19..9d17f0f504 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -97,13 +97,15 @@ impl CudaCompressedCiphertextList { let end_block_index = start_block_index + current_info.num_blocks() - 1; Some(( - decomp_key.unpack( - &self.packed_list, - current_info, - start_block_index, - end_block_index, - streams, - ).unwrap(), + decomp_key + .unpack( + &self.packed_list, + current_info, + start_block_index, + end_block_index, + streams, + ) + .unwrap(), current_info, )) }