Skip to content

Commit

Permalink
feat(gpu): optimize packing keyswitch on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jan 9, 2025
1 parent f633eed commit e120ed8
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 283 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
return BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
}

__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension,
uint32_t num_lwe,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t glwe_dimension) {
// TODO: activate it back, fix tests and extend to level_count > 1
return false;
}

// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
Expand All @@ -57,13 +48,17 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

Torus a_i = lwe_in[read_val_idx];

Torus state = init_decomposer_state(a_i, base_log, level_count);

Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
synchronize_threads_in_block();
lwe_out[write_state_idx] = state;
}

// Continue decomposiion of an array of Torus elements in place. Supposes
Expand All @@ -84,12 +79,16 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
return;

auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto state_idx = num_lwe * lwe_dimension + val_idx;

Torus state = buffer_in[val_idx];
Torus state = buffer_in[state_idx];
synchronize_threads_in_block();

Torus mod_b_mask = (1ll << base_log) - 1ll;

buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
synchronize_threads_in_block();
buffer_in[state_idx] = state;
}

// Multiply matrices A, B of size (M, K), (K, N) respectively
Expand Down Expand Up @@ -152,7 +151,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
} else {
Bs[innerRowB * BN + innerColB] = 0;
}
__syncthreads();
synchronize_threads_in_block();

// Advance blocktile for the next iteration of this loop
A += BK;
Expand All @@ -168,7 +167,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
}
}
__syncthreads();
synchronize_threads_in_block();
}

// Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
Expand Down Expand Up @@ -259,10 +258,6 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(

// Optimization of packing keyswitch when packing many LWEs

if (level_count > 1) {
PANIC("Fast path PKS only supports level_count==1");
}

cudaSetDevice(gpu_index);
check_cuda_error(cudaGetLastError());

Expand All @@ -273,10 +268,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// buffer and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension
// LWE-mask times two (to keep both decomposition state and decomposed
// intermediate value)
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

// ping pong the buffer between successive calls
// split the buffer in two parts of this size
Expand Down Expand Up @@ -309,29 +305,28 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
CEIL_DIV(num_lwes, BLOCK_SIZE_GEMM));
dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);

auto stride_KSK_buffer = glwe_accumulator_size;
auto stride_KSK_buffer = glwe_accumulator_size * level_count;

uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());

/*
TODO: transpose key to generalize to level_count > 1
auto ksk_block_size = glwe_accumulator_size;

for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());

tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
*/
tgemm<Torus, TorusVec>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}

// should we include the mask in the rotation ??
dim3 grid_rotate(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
Expand Down
25 changes: 7 additions & 18 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,13 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

if (can_use_pks_fast_path(input_lwe_dimension, num_lwes,
output_polynomial_size, level_count,
output_glwe_dimension)) {
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
} else
host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
}

void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
Expand Down
92 changes: 4 additions & 88 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ __host__ void scratch_packing_keyswitch_lwe_list_to_glwe(

int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int memory_unit = glwe_accumulator_size > lwe_dimension
// allocate at least LWE-mask times two: to keep both decomposition state and
// decomposed intermediate value
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

if (allocate_gpu_memory) {
*fp_ks_buffer = (int8_t *)cuda_malloc_async(
Expand Down Expand Up @@ -221,44 +223,6 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
}
}

// public functional packing keyswitch for a batch of LWE ciphertexts
//
// Selects the input each thread is working on using the y-block index.
//
// Assumes there are (glwe_dimension+1) * polynomial_size threads split through
// different thread blocks at the x-axis to work on that input.
template <typename Torus>
__global__ void packing_keyswitch_lwe_list_to_glwe(
Torus *glwe_array_out, Torus const *lwe_array_in, Torus const *fp_ksk,
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
Torus *d_mem) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;

const int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
const int lwe_size = (lwe_dimension_in + 1);

const int input_id = blockIdx.y;
const int degree = input_id;

// Select an input
auto lwe_in = lwe_array_in + input_id * lwe_size;
auto ks_glwe_out = d_mem + input_id * glwe_accumulator_size;
auto glwe_out = glwe_array_out + input_id * glwe_accumulator_size;

// KS LWE to GLWE
packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext<Torus>(
ks_glwe_out, lwe_in, fp_ksk, lwe_dimension_in, glwe_dimension,
polynomial_size, base_log, level_count);

// P * x ^degree
auto in_poly = ks_glwe_out + (tid / polynomial_size) * polynomial_size;
auto out_result = glwe_out + (tid / polynomial_size) * polynomial_size;
polynomial_accumulate_monic_monomial_mul<Torus>(out_result, in_poly, degree,
tid % polynomial_size,
polynomial_size, 1, true);
}

/// To-do: Rewrite this kernel for efficiency
template <typename Torus>
__global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
Expand All @@ -276,52 +240,4 @@ __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
}
}

template <typename Torus>
__host__ void host_packing_keyswitch_lwe_list_to_glwe(
cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

if (num_lwes > polynomial_size)
PANIC("Cuda error: too many LWEs to pack. The number of LWEs should be "
"smaller than "
"polynomial_size.")

cudaSetDevice(gpu_index);
int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int num_blocks = 0, num_threads = 0;
getNumBlocksAndThreads(glwe_accumulator_size, 128, num_blocks, num_threads);

dim3 grid(num_blocks, num_lwes);
dim3 threads(num_threads);

// The fast path of PKS uses the scratch buffer (d_mem) differently:
// it needs to store the decomposed masks in the first half of this buffer
// and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
? glwe_accumulator_size
: lwe_dimension_in;

auto d_mem = (Torus *)fp_ks_buffer;
auto d_tmp_glwe_array_out = d_mem + num_lwes * memory_unit;

// individually keyswitch each lwe
packing_keyswitch_lwe_list_to_glwe<Torus><<<grid, threads, 0, stream>>>(
d_tmp_glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
glwe_dimension, polynomial_size, base_log, level_count, d_mem);
check_cuda_error(cudaGetLastError());

// accumulate to a single glwe
accumulate_glwes<Torus><<<num_blocks, threads, 0, stream>>>(
glwe_out, d_tmp_glwe_array_out, glwe_dimension, polynomial_size,
num_lwes);
check_cuda_error(cudaGetLastError());
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,11 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
while (rem_lwes > 0) {
auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);

if (can_use_pks_fast_path(
input_lwe_dimension, chunk_size, compression_params.polynomial_size,
compression_params.ks_level, compression_params.glwe_dimension)) {
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);
} else {
host_packing_keyswitch_lwe_list_to_glwe<Torus>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);
}
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);

rem_lwes -= chunk_size;
lwe_subset += chunk_size * lwe_in_size;
Expand Down
Loading

0 comments on commit e120ed8

Please sign in to comment.