Skip to content

Commit

Permalink
DPL: avoid TMessage usage
Browse files Browse the repository at this point in the history
TMessage does not allow for non owned buffers, so we end up having
an extra buffer in private memory for (de)serializing. Using TBufferFile
directly allows to avoid that, so this moves the whole ROOT serialization
support in DPL to use it.
  • Loading branch information
ktf committed Feb 28, 2024
1 parent b23bd7b commit 03228df
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 102 deletions.
2 changes: 2 additions & 0 deletions Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "Framework/AlgorithmSpec.h"
#include "Framework/Logger.h"
#include <Monitoring/Monitoring.h>

#include <uv.h>
class TFile;

namespace o2::framework::readers
{
Expand Down
1 change: 1 addition & 0 deletions Framework/Core/include/Framework/DataAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class DataAllocator
} else if constexpr (has_root_dictionary<T>::value == true || is_specialization_v<T, ROOTSerialized> == true) {
// Serialize a snapshot of an object with root dictionary
payloadMessage = proxy.createOutputMessage(routeIndex);
payloadMessage->Rebuild(4096, {64});
if constexpr (is_specialization_v<T, ROOTSerialized> == true) {
// Explicitely ROOT serialize a snapshot of object.
// An object wrapped into type `ROOTSerialized` is explicitely marked to be ROOT serialized
Expand Down
13 changes: 10 additions & 3 deletions Framework/Core/include/Framework/DataRefUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ struct DataRefUtils {
throw runtime_error("Attempt to extract a TMessage from non-ROOT serialised message");
}

typename RSS::FairTMessage ftm(const_cast<char*>(ref.payload), payloadSize);
auto* storedClass = ftm.GetClass();
typename RSS::FairInputTBuffer ftm(const_cast<char*>(ref.payload), payloadSize);
auto* requestedClass = RSS::TClass::GetClass(typeid(T));
ftm.InitMap();
auto* storedClass = ftm.ReadClass();
// should always have the class description if has_root_dictionary is true
assert(requestedClass != nullptr);

ftm.SetBufferOffset(0);
ftm.ResetMap();
auto* object = ftm.ReadObjectAny(storedClass);
if (object == nullptr) {
throw runtime_error_f("Failed to read object with name %s from message using ROOT serialization.",
Expand Down Expand Up @@ -146,7 +149,11 @@ struct DataRefUtils {
throw runtime_error("ROOT serialization not supported, dictionary not found for data type");
}

typename RSS::FairTMessage ftm(const_cast<char*>(ref.payload), payloadSize);
typename RSS::FairInputTBuffer ftm(const_cast<char*>(ref.payload), payloadSize);
ftm.InitMap();
auto* classInfo = ftm.ReadClass();
ftm.SetBufferOffset(0);
ftm.ResetMap();
result.reset(static_cast<wrapped*>(ftm.ReadObjectAny(cl)));
if (result.get() == nullptr) {
throw runtime_error_f("Unable to extract class %s", cl == nullptr ? "<name not available>" : cl->GetName());
Expand Down
3 changes: 3 additions & 0 deletions Framework/Core/include/Framework/RootMessageContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class RootSerializedObject : public MessageContext::ContextObject
fair::mq::Parts finalize() final
{
assert(mParts.Size() == 1);
if (mPayloadMsg->GetSize() < sizeof(char*)) {
mPayloadMsg->Rebuild(4096, {64});
}
TMessageSerializer::Serialize(*mPayloadMsg, mObject.get(), nullptr);
mParts.AddPart(std::move(mPayloadMsg));
return ContextObject::finalize();
Expand Down
3 changes: 2 additions & 1 deletion Framework/Core/include/Framework/RootSerializationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace o2::framework
/// compiler.
struct RootSerializationSupport {
using TClass = ::TClass;
using FairTMessage = o2::framework::FairTMessage;
using FairInputTBuffer = o2::framework::FairInputTBuffer;
using FairOutputBuffer = o2::framework::FairOutputTBuffer;
using TObject = ::TObject;
};

Expand Down
130 changes: 59 additions & 71 deletions Framework/Core/include/Framework/TMessageSerializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
#include "Framework/RuntimeError.h"

#include <TList.h>
#include <TMessage.h>
#include <TBufferFile.h>
#include <TObjArray.h>
#include <TStreamerInfo.h>
#include <gsl/util>
#include <gsl/span>
#include <gsl/narrow>
Expand All @@ -28,67 +27,76 @@

namespace o2::framework
{
class FairTMessage;
class FairOutputTBuffer;
class FairInputTBuffer;

// utilities to produce a span over a byte buffer held by various message types
// this is to avoid littering code with casts and conversions (span has a signed index type(!))
gsl::span<std::byte> as_span(const FairTMessage& msg);
gsl::span<std::byte> as_span(const FairInputTBuffer& msg);
gsl::span<std::byte> as_span(const FairOutputTBuffer& msg);
gsl::span<std::byte> as_span(const fair::mq::Message& msg);

class FairTMessage : public TMessage
// A TBufferFile which we can use to serialise data to a FairMQ message.
class FairOutputTBuffer : public TBufferFile
{
public:
using TMessage::TMessage;
FairTMessage() : TMessage(kMESS_OBJECT) {}
FairTMessage(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); }
FairTMessage(gsl::span<std::byte> buf) : TMessage(buf.data(), buf.size()) { ResetBit(kIsOwner); }
// This is to serialise data to FairMQ. We embed the pointer to the message
// in the data itself, so that we can use it to reallocate the message if needed.
// The FairMQ message retains ownership of the data.
// When deserialising the root object, keep in mind one needs to skip the 8 bytes
// for the pointer.
FairOutputTBuffer(fair::mq::Message& msg)
: TBufferFile(TBuffer::kWrite, msg.GetSize() - sizeof(char*), embedInItself(msg), false, fairMQrealloc)
{
}
// Helper function to keep track of the FairMQ message that holds the data
// in the data itself. We can use this to make sure the message can be reallocated
// even if we simply have a pointer to the data. Hopefully ROOT will not play dirty
// with us.
void* embedInItself(fair::mq::Message& msg);
// helper function to clean up the object holding the data after it is transported.
static void free(void* /*data*/, void* hint);
static char* fairMQrealloc(char* oldData, size_t newSize, size_t oldSize);
};

struct TMessageSerializer {
using StreamerList = std::vector<TVirtualStreamerInfo*>;
using CompressionLevel = int;
class FairInputTBuffer : public TBufferFile
{
public:
// This is to serialise data to FairMQ. The provided message is expeted to have 8 bytes
// of overhead, where the source embedded the pointer for the reallocation.
// Notice this will break if the sender and receiver are not using the same
// size for a pointer.
FairInputTBuffer(char* data, size_t size)
: TBufferFile(TBuffer::kRead, size - sizeof(char*), data + sizeof(char*), false, nullptr)
{
}
};

static void Serialize(fair::mq::Message& msg, const TObject* input,
CompressionLevel compressionLevel = -1);
struct TMessageSerializer {
static void Serialize(fair::mq::Message& msg, const TObject* input);

template <typename T>
static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl, //
CompressionLevel compressionLevel = -1);
static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl);

template <typename T = TObject>
static void Deserialize(const fair::mq::Message& msg, std::unique_ptr<T>& output);

static void serialize(FairTMessage& msg, const TObject* input,
CompressionLevel compressionLevel = -1);
static void serialize(o2::framework::FairOutputTBuffer& msg, const TObject* input);

template <typename T>
static void serialize(FairTMessage& msg, const T* input, //
const TClass* cl,
CompressionLevel compressionLevel = -1);
static void serialize(o2::framework::FairOutputTBuffer& msg, const T* input, const TClass* cl);

template <typename T = TObject>
static std::unique_ptr<T> deserialize(gsl::span<std::byte> buffer);
template <typename T = TObject>
static inline std::unique_ptr<T> deserialize(std::byte* buffer, size_t size);
static inline std::unique_ptr<T> deserialize(FairInputTBuffer& buffer);
};

inline void TMessageSerializer::serialize(FairTMessage& tm, const TObject* input,
CompressionLevel compressionLevel)
inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const TObject* input)
{
return serialize(tm, input, nullptr, compressionLevel);
return serialize(tm, input, nullptr);
}

template <typename T>
inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, //
const TClass* cl, CompressionLevel compressionLevel)
inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const T* input, const TClass* cl)
{
if (compressionLevel >= 0) {
// if negative, skip to use ROOT default
tm.SetCompressionLevel(compressionLevel);
}

// TODO: check what WriateObject and WriteObjectAny are doing
if (cl == nullptr) {
tm.WriteObject(input);
Expand All @@ -98,7 +106,7 @@ inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, //
}

template <typename T>
inline std::unique_ptr<T> TMessageSerializer::deserialize(gsl::span<std::byte> buffer)
inline std::unique_ptr<T> TMessageSerializer::deserialize(FairInputTBuffer& buffer)
{
TClass* tgtClass = TClass::GetClass(typeid(T));
if (tgtClass == nullptr) {
Expand All @@ -107,61 +115,41 @@ inline std::unique_ptr<T> TMessageSerializer::deserialize(gsl::span<std::byte> b
// FIXME: we need to add consistency check for buffer data to be serialized
// at the moment, TMessage might simply crash if an invalid or inconsistent
// buffer is provided
FairTMessage tm(buffer);
TClass* serializedClass = tm.GetClass();
buffer.InitMap();
TClass* serializedClass = buffer.ReadClass();
buffer.SetBufferOffset(0);
buffer.ResetMap();
if (serializedClass == nullptr) {
throw runtime_error_f("can not read class info from buffer");
}
if (tgtClass != serializedClass && serializedClass->GetBaseClass(tgtClass) == nullptr) {
throw runtime_error_f("can not convert serialized class %s into target class %s",
tm.GetClass()->GetName(),
serializedClass->GetName(),
tgtClass->GetName());
}
return std::unique_ptr<T>(reinterpret_cast<T*>(tm.ReadObjectAny(serializedClass)));
return std::unique_ptr<T>(reinterpret_cast<T*>(buffer.ReadObjectAny(serializedClass)));
}

template <typename T>
inline std::unique_ptr<T> TMessageSerializer::deserialize(std::byte* buffer, size_t size)
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input)
{
return deserialize<T>(gsl::span<std::byte>(buffer, gsl::narrow<gsl::span<std::byte>::size_type>(size)));
}

inline void FairTMessage::free(void* /*data*/, void* hint)
{
std::default_delete<FairTMessage> deleter;
deleter(static_cast<FairTMessage*>(hint));
}

inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input,
TMessageSerializer::CompressionLevel compressionLevel)
{
std::unique_ptr<FairTMessage> tm = std::make_unique<FairTMessage>(kMESS_OBJECT);

serialize(*tm, input, input->Class(), compressionLevel);

msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get());
tm.release();
FairOutputTBuffer output(msg);
serialize(output, input, input->Class());
}

template <typename T>
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, //
const TClass* cl, //
TMessageSerializer::CompressionLevel compressionLevel)
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, const TClass* cl)
{
std::unique_ptr<FairTMessage> tm = std::make_unique<FairTMessage>(kMESS_OBJECT);

serialize(*tm, input, cl, compressionLevel);

msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get());
tm.release();
FairOutputTBuffer output(msg);
serialize(output, input, cl);
}

