From 6b3e7a0730ca5bbf942a448aa9af65c6866fcc76 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 9 Sep 2024 16:31:21 -0500 Subject: [PATCH] feat: threadpool access in operator kernels --- src/operator/kernel.rs | 16 +++++++++++++++- tools/api-coverage.ts | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 393b638..5cf7065 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -1,5 +1,5 @@ use std::{ - ffi::{c_char, CString}, + ffi::{c_char, c_void, CString}, ops::{Deref, DerefMut}, ptr::{self, NonNull} }; @@ -249,6 +249,15 @@ impl KernelContext { Ok(NonNull::new(resource_ptr)) } + pub fn par_for(&self, total: usize, max_num_batches: usize, f: F) -> Result<()> + where + F: Fn(usize) + Sync + Send + { + let executor = Box::new(f) as Box; + ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total as _, max_num_batches as _, &executor as *const _ as *mut c_void)?]; + Ok(()) + } + // TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX // Runtime bug. // @@ -280,3 +289,8 @@ impl KernelContext { Ok(NonNull::new(stream_ptr)) } } + +extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: ort_sys::size_t) { + let executor = unsafe { &*user_data.cast::>() }; + executor(iterator as _) +} diff --git a/tools/api-coverage.ts b/tools/api-coverage.ts index 2f2788f..32e45b6 100644 --- a/tools/api-coverage.ts +++ b/tools/api-coverage.ts @@ -14,6 +14,7 @@ const IGNORED_SYMBOLS = new Set([ 'RegisterCustomOpsUsingFunction', 'SessionOptionsAppendExecutionProvider_CUDA', // we use V2 'SessionOptionsAppendExecutionProvider_TensorRT', // we use V2 + 'GetValueType', // we get value types via GetTypeInfo -> GetOnnxTypeFromTypeInfo, which is equivalent 'SetLanguageProjection', // someday we shall have `ORT_PROJECTION_RUST`, but alas, today is not that day... // we use allocator APIs directly on the Allocator struct