/*
 * One butterfly level (level >= 2) of the in-place forward FFT performed by
 * NSMFFT_direct_bundle.
 *
 * At level `level` the array is viewed as chunks of
 * 2 * (params::degree / 2^level) coefficients. Each chunk has its own twiddle,
 * read from negtwiddles[chunk_index + 2^(level - 1)]. Every thread handles
 * params::opt / 2 butterflies whose linear ids are
 * threadIdx.x + k * (params::degree / params::opt).
 *
 * Even levels visit the thread's butterflies from last to first, odd levels
 * from first to last, which mirrors the hand-unrolled code this helper
 * replaces. The butterflies of one level touch pairwise disjoint elements of
 * A, so the visiting order does not change the result.
 *
 * The helper ends with a block-wide barrier, so every thread of the block must
 * call it.
 */
template <class params, int level>
__device__ __forceinline__ void NSMFFT_direct_bundle_level(double2 *A) {
  // Distance between the two coefficients of a butterfly at this level.
  constexpr size_t half_chunk = params::degree / (size_t(1) << level);
  // Offset of this level's first twiddle inside negtwiddles.
  constexpr size_t twiddle_base = size_t(1) << (level - 1);
  constexpr int pairs_per_thread = params::opt / 2;
  constexpr size_t thread_stride = params::degree / params::opt;
  constexpr bool descending = (level % 2 == 0);

#pragma unroll
  for (int step = 0; step < pairs_per_thread; ++step) {
    const int slot = descending ? pairs_per_thread - 1 - step : step;
    const size_t tid = threadIdx.x + slot * thread_stride;

    // Chunk that butterfly `tid` belongs to, then its two coefficient indices.
    const size_t twid_id = tid / half_chunk;
    const size_t i1 = 2 * half_chunk * twid_id + (tid & (half_chunk - 1));
    const size_t i2 = i1 + half_chunk;

    const double2 w = negtwiddles[twid_id + twiddle_base];
    const double2 u = A[i1];
    const double2 v = A[i2] * w;

    A[i1] = u + v;
    A[i2] = u - v;
  }
  // The next level reads elements written by other threads.
  __syncthreads();
}

/*
 * Forward FFT variant of NSMFFT_direct (bnsmfft.cuh) whose input comes from a
 * per-thread array instead of from A.
 *
 * Level 1 takes its operands from `regs` and stores the results into A; all
 * following levels work in place on A. No bit reversal is done here: per the
 * original note, the twiddles in negtwiddles are already stored in reversed
 * order.
 *
 * Parameters
 *   A     working buffer of the FFT; level 1 writes indices
 *         [0, params::degree). It is shared between the threads of the block.
 *   regs  per-thread input, params::opt entries. With
 *         stride = params::degree / params::opt, entry k holds the value that
 *         NSMFFT_direct would have read from A[threadIdx.x + k * stride]
 *         (entries [0, opt/2) are the "u" operands, [opt/2, opt) the "v"
 *         operands of level 1). The former declaration `regs[4]` was only
 *         accurate for params::opt == 4; the parameter type is unchanged
 *         (an array parameter is a pointer).
 *
 * Launch requirements: every thread of the block must call this function
 * (it contains block-wide barriers). The indexing covers A exactly once when
 * blockDim.x == params::degree / params::opt - NOTE(review): confirm against
 * the launch configuration, it is not visible here.
 *
 * Call site seen with this change (programmable_bootstrap_multibit.cuh,
 * keybundle kernel): the kernel fills `double2 temp[params::opt / 2]` from its
 * integer accumulator and calls this function with (fft, temp) instead of
 * copying temp into fft and calling NSMFFT_direct. The template argument of
 * that call is not legible in this view; confirm that this function's
 * params::opt equals the length of that array.
 */
template <class params>
__device__ void NSMFFT_direct_bundle(double2 *A, const double2 *regs) {
  // Levels 2..7 are unconditional and divide by params::degree / 128.
  static_assert(params::degree >= 128,
                "NSMFFT_direct_bundle needs params::degree >= 128");

  // The replaced sequence had a barrier between the caller's last read of its
  // accumulator and the first store into the FFT buffer; the call site no
  // longer has one. NOTE(review): that barrier suggests A shares storage with
  // the accumulator - confirm. If so, the stores below would overwrite values
  // other threads have not read yet, hence the barrier here.
  __syncthreads();

  // level 1
  // Only one twiddle exists at this level and its real and imaginary parts
  // are equal (1/sqrt(2)), so it is written inline instead of being loaded.
  size_t tid = threadIdx.x;
#pragma unroll
  for (size_t i = 0; i < params::opt / 2; ++i) {
    const size_t i1 = tid;
    const size_t i2 = tid + params::degree / 2;

    const double2 u = regs[i];
    const double2 v =
        regs[i + params::opt / 2] * (double2){0.707106781186547461715008466854,
                                              0.707106781186547461715008466854};
    A[i1] = u + v;
    A[i2] = u - v;

    tid += params::degree / params::opt;
  }
  __syncthreads();

  // levels 2..7: needed for every supported size (minimum degree is 128)
  NSMFFT_direct_bundle_level<params, 2>(A);
  NSMFFT_direct_bundle_level<params, 3>(A);
  NSMFFT_direct_bundle_level<params, 4>(A);
  NSMFFT_direct_bundle_level<params, 5>(A);
  NSMFFT_direct_bundle_level<params, 6>(A);
  NSMFFT_direct_bundle_level<params, 7>(A);

  // From level 8 on, a level only exists when the polynomial is large enough.
  // The guards are compile-time, so the barrier inside the helper is either
  // reached by every thread or not compiled at all.
  if constexpr (params::degree >= 256) {
    NSMFFT_direct_bundle_level<params, 8>(A);
  }
  if constexpr (params::degree >= 512) {
    NSMFFT_direct_bundle_level<params, 9>(A);
  }
  if constexpr (params::degree >= 1024) {
    NSMFFT_direct_bundle_level<params, 10>(A);
  }
  if constexpr (params::degree >= 2048) {
    NSMFFT_direct_bundle_level<params, 11>(A);
  }
  if constexpr (params::degree >= 4096) {
    NSMFFT_direct_bundle_level<params, 12>(A);
  }
  if constexpr (params::degree >= 8192) {
    NSMFFT_direct_bundle_level<params, 13>(A);
  }
}