Skip to content

Commit

Permalink
fix(gpu): fix delta calculation when Torus is not a 64-bit type
Browse files: browse the repository at this point in the history
  • Loading branch information
pdroalves committed Jan 6, 2025
1 parent e087840 commit 3492e81
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
18 changes: 12 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,9 @@ void generate_lookup_table_with_encoding(Torus *acc, uint32_t glwe_dimension,
uint32_t input_modulus_sup = input_message_modulus * input_carry_modulus;
uint32_t output_modulus_sup = output_message_modulus * output_carry_modulus;
uint32_t box_size = polynomial_size / input_modulus_sup;
Torus output_delta = (1ul << 63) / output_modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus output_delta =
(static_cast<Torus>(1) << (nbits - 1)) / output_modulus_sup;

memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus));

Expand Down Expand Up @@ -698,7 +700,8 @@ void generate_many_lookup_table(

uint32_t modulus_sup = message_modulus * carry_modulus;
uint32_t box_size = polynomial_size / modulus_sup;
Torus delta = (1ul << 63) / modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) / modulus_sup;

memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus));

Expand Down Expand Up @@ -1099,7 +1102,8 @@ void host_compute_propagation_simulators_and_group_carries(
message_modulus, carry_modulus);

uint32_t modulus_sup = message_modulus * carry_modulus;
Torus delta = (1ull << 63) / modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) / modulus_sup;
auto simulators = mem->simulators;
auto grouping_pgns = mem->grouping_pgns;
host_radix_split_simulators_and_grouping_pgns<Torus>(
Expand Down Expand Up @@ -1426,8 +1430,8 @@ __host__ void
create_trivial_radix(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out, Torus const *scalar_array,
uint32_t lwe_dimension, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks, uint64_t message_modulus,
uint64_t carry_modulus) {
uint32_t num_scalar_blocks, Torus message_modulus,
Torus carry_modulus) {

cudaSetDevice(gpu_index);
size_t radix_size = (lwe_dimension + 1) * num_radix_blocks;
Expand All @@ -1447,7 +1451,9 @@ create_trivial_radix(cudaStream_t stream, uint32_t gpu_index,
// Value of the shift (delta) by which we multiply our messages.
// If message_modulus and carry_modulus are always powers of 2, this
// computation can be simplified.
uint64_t delta = ((uint64_t)1 << 63) / (message_modulus * carry_modulus);
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) /
(message_modulus * carry_modulus);

device_create_trivial_radix<Torus><<<grid, thds, 0, stream>>>(
lwe_array_out, scalar_array, num_scalar_blocks, lwe_dimension, delta);
Expand Down
16 changes: 9 additions & 7 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@ impl CudaCompressedCiphertextList {
let end_block_index = start_block_index + current_info.num_blocks() - 1;

Some((
decomp_key.unpack(
&self.packed_list,
current_info,
start_block_index,
end_block_index,
streams,
).unwrap(),
decomp_key
.unpack(
&self.packed_list,
current_info,
start_block_index,
end_block_index,
streams,
)
.unwrap(),
current_info,
))
}
Expand Down

0 comments on commit 3492e81

Please sign in to comment.