template <typename T>
inline void TMessageSerializer::Deserialize(const fair::mq::Message& msg, std::unique_ptr<T>& output)
{
// we know the message will not be modified by this,
// so const_cast should be OK here(IMHO).
output = deserialize(as_span(msg));
FairInputTBuffer input(static_cast<char*>(msg.GetData()), static_cast<int>(msg.GetSize()));
output = deserialize(input);
}

// gsl::narrow is used to do a runtime narrowing check, this might be a bit paranoid,
Expand All @@ -171,7 +159,7 @@ inline gsl::span<std::byte> as_span(const fair::mq::Message& msg)
return gsl::span<std::byte>{static_cast<std::byte*>(msg.GetData()), gsl::narrow<gsl::span<std::byte>::size_type>(msg.GetSize())};
}

inline gsl::span<std::byte> as_span(const FairTMessage& msg)
inline gsl::span<std::byte> as_span(const FairInputTBuffer& msg)
{
return gsl::span<std::byte>{reinterpret_cast<std::byte*>(msg.Buffer()),
gsl::narrow<gsl::span<std::byte>::size_type>(msg.BufferSize())};
Expand Down
7 changes: 5 additions & 2 deletions Framework/Core/src/CommonDataProcessors.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,12 @@ DataProcessorSpec CommonDataProcessors::getOutputObjHistSink(std::vector<OutputO
return;
}

FairTMessage tm(const_cast<char*>(ref.payload), static_cast<int>(datah->payloadSize));
InputObject obj;
obj.kind = tm.GetClass();
FairInputTBuffer tm(const_cast<char*>(ref.payload), static_cast<int>(datah->payloadSize));
tm.InitMap();
obj.kind = tm.ReadClass();
tm.SetBufferOffset(0);
tm.ResetMap();
if (obj.kind == nullptr) {
LOG(error) << "Cannot read class info from buffer.";
return;
Expand Down
37 changes: 37 additions & 0 deletions Framework/Core/src/TMessageSerializer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,44 @@
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
#include <Framework/TMessageSerializer.h>
#include <FairMQTransportFactory.h>
#include <algorithm>
#include <memory>

using namespace o2::framework;

void* FairOutputTBuffer::embedInItself(fair::mq::Message& msg)
{
// The first bytes of the message are used to store the pointer to the message itself
// so that we can reallocate it if needed.
if (sizeof(char*) > msg.GetSize()) {
throw std::runtime_error("Message size too small to embed pointer");
}
char* data = reinterpret_cast<char*>(msg.GetData());
char* ptr = reinterpret_cast<char*>(&msg);
std::memcpy(data, &ptr, sizeof(char*));
return data + sizeof(char*);
}

// Reallocation function. Get the message pointer from the data and call Rebuild.
char* FairOutputTBuffer::fairMQrealloc(char* oldData, size_t newSize, size_t oldSize)
{
// Old data is the pointer at the beginning of the message, so the pointer
// to the message is **stored** in the 8 bytes before it.
auto* msg = *(fair::mq::Message**)(oldData - sizeof(char*));
if (newSize <= msg->GetSize()) {
// no need to reallocate, the message is already big enough
return oldData;
}
// Create a shallow copy of the message
fair::mq::MessagePtr oldMsg = msg->GetTransport()->CreateMessage();
oldMsg->Copy(*msg);
// Copy the old data while rebuilding. Reference counting should make
// sure the old message is not deleted until the new one is ready.
// We need 8 extra bytes for the pointer to the message itself (realloc does not know about it)
// and we need to copy 8 bytes more than the old size (again, the extra pointer).
msg->Rebuild(newSize + 8, fair::mq::Alignment{64});
memcpy(msg->GetData(), oldMsg->GetData(), oldSize + 8);

return reinterpret_cast<char*>(msg->GetData()) + sizeof(char*);
}
29 changes: 25 additions & 4 deletions Framework/Core/test/test_DataRefUtils.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,38 @@

using namespace o2::framework;

TEST_CASE("PureRootTest")
{
TBufferFile buffer(TBuffer::kWrite);
TObjString s("test");
buffer.WriteObject(&s);

TBufferFile buffer2(TBuffer::kRead, buffer.BufferSize(), buffer.Buffer(), false);
buffer2.SetReadMode();
buffer2.InitMap();
TClass* storedClass = buffer2.ReadClass();
// ReadClass advances the buffer, so we need to reset it.
buffer2.SetBufferOffset(0);
buffer2.ResetMap();
REQUIRE(storedClass != nullptr);
auto* outS = (TObjString*)buffer2.ReadObjectAny(storedClass);
REQUIRE(outS != nullptr);
REQUIRE(outS->GetString() == "test");
}

// Simple test to do root deserialization.
TEST_CASE("TestRootSerialization")
{
DataRef ref;
TMessage* tm = new TMessage(kMESS_OBJECT);
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
auto msg = transport->CreateMessage(4096);
FairOutputTBuffer tm(*msg);
auto sOrig = std::make_unique<TObjString>("test");
tm->WriteObject(sOrig.get());
tm << sOrig.get();
o2::header::DataHeader dh;
dh.payloadSerializationMethod = o2::header::gSerializationMethodROOT;
ref.payload = tm->Buffer();
dh.payloadSize = tm->BufferSize();
ref.payload = (char*)msg->GetData();
dh.payloadSize = (size_t)msg->GetSize();
ref.header = reinterpret_cast<char const*>(&dh);

// Check by using the same type
Expand Down
Loading

0 comments on commit 03228df

Please sign in to comment.