Skip to content

Commit

Permalink
Merge pull request #26 from choich08365/HFJetTagging
Browse files Browse the repository at this point in the history
Updated jetTaggerHF GNN b-jet tagger
  • Loading branch information
hanseopark authored Jan 22, 2025
2 parents bb2e8e9 + 4662ee0 commit 8022672
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 3 deletions.
2 changes: 1 addition & 1 deletion PWGJE/Core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ o2physics_add_library(PWGJECore
SOURCES FastJetUtilities.cxx
JetFinder.cxx
JetBkgSubUtils.cxx
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib)
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib ONNXRuntime::ONNXRuntime)

o2physics_target_root_dictionary(PWGJECore
HEADERS JetFinder.h
Expand Down
216 changes: 216 additions & 0 deletions PWGJE/Core/JetTaggingUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
#include "Common/Core/trackUtilities.h"
#include "PWGJE/Core/JetUtilities.h"

#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
#else
#include <onnxruntime_cxx_api.h>
#endif

using namespace o2::constants::physics;

Check warning on line 47 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[using-directive]

Using directives are not allowed in headers.

enum JetTaggingSpecies {
Expand Down Expand Up @@ -102,6 +108,159 @@ struct BJetSVParams {
double mDecayLength3DError = 0.0;
};

// ONNX Runtime tensor (Ort::Value) allocator for using customized inputs of ML models.
class TensorAllocator
{
protected:
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
Ort::MemoryInfo mem_info;

Check warning on line 116 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
#endif
public:
TensorAllocator()
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
: mem_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
#endif
{
}
~TensorAllocator() = default;
template <typename T>
Ort::Value createTensor(std::vector<T>& input, std::vector<int64_t>& inputShape)
{
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
return Ort::Experimental::Value::CreateTensor<T>(input.data(), input.size(), inputShape);
#else
return Ort::Value::CreateTensor<T>(mem_info, input.data(), input.size(), inputShape.data(), inputShape.size());
#endif
}
};

// TensorAllocator for GNN b-jet tagger
class GNNBjetAllocator : public TensorAllocator
{
private:
int64_t nJetFeat;
int64_t nTrkFeat;
int64_t nFlav;
int64_t nTrkOrigin;
int64_t maxNNodes;

std::vector<float> tfJetMean;
std::vector<float> tfJetStdev;
std::vector<float> tfTrkMean;
std::vector<float> tfTrkStdev;

std::vector<std::vector<int64_t>> edgesList;

// Jet feature normalization
template <typename T>
T jetFeatureTransform(T feat, int idx) const
{
return (feat - tfJetMean[idx]) / tfJetStdev[idx];
}

// Track feature normalization
template <typename T>
T trkFeatureTransform(T feat, int idx) const
{
return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
}

// Edge input of GNN (fully-connected graph)
void setEdgesList(void)
{
for (int64_t nNodes = 0; nNodes <= maxNNodes; ++nNodes) {
std::vector<std::pair<int64_t, int64_t>> edges;
// Generate all permutations of (i, j) where i != j
for (int64_t i = 0; i < nNodes; ++i) {
for (int64_t j = 0; j < nNodes; ++j) {
if (i != j) {
edges.emplace_back(i, j);
}
}
}
// Add self-loops (i, i)
for (int64_t i = 0; i < nNodes; ++i) {
edges.emplace_back(i, i);
}
// Flatten
std::vector<int64_t> flattenedEdges;
for (const auto& edge : edges) {
flattenedEdges.push_back(edge.first);
}
for (const auto& edge : edges) {
flattenedEdges.push_back(edge.second);
}
edgesList.push_back(flattenedEdges);
}
}

// Replace NaN in a vector into value
template <typename T>
static int replaceNaN(std::vector<T>& vec, T value)
{
int numNaN = 0;
for (auto& el : vec) {

Check warning on line 202 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
if (std::isnan(el)) {
el = value;
++numNaN;
}
}
return numNaN;
}

public:
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
{
setEdgesList();
}
~GNNBjetAllocator() = default;

// Copy operator for initializing GNNBjetAllocator using Configurable values
GNNBjetAllocator& operator=(const GNNBjetAllocator& other)
{
nJetFeat = other.nJetFeat;
nTrkFeat = other.nTrkFeat;
nFlav = other.nFlav;
nTrkOrigin = other.nTrkOrigin;
maxNNodes = other.maxNNodes;
tfJetMean = other.tfJetMean;
tfJetStdev = other.tfJetStdev;
tfTrkMean = other.tfTrkMean;
tfTrkStdev = other.tfTrkStdev;
setEdgesList();
return *this;
}

// Allocate & Return GNN input tensors (std::vector<Ort::Value>)
template <typename T>
void getGNNInput(std::vector<T>& jetFeat, std::vector<std::vector<T>>& trkFeat, std::vector<T>& feat, std::vector<Ort::Value>& gnnInput)
{
int64_t nNodes = trkFeat.size();

std::vector<int64_t> edgesShape{2, nNodes * nNodes};
gnnInput.emplace_back(createTensor(edgesList[nNodes], edgesShape));

std::vector<int64_t> featShape{nNodes, nJetFeat + nTrkFeat};

int numNaN = replaceNaN(jetFeat, 0.f);
for (auto& aTrkFeat : trkFeat) {

Check warning on line 248 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
for (size_t i = 0; i < jetFeat.size(); ++i)
feat.push_back(jetFeatureTransform(jetFeat[i], i));
numNaN += replaceNaN(aTrkFeat, 0.f);
for (size_t i = 0; i < aTrkFeat.size(); ++i)
feat.push_back(trkFeatureTransform(aTrkFeat[i], i));
}

gnnInput.emplace_back(createTensor(feat, featShape));

if (numNaN > 0) {
LOGF(info, "NaN found in GNN input feature, number of NaN: %d", numNaN);
}
}
};

//________________________________________________________________________
bool isBHadron(int pc)
{
Expand Down Expand Up @@ -1005,6 +1164,63 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
// Sort the tracks based on their IP significance in descending order
std::sort(tracksParams.begin(), tracksParams.end(), compare);
}

// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
template <typename AnalysisJet, typename AnyTracks, typename AnyOriginalTracks>
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, int64_t nMaxConstit = 40)
{
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {

if (constituent.pt() < trackPtMin) {
continue;
}

int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);

auto origConstit = constituent.template track_as<AnyOriginalTracks>();

if (static_cast<int64_t>(tracksParams.size()) < nMaxConstit) {
tracksParams.emplace_back(std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()});
} else {
// If there are more than nMaxConstit constituents in the jet, select only nMaxConstit constituents with the highest DCA_XY significance.
size_t minIdx = 0;
for (size_t i = 0; i < tracksParams.size(); ++i) {
if (tracksParams[i][4] / tracksParams[i][5] < tracksParams[minIdx][4] / tracksParams[minIdx][5])
minIdx = i;
}
if (std::abs(constituent.dcaXY()) * sign / constituent.sigmadcaXY() > tracksParams[minIdx][4] / tracksParams[minIdx][5])
tracksParams[minIdx] = std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()};
}
}
}

