diff --git a/tree/ntuple/v7/inc/ROOT/RNTupleSerialize.hxx b/tree/ntuple/v7/inc/ROOT/RNTupleSerialize.hxx index 67380b6d90c28..589f2e3b9e9df 100644 --- a/tree/ntuple/v7/inc/ROOT/RNTupleSerialize.hxx +++ b/tree/ntuple/v7/inc/ROOT/RNTupleSerialize.hxx @@ -87,6 +87,7 @@ public: struct RClusterSummary { std::uint64_t fFirstEntry = 0; std::uint64_t fNEntries = 0; + std::uint8_t fFlags = 0; /// -1 for "all columns" std::int32_t fColumnGroupID = -1; }; diff --git a/tree/ntuple/v7/src/RNTupleSerialize.cxx b/tree/ntuple/v7/src/RNTupleSerialize.cxx index 4ed2067b45a97..10c86128cd1c8 100644 --- a/tree/ntuple/v7/src/RNTupleSerialize.cxx +++ b/tree/ntuple/v7/src/RNTupleSerialize.cxx @@ -1096,6 +1096,10 @@ std::uint32_t ROOT::Experimental::Internal::RNTupleSerializer::SerializeClusterSummary(const RClusterSummary &clusterSummary, void *buffer) { + if (clusterSummary.fNEntries >= (static_cast(1) << 56)) { + throw RException(R__FAIL("number of entries in cluster exceeds maximum of 2^56")); + } + auto base = reinterpret_cast(buffer); auto pos = base; void **where = (buffer == nullptr) ? &buffer : reinterpret_cast(&pos); @@ -1103,12 +1107,10 @@ ROOT::Experimental::Internal::RNTupleSerializer::SerializeClusterSummary(const R auto frame = pos; pos += SerializeRecordFramePreamble(*where); pos += SerializeUInt64(clusterSummary.fFirstEntry, *where); - if (clusterSummary.fColumnGroupID >= 0) { - pos += SerializeInt64(-static_cast(clusterSummary.fNEntries), *where); - pos += SerializeUInt32(clusterSummary.fColumnGroupID, *where); - } else { - pos += SerializeInt64(static_cast(clusterSummary.fNEntries), *where); - } + const std::uint64_t nEntriesAndFlags = + (static_cast(clusterSummary.fFlags) << 56) | clusterSummary.fNEntries; + pos += SerializeUInt64(nEntriesAndFlags, *where); + auto size = pos - frame; pos += SerializeFramePostscript(frame, size); return size; @@ -1131,21 +1133,20 @@ ROOT::Experimental::Internal::RNTupleSerializer::DeserializeClusterSummary(const return R__FAIL("too short cluster summary"); bytes += DeserializeUInt64(bytes, clusterSummary.fFirstEntry); - std::int64_t nEntries; - bytes += DeserializeInt64(bytes, nEntries); + std::uint64_t nEntriesAndFlags; + bytes += DeserializeUInt64(bytes, nEntriesAndFlags); - if (nEntries < 0) { - if (fnFrameSizeLeft() < sizeof(std::uint32_t)) - return R__FAIL("too short cluster summary"); - clusterSummary.fNEntries = -nEntries; - std::uint32_t columnGroupID; - bytes += DeserializeUInt32(bytes, columnGroupID); - clusterSummary.fColumnGroupID = columnGroupID; - } else { - clusterSummary.fNEntries = nEntries; - clusterSummary.fColumnGroupID = -1; + const std::uint64_t nEntries = (nEntriesAndFlags << 8) >> 8; + const std::uint8_t flags = nEntriesAndFlags >> 56; + + if (flags & 0x01) { + return R__FAIL("sharded cluster flag set in cluster summary; sharded clusters are currently unsupported."); } + clusterSummary.fNEntries = nEntries; + clusterSummary.fFlags = flags; + clusterSummary.fColumnGroupID = -1; + return frameSize; } @@ -1511,7 +1512,7 @@ ROOT::Experimental::Internal::RNTupleSerializer::SerializePageList(void *buffer, pos += SerializeListFramePreamble(nClusters, *where); for (auto clusterId : physClusterIDs) { const auto &clusterDesc = desc.GetClusterDescriptor(context.GetMemClusterId(clusterId)); - RClusterSummary summary{clusterDesc.GetFirstEntryIndex(), clusterDesc.GetNEntries(), -1}; + RClusterSummary summary{clusterDesc.GetFirstEntryIndex(), clusterDesc.GetNEntries(), 0, -1}; pos += SerializeClusterSummary(summary, *where); } pos += SerializeFramePostscript(buffer ? clusterSummaryFrame : nullptr, pos - clusterSummaryFrame); diff --git a/tree/ntuple/v7/test/ntuple_serialize.cxx b/tree/ntuple/v7/test/ntuple_serialize.cxx index ab2b6cd654372..2d9f4a3f77523 100644 --- a/tree/ntuple/v7/test/ntuple_serialize.cxx +++ b/tree/ntuple/v7/test/ntuple_serialize.cxx @@ -439,9 +439,10 @@ TEST(RNTuple, SerializeClusterSummary) { RNTupleSerializer::RClusterSummary summary; summary.fFirstEntry = 42; - summary.fNEntries = 137; + summary.fNEntries = (static_cast(1) << 56) - 1; + summary.fFlags = 0x02; - unsigned char buffer[28]; + unsigned char buffer[24]; ASSERT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, nullptr)); EXPECT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, buffer)); RNTupleSerializer::RClusterSummary reco; @@ -454,21 +455,26 @@ TEST(RNTuple, SerializeClusterSummary) EXPECT_EQ(24u, RNTupleSerializer::DeserializeClusterSummary(buffer, 24, reco).Unwrap()); EXPECT_EQ(summary.fFirstEntry, reco.fFirstEntry); EXPECT_EQ(summary.fNEntries, reco.fNEntries); + EXPECT_EQ(summary.fFlags, reco.fFlags); EXPECT_EQ(summary.fColumnGroupID, reco.fColumnGroupID); - summary.fColumnGroupID = 13; - ASSERT_EQ(28u, RNTupleSerializer::SerializeClusterSummary(summary, nullptr)); - EXPECT_EQ(28u, RNTupleSerializer::SerializeClusterSummary(summary, buffer)); + summary.fFlags |= 0x01; + EXPECT_EQ(24u, RNTupleSerializer::SerializeClusterSummary(summary, buffer)); try { - RNTupleSerializer::DeserializeClusterSummary(buffer, 27, reco).Unwrap(); - FAIL() << "too short cluster summary should fail"; + RNTupleSerializer::DeserializeClusterSummary(buffer, 24, reco).Unwrap(); + FAIL() << "sharded cluster flag should fail"; } catch (const RException& err) { - EXPECT_THAT(err.what(), testing::HasSubstr("too short")); + EXPECT_THAT(err.what(), testing::HasSubstr("sharded")); + } + + summary.fFlags = 0; + summary.fNEntries++; + try { + RNTupleSerializer::SerializeClusterSummary(summary, buffer); + FAIL() << "overesized cluster should fail"; + } catch (const RException &err) { + EXPECT_THAT(err.what(), testing::HasSubstr("exceed")); } - EXPECT_EQ(28u, RNTupleSerializer::DeserializeClusterSummary(buffer, 28, reco).Unwrap()); - EXPECT_EQ(summary.fFirstEntry, reco.fFirstEntry); - EXPECT_EQ(summary.fNEntries, reco.fNEntries); - EXPECT_EQ(summary.fColumnGroupID, reco.fColumnGroupID); } TEST(RNTuple, SerializeClusterGroup)