Skip to content

Commit

Permalink
[ntuple] fix up serialization of cluster summary
Browse files Browse the repository at this point in the history
  • Loading branch information
jblomer committed Sep 24, 2024
1 parent bec1bd6 commit 5cfe091
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 31 deletions.
1 change: 1 addition & 0 deletions tree/ntuple/v7/inc/ROOT/RNTupleSerialize.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public:
struct RClusterSummary {
std::uint64_t fFirstEntry = 0;
std::uint64_t fNEntries = 0;
std::uint8_t fFlags = 0;
/// -1 for "all columns"
std::int32_t fColumnGroupID = -1;
};
Expand Down
39 changes: 20 additions & 19 deletions tree/ntuple/v7/src/RNTupleSerialize.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1096,19 +1096,21 @@ std::uint32_t
ROOT::Experimental::Internal::RNTupleSerializer::SerializeClusterSummary(const RClusterSummary &clusterSummary,
void *buffer)
{
if (clusterSummary.fNEntries >= (static_cast<std::uint64_t>(1) << 56)) {
throw RException(R__FAIL("number of entries in cluster exceeds maximum of 2^56"));
}

auto base = reinterpret_cast<unsigned char *>(buffer);
auto pos = base;
void **where = (buffer == nullptr) ? &buffer : reinterpret_cast<void **>(&pos);

auto frame = pos;
pos += SerializeRecordFramePreamble(*where);
pos += SerializeUInt64(clusterSummary.fFirstEntry, *where);
if (clusterSummary.fColumnGroupID >= 0) {
pos += SerializeInt64(-static_cast<int64_t>(clusterSummary.fNEntries), *where);
pos += SerializeUInt32(clusterSummary.fColumnGroupID, *where);
} else {
pos += SerializeInt64(static_cast<int64_t>(clusterSummary.fNEntries), *where);
}
const std::uint64_t nEntriesAndFlags =
(static_cast<std::uint64_t>(clusterSummary.fFlags) << 56) | clusterSummary.fNEntries;
pos += SerializeUInt64(nEntriesAndFlags, *where);

auto size = pos - frame;
pos += SerializeFramePostscript(frame, size);
return size;
Expand All @@ -1131,21 +1133,20 @@ ROOT::Experimental::Internal::RNTupleSerializer::DeserializeClusterSummary(const
return R__FAIL("too short cluster summary");

bytes += DeserializeUInt64(bytes, clusterSummary.fFirstEntry);
std::int64_t nEntries;
bytes += DeserializeInt64(bytes, nEntries);
std::uint64_t nEntriesAndFlags;
bytes += DeserializeUInt64(bytes, nEntriesAndFlags);

if (nEntries < 0) {
if (fnFrameSizeLeft() < sizeof(std::uint32_t))
return R__FAIL("too short cluster summary");
clusterSummary.fNEntries = -nEntries;
std::uint32_t columnGroupID;
bytes += DeserializeUInt32(bytes, columnGroupID);
clusterSummary.fColumnGroupID = columnGroupID;
} else {
clusterSummary.fNEntries = nEntries;
clusterSummary.fColumnGroupID = -1;
const std::uint64_t nEntries = (nEntriesAndFlags << 8) >> 8;
const std::uint8_t flags = nEntriesAndFlags >> 56;

if (flags & 0x01) {
return R__FAIL("sharded cluster flag set in cluster summary; sharded clusters are currently unsupported.");
}

clusterSummary.fNEntries = nEntries;
clusterSummary.fFlags = flags;
clusterSummary.fColumnGroupID = -1;

return frameSize;
}

Expand Down Expand Up @@ -1511,7 +1512,7 @@ ROOT::Experimental::Internal::RNTupleSerializer::SerializePageList(void *buffer,
pos += SerializeListFramePreamble(nClusters, *where);
for (auto clusterId : physClusterIDs) {
const auto &clusterDesc = desc.GetClusterDescriptor(context.GetMemClusterId(clusterId));
RClusterSummary summary{clusterDesc.GetFirstEntryIndex(), clusterDesc.GetNEntries(), -1};
RClusterSummary summary{clusterDesc.GetFirstEntryIndex(), clusterDesc.GetNEntries(), 0, -1};
pos += SerializeClusterSummary(summary, *where);
}
pos += SerializeFramePostscript(buffer ? clusterSummaryFrame : nullptr, pos - clusterSummaryFrame);
Expand Down
30 changes: 18 additions & 12 deletions tree/ntuple/v7/test/ntuple_serialize.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,10 @@ TEST(RNTuple, SerializeClusterSummary)
{
RNTupleSerializer::RClusterSummary summary;
summary.fFirstEntry = 42;
summary.fNEntries = 137;
summary.fNEntries = (static_cast<std::uint64_t>(1) << 56) - 1;
summary.fFlags = 0x02;

unsigned char buffer[28];
unsigned char buffer[24];
ASSERT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, nullptr));
EXPECT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, buffer));
RNTupleSerializer::RClusterSummary reco;
Expand All @@ -454,21 +455,26 @@ TEST(RNTuple, SerializeClusterSummary)
EXPECT_EQ(24u, RNTupleSerializer::DeserializeClusterSummary(buffer, 24, reco).Unwrap());
EXPECT_EQ(summary.fFirstEntry, reco.fFirstEntry);
EXPECT_EQ(summary.fNEntries, reco.fNEntries);
EXPECT_EQ(summary.fFlags, reco.fFlags);
EXPECT_EQ(summary.fColumnGroupID, reco.fColumnGroupID);

summary.fColumnGroupID = 13;
ASSERT_EQ(28u, RNTupleSerializer::SerializeClusterSummary(summary, nullptr));
EXPECT_EQ(28u, RNTupleSerializer::SerializeClusterSummary(summary, buffer));
summary.fFlags |= 0x01;
EXPECT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, buffer));
try {
RNTupleSerializer::DeserializeClusterSummary(buffer, 27, reco).Unwrap();
FAIL() << "too short cluster summary should fail";
RNTupleSerializer::DeserializeClusterSummary(buffer, 24, reco).Unwrap();
FAIL() << "sharded cluster flag should fail";
} catch (const RException& err) {
EXPECT_THAT(err.what(), testing::HasSubstr("too short"));
EXPECT_THAT(err.what(), testing::HasSubstr("sharded"));
}

summary.fFlags = 0;
summary.fNEntries++;
try {
RNTupleSerializer::SerializeClusterSummary(summary, buffer);
FAIL() << "overesized cluster should fail";
} catch (const RException &err) {
EXPECT_THAT(err.what(), testing::HasSubstr("exceed"));
}
EXPECT_EQ(28u, RNTupleSerializer::DeserializeClusterSummary(buffer, 28, reco).Unwrap());
EXPECT_EQ(summary.fFirstEntry, reco.fFirstEntry);
EXPECT_EQ(summary.fNEntries, reco.fNEntries);
EXPECT_EQ(summary.fColumnGroupID, reco.fColumnGroupID);
}

TEST(RNTuple, SerializeClusterGroup)
Expand Down

0 comments on commit 5cfe091

Please sign in to comment.