// Discriminant value for GNN b-jet tagging
template <typename T>
T Db(const std::vector<T>& logits, double fC = 0.018)

Check warning on line 1199 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
{
auto softmax = [](const std::vector<T>& logits) {
std::vector<T> res;
T maxLogit = *std::max_element(logits.begin(), logits.end());
T sumLogit = 0.;
for (size_t i = 0; i < logits.size(); ++i) {
res.push_back(std::exp(logits[i] - maxLogit));
sumLogit += res[i];
}
for (size_t i = 0; i < logits.size(); ++i) {
res[i] /= sumLogit;
}
return res;
};

std::vector<T> softmaxLogits = softmax(logits);

if (softmaxLogits[1] == 0. && softmaxLogits[2] == 0.) {
LOG(debug) << "jettaggingutilities::Db, Divide by zero: softmaxLogits = (" << softmaxLogits[0] << ", " << softmaxLogits[1] << ", " << softmaxLogits[2] << ")";
}

return std::log(softmaxLogits[0] / (fC * softmaxLogits[1] + (1. - fC) * softmaxLogits[2]));
}

}; // namespace jettaggingutilities

#endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_
79 changes: 78 additions & 1 deletion PWGJE/TableProducer/jetTaggerHF.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ struct JetTaggerHFTask {
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};

