Skip to content

Commit

Permalink
rust: Add function to expose CUDA execution provider.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Nov 23, 2023
1 parent c7e89ac commit 16c7d90
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
6 changes: 6 additions & 0 deletions rust/onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ pub enum OrtError {
/// Error occurred when checking if ONNXRuntime tensor was properly initialized
#[error("Failed to check if tensor")]
IsTensorCheck,
/// Error occurred when creating CUDA provider options
#[error("Failed to create CUDA provider options: {0}")]
CudaProviderOptions(OrtApiError),
/// Error occurred when appending CUDA execution provider
#[error("Failed to append CUDA execution provider: {0}")]
AppendExecutionProviderCuda(OrtApiError),
}

/// Error used when dimensions of input (from model and from inference call)
Expand Down
33 changes: 32 additions & 1 deletion rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path};
use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path, ptr::null_mut};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -161,6 +161,37 @@ impl<'a> SessionBuilder<'a> {
Ok(self)
}

/// Append a CUDA execution provider
pub fn with_execution_provider_cuda(self) -> Result<SessionBuilder<'a>> {
let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = null_mut();
let status = unsafe {
self.env
.env()
.api()
.CreateCUDAProviderOptions
.unwrap()(&mut cuda_options)
};
status_to_result(status).map_err(OrtError::CudaProviderOptions)?;

let status = unsafe {
self.env
.env()
.api()
.SessionOptionsAppendExecutionProvider_CUDA_V2
.unwrap()(self.session_options_ptr, cuda_options)
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderCuda)?;

unsafe {
self.env
.env()
.api()
.ReleaseCUDAProviderOptions
.unwrap()(cuda_options);
};
Ok(self)
}

/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
#[cfg(feature = "model-fetching")]
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session>
Expand Down

0 comments on commit 16c7d90

Please sign in to comment.