From e120ed840c9de12db6043aef99387e61a98e1338 Mon Sep 17 00:00:00 2001
From: Andrei Stoian
Date: Mon, 30 Dec 2024 16:17:06 +0100
Subject: [PATCH] feat(gpu): optimize packing keyswitch on gpu

---
 .../src/crypto/fast_packing_keyswitch.cuh     |  61 ++++-----
 .../cuda/src/crypto/keyswitch.cu              |  25 +---
 .../cuda/src/crypto/keyswitch.cuh             |  92 +------------
 .../src/integer/compression/compression.cuh   |  20 +--
 .../ciphertext/compressed_ciphertext_list.rs  | 129 ------------------
 5 files changed, 44 insertions(+), 283 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh
index d78dcb37a5..f399627bab 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh
@@ -26,15 +26,6 @@ template uint64_t get_shared_mem_size_tgemm() {
   return BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
 }
 
-__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension,
-                                           uint32_t num_lwe,
-                                           uint32_t polynomial_size,
-                                           uint32_t level_count,
-                                           uint32_t glwe_dimension) {
-  // TODO: activate it back, fix tests and extend to level_count > 1
-  return false;
-}
-
 // Initialize decomposition by performing rounding
 // and decomposing one level of an array of Torus LWEs. Only
 // decomposes the mask elements of the incoming LWEs.
@@ -57,6 +48,8 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
   // is lwe_dimension + 1, while for writing it is lwe_dimension
   auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
   auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
+  auto write_state_idx =
+      num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
 
   Torus a_i = lwe_in[read_val_idx];
 
@@ -64,6 +57,8 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
   Torus mod_b_mask = (1ll << base_log) - 1ll;
 
   lwe_out[write_val_idx] = decompose_one(state, mod_b_mask, base_log);
+  synchronize_threads_in_block();
+  lwe_out[write_state_idx] = state;
 }
 
 // Continue decomposition of an array of Torus elements in place. Supposes
@@ -84,12 +79,16 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
     return;
 
   auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
+  auto state_idx = num_lwe * lwe_dimension + val_idx;
 
-  Torus state = buffer_in[val_idx];
+  Torus state = buffer_in[state_idx];
+  synchronize_threads_in_block();
 
   Torus mod_b_mask = (1ll << base_log) - 1ll;
 
   buffer_in[val_idx] = decompose_one(state, mod_b_mask, base_log);
+  synchronize_threads_in_block();
+  buffer_in[state_idx] = state;
 }
 
 // Multiply matrices A, B of size (M, K), (K, N) respectively
@@ -152,7 +151,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
     } else {
       Bs[innerRowB * BN + innerColB] = 0;
     }
-    __syncthreads();
+    synchronize_threads_in_block();
 
     // Advance blocktile for the next iteration of this loop
     A += BK;
@@ -168,7 +167,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
             As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
       }
     }
-    __syncthreads();
+    synchronize_threads_in_block();
   }
 
   // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
@@ -259,10 +258,6 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
 
   // Optimization of packing keyswitch when packing many LWEs
 
-  if (level_count > 1) {
-    PANIC("Fast path PKS only supports level_count==1");
-  }
-
   cudaSetDevice(gpu_index);
   check_cuda_error(cudaGetLastError());
 
@@ -273,10 +268,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   // buffer and the keyswitched GLWEs in the second half of the buffer. Thus the
   // scratch buffer for the fast path must determine the half-size of the
   // scratch buffer as the max between the size of the GLWE and the size of the
-  // LWE-mask
-  int memory_unit = glwe_accumulator_size > lwe_dimension
+  // LWE-mask times two (to keep both decomposition state and decomposed
+  // intermediate value)
+  int memory_unit = glwe_accumulator_size > lwe_dimension * 2
                         ? glwe_accumulator_size
-                        : lwe_dimension;
+                        : lwe_dimension * 2;
 
   // ping pong the buffer between successive calls
   // split the buffer in two parts of this size
@@ -309,7 +305,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
                  CEIL_DIV(num_lwes, BLOCK_SIZE_GEMM));
   dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);
 
-  auto stride_KSK_buffer = glwe_accumulator_size;
+  auto stride_KSK_buffer = glwe_accumulator_size * level_count;
 
   uint32_t shared_mem_size = get_shared_mem_size_tgemm();
   tgemm<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
       num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
       stride_KSK_buffer, d_mem_1);
   check_cuda_error(cudaGetLastError());
 
-  /*
-  TODO: transpose key to generalize to level_count > 1
+  auto ksk_block_size = glwe_accumulator_size;
 
-  for (int li = 1; li < level_count; ++li) {
-    decompose_vectorize_step_inplace
-        <<>>(
-            d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
-    check_cuda_error(cudaGetLastError());
+  for (int li = 1; li < level_count; ++li) {
+    decompose_vectorize_step_inplace
+        <<>>(
+            d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
+    check_cuda_error(cudaGetLastError());
 
-    tgemm<<>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-        fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
-    check_cuda_error(cudaGetLastError());
-  }
-  */
+    tgemm
+        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
+            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+    check_cuda_error(cudaGetLastError());
+  }
 
   // should we include the mask in the rotation ??
   dim3 grid_rotate(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
index b3f63176df..e8b7c86483 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
@@ -73,24 +73,13 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
     uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
     uint32_t num_lwes) {
 
-  if (can_use_pks_fast_path(input_lwe_dimension, num_lwes,
-                            output_polynomial_size, level_count,
-                            output_glwe_dimension)) {
-    host_fast_packing_keyswitch_lwe_list_to_glwe(
-        static_cast<cudaStream_t>(stream), gpu_index,
-        static_cast<uint64_t *>(glwe_array_out),
-        static_cast<const uint64_t *>(lwe_array_in),
-        static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
-        input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
-        base_log, level_count, num_lwes);
-  } else
-    host_packing_keyswitch_lwe_list_to_glwe(
-        static_cast<cudaStream_t>(stream), gpu_index,
-        static_cast<uint64_t *>(glwe_array_out),
-        static_cast<const uint64_t *>(lwe_array_in),
-        static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
-        input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
-        base_log, level_count, num_lwes);
+  host_fast_packing_keyswitch_lwe_list_to_glwe(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<uint64_t *>(glwe_array_out),
+      static_cast<const uint64_t *>(lwe_array_in),
+      static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
+      input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
+      base_log, level_count, num_lwes);
 }
 
 void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
index 5b79e5b892..0980543c35 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -164,9 +164,11 @@ __host__ void scratch_packing_keyswitch_lwe_list_to_glwe(
 
   int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
 
-  int memory_unit = glwe_accumulator_size > lwe_dimension
+  // allocate at least LWE-mask times two: to keep both decomposition state and
+  // decomposed intermediate value
+  int memory_unit = glwe_accumulator_size > lwe_dimension * 2
                         ? glwe_accumulator_size
-                        : lwe_dimension;
+                        : lwe_dimension * 2;
 
   if (allocate_gpu_memory) {
     *fp_ks_buffer = (int8_t *)cuda_malloc_async(
@@ -221,44 +223,6 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
   }
 }
 
-// public functional packing keyswitch for a batch of LWE ciphertexts
-//
-// Selects the input each thread is working on using the y-block index.
-//
-// Assumes there are (glwe_dimension+1) * polynomial_size threads split through
-// different thread blocks at the x-axis to work on that input.
-template
-__global__ void packing_keyswitch_lwe_list_to_glwe(
-    Torus *glwe_array_out, Torus const *lwe_array_in, Torus const *fp_ksk,
-    uint32_t lwe_dimension_in, uint32_t glwe_dimension,
-    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
-    Torus *d_mem) {
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-
-  const int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
-  const int lwe_size = (lwe_dimension_in + 1);
-
-  const int input_id = blockIdx.y;
-  const int degree = input_id;
-
-  // Select an input
-  auto lwe_in = lwe_array_in + input_id * lwe_size;
-  auto ks_glwe_out = d_mem + input_id * glwe_accumulator_size;
-  auto glwe_out = glwe_array_out + input_id * glwe_accumulator_size;
-
-  // KS LWE to GLWE
-  packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
-      ks_glwe_out, lwe_in, fp_ksk, lwe_dimension_in, glwe_dimension,
-      polynomial_size, base_log, level_count);
-
-  // P * x ^degree
-  auto in_poly = ks_glwe_out + (tid / polynomial_size) * polynomial_size;
-  auto out_result = glwe_out + (tid / polynomial_size) * polynomial_size;
-  polynomial_accumulate_monic_monomial_mul(out_result, in_poly, degree,
-                                           tid % polynomial_size,
-                                           polynomial_size, 1, true);
-}
-
 /// To-do: Rewrite this kernel for efficiency
 template
 __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
@@ -276,52 +240,4 @@ __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
   }
 }
 
-template
-__host__ void host_packing_keyswitch_lwe_list_to_glwe(
-    cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
-    Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
-    uint32_t lwe_dimension_in, uint32_t glwe_dimension,
-    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
-    uint32_t num_lwes) {
-
-  if (num_lwes > polynomial_size)
-    PANIC("Cuda error: too many LWEs to pack. The number of LWEs should be "
-          "smaller than "
-          "polynomial_size.")
-
-  cudaSetDevice(gpu_index);
-  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
-
-  int num_blocks = 0, num_threads = 0;
-  getNumBlocksAndThreads(glwe_accumulator_size, 128, num_blocks, num_threads);
-
-  dim3 grid(num_blocks, num_lwes);
-  dim3 threads(num_threads);
-
-  // The fast path of PKS uses the scratch buffer (d_mem) differently:
-  // it needs to store the decomposed masks in the first half of this buffer
-  // and the keyswitched GLWEs in the second half of the buffer. Thus the
-  // scratch buffer for the fast path must determine the half-size of the
-  // scratch buffer as the max between the size of the GLWE and the size of the
-  // LWE-mask
-  int memory_unit = glwe_accumulator_size > lwe_dimension_in
-                        ? glwe_accumulator_size
-                        : lwe_dimension_in;
-
-  auto d_mem = (Torus *)fp_ks_buffer;
-  auto d_tmp_glwe_array_out = d_mem + num_lwes * memory_unit;
-
-  // individually keyswitch each lwe
-  packing_keyswitch_lwe_list_to_glwe<<>>(
-      d_tmp_glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
-      glwe_dimension, polynomial_size, base_log, level_count, d_mem);
-  check_cuda_error(cudaGetLastError());
-
-  // accumulate to a single glwe
-  accumulate_glwes<<>>(
-      glwe_out, d_tmp_glwe_array_out, glwe_dimension, polynomial_size,
-      num_lwes);
-  check_cuda_error(cudaGetLastError());
-}
-
 #endif
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
index 93fbf33bf8..bbcf0aafef 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
@@ -117,21 +117,11 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
   while (rem_lwes > 0) {
     auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);
 
-    if (can_use_pks_fast_path(
-            input_lwe_dimension, chunk_size, compression_params.polynomial_size,
-            compression_params.ks_level, compression_params.glwe_dimension)) {
-      host_fast_packing_keyswitch_lwe_list_to_glwe(
-          streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
-          fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
-          compression_params.polynomial_size, compression_params.ks_base_log,
-          compression_params.ks_level, chunk_size);
-    } else {
-      host_packing_keyswitch_lwe_list_to_glwe(
-          streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
-          fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
-          compression_params.polynomial_size, compression_params.ks_base_log,
-          compression_params.ks_level, chunk_size);
-    }
+    host_fast_packing_keyswitch_lwe_list_to_glwe(
+        streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
+        fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
+        compression_params.polynomial_size, compression_params.ks_base_log,
+        compression_params.ks_level, chunk_size);
 
     rem_lwes -= chunk_size;
     lwe_subset += chunk_size * lwe_in_size;
diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
index 790fc0108a..20b0155a54 100644
--- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
+++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
@@ -718,133 +718,4 @@ mod tests {
             }
         }
     }
-
-    //#[test]
-    //fn test_gpu_ciphertext_compression_fast_path() {
-    //    /// Implement a test only for the storage of ciphertexts
-    //    /// using a custom parameter set which is supported by a fast-path
-    //    /// packing keyswitch (only for level_count==1)
-    //    const COMP_PARAM_CUSTOM_FAST_PATH: CompressionParameters = CompressionParameters {
-    //        br_level: DecompositionLevelCount(1),
-    //        br_base_log: DecompositionBaseLog(21),
-    //        packing_ks_level: DecompositionLevelCount(1),
-    //        packing_ks_base_log: DecompositionBaseLog(19),
-    //        packing_ks_polynomial_size: PolynomialSize(2048),
-    //        packing_ks_glwe_dimension: GlweDimension(1),
-    //        lwe_per_glwe: LweCiphertextCount(2048),
-    //        storage_log_modulus: CiphertextModulusLog(55),
-    //        packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(
-    //            StandardDev(2.845267479601915e-15),
-    //        ),
-    //    };
-
-    //    const NUM_BLOCKS: usize = 32;
-
-    //    let streams = CudaStreams::new_multi_gpu();
-
-    //    let (radix_cks, sks) = gen_keys_radix_gpu(
-    //        PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
-    //        NUM_BLOCKS,
-    //        &streams,
-    //    );
-    //    let cks = radix_cks.as_ref();
-
-    //    let private_compression_key =
-    //        cks.new_compression_private_key(COMP_PARAM_CUSTOM_FAST_PATH);
-
-    //    let (cuda_compression_key, cuda_decompression_key) =
-    //        radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);
-
-    //    const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_CUSTOM_FAST_PATH.lwe_per_glwe.0 /
-    //        NUM_BLOCKS;
-
-    //    let mut rng = rand::thread_rng();
-
-    //    let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;
-
-    //    // Hybrid
-    //    enum MessageType {
-    //        Unsigned(u128),
-    //        Signed(i128),
-    //        Boolean(bool),
-    //    }
-    //    for _ in 0..NB_OPERATOR_TESTS {
-    //        let mut builder = CudaCompressedCiphertextListBuilder::new();
-
-    //        let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
-    //        let mut messages = vec![];
-    //        for _ in 0..nb_messages {
-    //            let case_selector = rng.gen_range(0..3);
-    //            match case_selector {
-    //                0 => {
-    //                    // Unsigned
-    //                    let modulus = message_modulus.pow(NUM_BLOCKS as u32);
-    //                    let message = rng.gen::() % modulus;
-    //                    let ct = radix_cks.encrypt(message);
-    //                    let d_ct =
-    //                        CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
-    //                    let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
-    //                    builder.push(d_and_ct, &streams);
-    //                    messages.push(MessageType::Unsigned(message));
-    //                }
-    //                1 => {
-    //                    // Signed
-    //                    let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
-    //                    let message = rng.gen::() % modulus;
-    //                    let ct = radix_cks.encrypt_signed(message);
-    //                    let d_ct =
-    //                        CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct,
-    //                        &streams); let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
-    //                    builder.push(d_and_ct, &streams);
-    //                    messages.push(MessageType::Signed(message));
-    //                }
-    //                _ => {
-    //                    // Boolean
-    //                    let message = rng.gen::() % 2 != 0;
-    //                    let ct = radix_cks.encrypt_bool(message);
-    //                    let d_boolean_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
-    //                    let d_ct = d_boolean_ct.0;
-    //                    let d_and_boolean_ct =
-    //                        CudaBooleanBlock::from_cuda_radix_ciphertext(d_ct.ciphertext);
-    //                    builder.push(d_and_boolean_ct, &streams);
-    //                    messages.push(MessageType::Boolean(message));
-    //                }
-    //            }
-    //        }
-
-    //        let cuda_compressed = builder.build(&cuda_compression_key, &streams);
-
-    //        for (i, val) in messages.iter().enumerate() {
-    //            match val {
-    //                MessageType::Unsigned(message) => {
-    //                    let d_decompressed: CudaUnsignedRadixCiphertext = cuda_compressed
-    //                        .get(i, &cuda_decompression_key, &streams)
-    //                        .unwrap()
-    //                        .unwrap();
-    //                    let decompressed = d_decompressed.to_radix_ciphertext(&streams);
-    //                    let decrypted: u128 = radix_cks.decrypt(&decompressed);
-    //                    assert_eq!(decrypted, *message);
-    //                }
-    //                MessageType::Signed(message) => {
-    //                    let d_decompressed: CudaSignedRadixCiphertext = cuda_compressed
-    //                        .get(i, &cuda_decompression_key, &streams)
-    //                        .unwrap()
-    //                        .unwrap();
-    //                    let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
-    //                    let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
-    //                    assert_eq!(decrypted, *message);
-    //                }
-    //                MessageType::Boolean(message) => {
-    //                    let d_decompressed: CudaBooleanBlock = cuda_compressed
-    //                        .get(i, &cuda_decompression_key, &streams)
-    //                        .unwrap()
-    //                        .unwrap();
-    //                    let decompressed = d_decompressed.to_boolean_block(&streams);
-    //                    let decrypted = radix_cks.decrypt_bool(&decompressed);
-    //                    assert_eq!(decrypted, *message);
-    //                }
-    //            }
-    //        }
-    //    }
-    //}
 }
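
Note on the scratch-buffer layout this patch relies on (not part of the patch itself): the fast path splits fp_ks_buffer into two halves of `num_lwes * memory_unit` Torus elements, where the first half (d_mem_0) holds the decomposed LWE masks plus the running decomposition state (hence the `lwe_dimension * 2`), and the second half (d_mem_1) accumulates the keyswitched GLWEs. The standalone C++ sketch below only mirrors the sizing arithmetic visible in the diff; the names `ScratchLayout` and `compute_scratch_layout`, and the example parameter values in `main`, are illustrative assumptions rather than tfhe-rs API.

// Minimal, self-contained sketch of the fast-path scratch sizing (assumption:
// total buffer is two ping-pong halves of num_lwes * memory_unit elements).
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct ScratchLayout {
  uint64_t memory_unit;    // half-size unit per LWE, in Torus elements
  uint64_t d_mem_0_offset; // decomposed masks + decomposition state
  uint64_t d_mem_1_offset; // keyswitched GLWE accumulators
  uint64_t total_elems;    // both halves together
};

ScratchLayout compute_scratch_layout(uint32_t lwe_dimension,
                                     uint32_t glwe_dimension,
                                     uint32_t polynomial_size,
                                     uint32_t num_lwes) {
  uint64_t glwe_accumulator_size =
      uint64_t(glwe_dimension + 1) * polynomial_size;
  // Each half must fit either one GLWE accumulator per LWE, or the LWE mask
  // twice (decomposed value + state): this is the patch's `memory_unit`.
  uint64_t memory_unit =
      std::max<uint64_t>(glwe_accumulator_size, 2ull * lwe_dimension);
  ScratchLayout l;
  l.memory_unit = memory_unit;
  l.d_mem_0_offset = 0;
  l.d_mem_1_offset = uint64_t(num_lwes) * memory_unit;
  l.total_elems = 2ull * num_lwes * memory_unit;
  return l;
}

int main() {
  // Example values (assumed): glwe_dimension=1, polynomial_size=2048 as in the
  // removed fast-path test parameters, lwe_dimension=1024, 128 packed LWEs.
  ScratchLayout l = compute_scratch_layout(1024, 1, 2048, 128);
  std::printf("memory_unit=%llu d_mem_1_offset=%llu total=%llu Torus elems\n",
              (unsigned long long)l.memory_unit,
              (unsigned long long)l.d_mem_1_offset,
              (unsigned long long)l.total_elems);
  return 0;
}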