diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index fb0e282508e2c..7ede44819c02c 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -1166,6 +1166,14 @@ using is_soa_table_t = typename framework::is_specialization; template inline constexpr bool is_soa_table_like_v = framework::is_base_of_template_v; +template +class FilteredBase; +template +class Filtered; + +template +inline constexpr bool is_soa_filtered_v = framework::is_base_of_template_v; + /// Helper function to extract bound indices template static constexpr auto extractBindings(framework::pack) @@ -1173,10 +1181,126 @@ static constexpr auto extractBindings(framework::pack) return framework::pack{}; } +SelectionVector selectionToVector(gandiva::Selection const& sel); + +template +auto doSliceBy(T const* table, o2::framework::Preslice const& container, int value) +{ + if constexpr (o2::soa::is_binding_compatible_v()) { + std::shared_ptr out; + uint64_t offset = 0; + auto status = container.getSliceFor(value, table->asArrowTable(), out, offset); + auto t = typename T::self_t({out}, offset); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + return t; + } else { + static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); + } +} + +template +auto doSliceBy(T const* table, o2::framework::PresliceUnsorted const& container, int value) +{ + if constexpr (o2::soa::is_binding_compatible_v()) { + auto selection = container.getSliceFor(value); + if constexpr (soa::is_soa_filtered_v) { + auto t = soa::Filtered({table->asArrowTable()}, selection); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + return t; + } else { + auto t = soa::Filtered({table->asArrowTable()}, selection); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + return t; + } + } else { + static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); + } +} + template -class Filtered; +auto prepareFilteredSlice(T const* table, std::shared_ptr slice, uint64_t offset) +{ + if (offset >= table->tableSize()) { + if constexpr (soa::is_soa_filtered_v) { + Filtered fresult{{{slice}}, SelectionVector{}, 0}; + table->copyIndexBindings(fresult); + return fresult; + } else { + typename T::self_t fresult{{{slice}}, SelectionVector{}, 0}; + table->copyIndexBindings(fresult); + return fresult; + } + } + auto start = offset; + auto end = start + slice->num_rows(); + auto mSelectedRows = table->getSelectedRows(); + auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); + auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); + SelectionVector slicedSelection{start_iterator, stop_iterator}; + std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), + [&start](int64_t idx) { + return idx - static_cast(start); + }); + if constexpr (soa::is_soa_filtered_v) { + Filtered fresult{{{slice}}, std::move(slicedSelection), start}; + table->copyIndexBindings(fresult); + return fresult; + } else { + typename T::self_t fresult{{{slice}}, std::move(slicedSelection), start}; + table->copyIndexBindings(fresult); + return fresult; + } +} -SelectionVector selectionToVector(gandiva::Selection const& sel); +template +auto doFilteredSliceBy(T const* table, o2::framework::Preslice const& container, int value) +{ + if constexpr (o2::soa::is_binding_compatible_v()) { + uint64_t offset = 0; + std::shared_ptr slice = nullptr; + auto status = container.getSliceFor(value, table->asArrowTable(), slice, offset); + return prepareFilteredSlice(table, slice, offset); + } else { + static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); + } +} + +template +auto doSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) +{ + auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + auto [offset, count] = localCache.getSliceFor(value); + auto t = typename T::self_t({table->asArrowTable()->Slice(static_cast(offset), count)}, static_cast(offset)); + table->copyIndexBindings(t); + return t; +} + +template +auto doFilteredSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) +{ + auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + auto [offset, count] = localCache.getSliceFor(value); + auto slice = table->asArrowTable()->Slice(static_cast(offset), count); + return prepareFilteredSlice(table, slice, offset); +} + +template +auto doSliceByCachedUnsorted(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) +{ + auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + if constexpr (soa::is_soa_filtered_v) { + auto t = typename T::self_t({table->asArrowTable()}, localCache.getSliceFor(value)); + table->copyIndexBindings(t); + return t; + } else { + auto t = Filtered({table->asArrowTable()}, localCache.getSliceFor(value)); + table->copyIndexBindings(t); + return t; + } +} template auto select(T const& t, framework::expressions::Filter const& f) @@ -1192,6 +1316,7 @@ template class Table { public: + using self_t = Table; using table_t = Table; using columns = framework::pack; using column_types = framework::pack; @@ -1480,36 +1605,13 @@ class Table template auto sliceBy(o2::framework::Preslice const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - std::shared_ptr out; - uint64_t offset = 0; - auto status = container.getSliceFor(value, mTable, out, offset); - auto t = table_t({out}, offset); - copyIndexBindings(t); - t.bindInternalIndicesTo(this); - return t; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } + return doSliceBy(this, container, value); } template auto sliceBy(o2::framework::PresliceUnsorted const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - auto selection = container.getSliceFor(value); - auto t = soa::Filtered({this->asArrowTable()}, selection); - copyIndexBindings(t); - t.bindInternalIndicesTo(this); - return t; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } - } - - auto slice(uint64_t start, uint64_t end) const - { - return rawSlice(start, end); + return doSliceBy(this, container, value); } auto rawSlice(uint64_t start, uint64_t end) const @@ -2483,6 +2585,7 @@ struct Join : JoinBase { using base::bindExternalIndices; using base::bindInternalIndicesTo; + using self_t = Join; using table_t = base; using persistent_columns_t = typename table_t::persistent_columns_t; using iterator = typename table_t::template RowView, Ts...>; @@ -2492,49 +2595,24 @@ struct Join : JoinBase { auto sliceByCached(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto [offset, count] = localCache.getSliceFor(value); - auto t = Join({this->asArrowTable()->Slice(static_cast(offset), count)}, static_cast(offset)); - this->copyIndexBindings(t); - return t; + return doSliceByCached(this, node, value, cache); } auto sliceByCachedUnsorted(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto t = Filtered>({this->asArrowTable()}, localCache.getSliceFor(value)); - this->copyIndexBindings(t); - return t; + return doSliceByCachedUnsorted(this, node, value, cache); } template auto sliceBy(o2::framework::Preslice const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - std::shared_ptr out; - uint64_t offset = 0; - auto status = container.getSliceFor(value, this->asArrowTable(), out, offset); - auto t = table_t({out}, offset); - this->copyIndexBindings(t); - t.bindInternalIndicesTo(this); - return t; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } + return doSliceBy(this, container, value); } template auto sliceBy(o2::framework::PresliceUnsorted const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - auto selection = container.getSliceFor(value); - auto t = soa::Filtered>({this->asArrowTable()}, selection); - this->copyIndexBindings(t); - t.bindInternalIndicesTo(this); - return t; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } + return doSliceBy(this, container, value); } template @@ -2753,78 +2831,28 @@ class FilteredBase : public T auto sliceByCached(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto [offset, count] = localCache.getSliceFor(value); - auto slice = this->asArrowTable()->Slice(static_cast(offset), count); - if (offset >= this->tableSize()) { - self_t fresult{{slice}, SelectionVector{}, 0}; // empty slice - this->copyIndexBindings(fresult); - return fresult; - } - auto start = static_cast(offset); - auto end = start + slice->num_rows(); - auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); - auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&start](int64_t idx) { - return idx - static_cast(start); - }); - self_t fresult{{slice}, std::move(slicedSelection), start}; - this->copyIndexBindings(fresult); - return fresult; + return doFilteredSliceByCached(this, node, value, cache); } auto sliceByCachedUnsorted(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto t = std::decay_t{{this->asArrowTable()}, localCache.getSliceFor(value)}; + auto t = doSliceByCachedUnsorted(this, node, value, cache); t.intersectWithSelection(this->getSelectedRows()); - this->copyIndexBindings(t); return t; } template auto sliceBy(o2::framework::Preslice const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - uint64_t offset = 0; - std::shared_ptr result = nullptr; - auto status = container.getSliceFor(value, this->asArrowTable(), result, offset); - if (offset >= static_cast(this->tableSize())) { - self_t fresult{{result}, SelectionVector{}, 0}; // empty slice - this->copyIndexBindings(fresult); - return fresult; - } - auto start = offset; - auto end = start + result->num_rows(); - auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); - auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&start](int64_t idx) { - return idx - static_cast(start); - }); - self_t fresult{{result}, std::move(slicedSelection), start}; - this->copyIndexBindings(fresult); - return fresult; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } + return doFilteredSliceBy(this, container, value); } template auto sliceBy(o2::framework::PresliceUnsorted const& container, int value) const { - if constexpr (o2::soa::is_binding_compatible_v>()) { - auto selection = container.getSliceFor(value); - auto t = std::decay_t{{this->asArrowTable()}, selection}; - t.intersectWithSelection(this->getSelectedRows()); - this->copyIndexBindings(t); - return t; - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } + auto t = doSliceBy(this, container, value); + t.intersectWithSelection(this->getSelectedRows()); + return t; } auto select(framework::expressions::Filter const& f) const @@ -2844,18 +2872,6 @@ class FilteredBase : public T } protected: - auto slice(uint64_t start, uint64_t end) - { - auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); - auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&](int64_t idx) { - return idx - static_cast(start); - }); - return self_t{{this->asArrowTable()->Slice(start, end - start + 1)}, std::move(slicedSelection), start}; - } - void sumWithSelection(SelectionVector const& selection) { mCached = true; @@ -2921,6 +2937,7 @@ template class Filtered : public FilteredBase { public: + using base_t = T; using self_t = Filtered; using table_t = typename FilteredBase::table_t; using originals = originals_pack_t; @@ -3006,36 +3023,42 @@ class Filtered : public FilteredBase return operator*=(other.getSelectedRows()); } + template + auto rawSliceBy(o2::framework::Preslice const& container, int value) const + { + return (table_t)this->sliceBy(container, value); + } + auto sliceByCached(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto [offset, count] = localCache.getSliceFor(value); - auto slice = this->asArrowTable()->Slice(offset, count); - if (offset >= this->tableSize()) { - self_t fresult{{slice}, SelectionVector{}, 0}; // empty slice - this->copyIndexBindings(fresult); - return fresult; - } - auto start = offset; - auto end = start + slice->num_rows(); - auto start_iterator = std::lower_bound(this->getSelectedRows().begin(), this->getSelectedRows().end(), start); - auto stop_iterator = std::lower_bound(start_iterator, this->getSelectedRows().end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&start](int64_t idx) { - return idx - static_cast(start); - }); - self_t fresult{{slice}, std::move(slicedSelection), static_cast(start)}; - this->copyIndexBindings(fresult); - return fresult; + return doFilteredSliceByCached(this, node, value, cache); } auto sliceByCachedUnsorted(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto t = std::decay_t{{this->asArrowTable()}, localCache.getSliceFor(value)}; + auto t = doSliceByCachedUnsorted(this, node, value, cache); + t.intersectWithSelection(this->getSelectedRows()); + return t; + } + + template + auto sliceBy(o2::framework::Preslice const& container, int value) const + { + return doFilteredSliceBy(this, container, value); + } + + template + auto sliceBy(o2::framework::PresliceUnsorted const& container, int value) const + { + auto t = doSliceBy(this, container, value); t.intersectWithSelection(this->getSelectedRows()); - this->copyIndexBindings(t); + return t; + } + + auto select(framework::expressions::Filter const& f) const + { + auto t = o2::soa::select(*this, f); + copyIndexBindings(t); return t; } }; @@ -3045,6 +3068,7 @@ class Filtered> : public FilteredBase { public: using self_t = Filtered>; + using base_t = T; using table_t = typename FilteredBase::table_t; using originals = originals_pack_t; @@ -3146,33 +3170,27 @@ class Filtered> : public FilteredBase auto sliceByCached(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - auto [offset, count] = localCache.getSliceFor(value); - auto start = static_cast(offset); - auto slice = this->asArrowTable()->Slice(start, count); - auto end = start + slice->num_rows(); - auto start_iterator = std::lower_bound(this->getSelectedRows().begin(), this->getSelectedRows().end(), start); - auto stop_iterator = std::lower_bound(start_iterator, this->getSelectedRows().end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&start](int64_t idx) { - return idx - static_cast(start); - }); - SelectionVector copy = slicedSelection; - Filtered filteredTable{{slice}, std::move(slicedSelection), start}; - std::vector> filtered{filteredTable}; - self_t fresult{std::move(filtered), std::move(copy), start}; - this->copyIndexBindings(fresult); - return fresult; + return doFilteredSliceByCached(this, node, value, cache); } auto sliceByCachedUnsorted(framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) const { - auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey>(node.name), node.name}); - std::vector> filtered{Filtered{{this->asArrowTable()}, localCache.getSliceFor(value)}}; - auto t = std::decay_t{std::move(filtered), localCache.getSliceFor(value)}; + auto t = doSliceByCachedUnsorted(this, node, value, cache); + t.intersectWithSelection(this->getSelectedRows()); + return t; + } + + template + auto sliceBy(o2::framework::Preslice const& container, int value) const + { + return doFilteredSliceBy(this, container, value); + } + + template + auto sliceBy(o2::framework::PresliceUnsorted const& container, int value) const + { + auto t = doSliceBy(this, container, value); t.intersectWithSelection(this->getSelectedRows()); - this->copyIndexBindings(t); return t; } @@ -3187,9 +3205,6 @@ class Filtered> : public FilteredBase } }; -template -inline constexpr bool is_soa_filtered_v = framework::is_base_of_template_v; - /// Template for building an index table to access matching rows from non- /// joinable, but compatible tables, e.g. Collisions and ZDCs. /// First argument is the key table (BCs for the Collisions+ZDCs case), the rest