Skip to content

Commit

Permalink
Add kron, vecdot, and tensordot (#304)
Browse files Browse the repository at this point in the history
* add kron

* add max_value & min_value metafunctions

* add tests

* add vecdot

* add tensordot

* move out index::contains form expand_dims, add range index function

* add is_clipped_index metafunction

* update tests

* skip kron compile-time shape inference when on gcc

* update compiler-notes

* fix for gcc werror

* fix for gcc werror
  • Loading branch information
alifahrri authored Oct 29, 2024
1 parent 64d1bb5 commit e1e7bf5
Show file tree
Hide file tree
Showing 31 changed files with 7,546 additions and 21 deletions.
5 changes: 4 additions & 1 deletion docs/compiler-notes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,7 @@ Documenting various note on behaviour difference between clang & gcc (or with so

1. clang vs gcc disagree on capturing constexpr value in lambda expression
clang ok, gcc not ok
https://godbolt.org/z/a1o8P9957
https://godbolt.org/z/a1o8P9957

1. gcc `for` loop becomes goto in constexpr context (and breaks), works fine on clang
https://github.com/alifahrri/nmtools/issues/303
24 changes: 24 additions & 0 deletions include/nmtools/array/array/kron.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NMTOOLS_ARRAY_ARRAY_KRON_HPP
#define NMTOOLS_ARRAY_ARRAY_KRON_HPP

#include "nmtools/array/view/kron.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
, typename lhs_t, typename rhs_t>
constexpr auto kron(const lhs_t& lhs, const rhs_t& rhs
, context_t&& context=context_t{}, output_t&& output=output_t{}, meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto a = view::kron(lhs,rhs);
return eval(
a
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // kron
}

#endif // NMTOOLS_ARRAY_ARRAY_KRON_HPP
24 changes: 24 additions & 0 deletions include/nmtools/array/array/tensordot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP
#define NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP

#include "nmtools/array/view/tensordot.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
, typename lhs_t, typename rhs_t, typename axes_t=meta::ct<2>>
constexpr auto tensordot(const lhs_t& lhs, const rhs_t& rhs, axes_t axes=axes_t{}
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto a = view::tensordot(lhs,rhs,axes);
return eval(
a
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // tensordot
} // nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP
24 changes: 24 additions & 0 deletions include/nmtools/array/array/vecdot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NMTOOLS_ARRAY_ARRAY_VECDOT_HPP
#define NMTOOLS_ARRAY_ARRAY_VECDOT_HPP

#include "nmtools/array/view/vecdot.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
, typename lhs_t, typename rhs_t, typename dtype_t=none_t, typename keepdims_t=meta::false_type>
constexpr auto vecdot(const lhs_t& lhs, const rhs_t& rhs, dtype_t dtype=dtype_t{}, keepdims_t keepdims=keepdims_t{}
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto a = view::vecdot(lhs,rhs,dtype,keepdims);
return eval(
a
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // vecdot
} // nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_VECDOT_HPP
22 changes: 22 additions & 0 deletions include/nmtools/array/index/contains.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef NMTOOLS_ARRAY_INDEX_CONTAINS_HPP
#define NMTOOLS_ARRAY_INDEX_CONTAINS_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/shape.hpp"
#include "nmtools/utils/isequal.hpp"

namespace nmtools::index
{
template <typename array_t, typename value_t>
constexpr auto contains(const array_t& array, const value_t& value)
{
for (nm_size_t i=0; i<(nm_size_t)len(array); i++) {
if (utils::isequal(at(array,i),value)) {
return true;
}
}
return false;
} // contains
} // nmtools::index

#endif // NMTOOLS_ARRAY_INDEX_CONTAINS_HPP
21 changes: 1 addition & 20 deletions include/nmtools/array/index/expand_dims.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "nmtools/array/utility/at.hpp"
#include "nmtools/utils/isequal.hpp"
#include "nmtools/array/ndarray/hybrid.hpp"
#include "nmtools/array/index/contains.hpp"
#include "nmtools/array/index/normalize_axis.hpp"
#include "nmtools/utility/unwrap.hpp"

Expand All @@ -22,26 +23,6 @@ namespace nmtools::index
*/
struct shape_expand_dims_t {};

// TODO: remove
template <typename array_t, typename value_t>
constexpr auto contains(const array_t& array, const value_t& value)
{
if constexpr (meta::is_fixed_index_array_v<array_t>) {
bool contain = false;
meta::template_for<meta::len_v<array_t>>([&](auto i){
if (utils::isequal(at(array,i),value))
contain = true;
});
return contain;
}
else {
for (size_t i=0; i<len(array); i++)
if (utils::isequal(at(array,i),value))
return true;
return false;
}
} // contains

/**
* @brief extend the shape with value 1 for each given axis
*
Expand Down
96 changes: 96 additions & 0 deletions include/nmtools/array/index/range.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#ifndef NMTOOLS_ARRAY_INDEX_RANGE_HPP
#define NMTOOLS_ARRAY_INDEX_RANGE_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/shape.hpp"

namespace nmtools::index
{
struct range_t {};

template <typename start_t, typename stop_t, typename step_t=meta::ct<1>>
constexpr auto range([[maybe_unused]] start_t start
, [[maybe_unused]] stop_t stop
, [[maybe_unused]] step_t step=step_t{}
) {
using result_t = meta::resolve_optype_t<range_t,start_t,stop_t,step_t>;

auto result = result_t {};

if constexpr (!meta::is_fail_v<result_t>
&& !meta::is_constant_index_array_v<result_t>
) {
auto n = (stop - start) / step;
if constexpr (meta::is_resizable_v<result_t>) {
result.resize(n);
}

for (nm_size_t i=0; i<(nm_size_t)n; i++) {
at(result,i) = i * step;
}
}

return result;
} // range

template <typename stop_t>
constexpr auto range(stop_t stop)
{
return range(meta::ct_v<0>,stop,meta::ct_v<1>);
}
} // nmtools::index

namespace nmtools::meta
{
namespace error
{
template <typename...>
struct RANGE_UNSUPPORTED : detail::fail_t {};
}

template <typename start_t, typename stop_t, typename step_t>
struct resolve_optype<
void, index::range_t, start_t, stop_t, step_t
> {
static constexpr auto vtype = [](){
if constexpr (
!is_index_v<start_t>
|| !is_index_v<stop_t>
|| !is_index_v<step_t>
) {
using type = error::RANGE_UNSUPPORTED<start_t,stop_t,step_t>;
return as_value_v<type>;
} else if constexpr (is_constant_index_v<start_t>
&& is_constant_index_v<stop_t>
&& is_constant_index_v<step_t>
) {
constexpr auto start = to_value_v<start_t>;
constexpr auto stop = to_value_v<stop_t>;
constexpr auto step = to_value_v<step_t>;
constexpr auto start_cl = clipped_int64_t<int64_t(start > 0 ? start : 1)>(start);
constexpr auto stop_cl = clipped_int64_t<(int64_t)stop>(stop);
constexpr auto step_cl = clipped_int64_t<(int64_t)step>(step);
constexpr auto result = index::range(start_cl,stop_cl,step_cl);
using nmtools::at, nmtools::len;
return template_reduce<len(result)>([&](auto init, auto I){
using init_t = type_t<decltype(init)>;
using type = append_type_t<init_t,ct<at(result,I)>>;
return as_value_v<type>;
}, as_value_v<nmtools_tuple<>>);
} else {
constexpr auto max_dim = max_value_v<stop_t>;
if constexpr (!is_fail_v<decltype(max_dim)>) {
using type = nmtools_static_vector<nm_size_t,max_dim>;
return as_value_v<type>;
} else {
// TODO: small vector optimization
using type = nmtools_list<nm_size_t>;
return as_value_v<type>;
}
}
}();
using type = type_t<decltype(vtype)>;
}; // index::range_t
} // nmtools::meta

#endif // NMTOOLS_ARRAY_INDEX_RANGE_HPP
Loading

0 comments on commit e1e7bf5

Please sign in to comment.