Skip to content

Commit

Permalink
new C API KernelContext_GetScratchBuffer for ORT-extensions proj
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Jan 30, 2024
1 parent 04afe77 commit 818deae
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/stream_handles.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class Stream {
return nullptr;
}

virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; }

private:
StreamHandle handle_;
const OrtDevice& device_;
Expand Down
10 changes: 10 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4569,6 +4569,16 @@ struct OrtApi {
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);

/** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object.
* NOTE: callers are responsible to release this scratch buffer from the corresponding allocator
* \param[in] context OrtKernelContext instance
* \param[in] mem_info OrtMemoryInfo instance
* \param[in] count_or_bytes How many bytes is this scratch buffer
* \param[out] out A pointer to the scrach buffer
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
};

/*
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cann/cann_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/providers/cann/cann_call.h"

namespace onnxruntime {
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);

struct CannStream : Stream {
CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
Expand All @@ -23,10 +24,11 @@ struct CannStream : Stream {
void Flush() override;

bool own_stream_{true};

WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; }
};

void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
const OrtDevice::DeviceType device_type);

void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
namespace onnxruntime {

struct CudaStream;
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);

struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(CudaStream&);
Expand Down Expand Up @@ -47,6 +48,8 @@ struct CudaStream : Stream {

onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }

WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; }

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
Expand All @@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/rocm/rocm_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/framework/stream_handles.h"

namespace onnxruntime {
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);

struct RocmStream : Stream {
RocmStream(hipStream_t stream,
Expand Down Expand Up @@ -36,6 +37,8 @@ struct RocmStream : Stream {

void* GetResource(int version, int id) const override;

WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; }

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
Expand All @@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
bool use_existing_stream,
miopenHandle_t external_miopen_handle,
rocblas_handle external_rocblas_handle);
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelCont
#endif
};

ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
if (count_or_bytes == 0) {
*out = nullptr;
return nullptr;
}
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetAllocator(mem_info->device);
if (!allocator) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
}
onnxruntime::Stream* stream = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetComputeStream();
*out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn());
return nullptr;
};

#ifdef _WIN32
#pragma warning(pop)
#endif
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2724,6 +2724,8 @@ static constexpr OrtApi ort_api_1_to_18 = {
&OrtApis::SetDeterministicCompute,
&OrtApis::KernelContext_ParallelFor,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,

&OrtApis::KernelContext_GetScratchBuffer,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,4 +509,5 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
} // namespace OrtApis

0 comments on commit 818deae

Please sign in to comment.