-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add swapaxes * add tests * fix werror
- Loading branch information
Showing
10 changed files
with
580 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP | ||
|
||
#include "nmtools/array/view/swapaxes.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename array_t, typename axis1_t, typename axis2_t> | ||
constexpr auto swapaxes(const array_t& array, axis1_t axis1, axis2_t axis2 | ||
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto a = view::swapaxes(array,axis1,axis2); | ||
return eval( | ||
a | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} | ||
} // nmtools::array | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#ifndef NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP | ||
#define NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP | ||
|
||
#include "nmtools/meta.hpp" | ||
#include "nmtools/array/shape.hpp" | ||
#include "nmtools/array/index/normalize_axis.hpp" | ||
#include "nmtools/utility/unwrap.hpp" | ||
|
||
namespace nmtools::index | ||
{ | ||
struct swapaxes_to_transpose_t {}; | ||
|
||
template <typename src_dim_t, typename axis1_t, typename axis2_t> | ||
constexpr auto swapaxes_to_transpose([[maybe_unused]] const src_dim_t& src_dim, [[maybe_unused]] axis1_t axis1, [[maybe_unused]] axis2_t axis2) | ||
{ | ||
using result_t = meta::resolve_optype_t<swapaxes_to_transpose_t,src_dim_t,axis1_t,axis2_t>; | ||
|
||
auto result = result_t {}; | ||
if constexpr (!meta::is_constant_index_array_v<result_t> | ||
&& !meta::is_fail_v<result_t> | ||
) { | ||
if constexpr (meta::is_resizable_v<result_t>) { | ||
result.resize(src_dim); | ||
} | ||
|
||
for (nm_size_t i=0; i<(nm_size_t)src_dim; i++) { | ||
at(result,i) = i; | ||
} | ||
|
||
// TODO: propagate error handling | ||
auto m_axis1 = unwrap(normalize_axis(axis1,src_dim)); | ||
auto m_axis2 = unwrap(normalize_axis(axis2,src_dim)); | ||
|
||
auto tmp = at(result,m_axis1); | ||
at(result,m_axis1) = at(result,m_axis2); | ||
at(result,m_axis2) = tmp; | ||
} | ||
|
||
return result; | ||
} // swapaxes_to_transpose | ||
|
||
} // nmtools::index | ||
|
||
namespace nmtools::meta | ||
{ | ||
namespace error | ||
{ | ||
template <typename...> | ||
struct SWAPAXES_TO_TRANSPOSE_UNSUPPORTED : detail::fail_t {}; | ||
} | ||
|
||
template <typename src_dim_t, typename axis1_t, typename axis2_t> | ||
struct resolve_optype< | ||
void, index::swapaxes_to_transpose_t, src_dim_t, axis1_t, axis2_t | ||
> { | ||
static constexpr auto vtype = [](){ | ||
if constexpr (!is_index_v<src_dim_t> | ||
|| !is_index_v<axis1_t> | ||
|| !is_index_v<axis2_t> | ||
) { | ||
using type = error::SWAPAXES_TO_TRANSPOSE_UNSUPPORTED<src_dim_t,axis1_t,axis2_t>; | ||
return as_value_v<type>; | ||
} else if constexpr ( | ||
is_constant_index_v<src_dim_t> | ||
&& is_constant_index_v<axis1_t> | ||
&& is_constant_index_v<axis2_t> | ||
) { | ||
constexpr auto axis1 = to_value_v<axis1_t>; | ||
constexpr auto axis2 = to_value_v<axis2_t>; | ||
constexpr auto result = index::swapaxes_to_transpose(src_dim_t{},axis1,axis2); | ||
using nmtools::len, nmtools::at; | ||
return template_reduce<len(result)>([&](auto init, auto I){ | ||
using init_t = type_t<decltype(init)>; | ||
using type = append_type_t<init_t,ct<at(result,I)>>; | ||
return as_value_v<type>; | ||
}, as_value_v<nmtools_tuple<>>); | ||
} else if constexpr (is_constant_index_v<src_dim_t>) { | ||
constexpr auto N = to_value_v<src_dim_t>; | ||
using type = nmtools_array<nm_size_t,N>; | ||
return as_value_v<type>; | ||
} else if constexpr (is_clipped_integer_v<src_dim_t>) { | ||
constexpr auto N = to_value_v<src_dim_t>; | ||
using type = nmtools_static_vector<nm_size_t,N>; | ||
return as_value_v<type>; | ||
} else { | ||
// TODO: support small vector | ||
using type = nmtools_list<nm_size_t>; | ||
return as_value_v<type>; | ||
} | ||
}(); | ||
using type = type_t<decltype(vtype)>; | ||
}; | ||
} // nmtools::meta | ||
|
||
#include "nmtools/array/view/transpose.hpp" | ||
|
||
namespace nmtools::view | ||
{ | ||
template <typename array_t, typename axis1_t, typename axis2_t> | ||
constexpr auto swapaxes(const array_t& array, axis1_t axis1, axis2_t axis2) | ||
{ | ||
auto src_dim = dim<true>(array); | ||
auto axes = index::swapaxes_to_transpose(src_dim,axis1,axis2); | ||
return view::transpose(array,axes); | ||
} | ||
} // nmtools::view | ||
|
||
#endif // NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
#ifndef NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP | ||
#define NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP | ||
|
||
#include "nmtools/testing/testing.hpp" | ||
#include "nmtools/testing/array_cast.hpp" | ||
|
||
NMTOOLS_TESTING_DECLARE_CASE(array,swapaxes) | ||
{ | ||
using namespace literals; | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case1) | ||
{ | ||
inline int a[1][3] = { | ||
{1,2,3} | ||
}; | ||
inline int axis1 = 0; | ||
inline int axis2 = 1; | ||
|
||
inline auto axis1_ct = 0_ct; | ||
inline auto axis2_ct = 1_ct; | ||
NMTOOLS_CAST_ARRAYS(a) | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case1) | ||
{ | ||
inline int result[3][1] = { | ||
{1}, | ||
{2}, | ||
{3}, | ||
}; | ||
} | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case2) | ||
{ | ||
inline int a[2][3][2] = { | ||
{ | ||
{0,1}, | ||
{2,3}, | ||
{4,5}, | ||
}, | ||
{ | ||
{ 6, 7}, | ||
{ 8, 9}, | ||
{10,11}, | ||
} | ||
}; | ||
inline int axis1 = 0; | ||
inline int axis2 = 1; | ||
|
||
inline auto axis1_ct = 0_ct; | ||
inline auto axis2_ct = 1_ct; | ||
NMTOOLS_CAST_ARRAYS(a) | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case2) | ||
{ | ||
inline int result[3][2][2] = { | ||
{ | ||
{0,1}, | ||
{6,7}, | ||
}, | ||
{ | ||
{2,3}, | ||
{8,9}, | ||
}, | ||
{ | ||
{ 4, 5}, | ||
{10,11}, | ||
} | ||
}; | ||
} | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case3) | ||
{ | ||
inline int a[2][3][2] = { | ||
{ | ||
{0,1}, | ||
{2,3}, | ||
{4,5}, | ||
}, | ||
{ | ||
{ 6, 7}, | ||
{ 8, 9}, | ||
{10,11}, | ||
} | ||
}; | ||
inline int axis1 = 2; | ||
inline int axis2 = 1; | ||
|
||
inline int axis1_ct = 2_ct; | ||
inline int axis2_ct = 1_ct; | ||
NMTOOLS_CAST_ARRAYS(a) | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case3) | ||
{ | ||
inline int result[2][2][3] = { | ||
{ | ||
{0,2,4}, | ||
{1,3,5}, | ||
}, | ||
{ | ||
{6,8,10}, | ||
{7,9,11}, | ||
} | ||
}; | ||
} | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case4) | ||
{ | ||
inline int a[2][3][1][2] = { | ||
{ | ||
{ | ||
{0,1}, | ||
}, | ||
{ | ||
{2,3}, | ||
}, | ||
{ | ||
{4,5}, | ||
} | ||
}, | ||
{ | ||
{ | ||
{ 6, 7}, | ||
}, | ||
{ | ||
{ 8, 9}, | ||
}, | ||
{ | ||
{10,11}, | ||
} | ||
} | ||
}; | ||
inline int axis1 = -2; | ||
inline int axis2 = -1; | ||
|
||
inline int axis1_ct = "-2"_ct; | ||
inline int axis2_ct = "-1"_ct; | ||
NMTOOLS_CAST_ARRAYS(a) | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case4) | ||
{ | ||
inline int result[2][3][2][1] = { | ||
{ | ||
{ | ||
{0}, | ||
{1}, | ||
}, | ||
{ | ||
{2}, | ||
{3}, | ||
}, | ||
{ | ||
{4}, | ||
{5}, | ||
}, | ||
}, | ||
{ | ||
{ | ||
{6}, | ||
{7}, | ||
}, | ||
{ | ||
{8}, | ||
{9}, | ||
}, | ||
{ | ||
{10}, | ||
{11}, | ||
}, | ||
} | ||
}; | ||
} | ||
} | ||
|
||
#endif // NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#ifndef NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP | ||
#define NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP | ||
|
||
#include "nmtools/testing/testing.hpp" | ||
#include "nmtools/testing/array_cast.hpp" | ||
|
||
NMTOOLS_TESTING_DECLARE_CASE(index,swapaxes_to_transpose) | ||
{ | ||
using namespace literals; | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case1) | ||
{ | ||
inline int src_dim = 2; | ||
inline int axis1 = 0; | ||
inline int axis2 = 1; | ||
|
||
inline auto src_dim_ct = 2_ct; | ||
inline auto axis1_ct = 0_ct; | ||
inline auto axis2_ct = 1_ct; | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case1) | ||
{ | ||
inline int result[2] = {1,0}; | ||
} | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case2) | ||
{ | ||
inline int src_dim = 3; | ||
inline int axis1 = 0; | ||
inline int axis2 = 1; | ||
|
||
inline auto src_dim_ct = 3_ct; | ||
inline auto axis1_ct = 0_ct; | ||
inline auto axis2_ct = 1_ct; | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case2) | ||
{ | ||
inline int result[3] = {1,0,2}; | ||
} | ||
|
||
NMTOOLS_TESTING_DECLARE_ARGS(case3) | ||
{ | ||
inline int src_dim = 3; | ||
inline int axis1 = 2; | ||
inline int axis2 = 1; | ||
|
||
inline auto src_dim_ct = 3_ct; | ||
inline auto axis1_ct = 2_ct; | ||
inline auto axis2_ct = 1_ct; | ||
} | ||
NMTOOLS_TESTING_DECLARE_EXPECT(case3) | ||
{ | ||
inline int result[3] = {0,2,1}; | ||
} | ||
} | ||
|
||
#endif // NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.