Skip to content

Commit

Permalink
fix(gpu): add template parameter to packing keyswitch calls
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Oct 16, 2024
1 parent d794f4d commit d780276
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,16 @@ __global__ void packing_keyswitch_lwe_list_to_glwe(
auto ks_glwe_out = d_mem + input_id * glwe_accumulator_size;
auto glwe_out = glwe_array_out + input_id * glwe_accumulator_size;
// KS LWE to GLWE
packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext<Torus>(
ks_glwe_out, lwe_in, fp_ksk, lwe_dimension_in, glwe_dimension,
polynomial_size, base_log, level_count);

// P * x ^degree
auto in_poly = ks_glwe_out + (tid / polynomial_size) * polynomial_size;
auto out_result = glwe_out + (tid / polynomial_size) * polynomial_size;
polynomial_accumulate_monic_monomial_mul(out_result, in_poly, degree,
tid % polynomial_size,
polynomial_size, 1, true);
polynomial_accumulate_monic_monomial_mul<Torus>(out_result, in_poly, degree,
tid % polynomial_size,
polynomial_size, 1, true);
}

/// To-do: Rewrite this kernel for efficiency
Expand Down Expand Up @@ -299,13 +299,13 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
auto d_tmp_glwe_array_out = d_mem + num_lwes * glwe_accumulator_size;

// individually keyswitch each lwe
packing_keyswitch_lwe_list_to_glwe<<<grid, threads, 0, stream>>>(
packing_keyswitch_lwe_list_to_glwe<Torus><<<grid, threads, 0, stream>>>(
d_tmp_glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
glwe_dimension, polynomial_size, base_log, level_count, d_mem);
check_cuda_error(cudaGetLastError());

// accumulate to a single glwe
accumulate_glwes<<<num_blocks, threads, 0, stream>>>(
accumulate_glwes<Torus><<<num_blocks, threads, 0, stream>>>(
glwe_out, d_tmp_glwe_array_out, glwe_dimension, polynomial_size,
num_lwes);
check_cuda_error(cudaGetLastError());
Expand Down

0 comments on commit d780276

Please sign in to comment.