diff --git a/rust/onnxruntime/src/error.rs b/rust/onnxruntime/src/error.rs
index 11f7e0a436d28..8b54f1c998448 100644
--- a/rust/onnxruntime/src/error.rs
+++ b/rust/onnxruntime/src/error.rs
@@ -118,6 +118,12 @@ pub enum OrtError {
     /// Error occurred when checking if ONNXRuntime tensor was properly initialized
     #[error("Failed to check if tensor")]
     IsTensorCheck,
+    /// Error occurred when creating CUDA provider options
+    #[error("Failed to create CUDA provider options: {0}")]
+    CudaProviderOptions(OrtApiError),
+    /// Error occurred when appending CUDA execution provider
+    #[error("Failed to append CUDA execution provider: {0}")]
+    AppendExecutionProviderCuda(OrtApiError),
 }
 
 /// Error used when dimensions of input (from model and from inference call)
diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs
index d8ccd5431a9f8..2071a6fb1f503 100644
--- a/rust/onnxruntime/src/session.rs
+++ b/rust/onnxruntime/src/session.rs
@@ -1,6 +1,6 @@
 //! Module containing session types
 
-use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path};
+use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path, ptr::null_mut};
 
 #[cfg(not(target_family = "windows"))]
 use std::os::unix::ffi::OsStrExt;
@@ -161,6 +161,37 @@ impl<'a> SessionBuilder<'a> {
         Ok(self)
     }
 
+    /// Append a CUDA execution provider
+    pub fn with_execution_provider_cuda(self) -> Result<SessionBuilder<'a>> {
+        let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = null_mut();
+        let status = unsafe {
+            self.env
+                .env()
+                .api()
+                .CreateCUDAProviderOptions
+                .unwrap()(&mut cuda_options)
+        };
+        status_to_result(status).map_err(OrtError::CudaProviderOptions)?;
+
+        let status = unsafe {
+            self.env
+                .env()
+                .api()
+                .SessionOptionsAppendExecutionProvider_CUDA_V2
+                .unwrap()(self.session_options_ptr, cuda_options)
+        };
+        status_to_result(status).map_err(OrtError::AppendExecutionProviderCuda)?;
+
+        unsafe {
+            self.env
+                .env()
+                .api()
+                .ReleaseCUDAProviderOptions
+                .unwrap()(cuda_options);
+        };
+        Ok(self)
+    }
+
     /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
     #[cfg(feature = "model-fetching")]
     pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>