From c81fd439e897236db62e26973b5b98929db5485e Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 14 Nov 2024 11:17:26 +0100 Subject: [PATCH] chore(gpu): update asserts on base log now that we don't cast to u32 in decomposition --- .../cuda/src/pbs/programmable_bootstrap_classic.cu | 4 ++-- .../cuda/src/pbs/programmable_bootstrap_multibit.cu | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu index cd6275122a..dd3d446204 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu @@ -654,8 +654,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - if (base_log > 32) - PANIC("Cuda error (classical PBS): base log should be <= 32") + if (base_log > 64) + PANIC("Cuda error (classical PBS): base log should be <= 64") pbs_buffer *buffer = (pbs_buffer *)mem_ptr; diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu index 6c7418cada..72b8982549 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu @@ -69,9 +69,6 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - if (base_log > 32) - PANIC("Cuda error (multi-bit PBS): base log should be <= 32") - switch (polynomial_size) { case 256: host_cg_multi_bit_programmable_bootstrap>( @@ -147,9 +144,6 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - if (base_log > 32) - PANIC("Cuda error (multi-bit PBS): base log should be <= 32") - switch (polynomial_size) { case 256: host_multi_bit_programmable_bootstrap>( @@ -224,6 +218,9 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { + if (base_log > 64) + PANIC("Cuda error (multi-bit PBS): base log should be <= 64") + pbs_buffer *buffer = (pbs_buffer *)mem_ptr;