Skip to content

Commit

Permalink
Add swapaxes (#296)
Browse files Browse the repository at this point in the history
* add swapaxes

* add tests

* fix werror
  • Loading branch information
alifahrri authored Sep 22, 2024
1 parent e0c45fd commit 90593e4
Show file tree
Hide file tree
Showing 10 changed files with 580 additions and 0 deletions.
24 changes: 24 additions & 0 deletions include/nmtools/array/array/swapaxes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP
#define NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP

#include "nmtools/array/view/swapaxes.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
, typename array_t, typename axis1_t, typename axis2_t>
constexpr auto swapaxes(const array_t& array, axis1_t axis1, axis2_t axis2
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto a = view::swapaxes(array,axis1,axis2);
return eval(
a
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
}
} // nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_SWAPAXES_HPP
108 changes: 108 additions & 0 deletions include/nmtools/array/view/swapaxes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#ifndef NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP
#define NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/shape.hpp"
#include "nmtools/array/index/normalize_axis.hpp"
#include "nmtools/utility/unwrap.hpp"

namespace nmtools::index
{
struct swapaxes_to_transpose_t {};

template <typename src_dim_t, typename axis1_t, typename axis2_t>
constexpr auto swapaxes_to_transpose([[maybe_unused]] const src_dim_t& src_dim, [[maybe_unused]] axis1_t axis1, [[maybe_unused]] axis2_t axis2)
{
using result_t = meta::resolve_optype_t<swapaxes_to_transpose_t,src_dim_t,axis1_t,axis2_t>;

auto result = result_t {};
if constexpr (!meta::is_constant_index_array_v<result_t>
&& !meta::is_fail_v<result_t>
) {
if constexpr (meta::is_resizable_v<result_t>) {
result.resize(src_dim);
}

for (nm_size_t i=0; i<(nm_size_t)src_dim; i++) {
at(result,i) = i;
}

// TODO: propagate error handling
auto m_axis1 = unwrap(normalize_axis(axis1,src_dim));
auto m_axis2 = unwrap(normalize_axis(axis2,src_dim));

auto tmp = at(result,m_axis1);
at(result,m_axis1) = at(result,m_axis2);
at(result,m_axis2) = tmp;
}

return result;
} // swapaxes_to_transpose

} // nmtools::index

namespace nmtools::meta
{
namespace error
{
template <typename...>
struct SWAPAXES_TO_TRANSPOSE_UNSUPPORTED : detail::fail_t {};
}

template <typename src_dim_t, typename axis1_t, typename axis2_t>
struct resolve_optype<
void, index::swapaxes_to_transpose_t, src_dim_t, axis1_t, axis2_t
> {
static constexpr auto vtype = [](){
if constexpr (!is_index_v<src_dim_t>
|| !is_index_v<axis1_t>
|| !is_index_v<axis2_t>
) {
using type = error::SWAPAXES_TO_TRANSPOSE_UNSUPPORTED<src_dim_t,axis1_t,axis2_t>;
return as_value_v<type>;
} else if constexpr (
is_constant_index_v<src_dim_t>
&& is_constant_index_v<axis1_t>
&& is_constant_index_v<axis2_t>
) {
constexpr auto axis1 = to_value_v<axis1_t>;
constexpr auto axis2 = to_value_v<axis2_t>;
constexpr auto result = index::swapaxes_to_transpose(src_dim_t{},axis1,axis2);
using nmtools::len, nmtools::at;
return template_reduce<len(result)>([&](auto init, auto I){
using init_t = type_t<decltype(init)>;
using type = append_type_t<init_t,ct<at(result,I)>>;
return as_value_v<type>;
}, as_value_v<nmtools_tuple<>>);
} else if constexpr (is_constant_index_v<src_dim_t>) {
constexpr auto N = to_value_v<src_dim_t>;
using type = nmtools_array<nm_size_t,N>;
return as_value_v<type>;
} else if constexpr (is_clipped_integer_v<src_dim_t>) {
constexpr auto N = to_value_v<src_dim_t>;
using type = nmtools_static_vector<nm_size_t,N>;
return as_value_v<type>;
} else {
// TODO: support small vector
using type = nmtools_list<nm_size_t>;
return as_value_v<type>;
}
}();
using type = type_t<decltype(vtype)>;
};
} // nmtools::meta

#include "nmtools/array/view/transpose.hpp"

namespace nmtools::view
{
template <typename array_t, typename axis1_t, typename axis2_t>
constexpr auto swapaxes(const array_t& array, axis1_t axis1, axis2_t axis2)
{
auto src_dim = dim<true>(array);
auto axes = index::swapaxes_to_transpose(src_dim,axis1,axis2);
return view::transpose(array,axes);
}
} // nmtools::view

#endif // NMTOOLS_ARRAY_VIEW_SWAPAXES_HPP
174 changes: 174 additions & 0 deletions include/nmtools/testing/data/array/swapaxes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#ifndef NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP
#define NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP

#include "nmtools/testing/testing.hpp"
#include "nmtools/testing/array_cast.hpp"

NMTOOLS_TESTING_DECLARE_CASE(array,swapaxes)
{
using namespace literals;

NMTOOLS_TESTING_DECLARE_ARGS(case1)
{
inline int a[1][3] = {
{1,2,3}
};
inline int axis1 = 0;
inline int axis2 = 1;

inline auto axis1_ct = 0_ct;
inline auto axis2_ct = 1_ct;
NMTOOLS_CAST_ARRAYS(a)
}
NMTOOLS_TESTING_DECLARE_EXPECT(case1)
{
inline int result[3][1] = {
{1},
{2},
{3},
};
}

NMTOOLS_TESTING_DECLARE_ARGS(case2)
{
inline int a[2][3][2] = {
{
{0,1},
{2,3},
{4,5},
},
{
{ 6, 7},
{ 8, 9},
{10,11},
}
};
inline int axis1 = 0;
inline int axis2 = 1;

inline auto axis1_ct = 0_ct;
inline auto axis2_ct = 1_ct;
NMTOOLS_CAST_ARRAYS(a)
}
NMTOOLS_TESTING_DECLARE_EXPECT(case2)
{
inline int result[3][2][2] = {
{
{0,1},
{6,7},
},
{
{2,3},
{8,9},
},
{
{ 4, 5},
{10,11},
}
};
}

NMTOOLS_TESTING_DECLARE_ARGS(case3)
{
inline int a[2][3][2] = {
{
{0,1},
{2,3},
{4,5},
},
{
{ 6, 7},
{ 8, 9},
{10,11},
}
};
inline int axis1 = 2;
inline int axis2 = 1;

inline int axis1_ct = 2_ct;
inline int axis2_ct = 1_ct;
NMTOOLS_CAST_ARRAYS(a)
}
NMTOOLS_TESTING_DECLARE_EXPECT(case3)
{
inline int result[2][2][3] = {
{
{0,2,4},
{1,3,5},
},
{
{6,8,10},
{7,9,11},
}
};
}

NMTOOLS_TESTING_DECLARE_ARGS(case4)
{
inline int a[2][3][1][2] = {
{
{
{0,1},
},
{
{2,3},
},
{
{4,5},
}
},
{
{
{ 6, 7},
},
{
{ 8, 9},
},
{
{10,11},
}
}
};
inline int axis1 = -2;
inline int axis2 = -1;

inline int axis1_ct = "-2"_ct;
inline int axis2_ct = "-1"_ct;
NMTOOLS_CAST_ARRAYS(a)
}
NMTOOLS_TESTING_DECLARE_EXPECT(case4)
{
inline int result[2][3][2][1] = {
{
{
{0},
{1},
},
{
{2},
{3},
},
{
{4},
{5},
},
},
{
{
{6},
{7},
},
{
{8},
{9},
},
{
{10},
{11},
},
}
};
}
}

#endif // NMTOOLS_TESTING_DATA_ARRAY_SWAPAXES_HPP
57 changes: 57 additions & 0 deletions include/nmtools/testing/data/index/swapaxes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP
#define NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP

#include "nmtools/testing/testing.hpp"
#include "nmtools/testing/array_cast.hpp"

NMTOOLS_TESTING_DECLARE_CASE(index,swapaxes_to_transpose)
{
using namespace literals;

NMTOOLS_TESTING_DECLARE_ARGS(case1)
{
inline int src_dim = 2;
inline int axis1 = 0;
inline int axis2 = 1;

inline auto src_dim_ct = 2_ct;
inline auto axis1_ct = 0_ct;
inline auto axis2_ct = 1_ct;
}
NMTOOLS_TESTING_DECLARE_EXPECT(case1)
{
inline int result[2] = {1,0};
}

NMTOOLS_TESTING_DECLARE_ARGS(case2)
{
inline int src_dim = 3;
inline int axis1 = 0;
inline int axis2 = 1;

inline auto src_dim_ct = 3_ct;
inline auto axis1_ct = 0_ct;
inline auto axis2_ct = 1_ct;
}
NMTOOLS_TESTING_DECLARE_EXPECT(case2)
{
inline int result[3] = {1,0,2};
}

NMTOOLS_TESTING_DECLARE_ARGS(case3)
{
inline int src_dim = 3;
inline int axis1 = 2;
inline int axis2 = 1;

inline auto src_dim_ct = 3_ct;
inline auto axis1_ct = 2_ct;
inline auto axis2_ct = 1_ct;
}
NMTOOLS_TESTING_DECLARE_EXPECT(case3)
{
inline int result[3] = {0,2,1};
}
}

#endif // NMTOOLS_TESTING_DATA_INDEX_SWAPAXES_HPP
1 change: 1 addition & 0 deletions tests/array/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ set(ARRAY_EVAL_TEST_SOURCES
array/stddev.cpp
array/sum.cpp
array/split.cpp
array/swapaxes.cpp
array/take.cpp
array/tile.cpp
array/tri.cpp
Expand Down
Loading

0 comments on commit 90593e4

Please sign in to comment.