Skip to content

Commit

Permalink
fix(gpu): add missing synchronize in scalar add, refactor scalar add on cuda side
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 16, 2024
1 parent 8299e1c commit 9633b61
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ __global__ void device_integer_radix_scalar_addition_inplace(

int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num_blocks) {
Torus scalar = scalar_input[tid];
Torus *body = lwe_array + tid * (lwe_dimension + 1) + lwe_dimension;

*body += scalar * delta;
lwe_array[tid * (lwe_dimension + 1) + lwe_dimension] +=
scalar_input[tid] * delta;
}
}

Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ impl CudaServerKey {
unsafe {
carry_out = self.propagate_single_carry_assign_async(ct_left, stream);
}
stream.synchronize();

let num_scalar_blocks =
BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2())
Expand Down

0 comments on commit 9633b61

Please sign in to comment.