Skip to content

Commit

Permalink
feat: threadpool access in operator kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Sep 9, 2024
1 parent 41ef65a commit 6b3e7a0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/operator/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
ffi::{c_char, CString},
ffi::{c_char, c_void, CString},
ops::{Deref, DerefMut},
ptr::{self, NonNull}
};
Expand Down Expand Up @@ -249,6 +249,15 @@ impl KernelContext {
Ok(NonNull::new(resource_ptr))
}

pub fn par_for<F>(&self, total: usize, max_num_batches: usize, f: F) -> Result<()>
where
F: Fn(usize) + Sync + Send
{
let executor = Box::new(f) as Box<dyn Fn(usize) + Sync + Send>;
ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total as _, max_num_batches as _, &executor as *const _ as *mut c_void)?];
Ok(())
}

// TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX
// Runtime bug.
//
Expand Down Expand Up @@ -280,3 +289,8 @@ impl KernelContext {
Ok(NonNull::new(stream_ptr))
}
}

extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: ort_sys::size_t) {
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
executor(iterator as _)
}
1 change: 1 addition & 0 deletions tools/api-coverage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const IGNORED_SYMBOLS = new Set<string>([
'RegisterCustomOpsUsingFunction',
'SessionOptionsAppendExecutionProvider_CUDA', // we use V2
'SessionOptionsAppendExecutionProvider_TensorRT', // we use V2
'GetValueType', // we get value types via GetTypeInfo -> GetOnnxTypeFromTypeInfo, which is equivalent
'SetLanguageProjection', // someday we shall have `ORT_PROJECTION_RUST`, but alas, today is not that day...

// we use allocator APIs directly on the Allocator struct
Expand Down

0 comments on commit 6b3e7a0

Please sign in to comment.