Skip to content

Commit

Permalink
Generalize Containers.h (#862)
Browse files Browse the repository at this point in the history
* add changes

* Format

---------

Co-authored-by: Colin Unger <lockshaw@lockshaw.net>
  • Loading branch information
KateUnger and lockshaw authored Jul 18, 2023
1 parent 804707d commit fc8d32a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 27 deletions.
31 changes: 11 additions & 20 deletions lib/utils/include/utils/containers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "bidict.h"
#include "invoke.h"
#include "optional.h"
#include "type_traits.h"
#include <algorithm>
#include <cassert>
#include <functional>
Expand All @@ -12,7 +13,6 @@
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -509,19 +509,21 @@ std::unordered_set<Out> flatmap_v2(std::unordered_set<In> const &v,
return result;
}

template <typename T, typename F>
std::vector<T> sorted_by(std::unordered_set<T> const &s, F const &f) {
std::vector<T> result(s.begin(), s.end());
inplace_sorted_by(s, f);
template <typename C, typename F, typename Elem = typename C::value_type>
std::vector<Elem> sorted_by(C const &c, F const &f) {
std::vector<Elem> result(c.begin(), c.end());
inplace_sorted_by(c, f);
return result;
}

template <typename T, typename F>
void inplace_sorted_by(std::vector<T> &v, F const &f) {
auto custom_comparator = [&](T const &lhs, T const &rhs) -> bool {
template <typename C, typename F, typename Elem = typename C::value_type>
void inplace_sorted_by(C &c, F const &f) {
CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C);

auto custom_comparator = [&](C const &lhs, C const &rhs) -> bool {
return f(lhs, rhs);
};
std::sort(v.begin(), v.end(), custom_comparator);
std::sort(c.begin(), c.end(), custom_comparator);
}

template <typename C, typename F>
Expand All @@ -531,17 +533,6 @@ C filter(C const &v, F const &f) {
return result;
}

template <typename T, typename F>
std::unordered_set<T> filter(std::unordered_set<T> const &v, F const &f) {
std::unordered_set<T> result;
for (T const &t : v) {
if (f(t)) {
result.insert(t);
}
}
return result;
}

template <typename C, typename F, typename Elem = typename C::value_type>
void inplace_filter(C &v, F const &f) {
std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); });
Expand Down
33 changes: 33 additions & 0 deletions lib/utils/include/utils/stack_map.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef _FLEXFLOW_UTILS_STACK_MAP_H
#define _FLEXFLOW_UTILS_STACK_MAP_H

#include "containers.h"
#include "optional.h"
#include "stack_vector.h"

Expand Down Expand Up @@ -28,6 +29,28 @@ struct stack_map {
}
}

size_t size() const {
return this->contents.size();
}

bool empty() const {
return this->contents.empty();
}

friend bool operator==(stack_map const &lhs, stack_map const &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
return lhs.sorted() == rhs.sorted();
}

friend bool operator!=(stack_map const &lhs, stack_map const &rhs) {
if (lhs.size() != rhs.size()) {
return true;
}
return lhs.sorted() != rhs.sorted();
}

V &at(K const &k) {
return this->contents.at(get_idx(k).value()).second;
}
Expand Down Expand Up @@ -70,6 +93,15 @@ struct stack_map {
}

private:
std::vector<std::pair<K, V>> sorted() const {
auto comparator = [](std::pair<K, V> const &lhs,
std::pair<K, V> const &rhs) {
return lhs.first < rhs.first;
};

return sorted_by(this->contents, comparator);
}

optional<size_t> get_idx(K const &k) const {
for (std::size_t idx = 0; idx < contents.size(); idx++) {
if (contents.at(idx).first == k) {
Expand All @@ -82,6 +114,7 @@ struct stack_map {

stack_vector<std::pair<K, V>, MAXSIZE> contents;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(stack_map<int, int, 10>);

} // namespace FlexFlow

Expand Down
4 changes: 4 additions & 0 deletions lib/utils/include/utils/stack_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ struct stack_vector {
return this->m_size;
}

bool empty() const {
return (this->m_size == 0);
}

private:
std::size_t m_size = 0;
std::array<optional<T>, MAXSIZE> contents;
Expand Down
17 changes: 10 additions & 7 deletions lib/utils/include/utils/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ struct is_rc_copy_virtual_compliant
std::is_move_assignable<T>>>,
std::has_virtual_destructor<T>> {};

template <typename T, typename Enable = void>
struct is_clonable : std::false_type {};

template <typename T>
struct is_clonable<T, void_t<decltype(std::declval<T>().clone())>>
: std::true_type {};

template <typename T, typename Enable = void>
struct is_streamable : std::false_type {};

Expand Down Expand Up @@ -120,6 +113,16 @@ struct elements_satisfy<Cond, std::tuple<>> : std::true_type {};
static_assert(
elements_satisfy<is_equal_comparable, std::tuple<int, float>>::value, "");

template <typename C, typename Tag>
struct supports_iterator_tag
: std::is_base_of<Tag,
typename std::iterator_traits<
typename C::iterator>::iterator_category> {};

#define CHECK_SUPPORTS_ITERATOR_TAG(TAG, ...) \
static_assert(supports_iterator_tag<__VA_ARGS__, TAG>::value, \
#__VA_ARGS__ " does not support required iterator tag " #TAG);

template <typename T>
using is_default_constructible = std::is_default_constructible<T>;

Expand Down

0 comments on commit fc8d32a

Please sign in to comment.