diff --git a/PWGJE/Core/CMakeLists.txt b/PWGJE/Core/CMakeLists.txt
index 592e1686ff5..95895ddc442 100644
--- a/PWGJE/Core/CMakeLists.txt
+++ b/PWGJE/Core/CMakeLists.txt
@@ -14,7 +14,7 @@ o2physics_add_library(PWGJECore
                       SOURCES FastJetUtilities.cxx
                               JetFinder.cxx
                               JetBkgSubUtils.cxx
-                      PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib)
+                      PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib ONNXRuntime::ONNXRuntime)
 
 o2physics_target_root_dictionary(PWGJECore
                                  HEADERS JetFinder.h
diff --git a/PWGJE/Core/JetTaggingUtilities.h b/PWGJE/Core/JetTaggingUtilities.h
index d0b08215f4c..d9bbf478397 100644
--- a/PWGJE/Core/JetTaggingUtilities.h
+++ b/PWGJE/Core/JetTaggingUtilities.h
@@ -38,6 +38,12 @@
 #include "Common/Core/trackUtilities.h"
 #include "PWGJE/Core/JetUtilities.h"
 
+#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
+#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
+#else
+#include <onnxruntime_cxx_api.h>
+#endif
+
 using namespace o2::constants::physics;
 
 enum JetTaggingSpecies {
@@ -102,6 +108,159 @@ struct BJetSVParams {
   double mDecayLength3DError = 0.0;
 };
 
+// ONNX Runtime tensor (Ort::Value) allocator for using customized inputs of ML models.
+class TensorAllocator
+{
+ protected:
+#if !__has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
+  Ort::MemoryInfo mem_info;
+#endif
+ public:
+  TensorAllocator()
+#if !__has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
+    : mem_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
+#endif
+  {
+  }
+  ~TensorAllocator() = default;
+  template <typename T>
+  Ort::Value createTensor(std::vector<T>& input, std::vector<int64_t>& inputShape)
+  {
+#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
+    return Ort::Experimental::Value::CreateTensor<T>(input.data(), input.size(), inputShape);
+#else
+    return Ort::Value::CreateTensor<T>(mem_info, input.data(), input.size(), inputShape.data(), inputShape.size());
+#endif
+  }
+};
+
+// TensorAllocator for GNN b-jet tagger
+class GNNBjetAllocator : public TensorAllocator
+{
+ private:
+  int64_t nJetFeat;
+  int64_t nTrkFeat;
+  int64_t nFlav;
+  int64_t nTrkOrigin;
+  int64_t maxNNodes;
+
+  std::vector<float> tfJetMean;
+  std::vector<float> tfJetStdev;
+  std::vector<float> tfTrkMean;
+  std::vector<float> tfTrkStdev;
+
+  std::vector<std::vector<int64_t>> edgesList;
+
+  // Jet feature normalization
+  template <typename T>
+  T jetFeatureTransform(T feat, int idx) const
+  {
+    return (feat - tfJetMean[idx]) / tfJetStdev[idx];
+  }
+
+  // Track feature normalization
+  template <typename T>
+  T trkFeatureTransform(T feat, int idx) const
+  {
+    return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
+  }
+
+  // Edge input of GNN (fully-connected graph)
+  void setEdgesList(void)
+  {
+    for (int64_t nNodes = 0; nNodes <= maxNNodes; ++nNodes) {
+      std::vector<std::pair<int64_t, int64_t>> edges;
+      // Generate all permutations of (i, j) where i != j
+      for (int64_t i = 0; i < nNodes; ++i) {
+        for (int64_t j = 0; j < nNodes; ++j) {
+          if (i != j) {
+            edges.emplace_back(i, j);
+          }
+        }
+      }
+      // Add self-loops (i, i)
+      for (int64_t i = 0; i < nNodes; ++i) {
+        edges.emplace_back(i, i);
+      }
+      // Flatten
+      std::vector<int64_t> flattenedEdges;
+      for (const auto& edge : edges) {
+        flattenedEdges.push_back(edge.first);
+      }
+      for (const auto& edge : edges) {
+        flattenedEdges.push_back(edge.second);
+      }
+      edgesList.push_back(flattenedEdges);
+    }
+  }
+
+  // Replace NaN in a vector into value
+  template <typename T>
+  static int replaceNaN(std::vector<T>& vec, T value)
+  {
+    int numNaN = 0;
+    for (auto& el : vec) {
+      if (std::isnan(el)) {
+        el = value;
+        ++numNaN;
+      }
+    }
+    return numNaN;
+  }
+
+ public:
+  GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
+  GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
+    : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
+  {
+    setEdgesList();
+  }
+  ~GNNBjetAllocator() = default;
+
+  // Copy operator for initializing GNNBjetAllocator using Configurable values
+  GNNBjetAllocator& operator=(const GNNBjetAllocator& other)
+  {
+    nJetFeat = other.nJetFeat;
+    nTrkFeat = other.nTrkFeat;
+    nFlav = other.nFlav;
+    nTrkOrigin = other.nTrkOrigin;
+    maxNNodes = other.maxNNodes;
+    tfJetMean = other.tfJetMean;
+    tfJetStdev = other.tfJetStdev;
+    tfTrkMean = other.tfTrkMean;
+    tfTrkStdev = other.tfTrkStdev;
+    setEdgesList();
+    return *this;
+  }
+
+  // Allocate & Return GNN input tensors (std::vector<Ort::Value>)
+  template <typename T>
+  void getGNNInput(std::vector<T>& jetFeat, std::vector<std::vector<T>>& trkFeat, std::vector<T>& feat, std::vector<Ort::Value>& gnnInput)
+  {
+    int64_t nNodes = trkFeat.size();
+
+    std::vector<int64_t> edgesShape{2, nNodes * nNodes};
+    gnnInput.emplace_back(createTensor(edgesList[nNodes], edgesShape));
+
+    std::vector<int64_t> featShape{nNodes, nJetFeat + nTrkFeat};
+
+    int numNaN = replaceNaN(jetFeat, 0.f);
+    for (auto& aTrkFeat : trkFeat) {
+      for (size_t i = 0; i < jetFeat.size(); ++i)
+        feat.push_back(jetFeatureTransform(jetFeat[i], i));
+      numNaN += replaceNaN(aTrkFeat, 0.f);
+      for (size_t i = 0; i < aTrkFeat.size(); ++i)
+        feat.push_back(trkFeatureTransform(aTrkFeat[i], i));
+    }
+
+    gnnInput.emplace_back(createTensor(feat, featShape));
+
+    if (numNaN > 0) {
+      LOGF(info, "NaN found in GNN input feature, number of NaN: %d", numNaN);
+    }
+  }
+};
+
 //________________________________________________________________________
 bool isBHadron(int pc)
 {
@@ -1005,6 +1164,63 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
   // Sort the tracks based on their IP significance in descending order
   std::sort(tracksParams.begin(), tracksParams.end(), compare);
 }
+
+// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
+template <typename AnalysisJet, typename AnyTracks, typename AnyOriginalTracks>
+void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, int64_t nMaxConstit = 40)
+{
+  for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
+
+    if (constituent.pt() < trackPtMin) {
+      continue;
+    }
+
+    int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);
+
+    auto origConstit = constituent.template track_as<AnyOriginalTracks>();
+
+    if (static_cast<int64_t>(tracksParams.size()) < nMaxConstit) {
+      tracksParams.emplace_back(std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()});
+    } else {
+      // If there are more than nMaxConstit constituents in the jet, select only nMaxConstit constituents with the highest DCA_XY significance.
+      size_t minIdx = 0;
+      for (size_t i = 0; i < tracksParams.size(); ++i) {
+        if (tracksParams[i][4] / tracksParams[i][5] < tracksParams[minIdx][4] / tracksParams[minIdx][5])
+          minIdx = i;
+      }
+      if (std::abs(constituent.dcaXY()) * sign / constituent.sigmadcaXY() > tracksParams[minIdx][4] / tracksParams[minIdx][5])
+        tracksParams[minIdx] = std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()};
+    }
+  }
+}
+
+// Discriminant value for GNN b-jet tagging
+template <typename T>
+T Db(const std::vector<T>& logits, double fC = 0.018)
+{
+  auto softmax = [](const std::vector<T>& logits) {
+    std::vector<T> res;
+    T maxLogit = *std::max_element(logits.begin(), logits.end());
+    T sumLogit = 0.;
+    for (size_t i = 0; i < logits.size(); ++i) {
+      res.push_back(std::exp(logits[i] - maxLogit));
+      sumLogit += res[i];
+    }
+    for (size_t i = 0; i < logits.size(); ++i) {
+      res[i] /= sumLogit;
+    }
+    return res;
+  };
+
+  std::vector<T> softmaxLogits = softmax(logits);
+
+  if (softmaxLogits[1] == 0. && softmaxLogits[2] == 0.) {
+    LOG(debug) << "jettaggingutilities::Db, Divide by zero: softmaxLogits = (" << softmaxLogits[0] << ", " << softmaxLogits[1] << ", " << softmaxLogits[2] << ")";
+  }
+
+  return std::log(softmaxLogits[0] / (fC * softmaxLogits[1] + (1. - fC) * softmaxLogits[2]));
+}
+
 }; // namespace jettaggingutilities
 
 #endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_
