From 03228df7d530292972830710a1bcc5bbf53845c8 Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:08:03 +0100 Subject: [PATCH] DPL: avoid TMessage usage TMessage does not allow for non owned buffers, so we end up having an extra buffer in private memory for (de)serializing. Using TBufferFile directly allows to avoid that, so this moves the whole ROOT serialization support in DPL to use it. --- .../src/AODJAlienReaderHelpers.h | 2 + .../Core/include/Framework/DataAllocator.h | 1 + .../Core/include/Framework/DataRefUtils.h | 13 +- .../include/Framework/RootMessageContext.h | 3 + .../Framework/RootSerializationSupport.h | 3 +- .../include/Framework/TMessageSerializer.h | 130 ++++++++---------- Framework/Core/src/CommonDataProcessors.cxx | 7 +- Framework/Core/src/TMessageSerializer.cxx | 37 +++++ Framework/Core/test/test_DataRefUtils.cxx | 29 +++- .../Core/test/test_TMessageSerializer.cxx | 51 +++++-- Framework/Utils/test/test_RootTreeWriter.cxx | 1 + Steer/DigitizerWorkflow/src/SimReaderSpec.cxx | 1 - Utilities/Mergers/src/ObjectStore.cxx | 11 +- Utilities/Mergers/test/benchmark_Types.cxx | 11 +- 14 files changed, 198 insertions(+), 102 deletions(-) diff --git a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h index 655e4b6c0b439..4b9fd710aca14 100644 --- a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h +++ b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h @@ -16,7 +16,9 @@ #include "Framework/AlgorithmSpec.h" #include "Framework/Logger.h" #include + #include +class TFile; namespace o2::framework::readers { diff --git a/Framework/Core/include/Framework/DataAllocator.h b/Framework/Core/include/Framework/DataAllocator.h index 8151d2f83c6c6..029e922aeb90b 100644 --- a/Framework/Core/include/Framework/DataAllocator.h +++ b/Framework/Core/include/Framework/DataAllocator.h @@ -359,6 +359,7 @@ class DataAllocator } else if constexpr (has_root_dictionary::value == true || is_specialization_v == true) { // Serialize a snapshot of an object with root dictionary payloadMessage = proxy.createOutputMessage(routeIndex); + payloadMessage->Rebuild(4096, {64}); if constexpr (is_specialization_v == true) { // Explicitely ROOT serialize a snapshot of object. // An object wrapped into type `ROOTSerialized` is explicitely marked to be ROOT serialized diff --git a/Framework/Core/include/Framework/DataRefUtils.h b/Framework/Core/include/Framework/DataRefUtils.h index defd10244bca5..264533def326d 100644 --- a/Framework/Core/include/Framework/DataRefUtils.h +++ b/Framework/Core/include/Framework/DataRefUtils.h @@ -71,12 +71,15 @@ struct DataRefUtils { throw runtime_error("Attempt to extract a TMessage from non-ROOT serialised message"); } - typename RSS::FairTMessage ftm(const_cast(ref.payload), payloadSize); - auto* storedClass = ftm.GetClass(); + typename RSS::FairInputTBuffer ftm(const_cast(ref.payload), payloadSize); auto* requestedClass = RSS::TClass::GetClass(typeid(T)); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); // should always have the class description if has_root_dictionary is true assert(requestedClass != nullptr); + ftm.SetBufferOffset(0); + ftm.ResetMap(); auto* object = ftm.ReadObjectAny(storedClass); if (object == nullptr) { throw runtime_error_f("Failed to read object with name %s from message using ROOT serialization.", @@ -146,7 +149,11 @@ struct DataRefUtils { throw runtime_error("ROOT serialization not supported, dictionary not found for data type"); } - typename RSS::FairTMessage ftm(const_cast(ref.payload), payloadSize); + typename RSS::FairInputTBuffer ftm(const_cast(ref.payload), payloadSize); + ftm.InitMap(); + auto* classInfo = ftm.ReadClass(); + ftm.SetBufferOffset(0); + ftm.ResetMap(); result.reset(static_cast(ftm.ReadObjectAny(cl))); if (result.get() == nullptr) { throw runtime_error_f("Unable to extract class %s", cl == nullptr ? "" : cl->GetName()); diff --git a/Framework/Core/include/Framework/RootMessageContext.h b/Framework/Core/include/Framework/RootMessageContext.h index bef60ebbbf9f9..b1124880cf30f 100644 --- a/Framework/Core/include/Framework/RootMessageContext.h +++ b/Framework/Core/include/Framework/RootMessageContext.h @@ -72,6 +72,9 @@ class RootSerializedObject : public MessageContext::ContextObject fair::mq::Parts finalize() final { assert(mParts.Size() == 1); + if (mPayloadMsg->GetSize() < sizeof(char*)) { + mPayloadMsg->Rebuild(4096, {64}); + } TMessageSerializer::Serialize(*mPayloadMsg, mObject.get(), nullptr); mParts.AddPart(std::move(mPayloadMsg)); return ContextObject::finalize(); diff --git a/Framework/Core/include/Framework/RootSerializationSupport.h b/Framework/Core/include/Framework/RootSerializationSupport.h index cbf7408b13c7d..a44093f9c02bf 100644 --- a/Framework/Core/include/Framework/RootSerializationSupport.h +++ b/Framework/Core/include/Framework/RootSerializationSupport.h @@ -21,7 +21,8 @@ namespace o2::framework /// compiler. struct RootSerializationSupport { using TClass = ::TClass; - using FairTMessage = o2::framework::FairTMessage; + using FairInputTBuffer = o2::framework::FairInputTBuffer; + using FairOutputBuffer = o2::framework::FairOutputTBuffer; using TObject = ::TObject; }; diff --git a/Framework/Core/include/Framework/TMessageSerializer.h b/Framework/Core/include/Framework/TMessageSerializer.h index 1f08b456c0218..34a5156074b81 100644 --- a/Framework/Core/include/Framework/TMessageSerializer.h +++ b/Framework/Core/include/Framework/TMessageSerializer.h @@ -16,9 +16,8 @@ #include "Framework/RuntimeError.h" #include -#include +#include #include -#include #include #include #include @@ -28,67 +27,76 @@ namespace o2::framework { -class FairTMessage; +class FairOutputTBuffer; +class FairInputTBuffer; // utilities to produce a span over a byte buffer held by various message types // this is to avoid littering code with casts and conversions (span has a signed index type(!)) -gsl::span as_span(const FairTMessage& msg); +gsl::span as_span(const FairInputTBuffer& msg); +gsl::span as_span(const FairOutputTBuffer& msg); gsl::span as_span(const fair::mq::Message& msg); -class FairTMessage : public TMessage +// A TBufferFile which we can use to serialise data to a FairMQ message. +class FairOutputTBuffer : public TBufferFile { public: - using TMessage::TMessage; - FairTMessage() : TMessage(kMESS_OBJECT) {} - FairTMessage(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); } - FairTMessage(gsl::span buf) : TMessage(buf.data(), buf.size()) { ResetBit(kIsOwner); } + // This is to serialise data to FairMQ. We embed the pointer to the message + // in the data itself, so that we can use it to reallocate the message if needed. + // The FairMQ message retains ownership of the data. + // When deserialising the root object, keep in mind one needs to skip the 8 bytes + // for the pointer. + FairOutputTBuffer(fair::mq::Message& msg) + : TBufferFile(TBuffer::kWrite, msg.GetSize() - sizeof(char*), embedInItself(msg), false, fairMQrealloc) + { + } + // Helper function to keep track of the FairMQ message that holds the data + // in the data itself. We can use this to make sure the message can be reallocated + // even if we simply have a pointer to the data. Hopefully ROOT will not play dirty + // with us. + void* embedInItself(fair::mq::Message& msg); // helper function to clean up the object holding the data after it is transported. - static void free(void* /*data*/, void* hint); + static char* fairMQrealloc(char* oldData, size_t newSize, size_t oldSize); }; -struct TMessageSerializer { - using StreamerList = std::vector; - using CompressionLevel = int; +class FairInputTBuffer : public TBufferFile +{ + public: + // This is to serialise data to FairMQ. The provided message is expeted to have 8 bytes + // of overhead, where the source embedded the pointer for the reallocation. + // Notice this will break if the sender and receiver are not using the same + // size for a pointer. + FairInputTBuffer(char* data, size_t size) + : TBufferFile(TBuffer::kRead, size - sizeof(char*), data + sizeof(char*), false, nullptr) + { + } +}; - static void Serialize(fair::mq::Message& msg, const TObject* input, - CompressionLevel compressionLevel = -1); +struct TMessageSerializer { + static void Serialize(fair::mq::Message& msg, const TObject* input); template - static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl, // - CompressionLevel compressionLevel = -1); + static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl); template static void Deserialize(const fair::mq::Message& msg, std::unique_ptr& output); - static void serialize(FairTMessage& msg, const TObject* input, - CompressionLevel compressionLevel = -1); + static void serialize(o2::framework::FairOutputTBuffer& msg, const TObject* input); template - static void serialize(FairTMessage& msg, const T* input, // - const TClass* cl, - CompressionLevel compressionLevel = -1); + static void serialize(o2::framework::FairOutputTBuffer& msg, const T* input, const TClass* cl); template - static std::unique_ptr deserialize(gsl::span buffer); - template - static inline std::unique_ptr deserialize(std::byte* buffer, size_t size); + static inline std::unique_ptr deserialize(FairInputTBuffer& buffer); }; -inline void TMessageSerializer::serialize(FairTMessage& tm, const TObject* input, - CompressionLevel compressionLevel) +inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const TObject* input) { - return serialize(tm, input, nullptr, compressionLevel); + return serialize(tm, input, nullptr); } template -inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, // - const TClass* cl, CompressionLevel compressionLevel) +inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const T* input, const TClass* cl) { - if (compressionLevel >= 0) { - // if negative, skip to use ROOT default - tm.SetCompressionLevel(compressionLevel); - } - // TODO: check what WriateObject and WriteObjectAny are doing if (cl == nullptr) { tm.WriteObject(input); @@ -98,7 +106,7 @@ inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, // } template -inline std::unique_ptr TMessageSerializer::deserialize(gsl::span buffer) +inline std::unique_ptr TMessageSerializer::deserialize(FairInputTBuffer& buffer) { TClass* tgtClass = TClass::GetClass(typeid(T)); if (tgtClass == nullptr) { @@ -107,53 +115,32 @@ inline std::unique_ptr TMessageSerializer::deserialize(gsl::span b // FIXME: we need to add consistency check for buffer data to be serialized // at the moment, TMessage might simply crash if an invalid or inconsistent // buffer is provided - FairTMessage tm(buffer); - TClass* serializedClass = tm.GetClass(); + buffer.InitMap(); + TClass* serializedClass = buffer.ReadClass(); + buffer.SetBufferOffset(0); + buffer.ResetMap(); if (serializedClass == nullptr) { throw runtime_error_f("can not read class info from buffer"); } if (tgtClass != serializedClass && serializedClass->GetBaseClass(tgtClass) == nullptr) { throw runtime_error_f("can not convert serialized class %s into target class %s", - tm.GetClass()->GetName(), + serializedClass->GetName(), tgtClass->GetName()); } - return std::unique_ptr(reinterpret_cast(tm.ReadObjectAny(serializedClass))); + return std::unique_ptr(reinterpret_cast(buffer.ReadObjectAny(serializedClass))); } -template -inline std::unique_ptr TMessageSerializer::deserialize(std::byte* buffer, size_t size) +inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input) { - return deserialize(gsl::span(buffer, gsl::narrow::size_type>(size))); -} - -inline void FairTMessage::free(void* /*data*/, void* hint) -{ - std::default_delete deleter; - deleter(static_cast(hint)); -} - -inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input, - TMessageSerializer::CompressionLevel compressionLevel) -{ - std::unique_ptr tm = std::make_unique(kMESS_OBJECT); - - serialize(*tm, input, input->Class(), compressionLevel); - - msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get()); - tm.release(); + FairOutputTBuffer output(msg); + serialize(output, input, input->Class()); } template -inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, // - const TClass* cl, // - TMessageSerializer::CompressionLevel compressionLevel) +inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, const TClass* cl) { - std::unique_ptr tm = std::make_unique(kMESS_OBJECT); - - serialize(*tm, input, cl, compressionLevel); - - msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get()); - tm.release(); + FairOutputTBuffer output(msg); + serialize(output, input, cl); } template @@ -161,7 +148,8 @@ inline void TMessageSerializer::Deserialize(const fair::mq::Message& msg, std::u { // we know the message will not be modified by this, // so const_cast should be OK here(IMHO). - output = deserialize(as_span(msg)); + FairInputTBuffer input(static_cast(msg.GetData()), static_cast(msg.GetSize())); + output = deserialize(input); } // gsl::narrow is used to do a runtime narrowing check, this might be a bit paranoid, @@ -171,7 +159,7 @@ inline gsl::span as_span(const fair::mq::Message& msg) return gsl::span{static_cast(msg.GetData()), gsl::narrow::size_type>(msg.GetSize())}; } -inline gsl::span as_span(const FairTMessage& msg) +inline gsl::span as_span(const FairInputTBuffer& msg) { return gsl::span{reinterpret_cast(msg.Buffer()), gsl::narrow::size_type>(msg.BufferSize())}; diff --git a/Framework/Core/src/CommonDataProcessors.cxx b/Framework/Core/src/CommonDataProcessors.cxx index 48a3eb1da95b9..0cf7224c25ac8 100644 --- a/Framework/Core/src/CommonDataProcessors.cxx +++ b/Framework/Core/src/CommonDataProcessors.cxx @@ -141,9 +141,12 @@ DataProcessorSpec CommonDataProcessors::getOutputObjHistSink(std::vector(ref.payload), static_cast(datah->payloadSize)); InputObject obj; - obj.kind = tm.GetClass(); + FairInputTBuffer tm(const_cast(ref.payload), static_cast(datah->payloadSize)); + tm.InitMap(); + obj.kind = tm.ReadClass(); + tm.SetBufferOffset(0); + tm.ResetMap(); if (obj.kind == nullptr) { LOG(error) << "Cannot read class info from buffer."; return; diff --git a/Framework/Core/src/TMessageSerializer.cxx b/Framework/Core/src/TMessageSerializer.cxx index 5388a6d716cda..c5da4cc576242 100644 --- a/Framework/Core/src/TMessageSerializer.cxx +++ b/Framework/Core/src/TMessageSerializer.cxx @@ -9,7 +9,44 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. #include +#include #include #include using namespace o2::framework; + +void* FairOutputTBuffer::embedInItself(fair::mq::Message& msg) +{ + // The first bytes of the message are used to store the pointer to the message itself + // so that we can reallocate it if needed. + if (sizeof(char*) > msg.GetSize()) { + throw std::runtime_error("Message size too small to embed pointer"); + } + char* data = reinterpret_cast(msg.GetData()); + char* ptr = reinterpret_cast(&msg); + std::memcpy(data, &ptr, sizeof(char*)); + return data + sizeof(char*); +} + +// Reallocation function. Get the message pointer from the data and call Rebuild. +char* FairOutputTBuffer::fairMQrealloc(char* oldData, size_t newSize, size_t oldSize) +{ + // Old data is the pointer at the beginning of the message, so the pointer + // to the message is **stored** in the 8 bytes before it. + auto* msg = *(fair::mq::Message**)(oldData - sizeof(char*)); + if (newSize <= msg->GetSize()) { + // no need to reallocate, the message is already big enough + return oldData; + } + // Create a shallow copy of the message + fair::mq::MessagePtr oldMsg = msg->GetTransport()->CreateMessage(); + oldMsg->Copy(*msg); + // Copy the old data while rebuilding. Reference counting should make + // sure the old message is not deleted until the new one is ready. + // We need 8 extra bytes for the pointer to the message itself (realloc does not know about it) + // and we need to copy 8 bytes more than the old size (again, the extra pointer). + msg->Rebuild(newSize + 8, fair::mq::Alignment{64}); + memcpy(msg->GetData(), oldMsg->GetData(), oldSize + 8); + + return reinterpret_cast(msg->GetData()) + sizeof(char*); +} diff --git a/Framework/Core/test/test_DataRefUtils.cxx b/Framework/Core/test/test_DataRefUtils.cxx index 37da7912bfe8b..d4accde0fecf0 100644 --- a/Framework/Core/test/test_DataRefUtils.cxx +++ b/Framework/Core/test/test_DataRefUtils.cxx @@ -21,17 +21,38 @@ using namespace o2::framework; +TEST_CASE("PureRootTest") +{ + TBufferFile buffer(TBuffer::kWrite); + TObjString s("test"); + buffer.WriteObject(&s); + + TBufferFile buffer2(TBuffer::kRead, buffer.BufferSize(), buffer.Buffer(), false); + buffer2.SetReadMode(); + buffer2.InitMap(); + TClass* storedClass = buffer2.ReadClass(); + // ReadClass advances the buffer, so we need to reset it. + buffer2.SetBufferOffset(0); + buffer2.ResetMap(); + REQUIRE(storedClass != nullptr); + auto* outS = (TObjString*)buffer2.ReadObjectAny(storedClass); + REQUIRE(outS != nullptr); + REQUIRE(outS->GetString() == "test"); +} + // Simple test to do root deserialization. TEST_CASE("TestRootSerialization") { DataRef ref; - TMessage* tm = new TMessage(kMESS_OBJECT); + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer tm(*msg); auto sOrig = std::make_unique("test"); - tm->WriteObject(sOrig.get()); + tm << sOrig.get(); o2::header::DataHeader dh; dh.payloadSerializationMethod = o2::header::gSerializationMethodROOT; - ref.payload = tm->Buffer(); - dh.payloadSize = tm->BufferSize(); + ref.payload = (char*)msg->GetData(); + dh.payloadSize = (size_t)msg->GetSize(); ref.header = reinterpret_cast(&dh); // Check by using the same type diff --git a/Framework/Core/test/test_TMessageSerializer.cxx b/Framework/Core/test/test_TMessageSerializer.cxx index bc5f817400a44..2807351058c1d 100644 --- a/Framework/Core/test/test_TMessageSerializer.cxx +++ b/Framework/Core/test/test_TMessageSerializer.cxx @@ -11,6 +11,7 @@ #include "Framework/TMessageSerializer.h" #include "Framework/RuntimeError.h" +#include #include "TestClasses.h" #include #include @@ -49,14 +50,14 @@ TEST_CASE("TestTMessageSerializer") array.SetOwner(); array.Add(new TNamed(testname, testtitle)); - FairTMessage msg; - TMessageSerializer::serialize(msg, &array); + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer buffer(*msg); + TMessageSerializer::serialize(buffer, &array); - auto buf = as_span(msg); - REQUIRE(buf.size() == msg.BufferSize()); - REQUIRE(static_cast(buf.data()) == static_cast(msg.Buffer())); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); // test deserialization with TObject as target class (default) - auto out = TMessageSerializer::deserialize(buf); + auto out = TMessageSerializer::deserialize(msg2); auto* outarr = dynamic_cast(out.get()); REQUIRE(out.get() == outarr); @@ -66,9 +67,9 @@ TEST_CASE("TestTMessageSerializer") REQUIRE(named->GetTitle() == std::string(testtitle)); // test deserialization with a wrong target class and check the exception - REQUIRE_THROWS_AS(TMessageSerializer::deserialize(buf), o2::framework::RuntimeErrorRef); + REQUIRE_THROWS_AS(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef); - REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(buf), o2::framework::RuntimeErrorRef, + REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef, ExceptionMatcher("can not convert serialized class TObjArray into target class TNamed")); } @@ -87,23 +88,29 @@ TEST_CASE("TestTMessageSerializer_NonTObject") TClass* cl = TClass::GetClass("std::vector"); REQUIRE(cl != nullptr); - FairTMessage msg; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer buffer(*msg); char* in = reinterpret_cast(&data); - TMessageSerializer::serialize(msg, in, cl); + TMessageSerializer::serialize(buffer, in, cl); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); - auto out = TMessageSerializer::deserialize>(as_span(msg)); + auto out = TMessageSerializer::deserialize>(msg2); REQUIRE(out); REQUIRE((*out.get()).size() == 2); REQUIRE((*out.get())[0] == o2::test::Polymorphic(0xaffe)); REQUIRE((*out.get())[1] == o2::test::Polymorphic(0xd00f)); // test deserialization with a wrong target class and check the exception - REQUIRE_THROWS_AS(TMessageSerializer::deserialize(as_span(msg)), RuntimeErrorRef); + REQUIRE_THROWS_AS(TMessageSerializer::deserialize(msg2), RuntimeErrorRef); } TEST_CASE("TestTMessageSerializer_InvalidBuffer") { const char* buffer = "this is for sure not a serialized ROOT object"; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(strlen(buffer) + 8); + memcpy((char*)msg->GetData() + 8, buffer, strlen(buffer)); // test deserialization of invalid buffer and check the exception // FIXME: at the moment, TMessage fails directly with a segfault, which it shouldn't do /* @@ -119,5 +126,23 @@ TEST_CASE("TestTMessageSerializer_InvalidBuffer") struct Dummy { }; auto matcher = ExceptionMatcher("class is not ROOT-serializable: ZL22CATCH2_INTERNAL_TEST_4vE5Dummy"); - REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize((std::byte*)buffer, strlen(buffer)), o2::framework::RuntimeErrorRef, matcher); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); + REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef, matcher); +} + +TEST_CASE("TestTMessageSerializer_CheckExpansion") +{ + const char* buffer = "this is for sure not a serialized ROOT object"; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(strlen(buffer) + 8); + FairOutputTBuffer msg2(*msg); + // The buffer starts after 8 bytes. + REQUIRE(msg2.Buffer() == (char*)msg->GetData() + 8); + // The first 8 bytes of the buffer store the pointer to the message itself. + REQUIRE(*(fair::mq::Message**)msg->GetData() == msg.get()); + // Notice that TBuffer does the same trick with the reallocation function, + // so in the end the useful buffer size is the message size minus 16. + REQUIRE(msg2.BufferSize() == (msg->GetSize() - 16)); + // This will not fit the original buffer size, so the buffer will be expanded. + msg2.Expand(100); } diff --git a/Framework/Utils/test/test_RootTreeWriter.cxx b/Framework/Utils/test/test_RootTreeWriter.cxx index 3194508f3d775..62e1eb62cb4f1 100644 --- a/Framework/Utils/test/test_RootTreeWriter.cxx +++ b/Framework/Utils/test/test_RootTreeWriter.cxx @@ -179,6 +179,7 @@ TEST_CASE("test_RootTreeWriter") auto createSerializedMessage = [&transport, &store](DataHeader&& dh, auto& data) { fair::mq::MessagePtr payload = transport->CreateMessage(); + payload->Rebuild(4096, {64}); auto* cl = TClass::GetClass(typeid(decltype(data))); TMessageSerializer().Serialize(*payload, &data, cl); dh.payloadSize = payload->GetSize(); diff --git a/Steer/DigitizerWorkflow/src/SimReaderSpec.cxx b/Steer/DigitizerWorkflow/src/SimReaderSpec.cxx index 03bfa2eb23ede..6f8502f74a85b 100644 --- a/Steer/DigitizerWorkflow/src/SimReaderSpec.cxx +++ b/Steer/DigitizerWorkflow/src/SimReaderSpec.cxx @@ -25,7 +25,6 @@ #include "DetectorsRaw/HBFUtils.h" #include #include -#include // object serialization #include // std::unique_ptr #include // memcpy #include // std::string diff --git a/Utilities/Mergers/src/ObjectStore.cxx b/Utilities/Mergers/src/ObjectStore.cxx index e88358507c31e..3bb49f1dfc9d8 100644 --- a/Utilities/Mergers/src/ObjectStore.cxx +++ b/Utilities/Mergers/src/ObjectStore.cxx @@ -38,7 +38,7 @@ static std::string concat(Args&&... arguments) return std::move(ss.str()); } -void* readObject(const TClass* type, o2::framework::FairTMessage& ftm) +void* readObject(const TClass* type, o2::framework::FairInputTBuffer& ftm) { using namespace std::string_view_literals; auto* object = ftm.ReadObjectAny(type); @@ -60,7 +60,7 @@ MergeInterface* castToMergeInterface(bool inheritsFromTObject, void* object, TCl return objectAsMergeInterface; } -std::optional extractVector(o2::framework::FairTMessage& ftm, const TClass* storedClass) +std::optional extractVector(o2::framework::FairInputTBuffer& ftm, const TClass* storedClass) { if (!storedClass->InheritsFrom(TClass::GetClass(typeid(VectorOfRawTObjects)))) { return std::nullopt; @@ -88,11 +88,14 @@ ObjectStore extractObjectFrom(const framework::DataRef& ref) throw std::runtime_error(concat(errorPrefix, "It is not ROOT-serialized"sv)); } - o2::framework::FairTMessage ftm(const_cast(ref.payload), o2::framework::DataRefUtils::getPayloadSize(ref)); - auto* storedClass = ftm.GetClass(); + o2::framework::FairInputTBuffer ftm(const_cast(ref.payload), o2::framework::DataRefUtils::getPayloadSize(ref)); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); if (storedClass == nullptr) { throw std::runtime_error(concat(errorPrefix, "Unknown stored class"sv)); } + ftm.SetBufferOffset(0); + ftm.ResetMap(); if (const auto extractedVector = extractVector(ftm, storedClass)) { return extractedVector.value(); diff --git a/Utilities/Mergers/test/benchmark_Types.cxx b/Utilities/Mergers/test/benchmark_Types.cxx index 790fd329185ea..736685c5746b8 100644 --- a/Utilities/Mergers/test/benchmark_Types.cxx +++ b/Utilities/Mergers/test/benchmark_Types.cxx @@ -165,11 +165,16 @@ auto measure = [](Measurement m, auto* o, auto* i) -> double { tm->WriteObject(o); start = std::chrono::high_resolution_clock::now(); - o2::framework::FairTMessage ftm(const_cast(tm->Buffer()), tm->BufferSize()); - auto* storedClass = ftm.GetClass(); + // Needed to take into account that FairInputTBuffer expects the first 8 bytes to be the + // allocator pointer, which is not present in the TMessage buffer. + o2::framework::FairInputTBuffer ftm(const_cast(tm->Buffer() - 8), tm->BufferSize() + 8); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); if (storedClass == nullptr) { throw std::runtime_error("Unknown stored class"); } + ftm.SetBufferOffset(0); + ftm.ResetMap(); auto* tObjectClass = TClass::GetClass(typeid(TObject)); if (!storedClass->InheritsFrom(tObjectClass)) { @@ -738,4 +743,4 @@ int main(int argc, const char* argv[]) file.close(); return 0; -} \ No newline at end of file +}