Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(capi): add mechanism to get panic message as const char * #1740

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tfhe/c_api_tests/test_high_level_error.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// If this test break the c_api doc needs to be updated

#include "tfhe.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
int main(void) {
tfhe_error_disable_automatic_prints();

int ok = 0;
// Prepare the config builder for the high level API and choose which types to enable
ConfigBuilder *builder;
Config *config;

config_builder_default(&builder);
config_builder_build(builder, &config);

ClientKey *client_key = NULL;
ServerKey *server_key = NULL;

generate_keys(config, &client_key, &server_key);

// Intentionally forget the set_server_key to test error
// set_server_key(server_key);

FheUint128 *lhs = NULL;
FheUint128 *rhs = NULL;
FheUint128 *result = NULL;

U128 clear_lhs = {.w0 = 10, .w1 = 20};
U128 clear_rhs = {.w0 = 1, .w1 = 2};

ok = fhe_uint128_try_encrypt_with_client_key_u128(clear_lhs, client_key, &lhs);
assert(ok == 0);

ok = fhe_uint128_try_encrypt_with_client_key_u128(clear_rhs, client_key, &rhs);
assert(ok == 0);

const char *last_error = tfhe_error_get_last();
assert(last_error != NULL);
assert(strcmp(last_error, "no error") == 0);

// Compute the subtraction
ok = fhe_uint128_sub(lhs, rhs, &result);
assert(ok == 1);

last_error = tfhe_error_get_last();
assert(last_error != NULL);
printf("Error message Received from tfhe-rs: '%s'\n", last_error);

// Destroy the ciphertexts
fhe_uint128_destroy(lhs);
fhe_uint128_destroy(rhs);

// Destroy the keys
client_key_destroy(client_key);
server_key_destroy(server_key);

return EXIT_SUCCESS;
}
122 changes: 122 additions & 0 deletions tfhe/src/c_api/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use std::any::Any;
use std::cell::RefCell;
use std::ffi::{c_char, CString};
use std::panic::AssertUnwindSafe;

enum LastError {
None,
Message(CString),
// Only used when there was a memory panic happens
// when trying to store the last panic message into the CString above
NoMemory,
}

impl LastError {
const NO_MEMORY_MSG: [c_char; 10] = [
b'n' as c_char,
b'o' as c_char,
b' ' as c_char,
b'm' as c_char,
b'e' as c_char,
b'm' as c_char,
b'o' as c_char,
b'r' as c_char,
b'y' as c_char,
b'\0' as c_char,
];

const NO_ERROR_MSG: [c_char; 9] = [
b'n' as c_char,
b'o' as c_char,
b' ' as c_char,
b'e' as c_char,
b'r' as c_char,
b'r' as c_char,
b'o' as c_char,
b'r' as c_char,
b'\0' as c_char,
];

fn as_ptr(&self) -> *const c_char {
match self {
Self::None => Self::NO_ERROR_MSG.as_ptr(),
Self::Message(cstring) => cstring.as_ptr(),
Self::NoMemory => Self::NO_MEMORY_MSG.as_ptr(),
}
}

/// Does not include the nul-byte
fn len(&self) -> usize {
match self {
Self::None => Self::NO_ERROR_MSG.len() - 1,
Self::Message(cstring) => cstring.as_bytes().len(),
Self::NoMemory => Self::NO_MEMORY_MSG.len() - 1,
}
}
}

std::thread_local! {
pub(in crate::c_api) static LAST_LOCAL_ERROR: RefCell<LastError> = const { RefCell::new(LastError::None) };
}

pub(in crate::c_api) fn replace_last_error_with_panic_payload(payload: &Box<dyn Any + Send>) {
LAST_LOCAL_ERROR.with(|local_error| {
let _previous_error = local_error.replace(panic_payload_to_error(payload));
});
}

fn panic_payload_to_error(payload: &Box<dyn Any + Send>) -> LastError {
// Add a catch panic as technically the to_vec could fail
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
// Rust doc says:
// An invocation of the panic!() macro in Rust 2021 or later
// will always result in a panic payload of type &'static str or String.
payload
.downcast_ref::<&str>()
.map_or_else(|| b"panic occurred".to_vec(), |s| s.as_bytes().to_vec())
}));

result.map_or_else(
|_| LastError::NoMemory,
|bytes| LastError::Message(CString::new(bytes).unwrap()),
)
}

/// Returns a pointer to a nul-terminated string describing the last error in the thread.
///
/// * This pointer is only valid for as long as no other tfhe-rs function are called within this
/// thread.
///
/// This should be used to directly print/display or copy the error.
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_get_last() -> *const c_char {
LAST_LOCAL_ERROR.with_borrow(|last_error| last_error.as_ptr())
}

/// Returns the length of the current error message stored
///
/// The length **DOES NOT INCLUDE** the nul byte
IceTDrinker marked this conversation as resolved.
Show resolved Hide resolved
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_get_size() -> usize {
LAST_LOCAL_ERROR.with_borrow(|last_error| last_error.len())
}

/// Clears the last error
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_clear() {
LAST_LOCAL_ERROR.replace(LastError::None);
}

/// Disables panic prints to stderr when a thread panics
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_disable_automatic_prints() {
std::panic::set_hook(Box::new(|_panic_info| {}));
}

/// Enables panic prints to stderr when a thread panics
#[no_mangle]
pub unsafe extern "C" fn tfhe_error_enable_automatic_prints() {
// if the current hook is the default one, 'taking' it
// will still make is registered
let _ = std::panic::take_hook();
}
1 change: 1 addition & 0 deletions tfhe/src/c_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod boolean;
pub mod buffer;
pub mod core_crypto;
pub mod error;
#[cfg(feature = "high-level-c-api")]
pub mod high_level_api;
#[cfg(feature = "shortint-c-api")]
Expand Down
5 changes: 4 additions & 1 deletion tfhe/src/c_api/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ where
{
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) {
Ok(_) => 0,
_ => 1,
Err(err) => {
super::error::replace_last_error_with_panic_payload(&err);
1
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Display for UninitializedServerKey {
write!(
f,
"The server key was not properly initialized.\n\
Did you forget to call `set_server_key` in the current thread ?
Did you forget to call `set_server_key` in the current thread ?\
",
)
}
Expand Down
Loading