Skip to content

Commit

Permalink
refactor(gpu): random test fft
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Jun 26, 2024
1 parent e854823 commit 46da1d3
Show file tree
Hide file tree
Showing 2 changed files with 317 additions and 2 deletions.
308 changes: 308 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/fft/bnsmfft.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -715,4 +715,312 @@ __global__ void batch_polynomial_mul(double2 *d_input1, double2 *d_input2,
}
}

template <class params> __device__ void NSMFFT_direct_bundle(double2 *A, const double2 regs[4]) {

/* We don't make bit reverse here, since twiddles are already reversed
* Each thread is always in charge of "opt/2" pairs of coefficients,
* which is why we always loop through N/2 by N/opt strides
* The pragma unroll instruction tells the compiler to unroll the
* full loop, which should increase performance
*/

size_t tid = threadIdx.x;
size_t twid_id;
size_t i1, i2;
double2 u, v, w;
// level 1
// we don't make actual complex multiplication on level1 since we have only
// one twiddle, it's real and image parts are equal, so we can multiply
// it with simpler operations
// degree = 1024, opt = 2 ->
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
i1 = tid;
i2 = tid + params::degree / 2;

//u = A[i1];
//v = A[i2] * (double2){0.707106781186547461715008466854,
// 0.707106781186547461715008466854};

u = regs[i];
v = regs[i + params::opt / 2] * (double2){0.707106781186547461715008466854,
0.707106781186547461715008466854};
//A[i1] += v;
A[i1] = u + v;
A[i2] = u - v;

tid += params::degree / params::opt; //256
}

__syncthreads();

// level 2
// from this level there are more than one twiddles and none of them has equal
// real and imag parts, so complete complex multiplication is needed
// for each level params::degree / 2^level represents number of coefficients
// inside divided chunk of specific level
//
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 4);
i1 = 2 * (params::degree / 4) * twid_id + (tid & (params::degree / 4 - 1));
i2 = i1 + params::degree / 4;

w = negtwiddles[twid_id + 2];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;
}
__syncthreads();

// level 3
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 8);
i1 = 2 * (params::degree / 8) * twid_id + (tid & (params::degree / 8 - 1));
i2 = i1 + params::degree / 8;

w = negtwiddles[twid_id + 4];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();

// level 4
//tid = threadIdx.x;
//for (size_t i = 0; i < params::opt / 2; ++i) {
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 16);
i1 =
2 * (params::degree / 16) * twid_id + (tid & (params::degree / 16 - 1));
i2 = i1 + params::degree / 16;

w = negtwiddles[twid_id + 8];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

//tid += params::degree / params::opt;
}
__syncthreads();

// level 5
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 32);
i1 =
2 * (params::degree / 32) * twid_id + (tid & (params::degree / 32 - 1));
i2 = i1 + params::degree / 32;

w = negtwiddles[twid_id + 16];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();

// level 6
//tid = threadIdx.x;
//for (size_t i = 0; i < params::opt / 2; ++i) {
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 64);
i1 =
2 * (params::degree / 64) * twid_id + (tid & (params::degree / 64 - 1));
i2 = i1 + params::degree / 64;

w = negtwiddles[twid_id + 32];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;
//tid += params::degree / params::opt;
}
__syncthreads();

// level 7
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 128);
i1 = 2 * (params::degree / 128) * twid_id +
(tid & (params::degree / 128 - 1));
i2 = i1 + params::degree / 128;

w = negtwiddles[twid_id + 64];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();

// from level 8, we need to check size of params degree, because we support
// minimum actual polynomial size = 256, when compressed size is halfed and
// minimum supported compressed size is 128, so we always need first 7
// levels of butterfly operation, since butterfly levels are hardcoded
// we need to check if polynomial size is big enough to require specific level
// of butterfly.
if constexpr (params::degree >= 256) {
// level 8
//tid = threadIdx.x;
//for (size_t i = 0; i < params::opt / 2; ++i) {
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 256);
i1 = 2 * (params::degree / 256) * twid_id +
(tid & (params::degree / 256 - 1));
i2 = i1 + params::degree / 256;

w = negtwiddles[twid_id + 128];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

//tid += params::degree / params::opt;
}
__syncthreads();
}

if constexpr (params::degree >= 512) {
// level 9
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 512);
i1 = 2 * (params::degree / 512) * twid_id +
(tid & (params::degree / 512 - 1));
i2 = i1 + params::degree / 512;

w = negtwiddles[twid_id + 256];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();
}

if constexpr (params::degree >= 1024) {
// level 10
//tid = threadIdx.x;
//for (size_t i = 0; i < params::opt / 2; ++i) {
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 1024);
i1 = 2 * (params::degree / 1024) * twid_id +
(tid & (params::degree / 1024 - 1));
i2 = i1 + params::degree / 1024;

w = negtwiddles[twid_id + 512];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

//tid += params::degree / params::opt;
}
__syncthreads();
}

if constexpr (params::degree >= 2048) {
// level 11
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 2048);
i1 = 2 * (params::degree / 2048) * twid_id +
(tid & (params::degree / 2048 - 1));
i2 = i1 + params::degree / 2048;

w = negtwiddles[twid_id + 1024];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();
}

if constexpr (params::degree >= 4096) {
// level 12
//tid = threadIdx.x;
//for (size_t i = 0; i < params::opt / 2; ++i) {
#pragma unroll
for (int i = params::opt / 2 - 1; i >= 0 ; --i) {
tid = threadIdx.x + i * params::degree / params::opt;
twid_id = tid / (params::degree / 4096);
i1 = 2 * (params::degree / 4096) * twid_id +
(tid & (params::degree / 4096 - 1));
i2 = i1 + params::degree / 4096;

w = negtwiddles[twid_id + 2048];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

//tid += params::degree / params::opt;
}
__syncthreads();
}

if constexpr (params::degree >= 8192) {
// level 13
tid = threadIdx.x;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / (params::degree / 8192);
i1 = 2 * (params::degree / 8192) * twid_id +
(tid & (params::degree / 8192 - 1));
i2 = i1 + params::degree / 8192;

w = negtwiddles[twid_id + 4096];
u = A[i1];
v = A[i2] * w;

A[i1] += v;
A[i2] = u - v;

tid += params::degree / params::opt;
}
__syncthreads();
}
}

#endif // GPU_BOOTSTRAP_FFT_CUH
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
Torus *bsk_poly = bsk_slice + poly_id * params::degree;

// opt = 8 degree/opt = 256
copy_polynomial<Torus, params::opt, params::degree / params::opt>(
bsk_poly, accumulator);

Expand Down Expand Up @@ -114,6 +115,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
// Move accumulator to local memory
double2 temp[params::opt / 2];
int tid = threadIdx.x;
//opt = 8 degree=2048 degree/opt =256
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
temp[i].x = __ll2double_rn((int64_t)accumulator[tid]);
Expand All @@ -123,17 +125,22 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
temp[i].y /= (double)std::numeric_limits<Torus>::max();
tid += params::degree / params::opt;
}

/*
synchronize_threads_in_block();
// Move from local memory back to shared memory but as complex
tid = threadIdx.x;
//Loop for 4 times ... temp[4]
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] = temp[i];
tid += params::degree / params::opt;
tid += params::degree / params::opt; // degree 2048 opt 8 degree/opt = 256
}
synchronize_threads_in_block();
NSMFFT_direct<HalfDegree<params>>(fft);
*/
NSMFFT_direct_bundle<HalfDegree<params>>(fft, temp);

// lwe iteration
auto keybundle_out = get_ith_mask_kth_block(
Expand Down

0 comments on commit 46da1d3

Please sign in to comment.