diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index c75876ecf1..206a817ce8 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -4,6 +4,7 @@ #include "bidict.h" #include "invoke.h" #include "optional.h" +#include "type_traits.h" #include #include #include @@ -12,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -509,19 +509,21 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } -template -std::vector sorted_by(std::unordered_set const &s, F const &f) { - std::vector result(s.begin(), s.end()); - inplace_sorted_by(s, f); +template +std::vector sorted_by(C const &c, F const &f) { + std::vector result(c.begin(), c.end()); + inplace_sorted_by(c, f); return result; } -template -void inplace_sorted_by(std::vector &v, F const &f) { - auto custom_comparator = [&](T const &lhs, T const &rhs) -> bool { +template +void inplace_sorted_by(C &c, F const &f) { + CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C); + + auto custom_comparator = [&](C const &lhs, C const &rhs) -> bool { return f(lhs, rhs); }; - std::sort(v.begin(), v.end(), custom_comparator); + std::sort(c.begin(), c.end(), custom_comparator); } template @@ -531,17 +533,6 @@ C filter(C const &v, F const &f) { return result; } -template -std::unordered_set filter(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (T const &t : v) { - if (f(t)) { - result.insert(t); - } - } - return result; -} - template void inplace_filter(C &v, F const &f) { std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); }); diff --git a/lib/utils/include/utils/stack_map.h b/lib/utils/include/utils/stack_map.h index 17e42ce755..bc08f44dc4 100644 --- a/lib/utils/include/utils/stack_map.h +++ b/lib/utils/include/utils/stack_map.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_STACK_MAP_H #define _FLEXFLOW_UTILS_STACK_MAP_H +#include "containers.h" #include "optional.h" #include "stack_vector.h" @@ -28,6 +29,28 @@ struct stack_map { } } + size_t size() const { + return this->contents.size(); + } + + bool empty() const { + return this->contents.empty(); + } + + friend bool operator==(stack_map const &lhs, stack_map const &rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + return lhs.sorted() == rhs.sorted(); + } + + friend bool operator!=(stack_map const &lhs, stack_map const &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } + return lhs.sorted() != rhs.sorted(); + } + V &at(K const &k) { return this->contents.at(get_idx(k).value()).second; } @@ -70,6 +93,15 @@ struct stack_map { } private: + std::vector> sorted() const { + auto comparator = [](std::pair const &lhs, + std::pair const &rhs) { + return lhs.first < rhs.first; + }; + + return sorted_by(this->contents, comparator); + } + optional get_idx(K const &k) const { for (std::size_t idx = 0; idx < contents.size(); idx++) { if (contents.at(idx).first == k) { @@ -82,6 +114,7 @@ struct stack_map { stack_vector, MAXSIZE> contents; }; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(stack_map); } // namespace FlexFlow diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index e74b9b6365..8564d8a4bf 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -259,6 +259,10 @@ struct stack_vector { return this->m_size; } + bool empty() const { + return (this->m_size == 0); + } + private: std::size_t m_size = 0; std::array, MAXSIZE> contents; diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index 299856f3fc..ee44f01983 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -44,13 +44,6 @@ struct is_rc_copy_virtual_compliant std::is_move_assignable>>, std::has_virtual_destructor> {}; -template -struct is_clonable : std::false_type {}; - -template -struct is_clonable().clone())>> - : std::true_type {}; - template struct is_streamable : std::false_type {}; @@ -120,6 +113,16 @@ struct elements_satisfy> : std::true_type {}; static_assert( elements_satisfy>::value, ""); +template +struct supports_iterator_tag + : std::is_base_of::iterator_category> {}; + +#define CHECK_SUPPORTS_ITERATOR_TAG(TAG, ...) \ + static_assert(supports_iterator_tag<__VA_ARGS__, TAG>::value, \ + #__VA_ARGS__ " does not support required iterator tag " #TAG); + template using is_default_constructible = std::is_default_constructible;