Skip to content

Commit

Permalink
rust: Use c_char for ffi.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Nov 23, 2023
1 parent 2ac381c commit 9a050a1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions rust/onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing error definitions.
use std::{io, path::PathBuf};
use std::{io, path::PathBuf, ffi::c_char};

use thiserror::Error;

Expand Down Expand Up @@ -211,7 +211,7 @@ impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
if status.0.is_null() {
Ok(())
} else {
let raw: *const i8 = unsafe {
let raw: *const c_char = unsafe {
ENV.get()
.unwrap()
.lock()
Expand Down
16 changes: 9 additions & 7 deletions rust/onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ to download.
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/main/onnxruntime/examples/sample.rs)
//! example for more details.
use std::ffi::c_char;

use onnxruntime_sys as sys;

// Make functions `extern "stdcall"` for Windows 32bit.
Expand Down Expand Up @@ -187,8 +189,8 @@ use sys::OnnxEnumInt;
// Re-export ndarray as it's part of the public API anyway
pub use ndarray;

fn char_p_to_string(raw: *const i8) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
fn char_p_to_string(raw: *const c_char) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut c_char).to_owned() };

match c_string.into_string() {
Ok(string) => Ok(string),
Expand All @@ -201,7 +203,7 @@ mod onnxruntime {
//! Module containing a custom logger, used to catch the runtime's own logging and send it
//! to Rust's tracing logging instead.
use std::ffi::CStr;
use std::ffi::{CStr, c_char};
use tracing::{debug, error, info, span, trace, warn, Level};

use onnxruntime_sys as sys;
Expand Down Expand Up @@ -240,10 +242,10 @@ mod onnxruntime {
pub(crate) fn custom_logger(
_params: *mut std::ffi::c_void,
severity: sys::OrtLoggingLevel,
category: *const i8,
logid: *const i8,
code_location: *const i8,
message: *const i8,
category: *const c_char,
logid: *const c_char,
code_location: *const c_char,
message: *const c_char,
) {
let log_level = match severity {
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
Expand Down
18 changes: 10 additions & 8 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -423,17 +423,17 @@ impl Session {
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const i8> = output_names_cstring
let output_names_ptr: Vec<*const c_char> = output_names_cstring
.iter()
.map(|n| n.as_ptr().cast::<i8>())
.map(|n| n.as_ptr().cast::<c_char>())
.collect();

let input_names_ptr: Vec<*const i8> = self
let input_names_ptr: Vec<*const c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const i8)
.map(|n| n.into_raw() as *const c_char)
.collect();

{
Expand Down Expand Up @@ -508,7 +508,7 @@ impl Session {
.into_iter()
.map(|p| {
assert_not_null_pointer(p, "i8 for CString")?;
unsafe { Ok(CString::from_raw(p as *mut i8)) }
unsafe { Ok(CString::from_raw(p as *mut c_char)) }
})
.collect();
cstrings?;
Expand Down Expand Up @@ -632,6 +632,8 @@ unsafe fn get_tensor_dimensions(
/// Those functions are only to be used from inside the
/// `SessionBuilder::with_model_from_file()` method.
mod dangerous {
use std::ffi::c_char;

use super::{
assert_not_null_pointer, assert_null_pointer, char_p_to_string, get_tensor_dimensions,
status_to_result, sys, Input, OrtApiError, OrtError, Output, Result, TensorElementDataType,
Expand Down Expand Up @@ -694,14 +696,14 @@ mod dangerous {
*const sys::OrtSession,
usize,
*mut sys::OrtAllocator,
*mut *mut i8,
*mut *mut c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: usize,
env: _Environment,
) -> Result<String> {
let mut name_bytes: *mut i8 = std::ptr::null_mut();
let mut name_bytes: *mut c_char = std::ptr::null_mut();

let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::InputName)?;
Expand Down

0 comments on commit 9a050a1

Please sign in to comment.