feat(gpu): Add signed_overflowing_scalar_add and signed_overflowing_scalar_sub
bbarbakadze authored and agnesLeroy committed Jul 22, 2024
1 parent 230fa5a commit 95ef13f
Showing 8 changed files with 261 additions and 13 deletions.
14 changes: 14 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
@@ -1801,6 +1801,18 @@ mod cuda {
rng_func: default_signed_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: signed_overflowing_scalar_add,
display_name: overflowing_add,
rng_func: default_signed_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: signed_overflowing_scalar_sub,
display_name: overflowing_sub,
rng_func: default_signed_scalar
);

//===========================================
// Default
//===========================================
@@ -2132,6 +2144,8 @@ mod cuda {
cuda_scalar_le,
cuda_scalar_min,
cuda_scalar_max,
cuda_signed_overflowing_scalar_add,
cuda_signed_overflowing_scalar_sub,
);

fn cuda_bench_server_key_signed_cast_function<F>(
22 changes: 16 additions & 6 deletions tfhe/src/high_level_api/integers/signed/overflowing_ops.rs
@@ -140,9 +140,14 @@ where
(FheInt::new(result), FheBool::new(overflow))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
todo!("Cuda devices do not support signed integer");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let (result, overflow) = cuda_key.key.signed_overflowing_scalar_add(
&self.ciphertext.on_gpu(),
other,
streams,
);
(FheInt::new(result), FheBool::new(overflow))
}),
})
}
}
@@ -349,9 +354,14 @@ where
(FheInt::new(result), FheBool::new(overflow))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
todo!("Cuda devices do not support signed integer");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let (result, overflow) = cuda_key.key.signed_overflowing_scalar_sub(
&self.ciphertext.on_gpu(),
other,
streams,
);
(FheInt::new(result), FheBool::new(overflow))
}),
})
}
}
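The high-level API branch above now dispatches signed scalar overflowing_add / overflowing_sub to the new CUDA entry points. A minimal usage sketch follows; it is not part of this commit and assumes the `gpu` feature is enabled, a CUDA device is visible, and the CompressedServerKey::decompress_to_gpu() setup path, so exact details may vary with the tfhe version.

```rust
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheInt8};

fn main() {
    // Key setup: decompressing the server key to the GPU makes set_server_key
    // install the CUDA backend, so the scalar overflowing ops below reach the
    // new signed_overflowing_scalar_add / _sub paths.
    let config = ConfigBuilder::default().build();
    let client_key = ClientKey::generate(config);
    let compressed_server_key = CompressedServerKey::new(&client_key);
    set_server_key(compressed_server_key.decompress_to_gpu());

    let a = FheInt8::encrypt(120i8, &client_key);

    // 120 + 8 does not fit in an i8, so the overflow flag decrypts to true.
    let (sum, add_overflowed) = (&a).overflowing_add(8i8);
    assert_eq!(sum.decrypt(&client_key), 120i8.wrapping_add(8)); // -128
    assert!(add_overflowed.decrypt(&client_key));

    // 120 - 8 fits, so no overflow is reported.
    let (diff, sub_overflowed) = (&a).overflowing_sub(8i8);
    assert_eq!(diff.decrypt(&client_key), 112i8);
    assert!(!sub_overflowed.decrypt(&client_key));
}
```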
86 changes: 85 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_add.rs
@@ -1,8 +1,11 @@
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::SignedNumeric;
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::scalar_addition_integer_radix_assign_async;
use crate::integer::gpu::server_key::CudaServerKey;
use crate::prelude::CastInto;
@@ -279,4 +282,85 @@ impl CudaServerKey {
CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext)
}
}

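/// Computes homomorphically the addition of a ciphertext and a clear scalar, returning
/// the result together with a boolean block encrypting whether the signed addition
/// overflowed.
///
/// # Example
///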
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(gpu_index);
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);
///
/// let msg: i8 = 120;
/// let scalar: i8 = 8;
///
/// let ct1 = cks.encrypt_signed(msg);
///
/// // Copy to GPU
/// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams);
///
/// // Compute homomorphically an overflowing addition:
/// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_add(&d_ct1, scalar, &streams);
///
/// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams);
/// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams);
///
/// // Decrypt:
/// let dec_result: i8 = cks.decrypt_signed(&ct_res);
/// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed);
/// let (clear_result, clear_overflowed) = msg.overflowing_add(scalar);
/// assert_eq!(dec_result, clear_result);
/// assert_eq!(dec_overflowed, clear_overflowed);
/// ```
pub fn signed_overflowing_scalar_add<Scalar>(
&self,
ct_left: &CudaSignedRadixCiphertext,
scalar: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock)
where
Scalar: SignedNumeric + DecomposableInto<u64> + CastInto<u64>,
{
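// Carries of the input must be empty before the overflowing addition; work on a
// propagated copy so the caller's ciphertext is left untouched.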
let mut tmp_lhs;
unsafe {
tmp_lhs = ct_left.duplicate_async(streams);
if !tmp_lhs.block_carries_are_empty() {
self.full_propagate_assign_async(&mut tmp_lhs, streams);
}
}

let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix(
scalar,
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
streams,
);
let (result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams);

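// The trivial ciphertext above only holds as many blocks as the ciphertext, so the
// scalar may have been truncated. Inspect the scalar blocks beyond that count: for a
// negative scalar they must all equal message_modulus - 1 (two's complement sign
// extension), for a non-negative scalar they must all be 0; otherwise the scalar does
// not fit and the operation is reported as overflowing.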
let mut extra_scalar_block_iter =
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0);

let extra_blocks_have_correct_value = if scalar < Scalar::ZERO {
extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1))
} else {
extra_scalar_block_iter.all(|block| block == 0)
};

if extra_blocks_have_correct_value {
(result, overflowed)
} else {
let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
// The scalar's value does not fit in the ciphertext's blocks, so the addition is flagged as overflowing
(
result,
CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext),
)
}
}
}
88 changes: 85 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs
@@ -1,7 +1,8 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::core_crypto::prelude::{Numeric, SignedNumeric};
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext};
use crate::integer::gpu::server_key::CudaServerKey;
use crate::integer::server_key::TwosComplementNegation;
use crate::prelude::CastInto;
@@ -163,4 +164,85 @@ impl CudaServerKey {
}
stream.synchronize();
}

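/// Computes homomorphically the subtraction of a clear scalar from a ciphertext, returning
/// the result together with a boolean block encrypting whether the signed subtraction
/// overflowed.
///
/// # Example
///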
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(gpu_index);
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);
///
/// let msg: i8 = 120;
/// let scalar: i8 = 8;
///
/// let ct1 = cks.encrypt_signed(msg);
///
/// // Copy to GPU
/// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams);
///
/// // Compute homomorphically an overflowing subtraction:
/// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_sub(&d_ct1, scalar, &streams);
///
/// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams);
/// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams);
///
/// // Decrypt:
/// let dec_result: i8 = cks.decrypt_signed(&ct_res);
/// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed);
/// let (clear_result, clear_overflowed) = msg.overflowing_sub(scalar);
/// assert_eq!(dec_result, clear_result);
/// assert_eq!(dec_overflowed, clear_overflowed);
/// ```
pub fn signed_overflowing_scalar_sub<Scalar>(
&self,
ct_left: &CudaSignedRadixCiphertext,
scalar: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock)
where
Scalar: SignedNumeric + DecomposableInto<u64> + CastInto<u64>,
{
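// Carries of the input must be empty before the overflowing subtraction; work on a
// propagated copy so the caller's ciphertext is left untouched.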
let mut tmp_lhs;
unsafe {
tmp_lhs = ct_left.duplicate_async(streams);
if !tmp_lhs.block_carries_are_empty() {
self.full_propagate_assign_async(&mut tmp_lhs, streams);
}
}

let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix(
scalar,
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
streams,
);
let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams);

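// The trivial ciphertext above only holds as many blocks as the ciphertext, so the
// scalar may have been truncated. Inspect the scalar blocks beyond that count: for a
// negative scalar they must all equal message_modulus - 1 (two's complement sign
// extension), for a non-negative scalar they must all be 0; otherwise the scalar does
// not fit and the operation is reported as overflowing.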
let mut extra_scalar_block_iter =
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0);

let extra_blocks_have_correct_value = if scalar < Scalar::ZERO {
extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1))
} else {
extra_scalar_block_iter.all(|block| block == 0)
};

if extra_blocks_have_correct_value {
(result, overflowed)
} else {
let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
// The scalar's value does not fit in the ciphertext's blocks, so the subtraction is flagged as overflowing
(
result,
CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext),
)
}
}
}
37 changes: 37 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
@@ -389,3 +389,40 @@
)
}
}

// Executor for signed overflowing scalar ops: runs a CUDA function taking a signed radix
// ciphertext and an i64 scalar through the generic integer test harness.
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, i64), (SignedRadixCiphertext, BooleanBlock)>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

let (d_res, d_res_bool) = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

(
d_res.to_signed_radix_ciphertext(&context.streams),
d_res_bool.to_boolean_block(&context.streams),
)
}
}
Changed file (path not shown in this view):
@@ -3,12 +3,14 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_add::{
signed_default_scalar_add_test, signed_unchecked_scalar_add_test,
signed_default_overflowing_scalar_add_test, signed_default_scalar_add_test,
signed_unchecked_scalar_add_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_add);
create_gpu_parametrized_test!(integer_signed_scalar_add);
create_gpu_parametrized_test!(integer_signed_overflowing_scalar_add);

fn integer_signed_unchecked_scalar_add<P>(param: P)
where
@@ -25,3 +27,11 @@ where
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add);
signed_default_scalar_add_test(param, executor);
}

fn integer_signed_overflowing_scalar_add<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_add);
signed_default_overflowing_scalar_add_test(param, executor);
}
Changed file (path not shown in this view):
@@ -2,10 +2,13 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::signed_unchecked_scalar_sub_test;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::{
signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub);
create_gpu_parametrized_test!(integer_signed_overflowing_scalar_sub);

fn integer_signed_unchecked_scalar_sub<P>(param: P)
where
@@ -14,3 +17,11 @@ where
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
signed_unchecked_scalar_sub_test(param, executor);
}

fn integer_signed_overflowing_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_sub);
signed_default_overflowing_scalar_sub_test(param, executor);
}
Changed file (path not shown in this view):
@@ -195,7 +195,7 @@ where
let ctxt_0 = cks.encrypt_signed(clear_0);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1);
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
