Skip to content

Commit

Permalink
refactor(gpu): fix sample extraction when nth > 0 and keep input unch…
Browse files Browse the repository at this point in the history
…anged
  • Loading branch information
pdroalves committed Jul 25, 2024
1 parent 2004333 commit a53a7b5
Show file tree
Hide file tree
Showing 15 changed files with 612 additions and 34 deletions.
16 changes: 16 additions & 0 deletions backends/tfhe-cuda-backend/cuda/include/functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef CUDA_FUNCTIONS_H_
#define CUDA_FUNCTIONS_H_

#include "polynomial/functions.cuh"
#include "polynomial/parameters.cuh"
#include <cstdint>

extern "C" {
void cuda_glwe_sample_extract_64(void **streams, uint32_t *gpu_indexes,
uint32_t gpu_count, void *lwe_array_out,
void *glwe_array_in, uint32_t *nth_array,
uint32_t num_samples, uint32_t glwe_dimension,
uint32_t polynomial_size);
}

#endif
1 change: 1 addition & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CNCRT_TORUS_CUH
#define CNCRT_TORUS_CUH

#include "device.h"
#include "types/int128.cuh"
#include <limits>

Expand Down
56 changes: 56 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "functions.h"

void cuda_glwe_sample_extract_64(void **streams, uint32_t *gpu_indexes,
uint32_t gpu_count, void *lwe_array_out,
void *glwe_in, uint32_t *nth_array,
uint32_t num_samples, uint32_t glwe_dimension,
uint32_t polynomial_size) {

switch (polynomial_size) {
case 256:
host_sample_extract<uint64_t, AmortizedDegree<256>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 512:
host_sample_extract<uint64_t, AmortizedDegree<512>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 1024:
host_sample_extract<uint64_t, AmortizedDegree<1024>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 2048:
host_sample_extract<uint64_t, AmortizedDegree<2048>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 4096:
host_sample_extract<uint64_t, AmortizedDegree<4096>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 8192:
host_sample_extract<uint64_t, AmortizedDegree<8192>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
case 16384:
host_sample_extract<uint64_t, AmortizedDegree<16384>>(
(cudaStream_t *)(streams), (uint64_t *)lwe_array_out,
(uint64_t *)glwe_in, (uint32_t *)nth_array, num_samples,
glwe_dimension);
break;
default:
PANIC("Cuda error: unsupported polynomial size. Supported "
"N's are powers of two in the interval [256..16384].")
}
}
87 changes: 53 additions & 34 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,67 +191,86 @@ __device__ void add_to_torus(double2 *m_values, Torus *result,

// Extracts the body of the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *accumulator,
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *glwe,
uint32_t glwe_dimension, uint32_t nth = 0) {
// Set first coefficient of the accumulator as the body of the LWE sample
// Set first coefficient of the glwe as the body of the LWE sample
lwe_array_out[glwe_dimension * params::degree] =
accumulator[glwe_dimension * params::degree + nth];
glwe[glwe_dimension * params::degree + nth];
}

