Skip to content

Commit

Permalink
new aliasing logic using compile-time hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed Jun 29, 2024
1 parent 5dcfa50 commit a759ea5
Show file tree
Hide file tree
Showing 54 changed files with 810 additions and 155 deletions.
34 changes: 34 additions & 0 deletions include/nmtools/array/array/atleast_nd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP
#define NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP

#include "nmtools/array/view/atleast_nd.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
/**
* @brief Eagerly compute atleast_nd.
*
* @tparam output_t
* @tparam context_t
* @tparam array_t
* @param array Input array
* @param context Evaluation context
* @param output
* @return constexpr auto
*/
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>,
typename array_t, typename nd_t>
constexpr auto atleast_nd(const array_t& array, nd_t nd
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto viewed = view::atleast_nd(array,nd);
return eval(viewed
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver
);
} // atleast_nd
} // namespace nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP
4 changes: 2 additions & 2 deletions include/nmtools/array/eval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ namespace nmtools::meta
// default impl of make_fixed_ndarray only support integral constant for now
using stype = ct<s>;
if constexpr (is_void_v<type>) {
using type = make_tuple_type_t<stype>;
using type = nmtools_tuple<stype>;
return as_value_v<type>;
} else {
using type = append_type_t<type,stype>;
Expand Down Expand Up @@ -496,7 +496,7 @@ namespace nmtools::meta
// default impl of make_fixed_ndarray only support integral constant for now
using stype = ct<s>;
if constexpr (is_void_v<type>) {
using type = make_tuple_type_t<stype>;
using type = nmtools_tuple<stype>;
return as_value_v<type>;
} else {
using type = append_type_t<type,stype>;
Expand Down
3 changes: 2 additions & 1 deletion include/nmtools/array/functional.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef NMTOOLS_ARRAY_FUNCTIONAL_HPP
#define NMTOOLS_ARRAY_FUNCTIONAL_HPP

#include "nmtools/array/functional/indexing.hpp"

#include "nmtools/array/functional/activations/celu.hpp"
#include "nmtools/array/functional/activations/elu.hpp"
#include "nmtools/array/functional/activations/hardshrink.hpp"
Expand Down Expand Up @@ -78,6 +80,5 @@
#include "nmtools/array/functional/squeeze.hpp"
#include "nmtools/array/functional/where.hpp"
#include "nmtools/array/functional/zeros.hpp"
#include "nmtools/array/functional/indexing.hpp"

#endif // NMTOOLS_ARRAY_FUNCTIONAL_HPP
15 changes: 11 additions & 4 deletions include/nmtools/array/functional/batch_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
#define NMTOOLS_ARRAY_FUNCTIONAL_BATCH_NORM_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/moveaxis.hpp"
#include "nmtools/array/functional/ufuncs/add.hpp"
#include "nmtools/array/functional/ufuncs/multiply.hpp"
#include "nmtools/array/functional/ufuncs/subtract.hpp"
#include "nmtools/array/functional/ufuncs/divide.hpp"
#include "nmtools/array/functional/ufuncs/sqrt.hpp"
#include "nmtools/array/view/batch_norm.hpp"

namespace nmtools::functional
{
constexpr inline auto batch_norm = functor_t{quinary_fmap_t{
[](const auto&...args){
return view::batch_norm(args...);
}}};
constexpr inline auto batch_norm_fun = [](const auto&...args){
return view::batch_norm(args...);
};

constexpr inline auto batch_norm = functor_t{quinary_fmap_t<decltype(batch_norm_fun)>{batch_norm_fun}};
} // namespace nmtools::functional

#endif // NMTOOLS_ARRAY_FUNCTIONAL_BATCH_NORM_HPP
5 changes: 5 additions & 0 deletions include/nmtools/array/functional/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
#define NMTOOLS_ARRAY_FUNCTIONAL_SOFTMAX_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/ufuncs/maximum.hpp"
#include "nmtools/array/functional/ufuncs/subtract.hpp"
#include "nmtools/array/functional/ufuncs/exp.hpp"
#include "nmtools/array/functional/ufuncs/add.hpp"
#include "nmtools/array/functional/ufuncs/divide.hpp"
#include "nmtools/array/view/softmax.hpp"

namespace nmtools::functional
Expand Down
2 changes: 2 additions & 0 deletions include/nmtools/array/functional/softmin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define NMTOOLS_ARRAY_FUNCTIONAL_SOFTMIN_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/ufuncs/negative.hpp"
#include "nmtools/array/functional/softmax.hpp"
#include "nmtools/array/view/softmin.hpp"

namespace nmtools::functional
Expand Down
2 changes: 2 additions & 0 deletions include/nmtools/array/functional/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define NMTOOLS_ARRAY_FUNCTIONAL_VAR_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/ufuncs/add.hpp"
#include "nmtools/array/functional/ufuncs/divide.hpp"
#include "nmtools/array/view/var.hpp"

namespace nmtools::functional
Expand Down
14 changes: 14 additions & 0 deletions include/nmtools/array/impl/utl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace nmtools::impl
}
};

