
feat(gpu): implement signed_overflowing_sub
bbarbakadze authored and agnesLeroy committed Jul 22, 2024
1 parent b443855 commit 230fa5a
Showing 6 changed files with 186 additions and 25 deletions.
47 changes: 29 additions & 18 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -2869,39 +2869,45 @@ template <typename Torus> struct int_last_block_inner_propagate_memory {
auto f_last_block_inner_propagation_lut =
[op, message_modulus, message_bit_mask,
bits_of_message](Torus lhs_block, Torus rhs_block) -> Torus {
Torus local_rhs_block = 0;
uint64_t rhs_block_modified;
if (op == SIGNED_OPERATION::SUBTRACTION) {
Torus flipped_rhs = !rhs_block;
local_rhs_block = (flipped_rhs << 1) & message_bit_mask;
// Subtraction is done by adding the negation
// Negation(x) = bit_flip(x) + 1
// Only add the flipped value, the +1 will be resolved by carry
// propagation computation
uint64_t flipped_rhs = ~rhs_block;

// Remove the last bit, it's not interesting in this step
rhs_block_modified = (flipped_rhs << 1) & message_bit_mask;
} else {
local_rhs_block = (rhs_block << 1) & message_bit_mask;
};
rhs_block_modified = (rhs_block << 1) & message_bit_mask;
}

Torus local_lhs_block = (lhs_block << 1) & message_bit_mask;
uint64_t lhs_block_modified = (lhs_block << 1) & message_bit_mask;

// whole_result contains the result of addition with
// the carry being in the first bit of carry space
// the message space contains the message, but with one 0
// on the right (lsb)
Torus whole_result = local_lhs_block + local_rhs_block;
Torus carry = whole_result >> bits_of_message;
Torus result = (whole_result & message_bit_mask) >> 1;
Torus propagation_result = 0;
// on the right (LSB)
uint64_t whole_result = lhs_block_modified + rhs_block_modified;
uint64_t carry = whole_result >> bits_of_message;
uint64_t result = (whole_result & message_bit_mask) >> 1;
OUTPUT_CARRY propagation_result;
if (carry == 1) {
// Addition of bits before last one generates a carry
// Addition of bits before the last one generates a carry
propagation_result = OUTPUT_CARRY::GENERATED;
} else if (result == ((message_modulus - 1) >> 1)) {
// Addition of bits before last one puts the bits
// in a state that makes it so that an input carry into last block
// gets propagated to last bit.
// Addition of bits before the last one puts the bits
// in a state that makes it so that an input carry into the last block
// gets propagated to the last bit.
propagation_result = OUTPUT_CARRY::PROPAGATED;
} else {
propagation_result = OUTPUT_CARRY::NONE;
};
}

// Shift the propagation result in carry part
// Shift the propagation result in the carry part
// to have less noise growth later
return propagation_result << bits_of_message;
return (static_cast<uint64_t>(propagation_result) << bits_of_message);
};

last_block_inner_propagation_lut = new int_radix_lut<Torus>(
@@ -3000,6 +3006,7 @@ template <typename Torus> struct int_signed_overflowing_add_or_sub_memory {
// temporary device buffers
Torus *result; // num_blocks
Torus *input_carries; // num_blocks
Torus *neg_rhs; // num_blocks
Torus *output_carry; // single block
Torus *last_block_inner_propagation; // single block

@@ -3012,6 +3019,9 @@ template <typename Torus> struct int_signed_overflowing_add_or_sub_memory {
result = (Torus *)cuda_malloc_async(
big_lwe_size * num_blocks * sizeof(Torus), streams[0], gpu_indexes[0]);

neg_rhs = (Torus *)cuda_malloc_async(
big_lwe_size * num_blocks * sizeof(Torus), streams[0], gpu_indexes[0]);

input_carries = (Torus *)cuda_malloc_async(
big_lwe_size * num_blocks * sizeof(Torus), streams[0], gpu_indexes[0]);
output_carry = (Torus *)cuda_malloc_async(big_lwe_size * sizeof(Torus),
@@ -3062,6 +3072,7 @@ template <typename Torus> struct int_signed_overflowing_add_or_sub_memory {

// temporary device buffers
cuda_drop_async(result, streams[0], gpu_indexes[0]);
cuda_drop_async(neg_rhs, streams[0], gpu_indexes[0]);
cuda_drop_async(input_carries, streams[0], gpu_indexes[0]);
cuda_drop_async(output_carry, streams[0], gpu_indexes[0]);
cuda_drop_async(last_block_inner_propagation, streams[0], gpu_indexes[0]);
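Note: the last-block inner-propagation LUT defined earlier in this file classifies the bits of the final (sign-carrying) block as generating, propagating, or not producing a carry into the sign bit; the overflow flag is derived from that classification later in the pipeline. Below is a clear-value sketch of the same classification, assuming 2-bit message blocks (message_modulus = 4); the function and type names are illustrative, not part of the backend.

```rust
#[derive(Debug, PartialEq)]
enum OutputCarry {
    None,
    Generated,
    Propagated,
}

fn last_block_inner_propagation(lhs_block: u64, rhs_block: u64, subtraction: bool) -> OutputCarry {
    let message_modulus: u64 = 4; // 2-bit messages, mirrors PARAM_MESSAGE_2_CARRY_2
    let bits_of_message = message_modulus.ilog2(); // 2
    let message_bit_mask = message_modulus - 1; // 0b11

    // For subtraction, add the bit-flipped rhs; the "+1" of the two's-complement
    // negation is delivered later by the carry-propagation computation.
    let rhs = if subtraction { !rhs_block } else { rhs_block };

    // Shift left by one and mask: only the bits *below* the block's last bit
    // take part in this addition.
    let lhs_modified = (lhs_block << 1) & message_bit_mask;
    let rhs_modified = (rhs << 1) & message_bit_mask;

    let whole_result = lhs_modified + rhs_modified;
    let carry = whole_result >> bits_of_message;
    let result = (whole_result & message_bit_mask) >> 1;

    if carry == 1 {
        // The bits below the last one already generate a carry into it.
        OutputCarry::Generated
    } else if result == (message_modulus - 1) >> 1 {
        // An incoming carry would be forwarded to the last bit.
        OutputCarry::Propagated
    } else {
        OutputCarry::None
    }
}

fn main() {
    assert_eq!(last_block_inner_propagation(0b01, 0b01, false), OutputCarry::Generated);
    assert_eq!(last_block_inner_propagation(0b01, 0b00, false), OutputCarry::Propagated);
    assert_eq!(last_block_inner_propagation(0b00, 0b00, false), OutputCarry::None);
}
```

As the comment in the LUT states, only the bit-flipped rhs is added at this stage; the "+1" of the two's-complement negation arrives through the regular carry-propagation step.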
8 changes: 6 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/addition.cuh
@@ -85,6 +85,7 @@ __host__ void host_integer_signed_overflowing_add_or_sub_kb(
assert(radix_params.message_modulus >= 4 && radix_params.carry_modulus >= 4);

auto result = mem_ptr->result;
auto neg_rhs = mem_ptr->neg_rhs;
auto input_carries = mem_ptr->input_carries;
auto output_carry = mem_ptr->output_carry;
auto last_block_inner_propagation = mem_ptr->last_block_inner_propagation;
@@ -97,8 +98,11 @@
host_addition(streams[0], gpu_indexes[0], result, lhs, rhs,
big_lwe_dimension, num_blocks);
} else {
host_subtraction(streams[0], gpu_indexes[0], result, lhs, rhs,
big_lwe_dimension, num_blocks);
host_integer_radix_negation(
streams, gpu_indexes, gpu_count, neg_rhs, rhs, big_lwe_dimension,
num_blocks, radix_params.message_modulus, radix_params.carry_modulus);
host_addition(streams[0], gpu_indexes[0], result, lhs, neg_rhs,
big_lwe_dimension, num_blocks);
}

// phase 2
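The change above makes the subtraction path negate rhs homomorphically and then reuse the addition path, instead of calling host_subtraction directly, so both operations share the same overflow-detection pipeline. A minimal clear-arithmetic sketch of the underlying identity, lhs - rhs ≡ lhs + neg(rhs) (mod 2^total_bits), with illustrative values:

```rust
fn main() {
    // Assume 4 blocks of 2-bit messages, i.e. 8-bit signed integers.
    let total_bits = 8u32;
    let modulus: i64 = 1 << total_bits;

    let lhs: i64 = 120;
    let rhs: i64 = -20;

    // Negation modulo 2^total_bits (what host_integer_radix_negation computes
    // on encrypted radix blocks, up to carry handling).
    let rhs_mod = rhs.rem_euclid(modulus);
    let neg_rhs = (modulus - rhs_mod) % modulus;

    // Subtraction becomes an addition, reduced modulo 2^total_bits and
    // reinterpreted as a signed 8-bit value.
    let wrapped = (lhs.rem_euclid(modulus) + neg_rhs) % modulus;
    let signed = if wrapped >= modulus / 2 { wrapped - modulus } else { wrapped };

    assert_eq!(signed as i8, (120i8).wrapping_sub(-20));
}
```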
12 changes: 12 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
@@ -1688,6 +1688,11 @@ mod cuda {
display_name: overflowing_add
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_signed_overflowing_sub,
display_name: overflowing_sub
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_add,
display_name: add,
Expand Down Expand Up @@ -1905,6 +1910,11 @@ mod cuda {
display_name: overflowing_add
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: signed_overflowing_sub,
display_name: overflowing_sub
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_add,
display_name: add,
@@ -2036,6 +2046,7 @@ mod cuda {
cuda_unchecked_min,
cuda_unchecked_max,
cuda_unchecked_signed_overflowing_add,
cuda_unchecked_signed_overflowing_sub,
);

criterion_group!(
@@ -2084,6 +2095,7 @@ mod cuda {
cuda_max,
cuda_if_then_else,
cuda_signed_overflowing_add,
cuda_signed_overflowing_sub,
);

criterion_group!(
11 changes: 8 additions & 3 deletions tfhe/src/high_level_api/integers/signed/overflowing_ops.rs
@@ -263,9 +263,14 @@ where
(FheInt::new(result), FheBool::new(overflow))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
todo!("Cuda devices do not support signed integer");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let (result, overflow) = cuda_key.key.signed_overflowing_sub(
&self.ciphertext.on_gpu(),
&other.ciphertext.on_gpu(),
streams,
);
(FheInt::new(result), FheBool::new(overflow))
}),
})
}
}
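With the Cuda branch filled in, the high-level overflowing_sub on FheInt now dispatches to cuda_key.key.signed_overflowing_sub through the thread-local CUDA streams. A hedged usage sketch of the high-level call (key generation is shown with the default CPU config; setting up a CUDA server key is elided and follows the crate's GPU configuration docs):

```rust
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheInt8};

fn main() {
    let config = ConfigBuilder::default().build();
    let (client_key, server_key) = generate_keys(config);
    // With the "gpu" feature and a CUDA server key set instead, this same call
    // takes the InternalServerKey::Cuda branch shown above.
    set_server_key(server_key);

    let a = FheInt8::encrypt(-120i8, &client_key);
    let b = FheInt8::encrypt(50i8, &client_key);

    // Overflowing subtraction: the wrapped difference plus an encrypted flag.
    let (result, overflowed) = a.overflowing_sub(&b);

    let result: i8 = result.decrypt(&client_key);
    let overflowed: bool = overflowed.decrypt(&client_key);
    assert_eq!(result, (-120i8).wrapping_sub(50));
    assert!(overflowed);
}
```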
111 changes: 110 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/sub.rs
@@ -4,12 +4,14 @@ use crate::core_crypto::prelude::{CiphertextModulus, LweBskGroupingFactor, LweCi
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaUnsignedRadixCiphertext,
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
unchecked_unsigned_overflowing_sub_integer_radix_kb_assign_async, PBSType,
};
use crate::integer::server_key::radix_parallel::sub::SignedOperation;
use crate::shortint::ciphertext::NoiseLevel;

impl CudaServerKey {
@@ -440,4 +442,111 @@ impl CudaServerKey {

(ct_res, ct_overflowed)
}

/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(gpu_index);
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);
/// let total_bits = num_blocks * cks.parameters().message_modulus().0.ilog2() as usize;
/// let modulus = 1 << total_bits;
///
/// let msg1: i8 = 120;
/// let msg2: i8 = 8;
///
/// let ct1 = cks.encrypt_signed(msg1);
/// let ct2 = cks.encrypt_signed(msg2);
///
/// // Copy to GPU
/// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams);
/// let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
///
/// // Compute homomorphically an overflowing subtraction:
/// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_sub(&d_ct1, &d_ct2, &streams);
///
/// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams);
/// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams);
///
/// // Decrypt:
/// let dec_result: i8 = cks.decrypt_signed(&ct_res);
/// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed);
/// let (clear_result, clear_overflowed) = msg1.overflowing_sub(msg2);
/// assert_eq!(dec_result, clear_result);
/// assert_eq!(dec_overflowed, clear_overflowed);
/// ```
pub fn signed_overflowing_sub(
&self,
ct_left: &CudaSignedRadixCiphertext,
ct_right: &CudaSignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
let mut tmp_lhs;
let mut tmp_rhs;
let (lhs, rhs) = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => (ct_left, ct_right),
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
(false, false) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
}
};

self.unchecked_signed_overflowing_sub(lhs, rhs, stream)
}

pub fn unchecked_signed_overflowing_sub(
&self,
ct_left: &CudaSignedRadixCiphertext,
ct_right: &CudaSignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0,
"lhs and rhs must have the name number of blocks ({} vs {})",
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
"inputs cannot be empty"
);

self.unchecked_signed_overflowing_add_or_sub(
ct_left,
ct_right,
SignedOperation::Subtraction,
stream,
)
}
}
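For reference, the overflow flag produced by unchecked_signed_overflowing_add_or_sub for a subtraction corresponds to the usual signed-overflow rule. A clear-value sketch of that rule (no ciphertexts involved, purely illustrative):

```rust
fn signed_overflow_on_sub(lhs: i8, rhs: i8) -> bool {
    let diff = lhs.wrapping_sub(rhs);
    // Overflow iff the operands have different signs and the result's sign
    // differs from the left operand's sign.
    (lhs < 0) != (rhs < 0) && (diff < 0) != (lhs < 0)
}

fn main() {
    for &(l, r) in &[(120i8, 8i8), (-120, 50), (i8::MIN, 1), (5, -7)] {
        assert_eq!(signed_overflow_on_sub(l, r), l.overflowing_sub(r).1);
    }
}
```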
22 changes: 21 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/tests_signed/test_sub.rs
@@ -3,13 +3,17 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_sub::{
signed_default_sub_test, signed_unchecked_sub_test,
signed_default_overflowing_sub_test, signed_default_sub_test,
signed_unchecked_overflowing_sub_test, signed_unchecked_sub_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_unchecked_sub);
create_gpu_parametrized_test!(integer_sub);

create_gpu_parametrized_test!(integer_unchecked_signed_overflowing_sub);
create_gpu_parametrized_test!(integer_signed_overflowing_sub);

fn integer_unchecked_sub<P>(param: P)
where
P: Into<PBSParameters>,
@@ -25,3 +29,19 @@ where
let executor = GpuFunctionExecutor::new(&CudaServerKey::sub);
signed_default_sub_test(param, executor);
}

fn integer_unchecked_signed_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_signed_overflowing_sub);
signed_unchecked_overflowing_sub_test(param, executor);
}

fn integer_signed_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_sub);
signed_default_overflowing_sub_test(param, executor);
}
