ORT GPU implementation (#13755)
ChSonnabend authored Dec 17, 2024
1 parent 922cad6 commit 051b0b3
Showing 3 changed files with 59 additions and 21 deletions.
16 changes: 15 additions & 1 deletion Common/ML/CMakeLists.txt
@@ -9,7 +9,21 @@
# granted to it by virtue of its status as an Intergovernmental Organization
# or submit itself to any jurisdiction.

# Pass ORT variables as a preprocessor definition
if(DEFINED ENV{ORT_ROCM_BUILD})
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
endif()
if(DEFINED ENV{ORT_CUDA_BUILD})
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
endif()
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
endif()
if(DEFINED ENV{ORT_TENSORRT_BUILD})
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
endif()

o2_add_library(ML
SOURCES src/ort_interface.cxx
SOURCES src/OrtInterface.cxx
TARGETVARNAME targetName
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
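
The CMake block above forwards the ORT_ROCM_BUILD, ORT_CUDA_BUILD, ORT_MIGRAPHX_BUILD and ORT_TENSORRT_BUILD environment variables as preprocessor definitions for the sources built in Common/ML. A minimal sketch of how such definitions can be consumed, assuming the same two-level check (macro defined and equal to 1) that OrtInterface.cxx uses further down; the helper name availableOrtBackends is purely illustrative and not part of this commit:

#include <string>
#include <vector>

// Illustrative only: list which ORT backends were compiled in, based on the
// ORT_*_BUILD definitions injected by the CMake block above.
std::vector<std::string> availableOrtBackends()
{
  std::vector<std::string> backends{"CPU"};
#if defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1
  backends.push_back("CUDA");
#endif
#if defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1
  backends.push_back("ROCM");
#endif
#if defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1
  backends.push_back("MIGRAPHX");
#endif
#if defined(ORT_TENSORRT_BUILD) && ORT_TENSORRT_BUILD == 1
  backends.push_back("TENSORRT");
#endif
  return backends;
}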
8 changes: 4 additions & 4 deletions Common/ML/include/ML/ort_interface.h → Common/ML/include/ML/OrtInterface.h
@@ -9,12 +9,12 @@
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file ort_interface.h
/// \file OrtInterface.h
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU

#ifndef O2_ML_ONNX_INTERFACE_H
#define O2_ML_ONNX_INTERFACE_H
#ifndef O2_ML_ORTINTERFACE_H
#define O2_ML_ORTINTERFACE_H

// C++ and system includes
#include <vector>
@@ -89,4 +89,4 @@ class OrtModel

} // namespace o2

#endif // O2_ML_ORT_INTERFACE_H
#endif // O2_ML_ORTINTERFACE_H
56 changes: 40 additions & 16 deletions Common/ML/src/ort_interface.cxx → Common/ML/src/OrtInterface.cxx
@@ -9,11 +9,11 @@
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file ort_interface.cxx
/// \file OrtInterface.cxx
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU

#include "ML/ort_interface.h"
#include "ML/OrtInterface.h"
#include "ML/3rdparty/GPUORTFloat16.h"

// ONNX includes
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

std::string dev_mem_str = "Hip";
#ifdef ORT_ROCM_BUILD
#if defined(ORT_ROCM_BUILD)
#if ORT_ROCM_BUILD == 1
if (device == "ROCM") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) ROCM execution provider set";
}
#endif
#ifdef ORT_MIGRAPHX_BUILD
#endif
#if defined(ORT_MIGRAPHX_BUILD)
#if ORT_MIGRAPHX_BUILD == 1
if (device == "MIGRAPHX") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) MIGraphX execution provider set";
}
#endif
#ifdef ORT_CUDA_BUILD
#endif
#if defined(ORT_CUDA_BUILD)
#if ORT_CUDA_BUILD == 1
if (device == "CUDA") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) CUDA execution provider set";
dev_mem_str = "Cuda";
}
#endif
#endif

if (allocateDeviceMemory) {
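
For context, a hedged sketch of how a caller might steer the provider selection above through the options map parsed at the top of this hunk. The keys device-id, allocate-device-memory, intra-op-num-threads, logging-level, enable-profiling and enable-optimizations are taken verbatim from the code; the "device" key and the function name makeCudaOptions are assumptions, since only the parsed `device` variable is visible in this excerpt:

#include <string>
#include <unordered_map>

// Hedged usage sketch: build an options map that asks reset() for the CUDA
// execution provider on device 1. The "device" key is assumed; the other keys
// are the ones parsed at the top of this hunk.
std::unordered_map<std::string, std::string> makeCudaOptions()
{
  return {
    {"device", "CUDA"},              // assumed key; matched against "ROCM"/"MIGRAPHX"/"CUDA" above
    {"device-id", "1"},
    {"allocate-device-memory", "0"},
    {"intra-op-num-threads", "1"},
    {"logging-level", "2"},          // 2 is the new default (ORT warning severity)
    {"enable-profiling", "0"},
    {"enable-optimizations", "0"}};
}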
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()));
pImplOrt->env = std::make_shared<Ort::Env>(
OrtLoggingLevel(loggingLevel),
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
// Integrate ORT logging into Fairlogger
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
}
},
(void*)3);
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
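
The custom Ort::Env logger above is what routes ONNX Runtime output into Fairlogger, and the "logging-level" option (default now 2) is applied as an OrtLoggingLevel to both the session options and the environment. As a reference, a comment-style summary of that mapping; the numeric values are the OrtLoggingLevel enum of the ONNX Runtime C API:

// "logging-level" value (OrtLoggingLevel)  ->  Fairlogger stream used by the callback above
//   0  ORT_LOGGING_LEVEL_VERBOSE  ->  LOG(debug)
//   1  ORT_LOGGING_LEVEL_INFO     ->  LOG(info)
//   2  ORT_LOGGING_LEVEL_WARNING  ->  LOG(warning)   (new default in this commit)
//   3  ORT_LOGGING_LEVEL_ERROR    ->  LOG(error)
//   4  ORT_LOGGING_LEVEL_FATAL    ->  LOG(fatal)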
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
[&](const std::string& str) { return str.c_str(); });

// Print names
if (loggingLevel > 1) {
LOG(info) << "Input Nodes:";
for (size_t i = 0; i < mInputNames.size(); i++) {
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
}
LOG(info) << "\tInput Nodes:";
for (size_t i = 0; i < mInputNames.size(); i++) {
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
}

LOG(info) << "Output Nodes:";
for (size_t i = 0; i < mOutputNames.size(); i++) {
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
}
LOG(info) << "\tOutput Nodes:";
for (size_t i = 0; i < mOutputNames.size(); i++) {
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
}
}

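The printShape(...) helper used in the log lines above is not part of this diff. A minimal, purely hypothetical sketch of such a helper (name, signature and formatting are assumptions) could look like this:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper in the spirit of printShape(...) above (the real
// implementation is outside this diff): render a tensor shape such as
// {1, 3, -1} as "1x3x-1".
std::string printShapeSketch(const std::vector<int64_t>& shape)
{
  std::ostringstream out;
  for (size_t i = 0; i < shape.size(); ++i) {
    out << shape[i];
    if (i + 1 < shape.size()) {
      out << "x";
    }
  }
  return out.str();
}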
