Skip to content

Commit

Permalink
chore(gpu): update asserts on base log now that we don't cast to u32 …
Browse files Browse the repository at this point in the history
…in decomposition
  • Loading branch information
agnesLeroy committed Nov 14, 2024
1 parent 5a664aa commit c81fd43
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) {
if (base_log > 32)
PANIC("Cuda error (classical PBS): base log should be <= 32")
if (base_log > 64)
PANIC("Cuda error (classical PBS): base log should be <= 64")

pbs_buffer<uint64_t, CLASSICAL> *buffer =
(pbs_buffer<uint64_t, CLASSICAL> *)mem_ptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t lut_count, uint32_t lut_stride) {

if (base_log > 32)
PANIC("Cuda error (multi-bit PBS): base log should be <= 32")

switch (polynomial_size) {
case 256:
host_cg_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<256>>(
Expand Down Expand Up @@ -147,9 +144,6 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t lut_count, uint32_t lut_stride) {

if (base_log > 32)
PANIC("Cuda error (multi-bit PBS): base log should be <= 32")

switch (polynomial_size) {
case 256:
host_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<256>>(
Expand Down Expand Up @@ -224,6 +218,9 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64(
uint32_t level_count, uint32_t num_samples, uint32_t lut_count,
uint32_t lut_stride) {

if (base_log > 64)
PANIC("Cuda error (multi-bit PBS): base log should be <= 64")

pbs_buffer<uint64_t, MULTI_BIT> *buffer =
(pbs_buffer<uint64_t, MULTI_BIT> *)mem_ptr;

Expand Down

0 comments on commit c81fd43

Please sign in to comment.