From a759ea53ae7bd9f708b5fd3cd0bc303236ace870 Mon Sep 17 00:00:00 2001 From: Fahri Ali Rahman Date: Sat, 29 Jun 2024 18:03:56 +0700 Subject: [PATCH] new aliasing logic using compile-time hashing --- include/nmtools/array/array/atleast_nd.hpp | 34 ++++ include/nmtools/array/eval.hpp | 4 +- include/nmtools/array/functional.hpp | 3 +- .../nmtools/array/functional/batch_norm.hpp | 15 +- include/nmtools/array/functional/softmax.hpp | 5 + include/nmtools/array/functional/softmin.hpp | 2 + include/nmtools/array/functional/var.hpp | 2 + include/nmtools/array/impl/utl.hpp | 14 ++ include/nmtools/array/index/alias.hpp | 173 ++++++++++++++++++ include/nmtools/array/index/append.hpp | 104 +++++++++++ include/nmtools/array/index/argsort.hpp | 2 +- include/nmtools/array/index/as_tuple.hpp | 2 +- include/nmtools/array/index/atleast_nd.hpp | 12 +- include/nmtools/array/index/choose.hpp | 2 +- include/nmtools/array/index/concatenate.hpp | 4 +- include/nmtools/array/index/filter.hpp | 2 +- include/nmtools/array/index/gather.hpp | 2 +- include/nmtools/array/index/logical_not.hpp | 2 +- include/nmtools/array/index/matmul.hpp | 12 +- include/nmtools/array/index/moveaxis.hpp | 2 +- include/nmtools/array/index/ndenumerate.hpp | 2 +- .../nmtools/array/index/normalize_axis.hpp | 4 +- include/nmtools/array/index/outer.hpp | 2 +- include/nmtools/array/index/pack.hpp | 2 +- include/nmtools/array/index/pad.hpp | 2 +- .../array/index/remove_single_dims.hpp | 2 +- include/nmtools/array/index/repeat.hpp | 2 +- include/nmtools/array/index/reshape.hpp | 18 +- include/nmtools/array/index/reverse.hpp | 2 +- include/nmtools/array/index/scatter.hpp | 2 +- include/nmtools/array/index/slice.hpp | 8 +- include/nmtools/array/index/take.hpp | 2 +- include/nmtools/array/ndarray/dynamic.hpp | 6 +- include/nmtools/array/ndarray/fixed.hpp | 4 +- include/nmtools/array/ndarray/hybrid.hpp | 4 +- include/nmtools/array/shape.hpp | 4 +- include/nmtools/array/view/alias.hpp | 149 +++++++++++++-- include/nmtools/array/view/batch_norm.hpp | 19 +- .../nmtools/array/view/broadcast_arrays.hpp | 29 +-- include/nmtools/array/view/decorator.hpp | 102 +++++++++-- include/nmtools/array/view/mean.hpp | 3 +- include/nmtools/array/view/moveaxis.hpp | 7 +- include/nmtools/array/view/ref.hpp | 2 +- .../array/view/ref/initializer_list.hpp | 2 +- include/nmtools/array/view/softmax.hpp | 25 +-- include/nmtools/array/view/transpose.hpp | 36 +++- include/nmtools/array/view/ufunc.hpp | 3 +- include/nmtools/array/view/ufunc/detail.hpp | 4 +- include/nmtools/array/view/ufunc/ufunc.hpp | 2 +- include/nmtools/array/view/ufuncs/power.hpp | 44 ++++- include/nmtools/array/view/var.hpp | 21 +-- include/nmtools/utility/ct_digraph.hpp | 49 +++-- include/nmtools/utility/fwd.hpp | 2 +- include/nmtools/utility/unwrap.hpp | 2 +- 54 files changed, 810 insertions(+), 155 deletions(-) create mode 100644 include/nmtools/array/array/atleast_nd.hpp create mode 100644 include/nmtools/array/index/alias.hpp create mode 100644 include/nmtools/array/index/append.hpp diff --git a/include/nmtools/array/array/atleast_nd.hpp b/include/nmtools/array/array/atleast_nd.hpp new file mode 100644 index 000000000..e874c237d --- /dev/null +++ b/include/nmtools/array/array/atleast_nd.hpp @@ -0,0 +1,34 @@ +#ifndef NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP +#define NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP + +#include "nmtools/array/view/atleast_nd.hpp" +#include "nmtools/array/eval.hpp" + +namespace nmtools::array +{ + /** + * @brief Eagerly compute atleast_nd. + * + * @tparam output_t + * @tparam context_t + * @tparam array_t + * @param array Input array + * @param context Evaluation context + * @param output + * @return constexpr auto + */ + template , + typename array_t, typename nd_t> + constexpr auto atleast_nd(const array_t& array, nd_t nd + , context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value resolver=meta::as_value_v) + { + auto viewed = view::atleast_nd(array,nd); + return eval(viewed + ,nmtools::forward(context) + ,nmtools::forward(output) + ,resolver + ); + } // atleast_nd +} // namespace nmtools::array + +#endif // NMTOOLS_ARRAY_ARRAY_ATLEAST_ND_HPP \ No newline at end of file diff --git a/include/nmtools/array/eval.hpp b/include/nmtools/array/eval.hpp index d2619d68a..9205ae631 100644 --- a/include/nmtools/array/eval.hpp +++ b/include/nmtools/array/eval.hpp @@ -452,7 +452,7 @@ namespace nmtools::meta // default impl of make_fixed_ndarray only support integral constant for now using stype = ct; if constexpr (is_void_v) { - using type = make_tuple_type_t; + using type = nmtools_tuple; return as_value_v; } else { using type = append_type_t; @@ -496,7 +496,7 @@ namespace nmtools::meta // default impl of make_fixed_ndarray only support integral constant for now using stype = ct; if constexpr (is_void_v) { - using type = make_tuple_type_t; + using type = nmtools_tuple; return as_value_v; } else { using type = append_type_t; diff --git a/include/nmtools/array/functional.hpp b/include/nmtools/array/functional.hpp index 1e5610377..7a1785152 100644 --- a/include/nmtools/array/functional.hpp +++ b/include/nmtools/array/functional.hpp @@ -1,6 +1,8 @@ #ifndef NMTOOLS_ARRAY_FUNCTIONAL_HPP #define NMTOOLS_ARRAY_FUNCTIONAL_HPP +#include "nmtools/array/functional/indexing.hpp" + #include "nmtools/array/functional/activations/celu.hpp" #include "nmtools/array/functional/activations/elu.hpp" #include "nmtools/array/functional/activations/hardshrink.hpp" @@ -78,6 +80,5 @@ #include "nmtools/array/functional/squeeze.hpp" #include "nmtools/array/functional/where.hpp" #include "nmtools/array/functional/zeros.hpp" -#include "nmtools/array/functional/indexing.hpp" #endif // NMTOOLS_ARRAY_FUNCTIONAL_HPP \ No newline at end of file diff --git a/include/nmtools/array/functional/batch_norm.hpp b/include/nmtools/array/functional/batch_norm.hpp index 192874e98..759ba44f1 100644 --- a/include/nmtools/array/functional/batch_norm.hpp +++ b/include/nmtools/array/functional/batch_norm.hpp @@ -2,14 +2,21 @@ #define NMTOOLS_ARRAY_FUNCTIONAL_BATCH_NORM_HPP #include "nmtools/array/functional/functor.hpp" +#include "nmtools/array/functional/moveaxis.hpp" +#include "nmtools/array/functional/ufuncs/add.hpp" +#include "nmtools/array/functional/ufuncs/multiply.hpp" +#include "nmtools/array/functional/ufuncs/subtract.hpp" +#include "nmtools/array/functional/ufuncs/divide.hpp" +#include "nmtools/array/functional/ufuncs/sqrt.hpp" #include "nmtools/array/view/batch_norm.hpp" namespace nmtools::functional { - constexpr inline auto batch_norm = functor_t{quinary_fmap_t{ - [](const auto&...args){ - return view::batch_norm(args...); - }}}; + constexpr inline auto batch_norm_fun = [](const auto&...args){ + return view::batch_norm(args...); + }; + + constexpr inline auto batch_norm = functor_t{quinary_fmap_t{batch_norm_fun}}; } // namespace nmtools::functional #endif // NMTOOLS_ARRAY_FUNCTIONAL_BATCH_NORM_HPP \ No newline at end of file diff --git a/include/nmtools/array/functional/softmax.hpp b/include/nmtools/array/functional/softmax.hpp index 5f1a3d2ff..27286aa39 100644 --- a/include/nmtools/array/functional/softmax.hpp +++ b/include/nmtools/array/functional/softmax.hpp @@ -2,6 +2,11 @@ #define NMTOOLS_ARRAY_FUNCTIONAL_SOFTMAX_HPP #include "nmtools/array/functional/functor.hpp" +#include "nmtools/array/functional/ufuncs/maximum.hpp" +#include "nmtools/array/functional/ufuncs/subtract.hpp" +#include "nmtools/array/functional/ufuncs/exp.hpp" +#include "nmtools/array/functional/ufuncs/add.hpp" +#include "nmtools/array/functional/ufuncs/divide.hpp" #include "nmtools/array/view/softmax.hpp" namespace nmtools::functional diff --git a/include/nmtools/array/functional/softmin.hpp b/include/nmtools/array/functional/softmin.hpp index 13e6568eb..796c3c0ba 100644 --- a/include/nmtools/array/functional/softmin.hpp +++ b/include/nmtools/array/functional/softmin.hpp @@ -2,6 +2,8 @@ #define NMTOOLS_ARRAY_FUNCTIONAL_SOFTMIN_HPP #include "nmtools/array/functional/functor.hpp" +#include "nmtools/array/functional/ufuncs/negative.hpp" +#include "nmtools/array/functional/softmax.hpp" #include "nmtools/array/view/softmin.hpp" namespace nmtools::functional diff --git a/include/nmtools/array/functional/var.hpp b/include/nmtools/array/functional/var.hpp index e62074574..d2cd44b78 100644 --- a/include/nmtools/array/functional/var.hpp +++ b/include/nmtools/array/functional/var.hpp @@ -2,6 +2,8 @@ #define NMTOOLS_ARRAY_FUNCTIONAL_VAR_HPP #include "nmtools/array/functional/functor.hpp" +#include "nmtools/array/functional/ufuncs/add.hpp" +#include "nmtools/array/functional/ufuncs/divide.hpp" #include "nmtools/array/view/var.hpp" namespace nmtools::functional diff --git a/include/nmtools/array/impl/utl.hpp b/include/nmtools/array/impl/utl.hpp index 33aaf8671..622c1927b 100644 --- a/include/nmtools/array/impl/utl.hpp +++ b/include/nmtools/array/impl/utl.hpp @@ -21,6 +21,7 @@ namespace nmtools::impl } }; + // TODO: remove this specialization template struct len_t> { @@ -33,6 +34,19 @@ namespace nmtools::impl } }; + // TODO: remove this specialization + template + struct len_t> + { + using tuple = const utl::tuplev2&; + using type = size_t; + + constexpr auto operator()(tuple) const noexcept + { + return sizeof...(Ts); + } + }; + template struct len_t> { diff --git a/include/nmtools/array/index/alias.hpp b/include/nmtools/array/index/alias.hpp new file mode 100644 index 000000000..c5d7a1aba --- /dev/null +++ b/include/nmtools/array/index/alias.hpp @@ -0,0 +1,173 @@ +#ifndef NMTOOLS_ARRAY_INDEX_ALIAS_HPP +#define NMTOOLS_ARRAY_INDEX_ALIAS_HPP + +#include "nmtools/meta.hpp" +#include "nmtools/array/at.hpp" +#include "nmtools/array/shape.hpp" +#include "nmtools/array/index/max.hpp" + +namespace nmtools::index +{ + struct alias_t {}; + + template + constexpr auto alias(const operands_ids_t& operands_ids, const reserved_ids_t& reserved_ids) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t{}; + + if constexpr (!meta::is_constant_index_array_v + && !meta::is_fail_v + ) { + auto size = len(operands_ids); + if constexpr (meta::is_resizable_v) { + result.resize(size); + } + nm_index_t max_reserved_id = [&]{ + if constexpr (is_none_v) { + return -1; + } else { + return max(reserved_ids); + } + }(); + nm_index_t max_operands_id = [&]{ + if constexpr (meta::is_constant_index_array_v) { + return max(operands_ids); + } else { + return -1; + } + }(); + nm_index_t max_id = max_reserved_id > max_operands_id ? max_reserved_id : max_operands_id; + nm_size_t tracked_id = max_id + 1; + for (nm_size_t i=0; i<(nm_size_t)size; i++) { + auto id = at(operands_ids,i); + if (id < 0) { + at(result,i) = tracked_id; + tracked_id++; + } else { + at(result,i) = id; + } + } + } + + return result; + } + + struct generate_alias_t {}; + + #ifndef NMTOOLS_ALIAS_DEFAULT_BASE + #define NMTOOLS_ALIAS_DEFAULT_BASE 512 + #endif + + #ifndef NMTOOLS_ALIAS_DEFAULT_PRIME + #define NMTOOLS_ALIAS_DEFAULT_PRIME 1033 + #endif + + // polynomial rolling hash + template , typename prime_t=meta::ct> + constexpr auto generate_alias(const aliases_t& aliases, base_t base=base_t{}, prime_t prime=prime_t{}) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_constant_index_v + && !meta::is_fail_v + ) { + result = 0; + auto N = len(aliases); + for (nm_size_t i=0; i<(nm_size_t)N; i++) { + result = (result * base + at(aliases,i)) % prime; + } + } + + return result; + } +} // namespace nmtools::index + +namespace nmtools::meta +{ + namespace error + { + template + struct ALIAS_UNSUPPORTED : detail::fail_t {}; + + template + struct GENERATE_ALIAS_UNSUPPORTED : detail::fail_t {}; + } + + template + struct resolve_optype< + void, index::alias_t, operands_ids_t, reserved_ids_t + > { + static constexpr auto vtype = [](){ + [[maybe_unused]] constexpr auto SIZE = len_v; + [[maybe_unused]] constexpr auto B_SIZE = bounded_size_v; + if constexpr ( + !is_index_array_v + || !(is_index_array_v || is_none_v) + ) { + using type = error::ALIAS_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_array_v + && (is_constant_index_array_v || is_none_v) + ) { + constexpr auto operands_ids = to_value_v; + constexpr auto reserved_ids = to_value_v; + constexpr auto result = index::alias(operands_ids,reserved_ids); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto index){ + using init_type = type_t; + using result_type = append_type_t>; + return as_value_v; + },as_value_v>); + } else if constexpr (SIZE > 0) { + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v) { + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: support small_vector + using type = nmtools_list; + return as_value_v; + } + }(); + using type = type_t; + }; + + template + struct resolve_optype< + void, index::generate_alias_t, aliases_t, base_t, prime_t + > { + static constexpr auto vtype = [](){ + if constexpr ( + !is_index_array_v + || !is_index_v + || !is_index_v + ) { + using type = error::GENERATE_ALIAS_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_array_v + && is_constant_index_v + && is_constant_index_v + ) { + constexpr auto aliases = to_value_v; + constexpr auto base = to_value_v; + constexpr auto prime = to_value_v; + constexpr auto result = index::generate_alias(aliases,base,prime); + using type = meta::ct<(nm_index_t)result>; + return as_value_v; + } else { + using type = nm_index_t; + return as_value_v; + } + }(); + using type = type_t; + }; +} // namespace nmtools::meta + +#endif // NMTOOLS_ARRAY_INDEX_ALIAS_HPP \ No newline at end of file diff --git a/include/nmtools/array/index/append.hpp b/include/nmtools/array/index/append.hpp new file mode 100644 index 000000000..e79a9307f --- /dev/null +++ b/include/nmtools/array/index/append.hpp @@ -0,0 +1,104 @@ +#ifndef NMTOOLS_ARRAY_INDEX_APPEND_HPP +#define NMTOOLS_ARRAY_INDEX_APPEND_HPP + +#include "nmtools/meta.hpp" + +namespace nmtools::index +{ + struct append_t {}; + + template + constexpr auto append(const indices_t& indices, index_t index) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_constant_index_array_v + && !meta::is_fail_v + ) { + auto n = len(indices); + if constexpr (meta::is_resizable_v) { + result.resize(n+1); + } + [[maybe_unused]] + auto f = [&](const auto& indices){ + auto n = len(indices); + for (nm_size_t i=0; i<(nm_size_t)n; i++) { + at(result,i) = at(indices,i); + } + }; + constexpr auto B_SIZE = meta::bounded_size_v; + // avoid calling at(indices,i) when indices is tuple<> + if constexpr (meta::is_fail_v) { + f(indices); + } else if constexpr (B_SIZE > 0) { + f(indices); + } + at(result,n) = index; + } + + return result; + } +} // namespace nmtools::index + +namespace nmtools::meta +{ + namespace error + { + template + struct APPEND_UNSUPPORTED : detail::fail_t {}; + } + + template + struct resolve_optype< + void, index::append_t, indices_t, index_t + > { + static constexpr auto vtype = [](){ + [[maybe_unused]] + constexpr auto DIM = len_v; + [[maybe_unused]] + constexpr auto B_DIM = bounded_size_v; + if constexpr (!is_index_array_v + || !is_index_v + ) { + using type = error::APPEND_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_array_v + && is_constant_index_v + ) { + constexpr auto indices = to_value_v; + constexpr auto idx = to_value_v; + constexpr auto result = index::append(indices,idx); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto index){ + using init_type = type_t; + using result_type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else if constexpr (DIM > 0) { + using element_t = get_index_element_type_t; + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v) { + if constexpr (B_DIM == 0) { + using type = nmtools_array; + return as_value_v; + } else { + using element_t = get_index_element_type_t; + using type = nmtools_static_vector; + return as_value_v; + } + } else { + using element_t = get_index_element_type_t; + // TODO: support small_vector + using type = nmtools_list; + return as_value_v; + } + }(); + using type = type_t; + }; +} // namespace nmtools::meta + +#endif // NMTOOLS_ARRAY_INDEX_APPEND_HPP \ No newline at end of file diff --git a/include/nmtools/array/index/argsort.hpp b/include/nmtools/array/index/argsort.hpp index 603651ede..9a62e33a9 100644 --- a/include/nmtools/array/index/argsort.hpp +++ b/include/nmtools/array/index/argsort.hpp @@ -98,7 +98,7 @@ namespace nmtools::meta constexpr auto array = to_value_v; constexpr auto args = index::argsort(array); // transform back to type - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; return template_reduce([&](auto init, auto index){ using init_t = type_t; using init_i = ct; diff --git a/include/nmtools/array/index/as_tuple.hpp b/include/nmtools/array/index/as_tuple.hpp index 6bf49373f..e7a9facfb 100644 --- a/include/nmtools/array/index/as_tuple.hpp +++ b/include/nmtools/array/index/as_tuple.hpp @@ -22,7 +22,7 @@ namespace nmtools::index template typename index_sequence, typename array_t, size_t...Is> constexpr decltype(auto) as_tuple(const array_t& array, index_sequence) { - using tuple_t = meta::make_tuple_type_t(array))>...>; + using tuple_t = nmtools_tuple(array))>...>; return tuple_t{nmtools::get(array)...}; } // as_tuple diff --git a/include/nmtools/array/index/atleast_nd.hpp b/include/nmtools/array/index/atleast_nd.hpp index 0a659597b..3a318fe17 100644 --- a/include/nmtools/array/index/atleast_nd.hpp +++ b/include/nmtools/array/index/atleast_nd.hpp @@ -17,9 +17,17 @@ namespace nmtools::index template constexpr auto shape_atleast_nd(const shape_t& shape, [[maybe_unused]] nd_t nd) { - using result_t = meta::resolve_optype_t; + using result_t [[maybe_unused]] = meta::resolve_optype_t; - if constexpr (meta::is_constant_index_array_v) { + if constexpr (meta::is_maybe_v) { + using shape_type = meta::get_maybe_type_t; + using result_t = meta::resolve_optype_t; + using return_t = meta::conditional_t,result_t,nmtools_maybe>; + return (static_cast(shape) + ? return_t{shape_atleast_nd(*shape,nd)} + : return_t{meta::Nothing} + ); + } else if constexpr (meta::is_constant_index_array_v) { return result_t {}; } else { auto result = result_t {}; diff --git a/include/nmtools/array/index/choose.hpp b/include/nmtools/array/index/choose.hpp index 8a557d42a..fd6a3e6d2 100644 --- a/include/nmtools/array/index/choose.hpp +++ b/include/nmtools/array/index/choose.hpp @@ -133,7 +133,7 @@ namespace nmtools::meta ) { // TODO: compute at compile-time here, then maps back to type constexpr auto N = fixed_index_array_size_v; - using type = make_fixed_ndarray_t>>; + using type = make_fixed_ndarray_t>>; return as_value_v; } else if constexpr ( is_hybrid_index_array_v diff --git a/include/nmtools/array/index/concatenate.hpp b/include/nmtools/array/index/concatenate.hpp index 4f73eab90..e3e8a77ed 100644 --- a/include/nmtools/array/index/concatenate.hpp +++ b/include/nmtools/array/index/concatenate.hpp @@ -129,7 +129,7 @@ namespace nmtools::meta if constexpr (is_fixed_index_array_v) { constexpr auto N = len_v; using elem_t = remove_cvref_t>; - return as_value_v>; + return as_value_v>; } else if constexpr (is_index_array_v) { return as_value_v; } else { @@ -223,7 +223,7 @@ namespace nmtools::index else success = false; // TODO: use optional instead - using return_t = meta::make_tuple_type_t; + using return_t = nmtools_tuple; return return_t{success,ret}; } } // shape_concatenate diff --git a/include/nmtools/array/index/filter.hpp b/include/nmtools/array/index/filter.hpp index ee1203fcf..d200ac514 100644 --- a/include/nmtools/array/index/filter.hpp +++ b/include/nmtools/array/index/filter.hpp @@ -25,7 +25,7 @@ namespace nmtools::index auto chosen = choose(indices,array); using indices_t = meta::remove_cvref_t; using chosen_t = meta::remove_cvref_t; - using return_t = meta::make_tuple_type_t; + using return_t = nmtools_tuple; return return_t{indices, chosen}; } // filter } // namespace nmtools::index diff --git a/include/nmtools/array/index/gather.hpp b/include/nmtools/array/index/gather.hpp index 9585632fd..9ea823f4e 100644 --- a/include/nmtools/array/index/gather.hpp +++ b/include/nmtools/array/index/gather.hpp @@ -103,7 +103,7 @@ namespace nmtools::meta constexpr auto result = index::gather(vector,indices); // assuming len(result) > 0 // transform back to type - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; return template_reduce([&](auto init, auto index){ using init_t = type_t; using result_i = ct; diff --git a/include/nmtools/array/index/logical_not.hpp b/include/nmtools/array/index/logical_not.hpp index dcd5f71f8..865627918 100644 --- a/include/nmtools/array/index/logical_not.hpp +++ b/include/nmtools/array/index/logical_not.hpp @@ -72,7 +72,7 @@ namespace nmtools::meta constexpr auto result = index::logical_not(array); // assuming len(result) > 0 // transform back to type - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; return template_reduce<::nmtools::len(result)-1>([&](auto init, auto index){ using init_t = type_t; using result_i = ct; diff --git a/include/nmtools/array/index/matmul.hpp b/include/nmtools/array/index/matmul.hpp index 43767ee95..5bcc35d24 100644 --- a/include/nmtools/array/index/matmul.hpp +++ b/include/nmtools/array/index/matmul.hpp @@ -97,7 +97,7 @@ namespace nmtools::meta using element_t = get_index_element_type_t; // NOTE: use make_hybrid_ndarray instead of make_hybrid_ndarray_t to avoid including ndarray.hpp using hybrid_t = typename make_hybrid_ndarray::type; - using type = make_tuple_type_t; + using type = nmtools_tuple; return as_value_v; } else if constexpr (is_fixed_index_array_v && is_constant_index_v) { // number of dimension of the array @@ -113,14 +113,14 @@ namespace nmtools::meta }(); using element_t = get_index_element_type_t; - using left_size_t = make_tuple_type_t>; - using right_size_t = make_tuple_type_t>; + using left_size_t = nmtools_tuple>; + using right_size_t = nmtools_tuple>; using left_t = typename make_fixed_ndarray::type; using right_t = typename make_fixed_ndarray::type; - using type = make_tuple_type_t; + using type = nmtools_tuple; return as_value_v; } else if constexpr (is_index_array_v) { - using type = make_tuple_type_t; + using type = nmtools_tuple; return as_value_v; } else { // unhandled type @@ -150,7 +150,7 @@ namespace nmtools::index constexpr auto shape_matmul(const lhs_shape_t& ashape, const rhs_shape_t& bshape) { using result_t = meta::resolve_optype_t; - using return_t = meta::make_maybe_type_t; + using return_t = nmtools_maybe; if constexpr (meta::is_constant_index_array_v) { // still use maybe for simplicity at caller site diff --git a/include/nmtools/array/index/moveaxis.hpp b/include/nmtools/array/index/moveaxis.hpp index a81d62334..0e6ba2a66 100644 --- a/include/nmtools/array/index/moveaxis.hpp +++ b/include/nmtools/array/index/moveaxis.hpp @@ -158,7 +158,7 @@ namespace nmtools::index // return ret; } else { // NOTE: quick hack so no need to modify caller to handle non maybe type - using return_t = meta::make_maybe_type_t; + using return_t = nmtools_maybe; static_assert( !meta::is_fail_v , "unsupported moveaxis_to_transpose" ); diff --git a/include/nmtools/array/index/ndenumerate.hpp b/include/nmtools/array/index/ndenumerate.hpp index 97b58f627..9012c2267 100644 --- a/include/nmtools/array/index/ndenumerate.hpp +++ b/include/nmtools/array/index/ndenumerate.hpp @@ -56,7 +56,7 @@ namespace nmtools::index } // size constexpr inline auto operator[](size_t i) const - -> meta::make_tuple_type_t + -> nmtools_tuple { auto idx = ndindex_[i]; const auto& val = apply_at(array,idx); diff --git a/include/nmtools/array/index/normalize_axis.hpp b/include/nmtools/array/index/normalize_axis.hpp index b21245237..eecea5006 100644 --- a/include/nmtools/array/index/normalize_axis.hpp +++ b/include/nmtools/array/index/normalize_axis.hpp @@ -35,7 +35,7 @@ namespace nmtools::index using ndim_t [[maybe_unused]] = decltype(ndim); if constexpr (! (meta::is_constant_index_v || meta::is_constant_index_array_v)) { - using return_t = meta::make_maybe_type_t; + using return_t = nmtools_maybe; // using return_t = utl::maybe; // auto ret = return_t {}; @@ -98,7 +98,7 @@ namespace nmtools::index } else { // NOTE: quick-hack // TODO: no need to return maybe type - using return_t = meta::make_maybe_type_t; + using return_t = nmtools_maybe; auto valid = [&](){ const auto result = result_t {}; diff --git a/include/nmtools/array/index/outer.hpp b/include/nmtools/array/index/outer.hpp index ad559d8dc..ced7b4dcd 100644 --- a/include/nmtools/array/index/outer.hpp +++ b/include/nmtools/array/index/outer.hpp @@ -99,7 +99,7 @@ namespace nmtools::index for (size_t i=0; i; + using return_t = nmtools_tuple; return return_t{aidx,bidx}; } } // namespace nmtools::index diff --git a/include/nmtools/array/index/pack.hpp b/include/nmtools/array/index/pack.hpp index b9ab004a8..cad83c88f 100644 --- a/include/nmtools/array/index/pack.hpp +++ b/include/nmtools/array/index/pack.hpp @@ -21,7 +21,7 @@ namespace nmtools::index // using common_t = meta::promote_index_t; using common_t = meta::type_t>; if constexpr (meta::is_integral_v) { - using array_t = meta::make_array_type_t; + using array_t = nmtools_array; return array_t{index_,indices...}; } else /* if constexpr (meta::is_index_array_v) */ { static_assert (sizeof...(indices)==0 diff --git a/include/nmtools/array/index/pad.hpp b/include/nmtools/array/index/pad.hpp index 6459fd70a..acf8360f4 100644 --- a/include/nmtools/array/index/pad.hpp +++ b/include/nmtools/array/index/pad.hpp @@ -102,7 +102,7 @@ namespace nmtools::index { using result_t = meta::resolve_optype_t; // use maybe type to indicate out of bound index (of src shape) - using return_t = meta::make_maybe_type_t; + using return_t = nmtools_maybe; using idx_t = meta::get_index_element_type_t; using s_idx_t = meta::make_signed_t; diff --git a/include/nmtools/array/index/remove_single_dims.hpp b/include/nmtools/array/index/remove_single_dims.hpp index 5615db9ab..146ae8c25 100644 --- a/include/nmtools/array/index/remove_single_dims.hpp +++ b/include/nmtools/array/index/remove_single_dims.hpp @@ -66,7 +66,7 @@ namespace nmtools::meta constexpr auto result = index::remove_single_dims(shape); // assuming len(result) > 0 // transform back to type - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; return template_reduce<::nmtools::len(result)-1>([&](auto init, auto index){ using init_t = type_t; using result_i = ct; diff --git a/include/nmtools/array/index/repeat.hpp b/include/nmtools/array/index/repeat.hpp index bc7c10949..d4e2a5f4a 100644 --- a/include/nmtools/array/index/repeat.hpp +++ b/include/nmtools/array/index/repeat.hpp @@ -141,7 +141,7 @@ namespace nmtools::meta } }, as_value_v>); } else if constexpr (is_none_v) { - using type = make_array_type_t; + using type = nmtools_array; return as_value_v; } else if constexpr ( is_index_array_v diff --git a/include/nmtools/array/index/reshape.hpp b/include/nmtools/array/index/reshape.hpp index 80c7356be..1215d49b7 100644 --- a/include/nmtools/array/index/reshape.hpp +++ b/include/nmtools/array/index/reshape.hpp @@ -45,7 +45,7 @@ namespace nmtools::index template constexpr auto shape_reshape(const src_shape_t& src_shape, const dst_shape_t& dst_shape) { - using result_t = meta::resolve_optype_t; + using result_t [[maybe_unused]] = meta::resolve_optype_t; using m_result_t [[maybe_unused]] = meta::get_maybe_type_t; // TODO: try to provide common function-lifting utility @@ -53,6 +53,7 @@ namespace nmtools::index // when src_shape is maybe, then assume the result_t is maybe if (static_cast(src_shape)) { auto result = shape_reshape(*src_shape,dst_shape); + #if 0 // even if we unwrap the src_shape, the result may be maybe type // since dst_shape may be different shape with src // because std::optional> is actually allowed @@ -67,9 +68,24 @@ namespace nmtools::index } else { return result_t{result}; } + #else + using return_t = result_t; + return (has_value(result) + ? return_t{unwrap(result)} + : return_t{meta::Nothing} + ); + #endif } else { return result_t{meta::Nothing}; } + } else if constexpr (meta::is_maybe_v) { + using dst_shape_type = meta::get_maybe_type_t; + using result_t = meta::resolve_optype_t; + using return_t = meta::conditional_t,result_t,nmtools_maybe>; + return (static_cast(dst_shape) + ? return_t{shape_reshape(src_shape,*dst_shape)} + : return_t{meta::Nothing} + ); } else if constexpr (meta::is_fail_v) { // let the caller decides what to do return result_t {}; diff --git a/include/nmtools/array/index/reverse.hpp b/include/nmtools/array/index/reverse.hpp index 27400db84..f3c005e88 100644 --- a/include/nmtools/array/index/reverse.hpp +++ b/include/nmtools/array/index/reverse.hpp @@ -62,7 +62,7 @@ namespace nmtools::meta if constexpr (is_constant_index_array_v) { constexpr auto indices = to_value_v; constexpr auto reversed = index::reverse(indices); - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; // convert back to type return template_reduce<::nmtools::len(reversed)-1>([&](auto init, auto index){ using init_t = type_t; diff --git a/include/nmtools/array/index/scatter.hpp b/include/nmtools/array/index/scatter.hpp index 5014e6f58..de6f491db 100644 --- a/include/nmtools/array/index/scatter.hpp +++ b/include/nmtools/array/index/scatter.hpp @@ -88,7 +88,7 @@ namespace nmtools::meta constexpr auto ind = to_value_v; constexpr auto res = index::scatter(vec, ind); // convert back to type - using init_type = make_tuple_type_t>; + using init_type = nmtools_tuple>; return template_reduce<::nmtools::len(res)-1>([&](auto init, auto index){ using init_t = type_t; using result_t = append_type_t>; diff --git a/include/nmtools/array/index/slice.hpp b/include/nmtools/array/index/slice.hpp index 037d13fdd..77b5e5c74 100644 --- a/include/nmtools/array/index/slice.hpp +++ b/include/nmtools/array/index/slice.hpp @@ -683,7 +683,7 @@ namespace nmtools::index constexpr auto NS = meta::len_v; if constexpr (NS==2) { const auto [start, stop] = slice; - using mresult_t = meta::make_tuple_type_t; + using mresult_t = nmtools_tuple; return mresult_t{start,stop,None}; } // return as it is to keep dtype @@ -883,7 +883,7 @@ namespace nmtools::index // TODO error handling // make sure sizeof...(slices) <= len(shape) - auto slices_pack = meta::make_tuple_type_t{slices...}; + auto slices_pack = nmtools_tuple{slices...}; // since res and shape may have different dim, // this var is to keep track of the active result index @@ -1039,7 +1039,7 @@ namespace nmtools::index if constexpr (meta::is_resizable_v) res.resize(dim); - auto slices_pack = meta::make_tuple_type_t{slices...}; + auto slices_pack = nmtools_tuple{slices...}; // since res and shape may have different dim, // also indices and shape may have different dim, @@ -1093,7 +1093,7 @@ namespace nmtools::index constexpr auto NS = meta::len_v; if constexpr (NS==2) { const auto [start, stop] = slice; - using mresult_t = meta::make_tuple_type_t; + using mresult_t = nmtools_tuple; return mresult_t{start,stop,None}; } // return as it is to keep dtype diff --git a/include/nmtools/array/index/take.hpp b/include/nmtools/array/index/take.hpp index 940e75bc4..fc31ec80a 100644 --- a/include/nmtools/array/index/take.hpp +++ b/include/nmtools/array/index/take.hpp @@ -111,7 +111,7 @@ namespace nmtools::meta > { // when slicing flattened array, the shape should be single element 1D array - using type = make_array_type_t; + using type = nmtools_array; }; // shape_take_t namespace error diff --git a/include/nmtools/array/ndarray/dynamic.hpp b/include/nmtools/array/ndarray/dynamic.hpp index e3c256baf..89a85cc54 100644 --- a/include/nmtools/array/ndarray/dynamic.hpp +++ b/include/nmtools/array/ndarray/dynamic.hpp @@ -181,7 +181,7 @@ namespace nmtools::array shape_.resize(sizeof...(shape)); meta::template_for([&](auto index){ constexpr auto i = decltype(index)::value; - using tuple_t = meta::make_tuple_type_t; + using tuple_t = nmtools_tuple; shape_.at(i) = nmtools::get(tuple_t{shape...}); }); strides_ = strides(); @@ -223,7 +223,7 @@ namespace nmtools::array constexpr decltype(auto) operator()(size_type n, size_types...ns) { using common_size_t = meta::promote_index_t; - auto indices = meta::make_array_type_t{ + auto indices = nmtools_array{ static_cast(n), static_cast(ns)... }; assert (dim()==indices.size()); @@ -244,7 +244,7 @@ namespace nmtools::array constexpr decltype(auto) operator()(size_type n, size_types...ns) const { using common_size_t = meta::promote_index_t; - auto indices = meta::make_array_type_t{ + auto indices = nmtools_array{ static_cast(n), static_cast(ns)... }; assert (dim()==indices.size()); diff --git a/include/nmtools/array/ndarray/fixed.hpp b/include/nmtools/array/ndarray/fixed.hpp index 0b978775a..09425cb3b 100644 --- a/include/nmtools/array/ndarray/fixed.hpp +++ b/include/nmtools/array/ndarray/fixed.hpp @@ -80,7 +80,7 @@ namespace nmtools::array { */ static constexpr auto strides() { - auto stride = meta::make_array_type_t{}; + auto stride = nmtools_array{}; for (size_t i=0; i> {}; #if 0 { - static inline constexpr auto value = meta::make_array_type_t{Shape1,ShapeN...}; + static inline constexpr auto value = nmtools_array{Shape1,ShapeN...}; using value_type = decltype(value); }; #endif diff --git a/include/nmtools/array/ndarray/hybrid.hpp b/include/nmtools/array/ndarray/hybrid.hpp index 223f6ede6..b9e9189db 100644 --- a/include/nmtools/array/ndarray/hybrid.hpp +++ b/include/nmtools/array/ndarray/hybrid.hpp @@ -231,7 +231,7 @@ namespace nmtools::array // using common_size_t = meta::promote_index_t; using promoted_t = meta::promote_index; using common_size_t = meta::type_t; - auto indices = meta::make_array_type_t{ + auto indices = nmtools_array{ static_cast(ns)... }; static_assert ( dimension == sizeof...(ns) @@ -252,7 +252,7 @@ namespace nmtools::array // using common_size_t = meta::promote_index_t; using promoted_t = meta::promote_index; using common_size_t = meta::type_t; - auto indices = meta::make_array_type_t{ + auto indices = nmtools_array{ static_cast(ns)... }; static_assert ( dimension == sizeof...(ns) diff --git a/include/nmtools/array/shape.hpp b/include/nmtools/array/shape.hpp index b5274940f..cd6258929 100644 --- a/include/nmtools/array/shape.hpp +++ b/include/nmtools/array/shape.hpp @@ -89,7 +89,7 @@ namespace nmtools::impl template constexpr auto repeat(const T& t) { - auto res = meta::make_array_type_t{}; + auto res = nmtools_array{}; meta::template_for([&](auto index){ constexpr auto i = decltype(index)::value; get(res) = t; @@ -140,7 +140,7 @@ namespace nmtools::impl // check for dynamic-shape array but fixed-dimension array else if constexpr (meta::nested_array_dim_v > 0) { constexpr auto N = meta::nested_array_dim_v; - auto shape_ = meta::make_array_type_t{}; + auto shape_ = nmtools_array{}; meta::template_for([&](auto index){ constexpr auto i = decltype(index)::value; // example for 3dim nested dynamic array diff --git a/include/nmtools/array/view/alias.hpp b/include/nmtools/array/view/alias.hpp index f44a56b9a..e41a94fc9 100644 --- a/include/nmtools/array/view/alias.hpp +++ b/include/nmtools/array/view/alias.hpp @@ -3,6 +3,9 @@ #include "nmtools/meta.hpp" #include "nmtools/array/utility/at.hpp" +#include "nmtools/utility/fwd.hpp" +#include "nmtools/utility/unwrap.hpp" +#include "nmtools/array/index/max.hpp" #include "nmtools/array/shape.hpp" #include "nmtools/array/view/decorator.hpp" @@ -17,6 +20,8 @@ namespace nmtools::view using const_reference = const value_type&; // array type as required by decorator using array_type = resolve_array_type_t; + + // TODO: assert id is constant index (or none) using id_type = id_t; static constexpr auto operands_ids = nmtools_tuple{id_type{}}; @@ -39,6 +44,38 @@ namespace nmtools::view return nmtools_tuple{id}; } + constexpr auto shape() const + { + if constexpr (meta::is_pointer_v>) { + return nmtools::shape(*array); + } else { + return nmtools::shape(array); + } + } + + constexpr auto dim() const + { + if constexpr (meta::is_pointer_v>) { + return nmtools::dim(*array); + } else { + return nmtools::dim(array); + } + } + + constexpr auto size() const + { + if constexpr (meta::is_pointer_v>) { + return nmtools::size(*array); + } else { + return nmtools::size(array); + } + } + + constexpr auto operands() const noexcept + { + return nmtools_tuple{array}; + } // operands + template constexpr auto index(size_types...indices) const { @@ -46,6 +83,41 @@ namespace nmtools::view } // index }; // alias_t + template + struct alias_t>> + { + using value_type = meta::get_element_type_t; + using const_reference = const value_type&; + // array type as required by decorator + using array_type = resolve_array_type_t; + using id_type = id_t; + + static constexpr auto operands_ids = nmtools_tuple{id_type{}}; + + array_type array; + id_type id; + + constexpr alias_t(const array_t& array, id_t id) + : array(array) + , id(id) + {} + + constexpr operator value_type() const + { + return array; + } + + constexpr auto attribute() const noexcept + { + return nmtools_tuple{id}; + } + + constexpr auto operands() const noexcept + { + return nmtools_tuple{array}; + } // operands + }; // alias + template struct alias_t>> { @@ -95,23 +167,68 @@ namespace nmtools::view } }; + // TODO: drop none default, make id mandatory template constexpr auto alias(const array_t& array, id_t id=id_t{}) { + // TODO: handle either type if constexpr (meta::is_maybe_v) { using array_type = meta::get_maybe_type_t; - using result_type = decorator_t; + using result_type = decltype(alias(meta::declval(),id)); using return_type = nmtools_maybe; - if (static_cast(array)) { - return return_type{decorator_t{{*array,id}}}; + return (static_cast(array) + ? return_type{alias(*array,id)} + : return_type{meta::Nothing} + ); + } else if constexpr (meta::is_same_view_v && !is_none_v) { + // Quick-hack: aliasing an alias will rename + const auto& m_array = array.array; + if constexpr (meta::is_pointer_v>) { + return alias(*array.array,id); } else { - return return_type{meta::Nothing}; + return alias(array.array,id); } + } else if constexpr (meta::is_view_v) { + // view is already aliased + static_assert( array_t::id_type::value == id_t::value ); + return array; } else { return decorator_t{{array,id}}; } } // alias + template + constexpr auto aliased(const arrays_t&...arrays) + { + auto array_pack = pack_operands(arrays...); + constexpr auto N = sizeof...(arrays); + constexpr auto initial_ids = meta::template_reduce([&](auto init, auto index){ + using array_type = decltype(unwrap(at(array_pack,index))); + constexpr auto id = get_id_v>; + return utility::tuple_append(init,meta::ct_v<(nm_index_t)id>); + },nmtools_tuple{}); + constexpr auto max_id = index::max(meta::to_value_v); + constexpr auto offset_id = max_id + 1; // if max_id: -1, then offset 0 + constexpr auto final_ids = meta::template_reduce([&](auto init, auto index){ + constexpr auto id = at(initial_ids,index); + constexpr auto final_id = ((id < 0) ? (offset_id + index) : id); + return utility::tuple_append(init,meta::ct_v); + },nmtools_tuple{}); + auto aliased = meta::template_reduce([&](auto init, auto index){ + const auto& array = at(array_pack,index); + if constexpr (meta::is_pointer_v>) { + return append_operands(init,alias(*array,at(final_ids,index))); + } else { + return append_operands(init,alias(array,at(final_ids,index))); + } + },nmtools_tuple{}); + if constexpr (N == 1) { + return nmtools::get<0>(aliased); + } else { + return aliased; + } + } // aliased + template constexpr auto alias(const T* ptr, size_t numel, id_t id=id_t{}) { @@ -121,28 +238,36 @@ namespace nmtools::view namespace nmtools::meta { - template - struct is_ndarray< view::decorator_t > + template + struct is_ndarray< view::decorator_t > { static constexpr auto value = is_ndarray_v; }; - template - struct is_num< view::decorator_t > + template + struct is_num< view::decorator_t > { static constexpr auto value = is_num_v; }; // specialization for ptr - template - struct is_ndarray< view::decorator_t > + template + struct is_ndarray< view::decorator_t > { static constexpr auto value = is_num_v; }; - template + template + struct get_element_type< + view::decorator_t + > + { + using type = get_element_type_t; + }; + + template struct get_element_type< - view::decorator_t + view::decorator_t > { using type = meta::remove_address_space_t; diff --git a/include/nmtools/array/view/batch_norm.hpp b/include/nmtools/array/view/batch_norm.hpp index a804fc48d..cae2a5328 100644 --- a/include/nmtools/array/view/batch_norm.hpp +++ b/include/nmtools/array/view/batch_norm.hpp @@ -1,7 +1,7 @@ #ifndef NMTOOLS_ARRAY_VIEW_BATCH_NORM_HPP #define NMTOOLS_ARRAY_VIEW_BATCH_NORM_HPP -#include "nmtools/array/view/atleast_3d.hpp" +#include "nmtools/array/view/atleast_nd.hpp" #include "nmtools/array/view/moveaxis.hpp" #include "nmtools/array/view/ufuncs/add.hpp" #include "nmtools/array/view/ufuncs/multiply.hpp" @@ -43,13 +43,20 @@ namespace nmtools::view auto src_axis = meta::ct_v<-1>; auto dst_axis = meta::ct_v<-3>; - auto weight_ = view::moveaxis(view::atleast_3d(weight),src_axis,dst_axis); - auto bias_ = view::moveaxis(view::atleast_3d(bias),src_axis,dst_axis); - auto mean_ = view::moveaxis(view::atleast_3d(mean),src_axis,dst_axis); - auto var_ = view::moveaxis(view::atleast_3d(var),src_axis,dst_axis); + auto aliased = view::aliased(input,mean,var,weight,bias); + auto a_input = nmtools::get<0>(aliased); + auto a_mean = nmtools::get<1>(aliased); + auto a_var = nmtools::get<2>(aliased); + auto a_weight = nmtools::get<3>(aliased); + auto a_bias = nmtools::get<4>(aliased); + + auto weight_ = view::moveaxis(view::atleast_nd(a_weight,meta::ct_v<3>),src_axis,dst_axis); + auto bias_ = view::moveaxis(view::atleast_nd(a_bias,meta::ct_v<3>),src_axis,dst_axis); + auto mean_ = view::moveaxis(view::atleast_nd(a_mean,meta::ct_v<3>),src_axis,dst_axis); + auto var_ = view::moveaxis(view::atleast_nd(a_var,meta::ct_v<3>),src_axis,dst_axis); auto stddev_ = view::sqrt(view::add(var_,eps)); - auto subtracted = view::subtract(input,mean_); + auto subtracted = view::subtract(a_input,mean_); auto divided = view::divide(subtracted,stddev_); auto multiplied = view::multiply(divided,weight_); return view::add(multiplied,bias_); diff --git a/include/nmtools/array/view/broadcast_arrays.hpp b/include/nmtools/array/view/broadcast_arrays.hpp index 9ecbf9bd6..838bd85c3 100644 --- a/include/nmtools/array/view/broadcast_arrays.hpp +++ b/include/nmtools/array/view/broadcast_arrays.hpp @@ -6,6 +6,7 @@ #include "nmtools/array/shape.hpp" #include "nmtools/meta.hpp" #include "nmtools/assert.hpp" +#include "nmtools/array/view/alias.hpp" namespace nmtools::view { @@ -16,30 +17,36 @@ namespace nmtools::view * @param arrays * @return constexpr auto */ - template - constexpr auto broadcast_arrays(const arrays_t&...arrays) + template typename tuple, typename...arrays_t, auto...Is> + constexpr auto aliased_broadcast_arrays(const tuple& arrays, meta::index_sequence) { static_assert( sizeof...(arrays_t) >= 2 , "please provide at least two arrays for broadcast_arrays"); - auto bcast_shape = index::broadcast_shape(shape(arrays)...); - auto bcast_size = index::broadcast_size(bcast_shape,size(arrays)...); + auto bcast_shape = index::broadcast_shape(shape(nmtools::get(arrays))...); + auto bcast_size = index::broadcast_size(bcast_shape,size(nmtools::get(arrays))...); if constexpr (meta::is_maybe_v) { // avoid tuple...> because not usable in constexpr evaluation using bcast_shape_t = meta::get_maybe_type_t; using bcast_size_t = meta::get_maybe_type_t; - using result_t = nmtools_tuple(),meta::declval())))...>; + using result_t = nmtools_tuple(arrays),meta::declval(),meta::declval())))...>; using return_t = nmtools_maybe; - if (static_cast(bcast_shape)) { - return return_t{nmtools_tuple{unwrap(view::broadcast_to(arrays,*bcast_shape,*bcast_size))...}}; - } else { - return return_t{meta::Nothing}; - } + return (static_cast(bcast_shape) + ? return_t{nmtools_tuple{unwrap(view::broadcast_to(nmtools::get(arrays),*bcast_shape,*bcast_size))...}} + : return_t{meta::Nothing} + ); } else { // unwrap to avoid construction of tuple...> at compile-time - return nmtools_tuple{unwrap(view::broadcast_to(arrays,bcast_shape,bcast_size))...}; + return nmtools_tuple{unwrap(view::broadcast_to(nmtools::get(arrays),bcast_shape,bcast_size))...}; } } // broadcast_arrays + + template + constexpr auto broadcast_arrays(const arrays_t&...arrays) + { + auto aliased_pack = view::aliased(arrays...); + return aliased_broadcast_arrays(aliased_pack,meta::make_index_sequence_v); + } } // namespace nmtools::view #endif // NMTOOLS_ARRAY_VIEW_BROADCAST_ARRAYS_HPP \ No newline at end of file diff --git a/include/nmtools/array/view/decorator.hpp b/include/nmtools/array/view/decorator.hpp index 8ba0df6c2..c318c1378 100644 --- a/include/nmtools/array/view/decorator.hpp +++ b/include/nmtools/array/view/decorator.hpp @@ -10,6 +10,8 @@ #include "nmtools/array/utility/apply_at.hpp" #include "nmtools/array/index/ref.hpp" #include "nmtools/array/index/product.hpp" +#include "nmtools/array/index/alias.hpp" +#include "nmtools/array/index/append.hpp" // TODO: move to shape.hpp #ifdef NMTOOLS_ENABLE_BOOST @@ -24,6 +26,7 @@ #define NMTOOLS_NO_BASE_ACCESS #endif +// TODO: remove namespace nmtools::view::detail { using utility::tuple_append, utility::tuple_cat; @@ -122,7 +125,35 @@ namespace nmtools::meta template typename lhs_t, typename rhs_t> constexpr inline auto is_same_view_v = is_same_view::value; -} + + template + struct get_type_id + { + static constexpr auto value = index::generate_alias(meta::type_name_v); + }; + + template + constexpr inline auto get_type_id_v = get_type_id::value; + + template + constexpr auto generate_view_id(const operands_ids_t&, as_value) + { + // avoid calling to_value_v> + constexpr auto id_sequence = [](){ + constexpr auto type_id = get_type_id_v; + constexpr auto B_SIZE = bounded_size_v; + static_assert( !is_fail_v ); + if constexpr (B_SIZE == 0) { + return nmtools_array{type_id}; + } else { + constexpr auto operands_ids = to_value_v; + return index::append(operands_ids,type_id); + } + }(); + constexpr auto node_id = index::generate_alias(id_sequence); + return meta::ct_v; + } +} // namespace nmtools::meta namespace nmtools::view { @@ -254,6 +285,7 @@ namespace nmtools::view */ template typename view_t, typename...Ts> struct decorator_t + // TODO: remove conditional macro, c++ for opencl temporarily dropped #ifndef NMTOOLS_NO_BASE_ACCESS : view_t #endif // NMTOOLS_NO_BASE_ACCESS @@ -282,6 +314,20 @@ namespace nmtools::view static constexpr auto arity = meta::ct_v>; + static constexpr auto operands_ids = [](){ + // TODO: check for unique-ness of operands ids + return meta::template_reduce([&](auto init, auto index){ + constexpr auto I = decltype(index)::value; + using operand_type = meta::remove_cvref_t>; + if constexpr (meta::has_id_type_v) { + return utility::tuple_append(init,typename operand_type::id_type{}); + } else { + return utility::tuple_append(init,index); + } + }, nmtools_tuple{}); + }(); + + #if 0 // check for "free"/"unnamed" operands, which we can NOT assign the id freely static constexpr auto unnamed_operands = [](){ return meta::template_reduce([&](auto init, auto index){ @@ -320,20 +366,24 @@ namespace nmtools::view return meta::as_value_v>; } }(); - using id_type = meta::type_t; - static constexpr auto operands_ids = [](){ - // TODO: check for unique-ness of operands ids - return meta::template_reduce([&](auto init, auto index){ - constexpr auto I = decltype(index)::value; - using operand_type = meta::remove_cvref_t>; - if constexpr (meta::has_id_type_v) { - return utility::tuple_append(init,typename operand_type::id_type{}); - } else { - return utility::tuple_append(init,index); - } - }, nmtools_tuple{}); + #else + static constexpr auto id_vtype = [](){ + if constexpr (meta::has_id_type_v) { + using type = typename view_type::id_type; + return meta::as_value_v; + } else { + // constexpr auto type_id = meta::get_type_id_v; + // constexpr auto id_sequence = index::append(operands_ids,type_id); + // constexpr auto id = index::generate_alias(id_sequence); + constexpr auto id = meta::generate_view_id(operands_ids,meta::as_value_v); + using type = meta::ct::value>; + return meta::as_value_v; + } }(); + #endif + + using id_type = meta::type_t; static_assert( meta::is_constant_index_v , "invalid identifier for view type" ); @@ -551,6 +601,28 @@ namespace nmtools::view }; // decorator_t + template + struct get_id + { + static constexpr auto value = -1; + }; + + template typename view_t, typename...Ts> + struct get_id< + decorator_t + > { + static constexpr auto value = decorator_t::id_type::value; + }; + + template + struct get_id : get_id {}; + + template + struct get_id : get_id {}; + + template + constexpr inline auto get_id_v = get_id::value; + // TODO: remove /** * @brief make view given parameters arrays @@ -639,7 +711,7 @@ namespace nmtools::view if constexpr (is_none_v && !meta::is_tuple_v) { // init typelist, use tuple for now, // TODO: deduce template template of nocv_array_t if possible - using type = meta::make_tuple_type_t; + using type = nmtools_tuple; return meta::as_value_v; } else if constexpr (is_none_v && meta::is_tuple_v) { return meta::as_value_v; @@ -805,7 +877,7 @@ namespace nmtools::view // convert to array type that has value semantics constexpr auto N = meta::len_v; using elem_t = meta::get_element_type_t; - using type = meta::make_array_type_t; + using type = nmtools_array; return meta::as_value_v; } else /* if constexpr ( is_none_v diff --git a/include/nmtools/array/view/mean.hpp b/include/nmtools/array/view/mean.hpp index c860ab450..40f35b106 100644 --- a/include/nmtools/array/view/mean.hpp +++ b/include/nmtools/array/view/mean.hpp @@ -122,7 +122,8 @@ namespace nmtools::view // but by composing two view (add.reduce + divide) instead auto shape = ::nmtools::shape(array); - auto divisor = detail::mean_divisor(shape,axis); + // TODO: error handling + auto divisor = detail::mean_divisor(unwrap(shape),axis); using divisor_t = decltype(divisor); using element_t = meta::get_element_type_t; auto dtype_ = [&](){ diff --git a/include/nmtools/array/view/moveaxis.hpp b/include/nmtools/array/view/moveaxis.hpp index 39901ccef..2225a9539 100644 --- a/include/nmtools/array/view/moveaxis.hpp +++ b/include/nmtools/array/view/moveaxis.hpp @@ -33,12 +33,7 @@ namespace nmtools::view { auto shape_ = shape(array); auto order = index::moveaxis_to_transpose(shape_,source,destination); - // order should be maybe type - using result_t = meta::remove_cvref_t; - nmtools_assert_prepare_type( return_t, result_t ); - nmtools_assert( order, "unsupported moveaxis arguments", return_t ); - - return return_t{view::transpose(array,*order)}; + return view::transpose(array,order); } // moveaxis } // namespace nmtools::view diff --git a/include/nmtools/array/view/ref.hpp b/include/nmtools/array/view/ref.hpp index 67f72c68a..96ed1357f 100644 --- a/include/nmtools/array/view/ref.hpp +++ b/include/nmtools/array/view/ref.hpp @@ -32,7 +32,7 @@ namespace nmtools::view constexpr auto identity(size_types...indices) { using common_size_t = meta::type_t>; - using indices_t = meta::make_array_type_t; + using indices_t = nmtools_array; auto ndindex = indices_t{static_cast(indices)...}; return ndindex; } // identity diff --git a/include/nmtools/array/view/ref/initializer_list.hpp b/include/nmtools/array/view/ref/initializer_list.hpp index c58597086..f101fa236 100644 --- a/include/nmtools/array/view/ref/initializer_list.hpp +++ b/include/nmtools/array/view/ref/initializer_list.hpp @@ -115,7 +115,7 @@ namespace nmtools */ constexpr auto shape() const { - auto shape_ = meta::make_array_type_t{}; + auto shape_ = nmtools_array{}; // get the size for each 'nested' list, dimension known at compile time // assuming each element on each 'axis' has the same size // this techniques is similar with nmtools::shape, but since diff --git a/include/nmtools/array/view/softmax.hpp b/include/nmtools/array/view/softmax.hpp index e3f19c931..5b75cde38 100644 --- a/include/nmtools/array/view/softmax.hpp +++ b/include/nmtools/array/view/softmax.hpp @@ -1,6 +1,7 @@ #ifndef NMTOOLS_ARRAY_VIEW_SOFTMAX_HPP #define NMTOOLS_ARRAY_VIEW_SOFTMAX_HPP +#include "nmtools/array/view/alias.hpp" #include "nmtools/array/view/ufuncs/exp.hpp" #include "nmtools/array/view/ufuncs/add.hpp" #include "nmtools/array/view/ufuncs/maximum.hpp" @@ -13,23 +14,25 @@ namespace nmtools::view * @brief Applies softmax function along specified axis * * @tparam input_t - * @tparam dim_t + * @tparam axis_t * @param input input array - * @param dim axis where softmax to be applied + * @param axis axis where softmax to be applied * @return constexpr auto */ - template - constexpr auto softmax(const input_t& input, dim_t dim) + template + constexpr auto softmax(const input_t& input, axis_t axis) { - // following pytorch, only allow index dim (index array dim not allowed) - static_assert( meta::is_index_v - , "unsupported softmax, expect dim to be index" + auto a_input = view::aliased(input); + // following pytorch, only allow index axis (index array axis not allowed) + static_assert( meta::is_index_v + , "unsupported softmax, expect axis to be index" ); // NOTE: this follow https://cs231n.github.io/linear-classify/#softmax for numerical stability - auto input_ = view::subtract(input,view::reduce_maximum(input,/*axis=*/dim,/*dtype=*/None,/*initial=*/None,/*keepdims=*/True)); - auto input_exp = view::exp(input_); - auto reduced = view::reduce_add(input_exp,/*axis=*/dim,/*dtype=*/None,/*initial=*/None,/*keepdims=*/True); - return view::divide(input_exp,reduced); + auto a = view::reduce_maximum(a_input,axis,/*dtype=*/None,/*initial=*/None,/*keepdims=*/True); + auto b = view::subtract(a_input,a); + auto c = view::exp(b); + auto d = view::reduce_add(c,axis,/*dtype=*/None,/*initial=*/None,/*keepdims=*/True); + return view::divide(c,d); } // softmax } // namespace nmtools::view diff --git a/include/nmtools/array/view/transpose.hpp b/include/nmtools/array/view/transpose.hpp index dfeabada9..a58c5ce66 100644 --- a/include/nmtools/array/view/transpose.hpp +++ b/include/nmtools/array/view/transpose.hpp @@ -66,20 +66,38 @@ namespace nmtools::view } }; // transpose_t - template - constexpr auto make_transpose(const array_t& array, const axes_t& axes=axes_t{}) + template + constexpr auto transposer(const shape_t& src_shape, size_t src_size, const axes_t& axes=axes_t{}) { - auto src_shape = shape(array); - auto src_size = size(array); - auto indexer = transpose_t{src_shape,axes,src_size}; - return indexing(array,indexer); + if constexpr (meta::is_maybe_v) { + // assume when shape is maybe, size is also maybe + using result_t = decltype(transposer(unwrap(src_shape),unwrap(src_size),axes)); + using return_t = meta::conditional_t,result_t,nmtools_maybe>; + return (static_cast(src_shape) + ? return_t{transposer(unwrap(src_shape),unwrap(src_size),axes)} + : return_t{meta::Nothing} + ); + } else if constexpr (meta::is_maybe_v) { + using result_t = decltype(transposer(src_shape,src_size,unwrap(axes))); + using return_t = nmtools_maybe; + return (static_cast(axes) + ? return_t{transposer(src_shape,src_size,*axes)} + : return_t{meta::Nothing} + ); + } else { + auto indexer = transpose_t{src_shape,axes,src_size}; + return indexer; + } } template constexpr auto transpose(const array_t& array, const axes_t& axes=axes_t{}) { - auto f = [](const auto&...args){ - return make_transpose(args...); + auto f = [](const auto& array, const auto& axes){ + auto src_shape = shape(array); + auto src_size = size(array); + auto indexer = transposer(src_shape,src_size,axes); + return indexing(array,indexer); }; return lift_indexing(f,array,axes); } @@ -100,7 +118,7 @@ namespace nmtools::array auto src_shape = as_static(attribute.src_shape); auto axes = as_static(attribute.axes); auto src_size = as_static(attribute.src_size); - return view::transpose_t{src_shape,axes,src_size}; + return view::transposer(src_shape,axes,src_size); } }; } // namespace nmtools::array diff --git a/include/nmtools/array/view/ufunc.hpp b/include/nmtools/array/view/ufunc.hpp index 0bd99d369..e638ac6fd 100644 --- a/include/nmtools/array/view/ufunc.hpp +++ b/include/nmtools/array/view/ufunc.hpp @@ -5,6 +5,7 @@ #include "nmtools/array/view/ufunc/reduce.hpp" #include "nmtools/array/view/ufunc/accumulate.hpp" #include "nmtools/array/view/ufunc/outer.hpp" +#include "nmtools/array/view/alias.hpp" namespace nmtools::view { @@ -380,7 +381,7 @@ namespace nmtools::view else if constexpr (meta::is_boolean_v) { using left_t = decltype(reduce(op,array,axis,dtype,initial,True)); using right_t = decltype(reduce(op,array,axis,dtype,initial,False)); - using either_t = meta::make_either_type_t; + using either_t = nmtools_either; return ( keepdims ? either_t{reduce(op,array,axis,dtype,initial,True)} diff --git a/include/nmtools/array/view/ufunc/detail.hpp b/include/nmtools/array/view/ufunc/detail.hpp index cc17a38f7..da642245b 100644 --- a/include/nmtools/array/view/ufunc/detail.hpp +++ b/include/nmtools/array/view/ufunc/detail.hpp @@ -42,7 +42,7 @@ namespace nmtools::view::detail static_assert( sizeof...(arrays_t) > 0, "nmtools internal error" ); static constexpr auto vtype = [](){ // TODO: no need to create full tuple here, use meta::type_list - using arrays_type = meta::make_tuple_type_t; + using arrays_type = nmtools_tuple; // for view / num type: simply copy (take value), // otherwise take reference return meta::template_reduce([](auto init, auto index){ @@ -51,7 +51,7 @@ namespace nmtools::view::detail using operand_t = view::resolve_array_type_t; using init_t = meta::type_t; if constexpr (meta::is_void_v) { - return meta::as_value_v>; + return meta::as_value_v>; } else { return meta::as_value_v>; } diff --git a/include/nmtools/array/view/ufunc/ufunc.hpp b/include/nmtools/array/view/ufunc/ufunc.hpp index 374ab5da4..21e0eee3c 100644 --- a/include/nmtools/array/view/ufunc/ufunc.hpp +++ b/include/nmtools/array/view/ufunc/ufunc.hpp @@ -90,7 +90,7 @@ namespace nmtools::view { // dont take reference for the array, a Num type should be copied // and view type should be cheap to copy - using operands_type = meta::make_tuple_type_t; + using operands_type = nmtools_tuple; using array_type = operands_type; using op_type = op_t; using result_type = detail::get_ufunc_result_type_t...>; diff --git a/include/nmtools/array/view/ufuncs/power.hpp b/include/nmtools/array/view/ufuncs/power.hpp index 2ecd8bee3..1f97c4ce4 100644 --- a/include/nmtools/array/view/ufuncs/power.hpp +++ b/include/nmtools/array/view/ufuncs/power.hpp @@ -6,6 +6,7 @@ namespace nmtools::view { + #if 0 template < typename lhs_t=none_t, typename rhs_t=none_t, typename res_t=none_t, typename=void> @@ -15,7 +16,47 @@ namespace nmtools::view NMTOOLS_UFUNC_CONSTEXPR auto operator()(const T& t, const U& u) const { - return math::pow(t,u); + // TODO: formalize casting rules + using common_t [[maybe_unused]] = meta::common_type_t; + using result_t [[maybe_unused]] = meta::conditional_t,common_t,res_t>; + auto lhs = [&](){ + if constexpr (!is_none_v) { + return static_cast(t); + } else { + return static_cast(t); + } + }(); + auto rhs = [&](){ + if constexpr (!is_none_v) { + return static_cast(u); + } else { + return static_cast(u); + } + }(); + auto result = math::pow(lhs,rhs); + if constexpr (!is_none_v) { + return static_cast(result); + } else { + return result; + } + } // operator() + }; // power_t + #else + template < + typename lhs_t=none_t, typename rhs_t=none_t, + typename res_t=none_t, typename=void> + struct power_t + { + template + NMTOOLS_UFUNC_CONSTEXPR + auto operator()(const T& t, const U& u) const + { + if constexpr (meta::is_view_v || meta::is_view_v) { + using common_t [[maybe_unused]] = meta::common_type_t; + return math::pow(static_cast(t),static_cast(u)); + } else { + return math::pow(t,u); + } } // operator() }; // power_t @@ -49,6 +90,7 @@ namespace nmtools::view return math::pow(t,u); } // operator() }; // power_t + #endif template NMTOOLS_UFUNC_CONSTEXPR diff --git a/include/nmtools/array/view/var.hpp b/include/nmtools/array/view/var.hpp index 403c0f08e..669e55d6f 100644 --- a/include/nmtools/array/view/var.hpp +++ b/include/nmtools/array/view/var.hpp @@ -1,6 +1,7 @@ #ifndef NMTOOLS_ARRAY_VIEW_VAR_HPP #define NMTOOLS_ARRAY_VIEW_VAR_HPP +#include "nmtools/array/view/alias.hpp" #include "nmtools/array/view/ufuncs/fabs.hpp" #include "nmtools/array/view/ufuncs/square.hpp" #include "nmtools/array/view/ufuncs/subtract.hpp" @@ -27,21 +28,17 @@ namespace nmtools::view template constexpr auto var(const array_t& array, const axis_t& axis, dtype_t dtype=dtype_t{}, ddof_t ddof=ddof_t{0}, keepdims_t keepdims=keepdims_t{}) { + auto input = view::aliased(array); // must keep dimension to properly subtract - auto a = mean(array,axis,dtype,/*keepdims=*/True); - auto b = subtract(array, a); - auto c = fabs(b); - auto d = square(c); + auto a = view::mean(input,axis,dtype,/*keepdims=*/True); + auto b = view::subtract(input, a); + auto c = view::fabs(b); + auto d = view::square(c); // no reason to start from other initial value - auto e = sum(d,axis,dtype,/*initial=*/None,keepdims); - // TODO: fix unwrap to handle bounded array - #if 0 + auto e = view::sum(d,axis,dtype,/*initial=*/None,keepdims); // TODO: error handling - auto N = detail::mean_divisor(::nmtools::shape(unwrap(array)),axis); - #else - auto N = detail::mean_divisor(::nmtools::shape(array),axis); - #endif - return divide(e,N-ddof); + auto N = detail::mean_divisor(::nmtools::shape(unwrap(input)),axis); + return view::divide(e,N-ddof); } // var } // namespace nmtools::view diff --git a/include/nmtools/utility/ct_digraph.hpp b/include/nmtools/utility/ct_digraph.hpp index b03f88d62..50c6151ab 100644 --- a/include/nmtools/utility/ct_digraph.hpp +++ b/include/nmtools/utility/ct_digraph.hpp @@ -2,6 +2,7 @@ #define NMTOOLS_UTILITY_CT_DIGRAPH_HPP #include "nmtools/meta.hpp" +#include "nmtools/array/at.hpp" #include "nmtools/utility/ct_map.hpp" namespace nmtools::utility @@ -54,19 +55,39 @@ namespace nmtools::utility constexpr auto add_edge(from_t from, to_t to) const noexcept { auto edges = digraph.at(from); - auto new_edges = tuple_append(edges,to); - auto new_digraph = digraph.update(from,/*edges*/new_edges); - using new_digraph_type = decltype(new_digraph); - using new_node_data_type = decltype(node_data); - using new_ct_digraph_type = ct_digraph< - meta::remove_cvref_t - , meta::remove_cvref_t - , meta::remove_cvref_t - >; - return new_ct_digraph_type( - new_digraph - , node_data - ); + // avoid duplication + auto contains = [](auto tuple, auto key){ + constexpr auto N = meta::len_v; + constexpr auto KEY = meta::to_value_v; + return meta::template_reduce([&](auto init, auto index){ + auto element = nmtools::at(tuple,index); + constexpr auto ELEMENT = meta::to_value_v; + if constexpr (ELEMENT == KEY) { + return meta::true_type{}; + } else { + return init; + } + }, meta::false_type{}); + }; + auto has_key = contains(edges,to); + constexpr auto HAS_KEY = meta::to_value_v; + if constexpr (HAS_KEY) { + return *this; + } else { + auto new_edges = tuple_append(edges,to); + auto new_digraph = digraph.update(from,/*edges*/new_edges); + using new_digraph_type = decltype(new_digraph); + using new_node_data_type = decltype(node_data); + using new_ct_digraph_type = ct_digraph< + meta::remove_cvref_t + , meta::remove_cvref_t + , meta::remove_cvref_t + >; + return new_ct_digraph_type( + new_digraph + , node_data + ); + } } constexpr auto nodes() const noexcept @@ -103,7 +124,7 @@ namespace nmtools::utility { return digraph.at(key); } - }; + }; // ct_digraph template ct_digraph(const ct_map&, const ct_map&) -> ct_digraph; diff --git a/include/nmtools/utility/fwd.hpp b/include/nmtools/utility/fwd.hpp index 243a1eb9f..cbd780365 100644 --- a/include/nmtools/utility/fwd.hpp +++ b/include/nmtools/utility/fwd.hpp @@ -222,7 +222,7 @@ namespace nmtools } template typename tuple, typename...Ts, auto...Is> - constexpr auto append_operands(const tuple& ts, const T& t) + constexpr auto append_operands(const tuple& ts, const T& t, meta::index_sequence) { using result_t = tuple< meta::fwd_operand_t>... diff --git a/include/nmtools/utility/unwrap.hpp b/include/nmtools/utility/unwrap.hpp index 2b9762779..96bb7e4dc 100644 --- a/include/nmtools/utility/unwrap.hpp +++ b/include/nmtools/utility/unwrap.hpp @@ -29,7 +29,7 @@ namespace nmtools } } - #if 0 + #if 1 template constexpr auto unwrap(const T(&t)[N]) -> const T(&)[N]