Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize Containers.h #862

Merged
merged 2 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions lib/utils/include/utils/containers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "bidict.h"
#include "invoke.h"
#include "optional.h"
#include "type_traits.h"
#include <algorithm>
#include <cassert>
#include <functional>
Expand All @@ -12,7 +13,6 @@
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -509,19 +509,21 @@ std::unordered_set<Out> flatmap_v2(std::unordered_set<In> const &v,
return result;
}

template <typename T, typename F>
std::vector<T> sorted_by(std::unordered_set<T> const &s, F const &f) {
std::vector<T> result(s.begin(), s.end());
inplace_sorted_by(s, f);
template <typename C, typename F, typename Elem = typename C::value_type>
std::vector<Elem> sorted_by(C const &c, F const &f) {
std::vector<Elem> result(c.begin(), c.end());
inplace_sorted_by(c, f);
return result;
}

template <typename T, typename F>
void inplace_sorted_by(std::vector<T> &v, F const &f) {
auto custom_comparator = [&](T const &lhs, T const &rhs) -> bool {
template <typename C, typename F, typename Elem = typename C::value_type>
void inplace_sorted_by(C &c, F const &f) {
CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C);

auto custom_comparator = [&](C const &lhs, C const &rhs) -> bool {
return f(lhs, rhs);
};
std::sort(v.begin(), v.end(), custom_comparator);
std::sort(c.begin(), c.end(), custom_comparator);
}

template <typename C, typename F>
Expand All @@ -531,17 +533,6 @@ C filter(C const &v, F const &f) {
return result;
}

template <typename T, typename F>
std::unordered_set<T> filter(std::unordered_set<T> const &v, F const &f) {
std::unordered_set<T> result;
for (T const &t : v) {
if (f(t)) {
result.insert(t);
}
}
return result;
}

template <typename C, typename F, typename Elem = typename C::value_type>
void inplace_filter(C &v, F const &f) {
std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); });
Expand Down
33 changes: 33 additions & 0 deletions lib/utils/include/utils/stack_map.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef _FLEXFLOW_UTILS_STACK_MAP_H
#define _FLEXFLOW_UTILS_STACK_MAP_H

#include "containers.h"
#include "optional.h"
#include "stack_vector.h"

Expand Down Expand Up @@ -28,6 +29,28 @@ struct stack_map {
}
}

size_t size() const {
return this->contents.size();
}

bool empty() const {
return this->contents.empty();
}

friend bool operator==(stack_map const &lhs, stack_map const &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
return lhs.sorted() == rhs.sorted();
}

friend bool operator!=(stack_map const &lhs, stack_map const &rhs) {
if (lhs.size() != rhs.size()) {
return true;
}
return lhs.sorted() != rhs.sorted();
}

V &at(K const &k) {
return this->contents.at(get_idx(k).value()).second;
}
Expand Down Expand Up @@ -70,6 +93,15 @@ struct stack_map {
}

private:
std::vector<std::pair<K, V>> sorted() const {
auto comparator = [](std::pair<K, V> const &lhs,
std::pair<K, V> const &rhs) {
return lhs.first < rhs.first;
};

return sorted_by(this->contents, comparator);
}

optional<size_t> get_idx(K const &k) const {
for (std::size_t idx = 0; idx < contents.size(); idx++) {
if (contents.at(idx).first == k) {
Expand All @@ -82,6 +114,7 @@ struct stack_map {

stack_vector<std::pair<K, V>, MAXSIZE> contents;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(stack_map<int, int, 10>);

} // namespace FlexFlow

Expand Down
4 changes: 4 additions & 0 deletions lib/utils/include/utils/stack_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ struct stack_vector {
return this->m_size;
}

bool empty() const {
return (this->m_size == 0);
}

private:
std::size_t m_size = 0;
std::array<optional<T>, MAXSIZE> contents;
Expand Down
17 changes: 10 additions & 7 deletions lib/utils/include/utils/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ struct is_rc_copy_virtual_compliant
std::is_move_assignable<T>>>,
std::has_virtual_destructor<T>> {};

template <typename T, typename Enable = void>
struct is_clonable : std::false_type {};

template <typename T>
struct is_clonable<T, void_t<decltype(std::declval<T>().clone())>>
: std::true_type {};

template <typename T, typename Enable = void>
struct is_streamable : std::false_type {};

Expand Down Expand Up @@ -120,6 +113,16 @@ struct elements_satisfy<Cond, std::tuple<>> : std::true_type {};
static_assert(
elements_satisfy<is_equal_comparable, std::tuple<int, float>>::value, "");

template <typename C, typename Tag>
struct supports_iterator_tag
: std::is_base_of<Tag,
typename std::iterator_traits<
typename C::iterator>::iterator_category> {};

#define CHECK_SUPPORTS_ITERATOR_TAG(TAG, ...) \
static_assert(supports_iterator_tag<__VA_ARGS__, TAG>::value, \
#__VA_ARGS__ " does not support required iterator tag " #TAG);

template <typename T>
using is_default_constructible = std::is_default_constructible<T>;

Expand Down
Loading