diff --git a/PWGJE/TableProducer/jetTaggerHF.cxx b/PWGJE/TableProducer/jetTaggerHF.cxx
index 03abadd36a9..036700d19ea 100644
--- a/PWGJE/TableProducer/jetTaggerHF.cxx
+++ b/PWGJE/TableProducer/jetTaggerHF.cxx
@@ -93,6 +93,24 @@ struct JetTaggerHFTask {
   Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
   Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
 
+  // GNN configuration
+  Configurable<double> fC{"fC", 0.018, "Parameter f_c for D_b calculation"};
+  Configurable<int> nJetFeat{"nJetFeat", 4, "Number of jet GNN input features"};
+  Configurable<int> nTrkFeat{"nTrkFeat", 13, "Number of track GNN input features"};
+  Configurable<int> nTrkOrigin{"nTrkOrigin", 5, "Number of track origin categories"};
+  Configurable<std::vector<float>> transformFeatureJetMean{"transformFeatureJetMean",
+                                                           std::vector<float>{3.7093048e+01, 3.1462731e+00, -8.9617318e-04, 4.5036483e+00},
+                                                           "Mean values for each GNN input feature (jet)"};
+  Configurable<std::vector<float>> transformFeatureJetStdev{"transformFeatureJetStdev",
+                                                            std::vector<float>{3.9559139e+01, 1.8156786e+00, 2.8845072e-01, 4.6293869e+00},
+                                                            "Stdev values for each GNN input feature (jet)"};
+  Configurable<std::vector<float>> transformFeatureTrkMean{"transformFeatureTrkMean",
+                                                           std::vector<float>{5.8772368e+00, 3.1470699e+00, -1.4703944e-03, 1.9976571e-03, 1.7700187e-03, 3.5821514e-03, 1.9987826e-03, 7.3673888e-03, 6.6411214e+00, 1.3810074e+02, 1.4888744e+02, 6.5751970e-01, 1.6469173e+00},
+                                                           "Mean values for each GNN input feature (track)"};
+  Configurable<std::vector<float>> transformFeatureTrkStdev{"transformFeatureTrkStdev",
+                                                            std::vector<float>{9.2763824e+00, 1.8162115e+00, 3.1512174e-01, 9.9999982e-01, 5.6147423e-02, 2.3086982e-02, 1.6523319e+00, 4.8507337e-02, 8.1565088e-01, 1.2891182e+01, 1.1064601e+01, 9.5457840e-01, 2.8930053e-01},
+                                                            "Stdev values for each GNN input feature (track)"};
+
   // axis spec
   ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
   ConfigurableAxis binJetFlavour{"binJetFlavour", {6, -0.5, 5.5}, ""};
 
@@ -101,6 +119,7 @@ struct JetTaggerHFTask {
   o2::ccdb::CcdbApi ccdbApi;
 
   using JetTracksExt = soa::Join<aod::JetTracks, aod::JTrackExtras, aod::JTrackPIs>;
+  using OriginalTracks = soa::Join<aod::Tracks, aod::TracksExtra>;
 
   bool useResoFuncFromIncJet = false;
   int maxOrder = -1;
@@ -115,6 +134,8 @@ struct JetTaggerHFTask {
   std::vector<int> decisionNonML;
   std::vector<float> scoreML;
 
+  jettaggingutilities::GNNBjetAllocator tensorAlloc;
+
   template <typename T, typename U>
   float calculateJetProbability(int origin, T const& jet, U const& tracks, bool const& isMC = false)
   {
@@ -194,6 +215,25 @@ struct JetTaggerHFTask {
         }
       }
     }
+    if (doprocessAlgorithmGNN) {
+      if constexpr (isMC) {
+        switch (origin) {
+          case 2:
+            registry.fill(HIST("h_db_b"), scoreML[jet.globalIndex()]);
+            break;
+          case 1:
+            registry.fill(HIST("h_db_c"), scoreML[jet.globalIndex()]);
+            break;
+          case 0:
+          case 3:
+            registry.fill(HIST("h_db_lf"), scoreML[jet.globalIndex()]);
+            break;
+          default:
+            LOGF(debug, "doprocessAlgorithmGNN, Unexpected origin value: %d (%d)", origin, jet.globalIndex());
+        }
+      }
+      registry.fill(HIST("h2_pt_db"), jet.pt(), scoreML[jet.globalIndex()]);
+    }
     taggingTable(decisionNonML[jet.globalIndex()], jetProb, scoreML[jet.globalIndex()]);
   }
@@ -276,7 +316,7 @@ struct JetTaggerHFTask {
       }
     }
 
