From 051b0b38fac7c659124678f682eec3ef01c307fc Mon Sep 17 00:00:00 2001 From: Christian Sonnabend Date: Tue, 17 Dec 2024 10:25:31 +0100 Subject: [PATCH] ORT GPU implementation (#13755) --- Common/ML/CMakeLists.txt | 16 +++++- .../ML/{ort_interface.h => OrtInterface.h} | 8 +-- .../{ort_interface.cxx => OrtInterface.cxx} | 56 +++++++++++++------ 3 files changed, 59 insertions(+), 21 deletions(-) rename Common/ML/include/ML/{ort_interface.h => OrtInterface.h} (96%) rename Common/ML/src/{ort_interface.cxx => OrtInterface.cxx} (87%) diff --git a/Common/ML/CMakeLists.txt b/Common/ML/CMakeLists.txt index 74287e774efa1..74be306c8b6a5 100644 --- a/Common/ML/CMakeLists.txt +++ b/Common/ML/CMakeLists.txt @@ -9,7 +9,21 @@ # granted to it by virtue of its status as an Intergovernmental Organization # or submit itself to any jurisdiction. +# Pass ORT variables as a preprocessor definition +if(DEFINED ENV{ORT_ROCM_BUILD}) + add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD}) +endif() +if(DEFINED ENV{ORT_CUDA_BUILD}) + add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD}) +endif() +if(DEFINED ENV{ORT_MIGRAPHX_BUILD}) + add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD}) +endif() +if(DEFINED ENV{ORT_TENSORRT_BUILD}) + add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD}) +endif() + o2_add_library(ML - SOURCES src/ort_interface.cxx + SOURCES src/OrtInterface.cxx TARGETVARNAME targetName PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime) diff --git a/Common/ML/include/ML/ort_interface.h b/Common/ML/include/ML/OrtInterface.h similarity index 96% rename from Common/ML/include/ML/ort_interface.h rename to Common/ML/include/ML/OrtInterface.h index e2049b8508cb4..89631d59a3846 100644 --- a/Common/ML/include/ML/ort_interface.h +++ b/Common/ML/include/ML/OrtInterface.h @@ -9,12 +9,12 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. -/// \file ort_interface.h +/// \file OrtInterface.h /// \author Christian Sonnabend /// \brief A header library for loading ONNX models and inferencing them on CPU and GPU -#ifndef O2_ML_ONNX_INTERFACE_H -#define O2_ML_ONNX_INTERFACE_H +#ifndef O2_ML_ORTINTERFACE_H +#define O2_ML_ORTINTERFACE_H // C++ and system includes #include @@ -89,4 +89,4 @@ class OrtModel } // namespace o2 -#endif // O2_ML_ORT_INTERFACE_H +#endif // O2_ML_ORTINTERFACE_H diff --git a/Common/ML/src/ort_interface.cxx b/Common/ML/src/OrtInterface.cxx similarity index 87% rename from Common/ML/src/ort_interface.cxx rename to Common/ML/src/OrtInterface.cxx index 27ac8eee16b7b..eb124ff6f12c9 100644 --- a/Common/ML/src/ort_interface.cxx +++ b/Common/ML/src/OrtInterface.cxx @@ -9,11 +9,11 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. -/// \file ort_interface.cxx +/// \file OrtInterface.cxx /// \author Christian Sonnabend /// \brief A header library for loading ONNX models and inferencing them on CPU and GPU -#include "ML/ort_interface.h" +#include "ML/OrtInterface.h" #include "ML/3rdparty/GPUORTFloat16.h" // ONNX includes @@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map optionsMap) deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0); allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0); intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0); - loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0); + loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2); enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0); enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0); std::string dev_mem_str = "Hip"; -#ifdef ORT_ROCM_BUILD +#if defined(ORT_ROCM_BUILD) +#if ORT_ROCM_BUILD == 1 if (device == "ROCM") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId)); LOG(info) << "(ORT) ROCM execution provider set"; } #endif -#ifdef ORT_MIGRAPHX_BUILD +#endif +#if defined(ORT_MIGRAPHX_BUILD) +#if ORT_MIGRAPHX_BUILD == 1 if (device == "MIGRAPHX") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId)); LOG(info) << "(ORT) MIGraphX execution provider set"; } #endif -#ifdef ORT_CUDA_BUILD +#endif +#if defined(ORT_CUDA_BUILD) +#if ORT_CUDA_BUILD == 1 if (device == "CUDA") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId)); LOG(info) << "(ORT) CUDA execution provider set"; dev_mem_str = "Cuda"; } +#endif #endif if (allocateDeviceMemory) { @@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map optionsMap) (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations)); (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel)); - pImplOrt->env = std::make_shared(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str())); + pImplOrt->env = std::make_shared( + OrtLoggingLevel(loggingLevel), + (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()), + // Integrate ORT logging into Fairlogger + [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { + if (severity == ORT_LOGGING_LEVEL_VERBOSE) { + LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } else if (severity == ORT_LOGGING_LEVEL_INFO) { + LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } else if (severity == ORT_LOGGING_LEVEL_WARNING) { + LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } else if (severity == ORT_LOGGING_LEVEL_ERROR) { + LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } else if (severity == ORT_LOGGING_LEVEL_FATAL) { + LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } else { + LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; + } + }, + (void*)3); + (pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events pImplOrt->session = std::make_shared(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions); for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) { @@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map optionsMap) [&](const std::string& str) { return str.c_str(); }); // Print names - if (loggingLevel > 1) { - LOG(info) << "Input Nodes:"; - for (size_t i = 0; i < mInputNames.size(); i++) { - LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]); - } + LOG(info) << "\tInput Nodes:"; + for (size_t i = 0; i < mInputNames.size(); i++) { + LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]); + } - LOG(info) << "Output Nodes:"; - for (size_t i = 0; i < mOutputNames.size(); i++) { - LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]); - } + LOG(info) << "\tOutput Nodes:"; + for (size_t i = 0; i < mOutputNames.size(); i++) { + LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]); } }