Skip to content

Commit

Permalink
fix(gpu): minor fixes to compression
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Sep 30, 2024
1 parent 18f655c commit 316e840
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ template <typename Torus>
__host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
Torus *array_out, Torus *array_in, uint32_t num_glwes,
int_compression<Torus> *mem_ptr) {
if(array_in == array_out)
PANIC("Cuda error: Input and output must be different");
cudaSetDevice(gpu_index);
auto compression_params = mem_ptr->compression_params;

Expand All @@ -63,6 +65,7 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
dim3 threads(num_threads);
pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
num_glwes, in_len, out_len);
check_cuda_error(cudaGetLastError());
}

template <typename Torus>
Expand Down Expand Up @@ -118,7 +121,6 @@ __host__ void host_integer_compress(cudaStream_t *streams,
num_glwes * (compression_params.glwe_dimension + 1) *
compression_params.polynomial_size,
mem_ptr->storage_log_modulus);
check_cuda_error(cudaGetLastError());

host_pack<Torus>(streams[0], gpu_indexes[0], glwe_array_out,
tmp_glwe_array_out, num_glwes, mem_ptr);
Expand Down Expand Up @@ -160,11 +162,15 @@ __global__ void extract(Torus *glwe_array_out, Torus *array_in, uint32_t index,
}
}

/// Extracts the GLWE ciphertext at position `glwe_index` from the input array
template <typename Torus>
__host__ void host_extract(cudaStream_t stream, uint32_t gpu_index,
Torus *glwe_array_out, Torus *array_in,
uint32_t glwe_index,
int_decompression<Torus> *mem_ptr) {
if(array_in == glwe_array_out)
PANIC("Cuda error: Input and output must be different");

cudaSetDevice(gpu_index);

auto compression_params = mem_ptr->compression_params;
Expand Down Expand Up @@ -282,7 +288,6 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes,
auto lut = h_mem_ptr->carry_extract_lut;
auto active_gpu_count = get_active_gpu_count(num_lwes, gpu_count);
if (active_gpu_count == 1) {

execute_pbs_async<Torus>(
streams, gpu_indexes, active_gpu_count, d_lwe_array_out,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec, extracted_lwe,
Expand Down
43 changes: 19 additions & 24 deletions tfhe/src/integer/gpu/list_compression/server_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,27 @@ impl CudaCompressionKey {
let lwe_ciphertext_count = LweCiphertextCount(total_num_blocks);

let gpu_index = streams.gpu_indexes[0];
let d_vec = unsafe {
let mut d_vec = CudaVec::new_async(
lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0,
streams,
gpu_index,
let mut d_vec = CudaVec::new_async(
lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0,
streams,
gpu_index,
);
let mut offset: usize = 0;
for ciphertext in vec_ciphertexts {
let dest_ptr = d_vec
.as_mut_c_ptr(gpu_index)
.add(offset * std::mem::size_of::<u64>());
let size = ciphertext.d_blocks.0.d_vec.len * std::mem::size_of::<u64>();
cuda_memcpy_async_gpu_to_gpu(
dest_ptr,
ciphertext.d_blocks.0.d_vec.as_c_ptr(gpu_index),
size as u64,
streams.ptr[gpu_index as usize],
streams.gpu_indexes[gpu_index as usize],
);
let mut offset: usize = 0;
for ciphertext in vec_ciphertexts {
let dest_ptr = d_vec
.as_mut_c_ptr(gpu_index)
.add(offset * std::mem::size_of::<u64>());
let size = ciphertext.d_blocks.0.d_vec.len * std::mem::size_of::<u64>();
cuda_memcpy_async_gpu_to_gpu(
dest_ptr,
ciphertext.d_blocks.0.d_vec.as_c_ptr(gpu_index),
size as u64,
streams.ptr[gpu_index as usize],
streams.gpu_indexes[gpu_index as usize],
);

offset += ciphertext.d_blocks.0.d_vec.len;
}

streams.synchronize();
d_vec
};
offset += ciphertext.d_blocks.0.d_vec.len;
}

CudaLweCiphertextList::from_cuda_vec(d_vec, lwe_ciphertext_count, ciphertext_modulus)
}
Expand Down

0 comments on commit 316e840

Please sign in to comment.