// GNN configuration
Configurable<double> fC{"fC", 0.018, "Parameter f_c for D_b calculation"};
Configurable<int64_t> nJetFeat{"nJetFeat", 4, "Number of jet GNN input features"};
Configurable<int64_t> nTrkFeat{"nTrkFeat", 13, "Number of track GNN input features"};
Configurable<int64_t> nTrkOrigin{"nTrkOrigin", 5, "Number of track origin categories"};
Configurable<std::vector<float>> transformFeatureJetMean{"transformFeatureJetMean",
std::vector<float>{3.7093048e+01, 3.1462731e+00, -8.9617318e-04, 4.5036483e+00},
"Mean values for each GNN input feature (jet)"};
Configurable<std::vector<float>> transformFeatureJetStdev{"transformFeatureJetStdev",
std::vector<float>{3.9559139e+01, 1.8156786e+00, 2.8845072e-01, 4.6293869e+00},
"Stdev values for each GNN input feature (jet)"};
Configurable<std::vector<float>> transformFeatureTrkMean{"transformFeatureTrkMean",
std::vector<float>{5.8772368e+00, 3.1470699e+00, -1.4703944e-03, 1.9976571e-03, 1.7700187e-03, 3.5821514e-03, 1.9987826e-03, 7.3673888e-03, 6.6411214e+00, 1.3810074e+02, 1.4888744e+02, 6.5751970e-01, 1.6469173e+00},
"Mean values for each GNN input feature (track)"};
Configurable<std::vector<float>> transformFeatureTrkStdev{"transformFeatureTrkStdev",
std::vector<float>{9.2763824e+00, 1.8162115e+00, 3.1512174e-01, 9.9999982e-01, 5.6147423e-02, 2.3086982e-02, 1.6523319e+00, 4.8507337e-02, 8.1565088e-01, 1.2891182e+01, 1.1064601e+01, 9.5457840e-01, 2.8930053e-01},
"Stdev values for each GNN input feature (track)"};

// axis spec
ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
ConfigurableAxis binJetFlavour{"binJetFlavour", {6, -0.5, 5.5}, ""};
Expand All @@ -101,6 +119,7 @@ struct JetTaggerHFTask {
o2::ccdb::CcdbApi ccdbApi;

using JetTracksExt = soa::Join<aod::JetTracks, aod::JTrackExtras, aod::JTrackPIs>;
using OriginalTracks = soa::Join<aod::Tracks, aod::TracksCov, aod::TrackSelection, aod::TracksDCA, aod::TracksDCACov, aod::TracksExtra>;

bool useResoFuncFromIncJet = false;
int maxOrder = -1;
Expand All @@ -115,6 +134,8 @@ struct JetTaggerHFTask {
std::vector<uint16_t> decisionNonML;
std::vector<float> scoreML;

jettaggingutilities::GNNBjetAllocator tensorAlloc;

template <typename T, typename U>
float calculateJetProbability(int origin, T const& jet, U const& tracks, bool const& isMC = false)
{
Expand Down Expand Up @@ -194,6 +215,25 @@ struct JetTaggerHFTask {
}
}
}
if (doprocessAlgorithmGNN) {
if constexpr (isMC) {
switch (origin) {
case 2:
registry.fill(HIST("h_db_b"), scoreML[jet.globalIndex()]);
break;
case 1:
registry.fill(HIST("h_db_c"), scoreML[jet.globalIndex()]);
break;
case 0:
case 3:
registry.fill(HIST("h_db_lf"), scoreML[jet.globalIndex()]);
break;
default:
LOGF(debug, "doprocessAlgorithmGNN, Unexpected origin value: %d (%d)", origin, jet.globalIndex());
}
}
registry.fill(HIST("h2_pt_db"), jet.pt(), scoreML[jet.globalIndex()]);
}
taggingTable(decisionNonML[jet.globalIndex()], jetProb, scoreML[jet.globalIndex()]);
}

