diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
index ba9251c71bce..cd2087c9d747 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -14,6 +14,7 @@ namespace tensorrt {
 namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
 constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
+constexpr const char* kUserComputeStream = "user_compute_stream";
 constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations";
 constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size";
 constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size";
@@ -55,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
 
 TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
   TensorrtExecutionProviderInfo info{};
+  void* user_compute_stream = nullptr;
   ORT_THROW_IF_ERROR(
       ProviderOptionsParser{}
           .AddValueParser(
@@ -71,6 +73,14 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
               })
          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations)
          .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
+          .AddValueParser(
+              tensorrt::provider_option_names::kUserComputeStream,
+              [&user_compute_stream](const std::string& value_str) -> Status {
+                size_t address;
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+                user_compute_stream = reinterpret_cast<void*>(address);
+                return Status::OK();
+              })
          .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size)
          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
          .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable)
@@ -107,6 +117,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
          .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
          .Parse(options));  // add new provider option here.
 
+  info.user_compute_stream = user_compute_stream;
+  info.has_user_compute_stream = (user_compute_stream != nullptr);
   return info;
 }
 
@@ -115,6 +127,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
       {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)},
       {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)},
       {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)},
       {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
@@ -171,6 +184,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
   const ProviderOptions options{
       {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)},
       {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)},
       {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)},
@@ -253,10 +267,14 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
   trt_provider_options_v2.device_id = internal_options.device_id;
 
   // The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well
-  // We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options
+  // We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options or user_compute_stream is provided
   if (options.find("has_user_compute_stream") != options.end()) {
     trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream;
   }
+  if (options.find("user_compute_stream") != options.end() && internal_options.user_compute_stream != nullptr) {
+    trt_provider_options_v2.user_compute_stream = internal_options.user_compute_stream;
+    trt_provider_options_v2.has_user_compute_stream = true;
+  }
 
   trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations;
   trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size;
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 7b56f0c68427..0c1945109f31 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -486,6 +486,15 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
         } else {
           ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n");
         }
+      } else if (option.first == "user_compute_stream") {
+        if (!option.second.empty()) {
+          auto stream = std::stoull(option.second, nullptr, 0);
+          params.user_compute_stream = reinterpret_cast<void*>(stream);
+          params.has_user_compute_stream = true;
+        } else {
+          params.has_user_compute_stream = false;
+          ORT_THROW("[ERROR] [TensorRT] The value for the key 'user_compute_stream' should be a string to define the compute stream for the inference to run on.\n");
+        }
       } else if (option.first == "trt_max_partition_iterations") {
         if (!option.second.empty()) {
           params.trt_max_partition_iterations = std::stoi(option.second);
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index ab56f3fa0f37..e4814aa7fc03 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -312,6 +312,7 @@ def test_set_providers_with_options(self):
             option["trt_engine_cache_path"] = engine_cache_path
             force_sequential_engine_build = "true"
             option["trt_force_sequential_engine_build"] = force_sequential_engine_build
+            option["user_compute_stream"] = "1"
 
             sess.set_providers(["TensorrtExecutionProvider"], [option])
             options = sess.get_provider_options()
@@ -326,6 +327,8 @@ def test_set_providers_with_options(self):
             self.assertEqual(option["trt_engine_cache_enable"], "1")
             self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path))
             self.assertEqual(option["trt_force_sequential_engine_build"], "1")
+            self.assertEqual(option["user_compute_stream"], "1")
+            self.assertEqual(option["has_user_compute_stream"], "1")
 
             from onnxruntime.capi import _pybind_state as C
 
@@ -354,6 +357,19 @@ def test_set_providers_with_options(self):
               sess.set_providers(['TensorrtExecutionProvider'], [option])
            """
 
+            try:
+                import torch
+
+                if torch.cuda.is_available():
+                    s = torch.cuda.Stream()
+                    option["user_compute_stream"] = str(s.cuda_stream)
+                    sess.set_providers(["TensorrtExecutionProvider"], [option])
+                    options = sess.get_provider_options()
+                    self.assertEqual(options["TensorrtExecutionProvider"]["user_compute_stream"], str(s.cuda_stream))
+                    self.assertEqual(options["TensorrtExecutionProvider"]["has_user_compute_stream"], "1")
+            except ImportError:
+                print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream")
+
         if "CUDAExecutionProvider" in onnxrt.get_available_providers():
             cuda_success = 0