// TODO: remove this specialization
template <typename...Ts>
struct len_t<utl::tuple<Ts...>>
{
Expand All @@ -33,6 +34,19 @@ namespace nmtools::impl
}
};

// TODO: remove this specialization
template <typename...Ts>
struct len_t<utl::tuplev2<Ts...>>
{
using tuple = const utl::tuplev2<Ts...>&;
using type = size_t;

constexpr auto operator()(tuple) const noexcept
{
return sizeof...(Ts);
}
};

template <typename T, typename allocator>
struct len_t<utl::vector<T,allocator>>
{
Expand Down
173 changes: 173 additions & 0 deletions include/nmtools/array/index/alias.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#ifndef NMTOOLS_ARRAY_INDEX_ALIAS_HPP
#define NMTOOLS_ARRAY_INDEX_ALIAS_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/at.hpp"
#include "nmtools/array/shape.hpp"
#include "nmtools/array/index/max.hpp"

namespace nmtools::index
{
struct alias_t {};

template <typename operands_ids_t, typename reserved_ids_t>
constexpr auto alias(const operands_ids_t& operands_ids, const reserved_ids_t& reserved_ids)
{
using result_t = meta::resolve_optype_t<alias_t,operands_ids_t,reserved_ids_t>;

auto result = result_t{};

if constexpr (!meta::is_constant_index_array_v<result_t>
&& !meta::is_fail_v<result_t>
) {
auto size = len(operands_ids);
if constexpr (meta::is_resizable_v<result_t>) {
result.resize(size);
}
nm_index_t max_reserved_id = [&]{
if constexpr (is_none_v<reserved_ids_t>) {
return -1;
} else {
return max(reserved_ids);
}
}();
nm_index_t max_operands_id = [&]{
if constexpr (meta::is_constant_index_array_v<operands_ids_t>) {
return max(operands_ids);
} else {
return -1;
}
}();
nm_index_t max_id = max_reserved_id > max_operands_id ? max_reserved_id : max_operands_id;
nm_size_t tracked_id = max_id + 1;
for (nm_size_t i=0; i<(nm_size_t)size; i++) {
auto id = at(operands_ids,i);
if (id < 0) {
at(result,i) = tracked_id;
tracked_id++;
} else {
at(result,i) = id;
}
}
}

return result;
}

struct generate_alias_t {};

#ifndef NMTOOLS_ALIAS_DEFAULT_BASE
#define NMTOOLS_ALIAS_DEFAULT_BASE 512
#endif

#ifndef NMTOOLS_ALIAS_DEFAULT_PRIME
#define NMTOOLS_ALIAS_DEFAULT_PRIME 1033
#endif

// polynomial rolling hash
template <typename aliases_t, typename base_t=meta::ct<NMTOOLS_ALIAS_DEFAULT_BASE>, typename prime_t=meta::ct<NMTOOLS_ALIAS_DEFAULT_PRIME>>
constexpr auto generate_alias(const aliases_t& aliases, base_t base=base_t{}, prime_t prime=prime_t{})
{
using result_t = meta::resolve_optype_t<generate_alias_t,aliases_t,base_t,prime_t>;

auto result = result_t {};

if constexpr (!meta::is_constant_index_v<result_t>
&& !meta::is_fail_v<result_t>
) {
result = 0;
auto N = len(aliases);
for (nm_size_t i=0; i<(nm_size_t)N; i++) {
result = (result * base + at(aliases,i)) % prime;
}
}

return result;
}
} // namespace nmtools::index

namespace nmtools::meta
{
namespace error
{
template <typename...>
struct ALIAS_UNSUPPORTED : detail::fail_t {};

template <typename...>
struct GENERATE_ALIAS_UNSUPPORTED : detail::fail_t {};
}

template <typename operands_ids_t, typename reserved_ids_t>
struct resolve_optype<
void, index::alias_t, operands_ids_t, reserved_ids_t
> {
static constexpr auto vtype = [](){
[[maybe_unused]] constexpr auto SIZE = len_v<operands_ids_t>;
[[maybe_unused]] constexpr auto B_SIZE = bounded_size_v<operands_ids_t>;
if constexpr (
!is_index_array_v<operands_ids_t>
|| !(is_index_array_v<reserved_ids_t> || is_none_v<reserved_ids_t>)
) {
using type = error::ALIAS_UNSUPPORTED<operands_ids_t,reserved_ids_t>;
return as_value_v<type>;
} else if constexpr (
is_constant_index_array_v<operands_ids_t>
&& (is_constant_index_array_v<reserved_ids_t> || is_none_v<reserved_ids_t>)
) {
constexpr auto operands_ids = to_value_v<operands_ids_t>;
constexpr auto reserved_ids = to_value_v<reserved_ids_t>;
constexpr auto result = index::alias(operands_ids,reserved_ids);
using nmtools::at, nmtools::len;
return template_reduce<len(result)>([&](auto init, auto index){
using init_type = type_t<decltype(init)>;
using result_type = append_type_t<init_type,ct<at(result,index)>>;
return as_value_v<result_type>;
},as_value_v<nmtools_tuple<>>);
} else if constexpr (SIZE > 0) {
using type = nmtools_array<nm_size_t,SIZE>;
return as_value_v<type>;
} else if constexpr (!is_fail_v<decltype(B_SIZE)>) {
using type = nmtools_static_vector<nm_size_t,B_SIZE>;
return as_value_v<type>;
} else {
// TODO: support small_vector
using type = nmtools_list<nm_size_t>;
return as_value_v<type>;
}
}();
using type = type_t<decltype(vtype)>;
};

template <typename aliases_t, typename base_t, typename prime_t>
struct resolve_optype<
void, index::generate_alias_t, aliases_t, base_t, prime_t
> {
static constexpr auto vtype = [](){
if constexpr (
!is_index_array_v<aliases_t>
|| !is_index_v<base_t>
|| !is_index_v<prime_t>
) {
using type = error::GENERATE_ALIAS_UNSUPPORTED<aliases_t,base_t,prime_t>;
return as_value_v<type>;
} else if constexpr (
is_constant_index_array_v<aliases_t>
&& is_constant_index_v<base_t>
&& is_constant_index_v<prime_t>
) {
constexpr auto aliases = to_value_v<aliases_t>;
constexpr auto base = to_value_v<base_t>;
constexpr auto prime = to_value_v<prime_t>;
constexpr auto result = index::generate_alias(aliases,base,prime);
using type = meta::ct<(nm_index_t)result>;
return as_value_v<type>;
} else {
using type = nm_index_t;
return as_value_v<type>;
}
}();
using type = type_t<decltype(vtype)>;
};
} // namespace nmtools::meta

#endif // NMTOOLS_ARRAY_INDEX_ALIAS_HPP
Loading

0 comments on commit a759ea5

Please sign in to comment.