Skip to content

Commit

Permalink
feat(capi): add mechanism to get panic message as const char *
Browse files Browse the repository at this point in the history
Previously, when an error occurred in the rust side, the panic message
would get printed to stderr, then the c function would return 1 to
indicate error.

This commit adds the ability to disable the automatic prints of panic
messages and adds functions to get the panic message as a const char *
to allow user better control on how to display error messages.
  • Loading branch information
tmontaigu committed Nov 14, 2024
1 parent 5a664aa commit 8e1c1c3
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 2 deletions.
60 changes: 60 additions & 0 deletions tfhe/c_api_tests/test_high_level_error.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// If this test break the c_api doc needs to be updated

#include "tfhe.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
int main(void) {
tfhe_error_disable_automatic_prints();

int ok = 0;
// Prepare the config builder for the high level API and choose which types to enable
ConfigBuilder *builder;
Config *config;

config_builder_default(&builder);
config_builder_build(builder, &config);

ClientKey *client_key = NULL;
ServerKey *server_key = NULL;

generate_keys(config, &client_key, &server_key);

// Intentionally forget the set_server_key to test error
// set_server_key(server_key);

FheUint128 *lhs = NULL;
FheUint128 *rhs = NULL;
FheUint128 *result = NULL;

U128 clear_lhs = {.w0 = 10, .w1 = 20};
U128 clear_rhs = {.w0 = 1, .w1 = 2};

ok = fhe_uint128_try_encrypt_with_client_key_u128(clear_lhs, client_key, &lhs);
assert(ok == 0);

ok = fhe_uint128_try_encrypt_with_client_key_u128(clear_rhs, client_key, &rhs);
assert(ok == 0);

const char *last_error = tfhe_error_get_last();
assert(last_error != NULL);
assert(strcmp(last_error, "no error") == 0);

// Compute the subtraction
ok = fhe_uint128_sub(lhs, rhs, &result);
assert(ok == 1);

last_error = tfhe_error_get_last();
assert(last_error != NULL);
printf("Error message Received from tfhe-rs: '%s'\n", last_error);

// Destroy the ciphertexts
fhe_uint128_destroy(lhs);
fhe_uint128_destroy(rhs);

// Destroy the keys
client_key_destroy(client_key);
server_key_destroy(server_key);

return EXIT_SUCCESS;
}
122 changes: 122 additions & 0 deletions tfhe/src/c_api/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use std::any::Any;
use std::cell::RefCell;
use std::ffi::{c_char, CString};
use std::panic::AssertUnwindSafe;

enum LastError {
None,
Message(CString),
// Only used when there was a memory panic happens
// when trying to store the last panic message into the CString above
NoMemory,
}

impl LastError {
const NO_MEMORY_MSG: [c_char; 10] = [
b'n' as c_char,
b'o' as c_char,
b' ' as c_char,
b'm' as c_char,
b'e' as c_char,
b'm' as c_char,
b'o' as c_char,
b'r' as c_char,
b'y' as c_char,
b'\0' as c_char,
];

const NO_ERROR_MSG: [c_char; 9] = [
b'n' as c_char,
b'o' as c_char,
b' ' as c_char,
b'e' as c_char,
b'r' as c_char,
b'r' as c_char,
b'o' as c_char,
b'r' as c_char,
b'\0' as c_char,
];

fn as_ptr(&self) -> *const c_char {
match self {
Self::None => Self::NO_ERROR_MSG.as_ptr(),
Self::Message(cstring) => cstring.as_ptr(),
Self::NoMemory => Self::NO_MEMORY_MSG.as_ptr(),
}
}

/// Does not include the nul-byte
fn len(&self) -> usize {
match self {
Self::None => Self::NO_ERROR_MSG.len() - 1,
Self::Message(cstring) => cstring.as_bytes().len(),
Self::NoMemory => Self::NO_MEMORY_MSG.len() - 1,
}
}
}

std::thread_local! {
pub(in crate::c_api) static LAST_LOCAL_ERROR: RefCell<LastError> = const { RefCell::new(LastError::None) };
}

pub(in crate::c_api) fn replace_last_error_with_panic_payload(payload: &Box<dyn Any + Send>) {
LAST_LOCAL_ERROR.with(|local_error| {
let _previous_error = local_error.replace(panic_payload_to_error(payload));
});
}

fn panic_payload_to_error(payload: &Box<dyn Any + Send>) -> LastError {
// Add a catch panic as technically the to_vec could fail
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
// Rust doc says:
// An invocation of the panic!() macro in Rust 2021 or later
// will always result in a panic payload of type &'static str or String.
payload
.downcast_ref::<&str>()
.map_or_else(|| b"panic occurred".to_vec(), |s| s.as_bytes().to_vec())
}));

result.map_or_else(
|_| LastError::NoMemory,
|bytes| LastError::Message(CString::new(bytes).unwrap()),
)
}

/// Returns a pointer to a nul-terminated string describing the last error in the thread.
///
/// * This pointer is only valid for as long as no other tfhe-rs function are called within this
/// thread.
///
/// This should be used to directly print/display or copy the error.
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_get_last() -> *const c_char {
LAST_LOCAL_ERROR.with_borrow(|last_error| last_error.as_ptr())
}

/// Returns the length of the current error message stored
///
/// The length **DOES NOT INCLUDE** the nul byte
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_get_size() -> usize {
LAST_LOCAL_ERROR.with_borrow(|last_error| last_error.len())
}

/// Clears the last error
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_clear() {
LAST_LOCAL_ERROR.replace(LastError::None);
}

/// Disables panic prints to stderr when a thread panics
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_disable_automatic_prints() {
std::panic::set_hook(Box::new(|_panic_info| {}));
}

/// Enables panic prints to stderr when a thread panics
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_enable_automatic_prints() {
// if the current hook is the default one, 'taking' it
// will still make is registered
let _ = std::panic::take_hook();
}
1 change: 1 addition & 0 deletions tfhe/src/c_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod boolean;
pub mod buffer;
pub mod core_crypto;
pub mod error;
#[cfg(feature = "high-level-c-api")]
pub mod high_level_api;
#[cfg(feature = "shortint-c-api")]
Expand Down
5 changes: 4 additions & 1 deletion tfhe/src/c_api/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ where
{
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) {
Ok(_) => 0,
_ => 1,
Err(err) => {
super::error::replace_last_error_with_panic_payload(&err);
1
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Display for UninitializedServerKey {
write!(
f,
"The server key was not properly initialized.\n\
Did you forget to call `set_server_key` in the current thread ?
Did you forget to call `set_server_key` in the current thread ?\
",
)
}
Expand Down

0 comments on commit 8e1c1c3

Please sign in to comment.