Skip to content

Commit

Permalink
refactor(gpu): fix sample extraction when nth > 0 and keep input unch…
Browse files Browse the repository at this point in the history
…anged
  • Loading branch information
pdroalves committed Jul 29, 2024
1 parent 2004333 commit 249a7f0
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 34 deletions.
6 changes: 6 additions & 0 deletions backends/tfhe-cuda-backend/cuda/include/ciphertext.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CUDA_CIPHERTEXT_H
#define CUDA_CIPHERTEXT_H

#include "device.h"
#include <cstdint>

extern "C" {
Expand All @@ -14,5 +15,10 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
void *dest, void *src,
uint32_t number_of_cts,
uint32_t lwe_dimension);

void cuda_glwe_sample_extract_64(void **streams, void *lwe_array_out,
void *glwe_array_in, uint32_t *nth_array,
uint32_t num_glwes, uint32_t glwe_dimension,
uint32_t polynomial_size);
};
#endif
55 changes: 55 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ciphertext.cuh"
#include "polynomial/parameters.cuh"

void cuda_convert_lwe_ciphertext_vector_to_gpu_64(void *stream,
uint32_t gpu_index,
Expand All @@ -19,3 +20,57 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)dest,
(uint64_t *)src, number_of_cts, lwe_dimension);
}

void cuda_glwe_sample_extract_64(void **streams, void *lwe_array_out,
void *glwe_array_in, uint32_t *nth_array,
uint32_t num_glwes, uint32_t glwe_dimension,
uint32_t polynomial_size) {

switch (polynomial_size) {
case 256:
host_sample_extract<uint64_t, AmortizedDegree<256>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 512:
host_sample_extract<uint64_t, AmortizedDegree<512>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 1024:
host_sample_extract<uint64_t, AmortizedDegree<1024>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 2048:
host_sample_extract<uint64_t, AmortizedDegree<2048>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 4096:
host_sample_extract<uint64_t, AmortizedDegree<4096>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 8192:
host_sample_extract<uint64_t, AmortizedDegree<8192>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
case 16384:
host_sample_extract<uint64_t, AmortizedDegree<16384>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_array_in, (uint32_t *)nth_array, num_glwes,
glwe_dimension);
break;
default:
PANIC("Cuda error: unsupported polynomial size. Supported "
"N's are powers of two in the interval [256..16384].")
}
}
35 changes: 35 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "ciphertext.h"
#include "device.h"
#include "polynomial/functions.cuh"
#include <cstdint>

template <typename T>
Expand All @@ -25,4 +26,38 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu(cudaStream_t stream,
cuda_memcpy_async_to_cpu(dest, src, size, stream, gpu_index);
}

template <typename Torus, class params>
__global__ void apply_sample_extract(Torus *lwe_array_out, Torus *glwe_array_in,
uint32_t *nth_array,
uint32_t glwe_dimension) {

const int input_id = blockIdx.x;

const int glwe_input_size = (glwe_dimension + 1) * params::degree;
const int lwe_output_size = glwe_dimension * params::degree + 1;

auto lwe_out = lwe_array_out + input_id * lwe_output_size;

// We assume each GLWE will store the first polynomial_size inputs
uint32_t nth_per_glwe = params::degree;
auto glwe_in = glwe_array_in + (input_id / nth_per_glwe) * glwe_input_size;

auto nth = nth_array[input_id];

sample_extract_mask<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
sample_extract_body<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
}

template <typename Torus, class params>
__host__ void host_sample_extract(cudaStream_t *streams, Torus *lwe_array_out,
Torus *glwe_array_in, uint32_t *nth_array,
uint32_t num_glwes, uint32_t glwe_dimension) {

dim3 grid(num_glwes);
dim3 thds(params::degree / params::opt);
apply_sample_extract<Torus, params><<<grid, thds, 0, streams[0]>>>(
lwe_array_out, glwe_array_in, nth_array, glwe_dimension);
check_cuda_error(cudaGetLastError());
}

#endif
58 changes: 24 additions & 34 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,65 +191,55 @@ __device__ void add_to_torus(double2 *m_values, Torus *result,

// Extracts the body of the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *accumulator,
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *glwe,
uint32_t glwe_dimension, uint32_t nth = 0) {
// Set first coefficient of the accumulator as the body of the LWE sample
// Set first coefficient of the glwe as the body of the LWE sample
lwe_array_out[glwe_dimension * params::degree] =
accumulator[glwe_dimension * params::degree + nth];
glwe[glwe_dimension * params::degree + nth];
}

