diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index c235ee904762..26d78133b52f 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -100,6 +100,8 @@ class Stream { return nullptr; } + virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } + private: StreamHandle handle_; const OrtDevice& device_; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2ce9d361e8e5..1e8754925cab 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4569,6 +4569,16 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object. 
+ * NOTE: callers are responsible for releasing this scratch buffer through the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes Size of the scratch buffer in bytes + * \param[out] out A pointer to the scratch buffer (set to nullptr when count_or_bytes is 0) + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); }; /* diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index 4d03fe520120..5d822d23f966 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -12,6 +12,7 @@ #include "core/providers/cann/cann_call.h" namespace onnxruntime { +void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct CannStream : Stream { CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag); @@ -23,10 +24,11 @@ struct CannStream : Stream { void Flush() override; bool own_stream_{true}; + + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; } }; void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type); -void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index b02c167e9e9e..15e7a0553c84 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -11,6 +11,7 @@ namespace onnxruntime { struct CudaStream; +void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct 
DeferredCpuAllocator : public OrtAllocator { DeferredCpuAllocator(CudaStream&); @@ -47,6 +48,8 @@ struct CudaStream : Stream { onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; @@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis cudnnHandle_t external_cudnn_handle, cublasHandle_t external_cublass_handle, const CUDAExecutionProviderInfo& ep_info); -void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 1f3e5b75548e..30983ce03568 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -8,6 +8,7 @@ #include "core/framework/stream_handles.h" namespace onnxruntime { +void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct RocmStream : Stream { RocmStream(hipStream_t stream, @@ -36,6 +37,8 @@ struct RocmStream : Stream { void* GetResource(int version, int id) const override; + WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; @@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis bool use_existing_stream, miopenHandle_t external_miopen_handle, rocblas_handle external_rocblas_handle); -void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 4bae42f4b80a..de7da73ccee3 100644 --- 
a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -416,6 +416,32 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelCont
 #endif
 };
 
+// Hands a custom-op kernel a scratch buffer of `count_or_bytes` bytes taken from the allocator that
+// the kernel context returns for `mem_info->device`, forwarding the context's compute stream.
+// On return `*out` is the buffer, or nullptr when `count_or_bytes` is 0 or no allocator is found.
+// Yields ORT_INVALID_ARGUMENT when no allocator matches; every other path returns nullptr.
+ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
+  // Default the out-parameter so the error path below also leaves it well defined.
+  *out = nullptr;
+  if (count_or_bytes == 0) {
+    return nullptr;
+  }
+  const auto* kernel_context = reinterpret_cast<const onnxruntime::OpKernelContext*>(context);
+  onnxruntime::AllocatorPtr allocator = kernel_context->GetAllocator(mem_info->device);
+  if (!allocator) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
+  }
+  // Guard the stream before dereferencing it (NOTE(review): presumably null for kernels without a device
+  // stream - confirm). The fallback nullptr is also what the base Stream::GetWaitNotificationFn() returns.
+  onnxruntime::Stream* stream = kernel_context->GetComputeStream();
+  onnxruntime::WaitNotificationFn wait_fn = nullptr;
+  if (stream != nullptr) {
+    wait_fn = stream->GetWaitNotificationFn();
+  }
+  *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, wait_fn);
+  return nullptr;
+};
+
 #ifdef _WIN32
 #pragma warning(pop)
 #endif
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 91a7f0d930b5..9275b0c0ad74 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2724,6 +2724,8 @@ static constexpr OrtApi ort_api_1_to_18 = {
     &OrtApis::SetDeterministicCompute,
     &OrtApis::KernelContext_ParallelFor,
     &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
+
+    &OrtApis::KernelContext_GetScratchBuffer,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c1caafa4dcad..e6cbe02db9d4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -509,4 +509,5 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); +ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); } // namespace OrtApis