Expand Down Expand Up @@ -276,7 +316,7 @@ struct JetTaggerHFTask {
}
}

if (doprocessAlgorithmML) {
if (doprocessAlgorithmML || doprocessAlgorithmGNN) {
bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
Expand All @@ -287,6 +327,14 @@ struct JetTaggerHFTask {
// bMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
bMlResponse.init();
}

if (doprocessAlgorithmGNN) {
tensorAlloc = jettaggingutilities::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst);
registry.add("h_db_b", "#it{D}_{b} b-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h_db_c", "#it{D}_{b} c-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h_db_lf", "#it{D}_{b} lf-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h2_pt_db", "#it{p}_{T} vs. #it{D}_{b};#it{p}_{T}^{ch jet} (GeV/#it{c}^{2});#it{D}_{b}", {HistType::kTH2F, {{100, 0., 200.}, {50, -10., 35.}}});
}
}

template <typename AnyJets, typename AnyTracks, typename SecondaryVertices>
Expand Down Expand Up @@ -316,6 +364,29 @@ struct JetTaggerHFTask {
}
}

template <typename AnyJets, typename AnyTracks, typename AnyOriginalTracks>
void analyzeJetAlgorithmGNN(AnyJets const& jets, AnyTracks const& tracks, AnyOriginalTracks const& origTracks)
{
for (const auto& jet : jets) {
std::vector<std::vector<float>> trkFeat;
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, nJetConst);

std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};

if (trkFeat.size() > 0) {
std::vector<float> feat;
std::vector<Ort::Value> gnnInput;
tensorAlloc.getGNNInput(jetFeat, trkFeat, feat, gnnInput);

auto modelOutput = bMlResponse.getModelOutput(gnnInput, 0);
scoreML[jet.globalIndex()] = jettaggingutilities::Db(modelOutput, fC);
} else {
scoreML[jet.globalIndex()] = -999.;
LOGF(debug, "doprocessAlgorithmGNN, trkFeat.size() <= 0 (%d)", jet.globalIndex());
}
}
}

void processDummy(aod::JetCollisions const&)
{
}
Expand Down Expand Up @@ -354,6 +425,12 @@ struct JetTaggerHFTask {
}
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmML, "Fill ML evaluation score for charged jets", false);

void processAlgorithmGNN(JetTable const& jets, JetTracksExt const& jtracks, OriginalTracks const& origTracks)
{
analyzeJetAlgorithmGNN(jets, jtracks, origTracks);
}
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmGNN, "Fill GNN evaluation score (D_b) for charged jets", false);

void processFillTables(std::conditional_t<isMCD, soa::Join<JetTable, aod::ChargedMCDetectorLevelJetFlavourDef>, JetTable>::iterator const& jet, JetTracksExt const& tracks)
{
fillTables<isMCD>(jet, tracks);
Expand Down
2 changes: 1 addition & 1 deletion Tools/ML/MlResponse.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class MlResponse
LOG(fatal) << "Model index " << nModel << " is out of range! The number of initialised models is " << mModels.size() << ". Please check your configurables.";
}

TypeOutputScore* outputPtr = mModels[nModel].evalModel(input);
TypeOutputScore* outputPtr = mModels[nModel].template evalModel<TypeOutputScore>(input);
return std::vector<TypeOutputScore>{outputPtr, outputPtr + mNClasses};
}

Expand Down

0 comments on commit 8022672

Please sign in to comment.