// Extracts the mask from the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator,
uint32_t num_poly = 1, uint32_t nth = 0) {
for (int z = 0; z < num_poly; z++) {
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *glwe,
uint32_t glwe_dimension = 1,
uint32_t nth = 0) {
for (int z = 0; z < glwe_dimension; z++) {
Torus *lwe_array_out_slice =
(Torus *)lwe_array_out + (ptrdiff_t)(z * params::degree);
Torus *accumulator_slice =
(Torus *)accumulator + (ptrdiff_t)(z * params::degree);
Torus *glwe_slice = (Torus *)glwe + (ptrdiff_t)(z * params::degree);

synchronize_threads_in_block();
// Reverse the accumulator
// Reverse the glwe
// Set ACC = -ACC
int tid = threadIdx.x;
Torus result[params::opt];
#pragma unroll
for (int i = 0; i < params::opt; i++) {
result[i] = accumulator_slice[params::degree - tid - 1];
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Set ACC = -ACC
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
accumulator_slice[tid] =
SEL(-result[i], result[i], tid >= params::degree - nth);
auto x = glwe_slice[params::degree - tid - 1];
result[i] = SEL(-x, x, tid >= params::degree - nth);
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Perform ACC * X
// (equivalent to multiply_by_monomial_negacyclic_inplace(1))
// Copy to the mask of the LWE sample
tid = threadIdx.x;
result[params::opt];
for (int i = 0; i < params::opt; i++) {
// if (tid < 1)
// result[i] = -accumulator_slice[tid - 1 + params::degree];
// result[i] = -glwe_slice[tid - 1 + params::degree];
// else
// result[i] = accumulator_slice[tid - 1];
int x = tid - 1 + SEL(0, params::degree - nth, tid < 1);
result[i] = SEL(1, -1, tid < 1) * accumulator_slice[x];
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
// result[i] = glwe_slice[tid - 1];
uint32_t dst_idx = tid + 1 + nth;
if (dst_idx == params::degree)
lwe_array_out_slice[0] = -result[i];
else {
dst_idx =
SEL(dst_idx, dst_idx - params::degree, dst_idx >= params::degree);
lwe_array_out_slice[dst_idx] = result[i];
}

// Copy to the mask of the LWE sample
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
lwe_array_out_slice[tid] = result[i];
tid = tid + params::degree / params::opt;
tid += params::degree / params::opt;
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions backends/tfhe-cuda-backend/src/cuda_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,15 @@ extern "C" {
gpu_count: u32,
mem_ptr: *mut *mut i8,
);
pub fn cuda_glwe_sample_extract_64(
streams: *const *mut c_void,
lwe_array_out: *mut c_void,
glwe_array_in: *const c_void,
nth_array: *const u32,
num_glwes: u32,
glwe_dimension: u32,
polynomial_size: u32,
);

pub fn scratch_cuda_integer_radix_comparison_kb_64(
streams: *const *mut c_void,
Expand Down
60 changes: 60 additions & 0 deletions tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{extract_lwe_samples_from_glwe_ciphertext_list_async, CudaStreams};
use crate::core_crypto::prelude::{MonomialDegree, UnsignedTorus};
use itertools::Itertools;

/// For each [`GLWE Ciphertext`] (`CudaGlweCiphertextList`) given as input, extract the nth
/// coefficient from its body as an [`LWE ciphertext`](`CudaLweCiphertextList`). This variant is
/// GPU-accelerated.
pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
vec_nth: &[MonomialDegree],
streams: &CudaStreams,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus,
{
let in_lwe_dim = input_glwe_list
.glwe_dimension()
.to_equivalent_lwe_dimension(input_glwe_list.polynomial_size());

let out_lwe_dim = output_lwe_list.lwe_dimension();

assert_eq!(
in_lwe_dim, out_lwe_dim,
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
);

assert_eq!(
vec_nth.len(),
input_glwe_list.glwe_ciphertext_count().0 * input_glwe_list.polynomial_size().0,
"Mismatch between number of nths and number of GLWEs provided.",
);

assert_eq!(
input_glwe_list.ciphertext_modulus(),
output_lwe_list.ciphertext_modulus(),
"Mismatched moduli between input_glwe ({:?}) and output_lwe ({:?})",
input_glwe_list.ciphertext_modulus(),
output_lwe_list.ciphertext_modulus()
);

let nth_array: Vec<u32> = vec_nth.iter().map(|x| x.0 as u32).collect_vec();
let gpu_indexes = &streams.gpu_indexes;
unsafe {
let d_nth_array = CudaVec::from_cpu_async(&nth_array, streams, gpu_indexes[0]);
extract_lwe_samples_from_glwe_ciphertext_list_async(
streams,
&mut output_lwe_list.0.d_vec,
&input_glwe_list.0.d_vec,
&d_nth_array,
vec_nth.len() as u32,
input_glwe_list.glwe_dimension(),
input_glwe_list.polynomial_size(),
);
}
}
1 change: 1 addition & 0 deletions tfhe/src/core_crypto/gpu/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod lwe_linear_algebra;
pub mod lwe_multi_bit_programmable_bootstrapping;
pub mod lwe_programmable_bootstrapping;

pub mod glwe_sample_extraction;
mod lwe_keyswitch;
#[cfg(test)]
mod test;
Expand Down
Loading

0 comments on commit 249a7f0

Please sign in to comment.