-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(gpu): fix sample extraction when nth > 0 and keep input unchanged
- Loading branch information
Showing
15 changed files
with
617 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef CUDA_FUNCTIONS_H_ | ||
#define CUDA_FUNCTIONS_H_ | ||
|
||
#include "polynomial/functions.cuh" | ||
#include "polynomial/parameters.cuh" | ||
#include <cstdint> | ||
|
||
// C-linkage entry point for GLWE sample extraction (the `_64` variant).
// All ciphertext and stream buffers are passed as opaque pointers so the
// symbol can be called from outside C++.
extern "C" {
void cuda_glwe_sample_extract_64(
    void **streams,          // opaque stream handles
    uint32_t *gpu_indexes,   // GPU indexes associated with `streams`
    uint32_t gpu_count,      // number of entries in the two arrays above
    void *lwe_array_out,     // destination buffer for the extracted LWE data
    void *glwe_array_in,     // source GLWE buffer
    uint32_t *nth_array,     // coefficient index to extract, one per sample
    uint32_t num_samples,    // number of samples to extract
    uint32_t glwe_dimension, // GLWE dimension of the input
    uint32_t polynomial_size // polynomial size of the input
);
}
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#include "functions.h" | ||
|
||
void cuda_glwe_sample_extract_64(void **streams, uint32_t *gpu_indexes, | ||
uint32_t gpu_count, void *lwe_array_out, | ||
void *glwe_in, uint32_t *nth_array, | ||
uint32_t num_samples, uint32_t glwe_dimension, | ||
uint32_t polynomial_size) { | ||
|
||
switch (polynomial_size) { | ||
case 256: | ||
host_sample_extract<uint64_t, AmortizedDegree<256>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 512: | ||
host_sample_extract<uint64_t, AmortizedDegree<512>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 1024: | ||
host_sample_extract<uint64_t, AmortizedDegree<1024>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 2048: | ||
host_sample_extract<uint64_t, AmortizedDegree<2048>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 4096: | ||
host_sample_extract<uint64_t, AmortizedDegree<4096>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 8192: | ||
host_sample_extract<uint64_t, AmortizedDegree<8192>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
case 16384: | ||
host_sample_extract<uint64_t, AmortizedDegree<16384>>( | ||
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out, | ||
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples, | ||
glwe_dimension); | ||
break; | ||
default: | ||
PANIC("Cuda error: unsupported polynomial size. Supported " | ||
"N's are powers of two in the interval [256..16384].") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use crate::core_crypto::gpu::glwe_ciphertext::CudaGlweCiphertext; | ||
use crate::core_crypto::gpu::lwe_ciphertext::CudaLweCiphertext; | ||
use crate::core_crypto::gpu::vec::CudaVec; | ||
use crate::core_crypto::gpu::{extract_lwe_sample_from_glwe_ciphertext_async, CudaStreams}; | ||
use crate::core_crypto::prelude::{LweCiphertextCount, MonomialDegree, UnsignedTorus}; | ||
|
||
/// Extract the nth coefficient from the body of a [`GLWE Ciphertext`](`CudaGlweCiphertext`) as an | ||
/// [`LWE ciphertext`](`CudaLweCiphertext`). This variant is GPU-accelerated. | ||
pub fn cuda_extract_lwe_sample_from_glwe_ciphertext<Scalar>( | ||
input_glwe: &CudaGlweCiphertext<Scalar>, | ||
output_lwe: &mut CudaLweCiphertext<Scalar>, | ||
nth: MonomialDegree, | ||
streams: &CudaStreams, | ||
) where | ||
// CastInto required for PBS modulus switch which returns a usize | ||
Scalar: UnsignedTorus, | ||
{ | ||
let in_lwe_dim = input_glwe | ||
.glwe_dimension() | ||
.to_equivalent_lwe_dimension(input_glwe.polynomial_size()); | ||
|
||
let out_lwe_dim = output_lwe.lwe_dimension(); | ||
|
||
assert_eq!( | ||
in_lwe_dim, out_lwe_dim, | ||
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \ | ||
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.", | ||
); | ||
|
||
assert_eq!( | ||
input_glwe.ciphertext_modulus(), | ||
output_lwe.ciphertext_modulus(), | ||
"Mismatched moduli between input_glwe ({:?}) and output_lwe ({:?})", | ||
input_glwe.ciphertext_modulus(), | ||
output_lwe.ciphertext_modulus() | ||
); | ||
|
||
let nth_array: Vec<u32> = vec![nth.0 as u32]; | ||
let gpu_indexes = &streams.gpu_indexes; | ||
unsafe { | ||
let d_nth_array = CudaVec::from_cpu_async(&nth_array, streams, gpu_indexes[0]); | ||
extract_lwe_sample_from_glwe_ciphertext_async( | ||
streams, | ||
&mut output_lwe.0.d_vec, | ||
&input_glwe.0.d_vec, | ||
&d_nth_array, | ||
LweCiphertextCount(nth_array.len()), | ||
input_glwe.glwe_dimension(), | ||
input_glwe.polynomial_size(), | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.