-    if (doprocessAlgorithmML) {
+    if (doprocessAlgorithmML || doprocessAlgorithmGNN) {
       bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
       if (loadModelsFromCCDB) {
         ccdbApi.init(ccdbUrl);
@@ -287,6 +327,14 @@ struct JetTaggerHFTask {
       // bMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
       bMlResponse.init();
     }
+
+    if (doprocessAlgorithmGNN) {
+      tensorAlloc = jettaggingutilities::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst);
+      registry.add("h_db_b", "#it{D}_{b} b-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
+      registry.add("h_db_c", "#it{D}_{b} c-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
+      registry.add("h_db_lf", "#it{D}_{b} lf-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
+      registry.add("h2_pt_db", "#it{p}_{T} vs. #it{D}_{b};#it{p}_{T}^{ch jet} (GeV/#it{c}^{2});#it{D}_{b}", {HistType::kTH2F, {{100, 0., 200.}, {50, -10., 35.}}});
+    }
   }
 
   template <typename T>
@@ -316,6 +364,29 @@ struct JetTaggerHFTask {
     }
   }
 
+  template <typename AnyJets, typename AnyTracks, typename AnyOriginalTracks>
+  void analyzeJetAlgorithmGNN(AnyJets const& jets, AnyTracks const& tracks, AnyOriginalTracks const& origTracks)
+  {
+    for (const auto& jet : jets) {
+      std::vector<std::vector<float>> trkFeat;
+      jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, nJetConst);
+
+      std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};
+
+      if (trkFeat.size() > 0) {
+        std::vector<float> feat;
+        std::vector<Ort::Value> gnnInput;
+        tensorAlloc.getGNNInput(jetFeat, trkFeat, feat, gnnInput);
+
+        auto modelOutput = bMlResponse.getModelOutput(gnnInput, 0);
+        scoreML[jet.globalIndex()] = jettaggingutilities::Db(modelOutput, fC);
+      } else {
+        scoreML[jet.globalIndex()] = -999.;
+        LOGF(debug, "doprocessAlgorithmGNN, trkFeat.size() <= 0 (%d)", jet.globalIndex());
+      }
+    }
+  }
+
   void processDummy(aod::JetCollisions const&)
   {
   }
@@ -354,6 +425,12 @@ struct JetTaggerHFTask {
   }
   PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmML, "Fill ML evaluation score for charged jets", false);
 
