-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add kron, vecdot, and tensordot (#304)
* add kron * add max_value & min_value metafunctions * add tests * add vecdot * add tensordot * move out index::contains form expand_dims, add range index function * add is_clipped_index metafunction * update tests * skip kron compile-time shape inference when on gcc * update compiler-notes * fix for gcc werror * fix for gcc werror
- Loading branch information
Showing
31 changed files
with
7,546 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_KRON_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_KRON_HPP | ||
|
||
#include "nmtools/array/view/kron.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename lhs_t, typename rhs_t> | ||
constexpr auto kron(const lhs_t& lhs, const rhs_t& rhs | ||
, context_t&& context=context_t{}, output_t&& output=output_t{}, meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto a = view::kron(lhs,rhs); | ||
return eval( | ||
a | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} // kron | ||
} | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_KRON_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP | ||
|
||
#include "nmtools/array/view/tensordot.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename lhs_t, typename rhs_t, typename axes_t=meta::ct<2>> | ||
constexpr auto tensordot(const lhs_t& lhs, const rhs_t& rhs, axes_t axes=axes_t{} | ||
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto a = view::tensordot(lhs,rhs,axes); | ||
return eval( | ||
a | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} // tensordot | ||
} // nmtools::array | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_VECDOT_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_VECDOT_HPP | ||
|
||
#include "nmtools/array/view/vecdot.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename lhs_t, typename rhs_t, typename dtype_t=none_t, typename keepdims_t=meta::false_type> | ||
constexpr auto vecdot(const lhs_t& lhs, const rhs_t& rhs, dtype_t dtype=dtype_t{}, keepdims_t keepdims=keepdims_t{} | ||
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto a = view::vecdot(lhs,rhs,dtype,keepdims); | ||
return eval( | ||
a | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} // vecdot | ||
} // nmtools::array | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_VECDOT_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef NMTOOLS_ARRAY_INDEX_CONTAINS_HPP | ||
#define NMTOOLS_ARRAY_INDEX_CONTAINS_HPP | ||
|
||
#include "nmtools/meta.hpp" | ||
#include "nmtools/array/shape.hpp" | ||
#include "nmtools/utils/isequal.hpp" | ||
|
||
namespace nmtools::index | ||
{ | ||
template <typename array_t, typename value_t> | ||
constexpr auto contains(const array_t& array, const value_t& value) | ||
{ | ||
for (nm_size_t i=0; i<(nm_size_t)len(array); i++) { | ||
if (utils::isequal(at(array,i),value)) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} // contains | ||
} // nmtools::index | ||
|
||
#endif // NMTOOLS_ARRAY_INDEX_CONTAINS_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#ifndef NMTOOLS_ARRAY_INDEX_RANGE_HPP | ||
#define NMTOOLS_ARRAY_INDEX_RANGE_HPP | ||
|
||
#include "nmtools/meta.hpp" | ||
#include "nmtools/array/shape.hpp" | ||
|
||
namespace nmtools::index | ||
{ | ||
struct range_t {}; | ||
|
||
template <typename start_t, typename stop_t, typename step_t=meta::ct<1>> | ||
constexpr auto range([[maybe_unused]] start_t start | ||
, [[maybe_unused]] stop_t stop | ||
, [[maybe_unused]] step_t step=step_t{} | ||
) { | ||
using result_t = meta::resolve_optype_t<range_t,start_t,stop_t,step_t>; | ||
|
||
auto result = result_t {}; | ||
|
||
if constexpr (!meta::is_fail_v<result_t> | ||
&& !meta::is_constant_index_array_v<result_t> | ||
) { | ||
auto n = (stop - start) / step; | ||
if constexpr (meta::is_resizable_v<result_t>) { | ||
result.resize(n); | ||
} | ||
|
||
for (nm_size_t i=0; i<(nm_size_t)n; i++) { | ||
at(result,i) = i * step; | ||
} | ||
} | ||
|
||
return result; | ||
} // range | ||
|
||
template <typename stop_t> | ||
constexpr auto range(stop_t stop) | ||
{ | ||
return range(meta::ct_v<0>,stop,meta::ct_v<1>); | ||
} | ||
} // nmtools::index | ||
|
||
namespace nmtools::meta | ||
{ | ||
namespace error | ||
{ | ||
template <typename...> | ||
struct RANGE_UNSUPPORTED : detail::fail_t {}; | ||
} | ||
|
||
template <typename start_t, typename stop_t, typename step_t> | ||
struct resolve_optype< | ||
void, index::range_t, start_t, stop_t, step_t | ||
> { | ||
static constexpr auto vtype = [](){ | ||
if constexpr ( | ||
!is_index_v<start_t> | ||
|| !is_index_v<stop_t> | ||
|| !is_index_v<step_t> | ||
) { | ||
using type = error::RANGE_UNSUPPORTED<start_t,stop_t,step_t>; | ||
return as_value_v<type>; | ||
} else if constexpr (is_constant_index_v<start_t> | ||
&& is_constant_index_v<stop_t> | ||
&& is_constant_index_v<step_t> | ||
) { | ||
constexpr auto start = to_value_v<start_t>; | ||
constexpr auto stop = to_value_v<stop_t>; | ||
constexpr auto step = to_value_v<step_t>; | ||
constexpr auto start_cl = clipped_int64_t<int64_t(start > 0 ? start : 1)>(start); | ||
constexpr auto stop_cl = clipped_int64_t<(int64_t)stop>(stop); | ||
constexpr auto step_cl = clipped_int64_t<(int64_t)step>(step); | ||
constexpr auto result = index::range(start_cl,stop_cl,step_cl); | ||
using nmtools::at, nmtools::len; | ||
return template_reduce<len(result)>([&](auto init, auto I){ | ||
using init_t = type_t<decltype(init)>; | ||
using type = append_type_t<init_t,ct<at(result,I)>>; | ||
return as_value_v<type>; | ||
}, as_value_v<nmtools_tuple<>>); | ||
} else { | ||
constexpr auto max_dim = max_value_v<stop_t>; | ||
if constexpr (!is_fail_v<decltype(max_dim)>) { | ||
using type = nmtools_static_vector<nm_size_t,max_dim>; | ||
return as_value_v<type>; | ||
} else { | ||
// TODO: small vector optimization | ||
using type = nmtools_list<nm_size_t>; | ||
return as_value_v<type>; | ||
} | ||
} | ||
}(); | ||
using type = type_t<decltype(vtype)>; | ||
}; // index::range_t | ||
} // nmtools::meta | ||
|
||
#endif // NMTOOLS_ARRAY_INDEX_RANGE_HPP |
Oops, something went wrong.