// Extracts the mask from the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator,
uint32_t num_poly = 1, uint32_t nth = 0) {
for (int z = 0; z < num_poly; z++) {
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *glwe,
uint32_t glwe_dimension = 1,
uint32_t nth = 0) {
for (int z = 0; z < glwe_dimension; z++) {
Torus *lwe_array_out_slice =
(Torus *)lwe_array_out + (ptrdiff_t)(z * params::degree);
Torus *accumulator_slice =
(Torus *)accumulator + (ptrdiff_t)(z * params::degree);
Torus *glwe_slice = (Torus *)glwe + (ptrdiff_t)(z * params::degree);

synchronize_threads_in_block();
// Reverse the accumulator
// Reverse the glwe
// Set ACC = -ACC
int tid = threadIdx.x;
Torus result[params::opt];
#pragma unroll
for (int i = 0; i < params::opt; i++) {
result[i] = accumulator_slice[params::degree - tid - 1];
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Set ACC = -ACC
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
accumulator_slice[tid] =
SEL(-result[i], result[i], tid >= params::degree - nth);
auto x = glwe_slice[params::degree - tid - 1];
result[i] = SEL(-x, x, tid >= params::degree - nth);
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Perform ACC * X
// (equivalent to multiply_by_monomial_negacyclic_inplace(1))
// Copy to the mask of the LWE sample
tid = threadIdx.x;
result[params::opt];
for (int i = 0; i < params::opt; i++) {
// if (tid < 1)
// result[i] = -accumulator_slice[tid - 1 + params::degree];
// result[i] = -glwe_slice[tid - 1 + params::degree];
// else
// result[i] = accumulator_slice[tid - 1];
int x = tid - 1 + SEL(0, params::degree - nth, tid < 1);
result[i] = SEL(1, -1, tid < 1) * accumulator_slice[x];
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
// result[i] = glwe_slice[tid - 1];
uint32_t dst_idx = tid + 1 + nth;
if (dst_idx == params::degree)
lwe_array_out_slice[0] = -result[i];
else {
dst_idx =
SEL(dst_idx, dst_idx - params::degree, dst_idx >= params::degree);
lwe_array_out_slice[dst_idx] = result[i];
}

// Copy to the mask of the LWE sample
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
lwe_array_out_slice[tid] = result[i];
tid = tid + params::degree / params::opt;
tid += params::degree / params::opt;
}
}
}

template <typename Torus, class params>
__global__ void apply_sample_extract(Torus *lwe_array_out, Torus *glwe_in,
uint32_t *nth_array,
uint32_t glwe_dimension) {

const int input_id = blockIdx.x;

const int lwe_output_size = glwe_dimension * params::degree + 1;

auto lwe_out = lwe_array_out + input_id * lwe_output_size;
auto nth = nth_array[input_id];

sample_extract_mask<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
sample_extract_body<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
}

template <typename Torus, class params>
__host__ void host_sample_extract(cudaStream_t *streams, Torus *lwe_array_out,
Torus *glwe_in, uint32_t *nth_array,
uint32_t num_samples,
uint32_t glwe_dimension) {

dim3 grid(num_samples);
dim3 thds(params::degree / params::opt);
apply_sample_extract<Torus, params><<<grid, thds, 0, streams[0]>>>(
lwe_array_out, glwe_in, nth_array, glwe_dimension);
check_cuda_error(cudaGetLastError());
}

#endif
1 change: 1 addition & 0 deletions backends/tfhe-cuda-backend/cuda/src/types/int128.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CNCRT_INT128_CUH
#define CNCRT_INT128_CUH

#include <cstdint>
// abseil's int128 type
// licensed under Apache license

Expand Down
11 changes: 11 additions & 0 deletions backends/tfhe-cuda-backend/src/cuda_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,17 @@ extern "C" {
gpu_count: u32,
mem_ptr: *mut *mut i8,
);
pub fn cuda_glwe_sample_extract_64(
streams: *const *mut c_void,
gpu_indexes: *const u32,
gpu_count: u32,
lwe_array_out: *mut c_void,
glwe_in: *const c_void,
nth_array: *const u32,
num_samples: u32,
glwe_dimension: u32,
polynomial_size: u32,
);

pub fn scratch_cuda_integer_radix_comparison_kb_64(
streams: *const *mut c_void,
Expand Down
52 changes: 52 additions & 0 deletions tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use crate::core_crypto::gpu::glwe_ciphertext::CudaGlweCiphertext;
use crate::core_crypto::gpu::lwe_ciphertext::CudaLweCiphertext;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{extract_lwe_sample_from_glwe_ciphertext_async, CudaStreams};
use crate::core_crypto::prelude::{LweCiphertextCount, MonomialDegree, UnsignedTorus};

/// Extract the nth coefficient from the body of a [`GLWE Ciphertext`](`CudaGlweCiphertext`) as an
/// [`LWE ciphertext`](`CudaLweCiphertext`). This variant is GPU-accelerated.
pub fn cuda_extract_lwe_sample_from_glwe_ciphertext<Scalar>(
input_glwe: &CudaGlweCiphertext<Scalar>,
output_lwe: &mut CudaLweCiphertext<Scalar>,
nth: MonomialDegree,
streams: &CudaStreams,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus,
{
let in_lwe_dim = input_glwe
.glwe_dimension()
.to_equivalent_lwe_dimension(input_glwe.polynomial_size());

let out_lwe_dim = output_lwe.lwe_dimension();

assert_eq!(
in_lwe_dim, out_lwe_dim,
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
);

assert_eq!(
input_glwe.ciphertext_modulus(),
output_lwe.ciphertext_modulus(),
"Mismatched moduli between input_glwe ({:?}) and output_lwe ({:?})",
input_glwe.ciphertext_modulus(),
output_lwe.ciphertext_modulus()
);

let nth_array: Vec<u32> = vec![nth.0 as u32];
let gpu_indexes = &streams.gpu_indexes;
unsafe {
let d_nth_array = CudaVec::from_cpu_async(&nth_array, streams, gpu_indexes[0]);
extract_lwe_sample_from_glwe_ciphertext_async(
streams,
&mut output_lwe.0.d_vec,
&input_glwe.0.d_vec,
&d_nth_array,
LweCiphertextCount(nth_array.len()),
input_glwe.glwe_dimension(),
input_glwe.polynomial_size(),
);
}
}
1 change: 1 addition & 0 deletions tfhe/src/core_crypto/gpu/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod lwe_linear_algebra;
pub mod lwe_multi_bit_programmable_bootstrapping;
pub mod lwe_programmable_bootstrapping;

pub mod glwe_sample_extraction;
mod lwe_keyswitch;
#[cfg(test)]
mod test;
Expand Down
Loading

0 comments on commit a53a7b5

Please sign in to comment.