
Commit

Refactor plugin setup.
riga committed Mar 27, 2024
1 parent a1bd979 commit 8781e78
Showing 15 changed files with 232 additions and 275 deletions.
6 changes: 0 additions & 6 deletions cmssw/MLProf/ONNXRuntimeModule/plugins/BuildFile.xml

This file was deleted.

64 changes: 0 additions & 64 deletions cmssw/MLProf/ONNXRuntimeModule/test/onnx_runtime_template_cfg.py

This file was deleted.

25 changes: 20 additions & 5 deletions cmssw/MLProf/RuntimeMeasurement/plugins/BuildFile.xml
@@ -1,6 +1,21 @@
-<use name="FWCore/Framework" />
-<use name="FWCore/PluginManager" />
-<use name="FWCore/ParameterSet" />
-<use name="PhysicsTools/TensorFlow" />
+<library name="MLProfRuntimeMeasurementTFInference" file="TFInference.cc">
+    <use name="FWCore/Framework"/>
+    <use name="FWCore/PluginManager"/>
+    <use name="FWCore/ParameterSet"/>
 
-<flags EDM_PLUGIN="1" />
+    <use name="PhysicsTools/TensorFlow"/>
+    <use name="MLProf/Utils"/>
+
+    <flags EDM_PLUGIN="1"/>
+</library>
+
+<library name="MLProfRuntimeMeasurementONNXInference" file="ONNXInference.cc">
+    <use name="FWCore/Framework"/>
+    <use name="FWCore/PluginManager"/>
+    <use name="FWCore/ParameterSet"/>
+
+    <use name="PhysicsTools/ONNXRuntime"/>
+    <use name="MLProf/Utils"/>
+
+    <flags EDM_PLUGIN="1"/>
+</library>
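
Each <library> block with <flags EDM_PLUGIN="1"/> is built by scram into a separate EDM plugin library, one for the TensorFlow and one for the ONNX measurement module; the module class in the named source file must register itself with the framework, as ONNXInference.cc does at the bottom of this commit. A minimal sketch of such a plugin source (MinimalAnalyzer is an illustrative name, not part of this commit):

#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Framework/interface/stream/EDAnalyzer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"

// MinimalAnalyzer is an illustrative name, not part of this commit.
class MinimalAnalyzer : public edm::stream::EDAnalyzer<> {
public:
  explicit MinimalAnalyzer(const edm::ParameterSet&) {}
  void analyze(const edm::Event&, const edm::EventSetup&) override {}
};

// this registration is what EDM_PLUGIN="1" exposes to the plugin manager
DEFINE_FWK_MODULE(MinimalAnalyzer);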
cmssw/MLProf/RuntimeMeasurement/plugins/ONNXInference.cc

@@ -1,6 +1,5 @@
 /*
- * Example plugin to demonstrate the direct multi-threaded inference with ONNX
- * Runtime.
+ * Plugin to measure the inference runtime of an onnx model.
  */
 
 #include <chrono>
@@ -16,24 +15,24 @@
 #include "FWCore/Framework/interface/MakerMacros.h"
 #include "FWCore/Framework/interface/stream/EDAnalyzer.h"
 #include "FWCore/ParameterSet/interface/ParameterSet.h"
-#include "MLProf/Utils/interface/utils.h"
 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
 
+#include "MLProf/Utils/interface/utils.h"
+
 using namespace cms::Ort;
 
-class ONNXRuntimePlugin
-    : public edm::stream::EDAnalyzer<edm::GlobalCache<ONNXRuntime>> {
- public:
-  explicit ONNXRuntimePlugin(const edm::ParameterSet &, const ONNXRuntime *);
-  static void fillDescriptions(edm::ConfigurationDescriptions &);
+class ONNXInference : public edm::stream::EDAnalyzer<edm::GlobalCache<ONNXRuntime>> {
+public:
+  explicit ONNXInference(const edm::ParameterSet&, const ONNXRuntime*);
+  ~ONNXInference(){};
 
-  static std::unique_ptr<ONNXRuntime> initializeGlobalCache(
-      const edm::ParameterSet &);
-  static void globalEndJob(const ONNXRuntime *);
+  static void fillDescriptions(edm::ConfigurationDescriptions&);
+  static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet&);
+  static void globalEndJob(const ONNXRuntime*);
 
- private:
+private:
   void beginJob();
-  void analyze(const edm::Event &, const edm::EventSetup &);
+  void analyze(const edm::Event&, const edm::EventSetup&);
   void endJob();
 
   inline float drawNormal() { return normalPdf_(rndGen_); }
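
The edm::GlobalCache<ONNXRuntime> template parameter is the key piece here: each CMSSW stream constructs its own ONNXInference instance (with its own inputArrays_), while a single read-only ONNXRuntime session is created once per job in initializeGlobalCache(), shared by all streams through globalCache(), and cleaned up around globalEndJob(). A rough sketch of that lifecycle with a stand-in cache type (DummyCache and CachedAnalyzer are illustrative names, not part of MLProf):

#include <memory>

#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Framework/interface/stream/EDAnalyzer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"

struct DummyCache {};  // stands in for the shared, read-only ONNXRuntime

class CachedAnalyzer : public edm::stream::EDAnalyzer<edm::GlobalCache<DummyCache>> {
public:
  CachedAnalyzer(const edm::ParameterSet&, const DummyCache*) {}

  // called once per job, before any stream module is constructed
  static std::unique_ptr<DummyCache> initializeGlobalCache(const edm::ParameterSet&) {
    return std::make_unique<DummyCache>();
  }

  // called once per job, after all streams have finished
  static void globalEndJob(const DummyCache*) {}

  void analyze(const edm::Event&, const edm::EventSetup&) override {
    const DummyCache* shared = globalCache();  // same object in every stream
    (void)shared;
  }
};

DEFINE_FWK_MODULE(CachedAnalyzer);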
@@ -60,8 +59,7 @@ class ONNXRuntimePlugin
   FloatArrays inputArrays_;  // each stream hosts its own data
 };
 
-void ONNXRuntimePlugin::fillDescriptions(
-    edm::ConfigurationDescriptions &descriptions) {
+void ONNXInference::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
   // defining this function will lead to a *_cfi file being generated when
   // compiling
   edm::ParameterSetDescription desc;
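
The add<> calls between this hunk and the next are collapsed in this view. A hypothetical reconstruction of what they plausibly contain, using only parameter names that appear in getParameter calls elsewhere in this file (defaults, comments, and any parameters for the number of warm-up and measurement calls are not visible here):

// hypothetical reconstruction, not the committed code
desc.add<std::string>("graphPath");
desc.add<std::vector<std::string>>("inputTensorNames");
desc.add<std::vector<std::string>>("outputTensorNames");
desc.add<std::string>("outputFile");
desc.add<std::string>("inputType");
desc.add<std::vector<int>>("inputRanks");
desc.add<std::vector<int>>("flatInputSizes");
desc.add<int>("batchSize");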
@@ -89,12 +87,9 @@ void ONNXRuntimePlugin::fillDescriptions(
   descriptions.addWithDefaultLabel(desc);
 }
 
-ONNXRuntimePlugin::ONNXRuntimePlugin(const edm::ParameterSet &iConfig,
-                                     const ONNXRuntime *cache)
-    : inputTensorNames_(
-          iConfig.getParameter<std::vector<std::string>>("inputTensorNames")),
-      outputTensorNames_(
-          iConfig.getParameter<std::vector<std::string>>("outputTensorNames")),
+ONNXInference::ONNXInference(const edm::ParameterSet& iConfig, const ONNXRuntime* cache)
+    : inputTensorNames_(iConfig.getParameter<std::vector<std::string>>("inputTensorNames")),
+      outputTensorNames_(iConfig.getParameter<std::vector<std::string>>("outputTensorNames")),
       outputFile_(iConfig.getParameter<std::string>("outputFile")),
       inputTypeStr_(iConfig.getParameter<std::string>("inputType")),
       inputRanks_(iConfig.getParameter<std::vector<int>>("inputRanks")),
@@ -107,34 +102,28 @@ ONNXRuntimePlugin::ONNXRuntimePlugin(const edm::ParameterSet &iConfig,
       normalPdf_(0.0, 1.0) {
   // the number of input ranks must match the number of input tensors
   if ((int)inputRanks_.size() != nInputs_) {
-    throw cms::Exception("InvalidInputRanks")
-        << "number of input ranks must match number of input tensors";
+    throw cms::Exception("InvalidInputRanks") << "number of input ranks must match number of input tensors";
   }
   // the input must be at least 1 dimensional
   for (auto rank : inputRanks_) {
     if (rank < 1) {
-      throw cms::Exception("InvalidRank")
-          << "only ranks above 0 are supported, got " << rank;
+      throw cms::Exception("InvalidRank") << "only ranks above 0 are supported, got " << rank;
     }
   }
   // the sum of ranks must match the number of flat input sizes
-  if (std::accumulate(inputRanks_.begin(), inputRanks_.end(), 0) !=
-      (int)flatInputSizes_.size()) {
+  if (std::accumulate(inputRanks_.begin(), inputRanks_.end(), 0) != (int)flatInputSizes_.size()) {
     throw cms::Exception("InvalidFlatInputSizes")
-        << "sum of input ranks must match number of flat input sizes, got "
-        << flatInputSizes_.size();
+        << "sum of input ranks must match number of flat input sizes, got " << flatInputSizes_.size();
   }
   // batch size must be positive
   if (batchSize_ < 1) {
-    throw cms::Exception("InvalidBatchSize")
-        << "batch sizes must be positive, got " << batchSize_;
+    throw cms::Exception("InvalidBatchSize") << "batch size must be positive, got " << batchSize_;
   }
 
   // input sizes must be positive
   for (auto size : flatInputSizes_) {
     if (size < 1) {
-      throw cms::Exception("InvalidInputSize")
-          << "input sizes must be positive, got " << size;
+      throw cms::Exception("InvalidInputSize") << "input sizes must be positive, got " << size;
     }
   }
   // check the input type
@@ -146,15 +135,13 @@ ONNXRuntimePlugin::ONNXRuntimePlugin(const edm::ParameterSet &iConfig,
     inputType_ = mlprof::InputType::Zeros;
   } else {
     throw cms::Exception("InvalidInputType")
-        << "input type must be either 'incremental', 'zeros' or 'random', got "
-        << inputTypeStr_;
+        << "input type must be either 'incremental', 'zeros' or 'random', got " << inputTypeStr_;
   }
 
   // initialize the input_shapes array with inputRanks_ and flatInputSizes_
   int i = 0;
   for (auto rank : inputRanks_) {
-    std::vector<int64_t> input_shape(flatInputSizes_.begin() + i,
-                                     flatInputSizes_.begin() + i + rank);
+    std::vector<int64_t> input_shape(flatInputSizes_.begin() + i, flatInputSizes_.begin() + i + rank);
     input_shape.insert(input_shape.begin(), batchSize_);
     input_shapes_.push_back(input_shape);
     i += rank;
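
The loop above splits the flat list of dimensions into one shape per input: each input consumes inputRanks_[k] entries of flatInputSizes_, and the batch size is prepended as the leading dimension. A self-contained sketch of the same partitioning with illustrative numbers:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // illustrative numbers, not taken from the commit
  std::vector<int> inputRanks = {2, 1};
  std::vector<int64_t> flatInputSizes = {3, 4, 5};
  int64_t batchSize = 2;

  // same partitioning logic as the constructor above
  int i = 0;
  for (int rank : inputRanks) {
    std::vector<int64_t> shape(flatInputSizes.begin() + i, flatInputSizes.begin() + i + rank);
    shape.insert(shape.begin(), batchSize);  // prepend the batch dimension
    for (int64_t d : shape) std::cout << d << " ";  // prints "2 3 4", then "2 5"
    std::cout << "\n";
    i += rank;
  }
  return 0;
}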
@@ -167,26 +154,20 @@ ONNXRuntimePlugin::ONNXRuntimePlugin(const edm::ParameterSet &iConfig,
   }
 }
 
-std::unique_ptr<ONNXRuntime> ONNXRuntimePlugin::initializeGlobalCache(
-    const edm::ParameterSet &iConfig) {
-  return std::make_unique<ONNXRuntime>(
-      edm::FileInPath(iConfig.getParameter<std::string>("graphPath"))
-          .fullPath());
+std::unique_ptr<ONNXRuntime> ONNXInference::initializeGlobalCache(const edm::ParameterSet& iConfig) {
+  return std::make_unique<ONNXRuntime>(edm::FileInPath(iConfig.getParameter<std::string>("graphPath")).fullPath());
 }
 
-void ONNXRuntimePlugin::globalEndJob(const ONNXRuntime *cache) {}
+void ONNXInference::globalEndJob(const ONNXRuntime* cache) {}
 
-void ONNXRuntimePlugin::analyze(const edm::Event &iEvent,
-                                const edm::EventSetup &iSetup) {
+void ONNXInference::analyze(const edm::Event& iEvent, const edm::EventSetup& iSetup) {
   for (int i = 0; i < nInputs_; i++) {
-    std::vector<float> &group_data = inputArrays_[i];
+    std::vector<float>& group_data = inputArrays_[i];
     // fill the input
     for (int i = 0; i < (int)group_data.size(); i++) {
-      group_data[i] =
-          inputType_ == mlprof::InputType::Incremental
-              ? float(i)
-              : (inputType_ == mlprof::InputType::Zeros ? float(0)
-                                                        : drawNormal());
+      group_data[i] = inputType_ == mlprof::InputType::Incremental
+                          ? float(i)
+                          : (inputType_ == mlprof::InputType::Zeros ? float(0) : drawNormal());
     }
   }
 
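Before each measurement, every element of every input tensor is (re)filled according to the configured mode: 'incremental' writes the running index, 'zeros' writes 0, and 'random' draws from a standard normal via drawNormal(). A stand-alone sketch of the three modes (the generator types are assumptions; the diff only shows that drawNormal() calls normalPdf_(rndGen_)):

#include <random>
#include <vector>

int main() {
  // stand-ins for the module's generator members (types here are assumptions)
  std::default_random_engine rndGen;
  std::normal_distribution<float> normalPdf(0.0, 1.0);

  std::vector<float> data(8);
  for (int i = 0; i < (int)data.size(); i++) {
    // incremental mode: data[i] = float(i);
    // zeros mode:       data[i] = float(0);
    // random mode (shown): one draw from a standard normal per element
    data[i] = normalPdf(rndGen);
  }
  return 0;
}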
@@ -195,16 +176,14 @@ void ONNXRuntimePlugin::analyze(const edm::Event &iEvent,
 
   // pre calls to "warm up"
   for (int r = 0; r < nPreCalls_; r++) {
-    outputs = globalCache()->run(inputTensorNames_, inputArrays_, input_shapes_,
-                                 outputTensorNames_, batchSize_);
+    outputs = globalCache()->run(inputTensorNames_, inputArrays_, input_shapes_, outputTensorNames_, batchSize_);
   }
 
   // actual calls to measure runtimes
   std::vector<float> runtimes;
   for (int r = 0; r < nCalls_; r++) {
     auto start = std::chrono::high_resolution_clock::now();
-    outputs = globalCache()->run(inputTensorNames_, inputArrays_, input_shapes_,
-                                 outputTensorNames_, batchSize_);
+    outputs = globalCache()->run(inputTensorNames_, inputArrays_, input_shapes_, outputTensorNames_, batchSize_);
     auto end = std::chrono::high_resolution_clock::now();
     std::chrono::duration<float> runtime_in_seconds = (end - start);
     runtimes.push_back(runtime_in_seconds.count() * 1000);
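
The warm-up loop runs the network nPreCalls_ times without timing, so one-time costs (memory allocation, lazy kernel initialization) do not contaminate the measurement; each of the following nCalls_ runs is then timed with std::chrono and stored in milliseconds. A minimal sketch of the same timing pattern, with a sleep standing in for the inference call:

#include <chrono>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  // stand-in for a single inference call (illustrative only)
  auto infer = [] { std::this_thread::sleep_for(std::chrono::milliseconds(5)); };

  for (int r = 0; r < 3; r++) infer();  // warm-up: executed but not measured

  std::vector<float> runtimes;
  for (int r = 0; r < 10; r++) {
    auto start = std::chrono::high_resolution_clock::now();
    infer();
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<float> s = end - start;
    runtimes.push_back(s.count() * 1000);  // seconds -> milliseconds
  }
  std::cout << runtimes.size() << " measurements\n";
  return 0;
}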
@@ -214,4 +193,4 @@ void ONNXRuntimePlugin::analyze(const edm::Event &iEvent,
   mlprof::writeRuntimes(outputFile_, batchSize_, runtimes);
 }
 
-DEFINE_FWK_MODULE(ONNXRuntimePlugin);
+DEFINE_FWK_MODULE(ONNXInference);