From 7f75c3fd5e76383c8f7733232f5abb9dfc8b38d0 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Mon, 9 Sep 2024 14:01:33 +0200 Subject: [PATCH] generate many lut gpu --- tfhe/src/integer/gpu/server_key/radix/mod.rs | 31 ++++++++++++++++++-- tfhe/src/shortint/engine/mod.rs | 13 ++++---- tfhe/src/shortint/server_key/mod.rs | 4 ++- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index e5b82caa9c..c9c927f407 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -18,8 +18,8 @@ use crate::integer::gpu::{ CudaServerKey, PBSType, }; use crate::shortint::ciphertext::{Degree, NoiseLevel}; -use crate::shortint::engine::fill_accumulator; -use crate::shortint::server_key::{BivariateLookupTableOwned, LookupTableOwned}; +use crate::shortint::engine::{fill_accumulator, fill_many_lut_accumulator}; +use crate::shortint::server_key::{BivariateLookupTableOwned, LookupTableOwned, ManyLookupTableOwned}; use crate::shortint::PBSOrder; mod add; @@ -676,6 +676,33 @@ impl CudaServerKey { } } + + pub(crate) fn generate_many_lookup_table(&self, + functions: &[&dyn Fn(u64) -> u64]) -> ManyLookupTableOwned + where + F: Fn(u64) -> u64, + { + let (glwe_size, polynomial_size) = match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + (d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size) + } + CudaBootstrappingKey::MultiBit(d_bsk) => { + (d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size) + } + }; + let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, self.ciphertext_modulus); + + let (input_max_degree, sample_extraction_stride, per_function_output_degree) = + fill_many_lut_accumulator(&mut acc, polynomial_size, glwe_size, self.message_modulus, self.carry_modulus, functions); + + ManyLookupTableOwned { + acc, + input_max_degree, + sample_extraction_stride, + per_function_output_degree, + } + } + /// Generates a bivariate accumulator pub(crate) fn generate_lookup_table_bivariate(&self, f: F) -> BivariateLookupTableOwned where diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs index bc4a84ba01..c9896474e9 100644 --- a/tfhe/src/shortint/engine/mod.rs +++ b/tfhe/src/shortint/engine/mod.rs @@ -162,7 +162,10 @@ pub(crate) fn fill_accumulator_no_encoding( /// Fills a GlweCiphertext for use in a ManyLookupTable setting pub(crate) fn fill_many_lut_accumulator( accumulator: &mut GlweCiphertext, - server_key: &ServerKey, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, functions: &[&dyn Fn(u64) -> u64], ) -> (MaxDegree, usize, Vec) where @@ -170,11 +173,11 @@ where { assert_eq!( accumulator.polynomial_size(), - server_key.bootstrapping_key.polynomial_size() + polynomial_size ); assert_eq!( accumulator.glwe_size(), - server_key.bootstrapping_key.glwe_size() + glwe_size ); let mut accumulator_view = accumulator.as_mut_view(); @@ -182,10 +185,10 @@ where accumulator_view.get_mut_mask().as_mut().fill(0); // Modulus of the msg contained in the msg bits and operations buffer - let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0; + let modulus_sup = message_modulus.0 * carry_modulus.0; // N/(p/2) = size of each block - let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup; + let box_size = polynomial_size.0 / modulus_sup; // Value of the delta we multiply our messages by let delta = (1_u64 << 63) / (modulus_sup as u64); diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index 351b0b2cc7..22fd542615 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -45,6 +45,7 @@ use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::fft64::math::fft::Fft; use crate::core_crypto::prelude::ComputationBuffers; +use crate::integer::encryption::KnowsMessageModulus; use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel}; use crate::shortint::client_key::ClientKey; use crate::shortint::engine::{ @@ -883,7 +884,8 @@ impl ServerKey { self.ciphertext_modulus, ); let (input_max_degree, sample_extraction_stride, per_function_output_degree) = - fill_many_lut_accumulator(&mut acc, self, functions); + fill_many_lut_accumulator(&mut acc, self.bootstrapping_key.polynomial_size(), + self.bootstrapping_key.glwe_size(), self.message_modulus(), self.carry_modulus, functions); ManyLookupTableOwned { acc,