+  void processAlgorithmGNN(JetTable const& jets, JetTracksExt const& jtracks, OriginalTracks const& origTracks)
+  {
+    analyzeJetAlgorithmGNN(jets, jtracks, origTracks);
+  }
+  PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmGNN, "Fill GNN evaluation score (D_b) for charged jets", false);
+
   void processFillTables(std::conditional_t<isMC, soa::Join<JetTable, JetFlavourDefTable>, JetTable>::iterator const& jet, JetTracksExt const& tracks)
   {
     fillTables(jet, tracks);
diff --git a/Tools/ML/MlResponse.h b/Tools/ML/MlResponse.h
index 127512e52ee..2d16a67bf18 100644
--- a/Tools/ML/MlResponse.h
+++ b/Tools/ML/MlResponse.h
@@ -158,7 +158,7 @@ class MlResponse
       LOG(fatal) << "Model index " << nModel << " is out of range! The number of initialised models is " << mModels.size() << ". Please check your configurables.";
     }
 
-    TypeOutputScore* outputPtr = mModels[nModel].evalModel(input);
+    TypeOutputScore* outputPtr = mModels[nModel].template evalModel<TypeOutputScore>(input);
     return std::vector<TypeOutputScore>{outputPtr, outputPtr + mNClasses};
   }