Skip to content

Commit

Permalink
Add stack, hstack, vstack (#279)
Browse files Browse the repository at this point in the history
* update concatenate view to handle maybe type

* add stack

* add hstack

* add vstack

* add stack, vstack, hstack tests

* fix ci

* fix gcc werror
  • Loading branch information
alifahrri authored May 25, 2024
1 parent 5bba906 commit d0def64
Show file tree
Hide file tree
Showing 39 changed files with 3,456 additions and 25 deletions.
23 changes: 23 additions & 0 deletions include/nmtools/array/array/hstack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef NMTOOLS_ARRAY_ARRAY_HSTACK_HPP
#define NMTOOLS_ARRAY_ARRAY_HSTACK_HPP

#include "nmtools/array/view/hstack.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>,
typename lhs_array_t, typename rhs_array_t>
constexpr auto hstack(const lhs_array_t& lhs, const rhs_array_t& rhs,
context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto result = view::hstack(lhs,rhs);
return eval(result
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // hstack
} // namespace nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_HSTACK_HPP
23 changes: 23 additions & 0 deletions include/nmtools/array/array/stack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef NMTOOLS_ARRAY_ARRAY_STACK_HPP
#define NMTOOLS_ARRAY_ARRAY_STACK_HPP

#include "nmtools/array/view/stack.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>,
typename lhs_array_t, typename rhs_array_t, typename axis_t=meta::ct<0>>
constexpr auto stack(const lhs_array_t& lhs, const rhs_array_t& rhs, axis_t axis=axis_t{}
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto result = view::stack(lhs,rhs,axis);
return eval(result
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // stack
} // namespace nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_STACK_HPP
23 changes: 23 additions & 0 deletions include/nmtools/array/array/vstack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef NMTOOLS_ARRAY_ARRAY_VSTACK_HPP
#define NMTOOLS_ARRAY_ARRAY_VSTACK_HPP

#include "nmtools/array/view/vstack.hpp"
#include "nmtools/array/eval.hpp"

namespace nmtools::array
{
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>,
typename lhs_array_t, typename rhs_array_t>
constexpr auto vstack(const lhs_array_t& lhs, const rhs_array_t& rhs,
context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
{
auto result = view::vstack(lhs,rhs);
return eval(result
, nmtools::forward<context_t>(context)
, nmtools::forward<output_t>(output)
, resolver
);
} // vstack
} // namespace nmtools::array

#endif // NMTOOLS_ARRAY_ARRAY_VSTACK_HPP
25 changes: 25 additions & 0 deletions include/nmtools/array/functional/hstack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef NMTOOLS_ARRAY_FUNCTIONAL_HSTACK_HPP
#define NMTOOLS_ARRAY_FUNCTIONAL_HSTACK_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/indexing.hpp"
#include "nmtools/array/view/hstack.hpp"

namespace nmtools::functional
{
namespace fun
{
struct hstack_t
{
template <typename...args_t>
constexpr auto operator()(const args_t&...args) const
{
return view::hstack(args...);
}
};
}

constexpr inline auto hstack = functor_t{binary_fmap_t<fun::hstack_t>{}};
} // namespace nmtools::functional

#endif // NMTOOLS_ARRAY_FUNCTIONAL_HSTACK_HPP
25 changes: 25 additions & 0 deletions include/nmtools/array/functional/stack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef NMTOOLS_ARRAY_FUNCTIONAL_STACK_HPP
#define NMTOOLS_ARRAY_FUNCTIONAL_STACK_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/indexing.hpp"
#include "nmtools/array/view/stack.hpp"

namespace nmtools::functional
{
namespace fun
{
struct stack_t
{
template <typename...args_t>
constexpr auto operator()(const args_t&...args) const
{
return view::stack(args...);
}
};
}

constexpr inline auto stack = functor_t{binary_fmap_t<fun::stack_t>{}};
} // namespace nmtools::functional

#endif // NMTOOLS_ARRAY_FUNCTIONAL_STACK_HPP
25 changes: 25 additions & 0 deletions include/nmtools/array/functional/vstack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef NMTOOLS_ARRAY_FUNCTIONAL_VSTACK_HPP
#define NMTOOLS_ARRAY_FUNCTIONAL_VSTACK_HPP

#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/indexing.hpp"
#include "nmtools/array/view/vstack.hpp"

namespace nmtools::functional
{
namespace fun
{
struct vstack_t
{
template <typename...args_t>
constexpr auto operator()(const args_t&...args) const
{
return view::vstack(args...);
}
};
}

constexpr inline auto vstack = functor_t{binary_fmap_t<fun::vstack_t>{}};
} // namespace nmtools::functional

#endif // NMTOOLS_ARRAY_FUNCTIONAL_VSTACK_HPP
3 changes: 2 additions & 1 deletion include/nmtools/array/index/broadcast_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ namespace nmtools::meta
namespace error
{
// error type for in-compatible shapes
template <typename...>
struct BROADCAST_SHAPE_ERROR : detail::fail_t {};
// error type for unsupported ashape_t bshape_t
template <typename...>
Expand Down Expand Up @@ -346,7 +347,7 @@ namespace nmtools::meta
using type = nmtools_array<size_t,dim>;
return as_value_v<type>;
} else {
return as_value_v<error::BROADCAST_SHAPE_ERROR>;
return as_value_v<error::BROADCAST_SHAPE_ERROR<ashape_t,bshape_t>>;
}
} else if constexpr (is_none_v<ashape_t> && is_constant_index_array_v<bshape_t>) {
// broadcasting with none retain shape, just select the shape
Expand Down
9 changes: 6 additions & 3 deletions include/nmtools/array/index/moveaxis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ namespace nmtools::index
} else {
return result_t{meta::Nothing};
}
} else if constexpr ((! meta::is_constant_index_array_v<result_t>) && (! meta::is_fail_v<result_t>)) {
using return_t = utl::maybe<result_t>;
} else if constexpr (!meta::is_fail_v<result_t>
&& !meta::is_constant_index_array_v<result_t>
) {
using return_t = nmtools_maybe<result_t>;

auto dim = [&](){
if constexpr (meta::is_constant_index_array_v<shape_t>) {
Expand Down Expand Up @@ -169,6 +171,7 @@ namespace nmtools::meta
{
namespace error
{
template <typename...>
struct MOVEAXIS_TO_TRANSPOSE_UNSUPPORTED : detail::fail_t {};

template <typename...>
Expand Down Expand Up @@ -243,7 +246,7 @@ namespace nmtools::meta
return as_value_v<type>;
}
} else {
return as_value_v<error::MOVEAXIS_TO_TRANSPOSE_UNSUPPORTED>;
return as_value_v<error::MOVEAXIS_TO_TRANSPOSE_UNSUPPORTED<shape_t,source_t,destination_t>>;
}
}();
using type = type_t<decltype(vtype)>;
Expand Down
34 changes: 26 additions & 8 deletions include/nmtools/array/view/concatenate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,32 @@ namespace nmtools::view
template <typename lhs_array_t, typename rhs_array_t, typename axis_t>
constexpr auto concatenate(const lhs_array_t& lhs, const rhs_array_t& rhs, axis_t axis)
{
auto ashape = shape(lhs);
auto bshape = shape(rhs);
[[maybe_unused]] const auto [success, shape] = index::shape_concatenate(ashape,bshape,axis);
// TODO: use nmtools_assert macro
nmtools_cassert (success
, "unsupported concatenate, mismatched shape"
);
return decorator_t<concatenate_t,lhs_array_t,rhs_array_t,axis_t>{{lhs,rhs,axis}};
if constexpr (meta::is_maybe_v<lhs_array_t>) {
using lhs_type = meta::get_maybe_type_t<lhs_array_t>;
using result_t = decltype(concatenate(meta::declval<lhs_type>(),rhs,axis));
using return_t = meta::conditional_t<meta::is_maybe_v<result_t>,result_t,nmtools_maybe<result_t>>;
return (lhs
? return_t{concatenate(*lhs,rhs,axis)}
: return_t{meta::Nothing}
);
} else if constexpr (meta::is_maybe_v<rhs_array_t>) {
using rhs_type = meta::get_maybe_type_t<rhs_array_t>;
using result_t = decltype(concatenate(lhs,meta::declval<rhs_type>(),axis));
using return_t = meta::conditional_t<meta::is_maybe_v<result_t>,result_t,nmtools_maybe<result_t>>;
return (rhs
? return_t{(concatenate(lhs,*rhs,axis))}
: return_t{meta::Nothing}
);
} else {
auto ashape = shape(lhs);
auto bshape = shape(rhs);
[[maybe_unused]] const auto [success, shape] = index::shape_concatenate(ashape,bshape,axis);
// TODO: use nmtools_assert macro
nmtools_cassert (success
, "unsupported concatenate, mismatched shape"
);
return decorator_t<concatenate_t,lhs_array_t,rhs_array_t,axis_t>{{lhs,rhs,axis}};
}
} // concatenate
} // namespace nmtools::view

Expand Down
79 changes: 79 additions & 0 deletions include/nmtools/array/view/hstack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#ifndef NMTOOLS_ARRAY_INDEX_HSTACK_HPP
#define NMTOOLS_ARRAY_INDEX_HSTACK_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/shape.hpp"

namespace nmtools::index
{
struct hstack_axis_t {};

template <typename lhs_shape_t, typename rhs_shape_t>
constexpr auto hstack_axis([[maybe_unused]] const lhs_shape_t& lhs_shape, const rhs_shape_t&)
{
using result_t = meta::resolve_optype_t<hstack_axis_t,lhs_shape_t,rhs_shape_t>;

auto result = result_t {};

if constexpr (!meta::is_constant_index_v<result_t>) {
auto lhs_dim = len(lhs_shape);
result = ((lhs_dim == 1) ? 0 : 1);
}

return result;
}
} // namespace nmtools::index

namespace nmtools::meta
{
namespace error
{
template <typename...>
struct HSTACK_AXIS_UNSUPPORTED : detail::fail_t {};
}

template <typename lhs_shape_t, typename rhs_shape_t>
struct resolve_optype<void,index::hstack_axis_t,lhs_shape_t,rhs_shape_t>
{
static constexpr auto vtype = [](){
constexpr auto lhs_dim = len_v<lhs_shape_t>;
constexpr auto rhs_dim = len_v<rhs_shape_t>;
if constexpr ((lhs_dim > 0) && (lhs_dim == rhs_dim)) {
using type = ct<((lhs_dim == 1) ? 0 : 1)>;
return as_value_v<type>;
} else if constexpr (is_index_array_v<lhs_shape_t>
&& is_index_array_v<rhs_shape_t>
) {
using type = nm_size_t;
return as_value_v<type>;
} else {
using type = error::HSTACK_AXIS_UNSUPPORTED<lhs_shape_t,rhs_shape_t>;
return as_value_v<type>;
}
}();
using type = type_t<decltype(vtype)>;
};
} // namespace nmtools::meta

#endif // NMTOOLS_ARRAY_INDEX_HSTACK_HPP

#ifndef NMTOOLS_ARRAY_VIEW_HSTACK_HPP
#define NMTOOLS_ARRAY_VIEW_HSTACK_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/view/concatenate.hpp"

namespace nmtools::view
{
template <typename lhs_t, typename rhs_t>
constexpr auto hstack(const lhs_t& lhs, const rhs_t& rhs)
{
auto axis = index::hstack_axis(
shape<true>(lhs)
, shape<true>(rhs)
);
return concatenate(lhs,rhs,axis);
}
} // namespace nmtools::view

#endif // NMTOOLS_ARRAY_VIEW_HSTACK_HPP
4 changes: 3 additions & 1 deletion include/nmtools/array/view/indexing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ namespace nmtools::view
template <typename array_t, typename indexer_t>
constexpr auto indexing(const array_t& array, const indexer_t& indexer)
{
if constexpr (meta::is_maybe_v<array_t>) {
if constexpr (meta::is_fail_v<indexer_t>) {
return indexer;
} else if constexpr (meta::is_maybe_v<array_t>) {
using array_type = meta::get_maybe_type_t<array_t>;
using result_t = decltype(indexing(meta::declval<array_type>(),indexer));
using return_t = meta::conditional_t<meta::is_maybe_v<result_t>,result_t,nmtools_maybe<result_t>>;
Expand Down
13 changes: 7 additions & 6 deletions include/nmtools/array/view/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ namespace nmtools::view
constexpr auto reshaper(const src_shape_t& src_shape, const dst_shape_t& dst_shape, const src_size_t& src_size)
{
auto m_dst_shape = index::shape_reshape(src_shape,dst_shape);
if constexpr (meta::is_maybe_v<decltype(m_dst_shape)>) {
if constexpr (meta::is_fail_v<decltype(m_dst_shape)>) {
return m_dst_shape;
} else if constexpr (meta::is_maybe_v<decltype(m_dst_shape)>) {
using result_t = decltype(reshape_t{unwrap(src_shape),dst_shape,unwrap(src_size)});
using return_t = nmtools_maybe<result_t>;
if (static_cast<bool>(m_dst_shape)) {
return return_t{result_t{unwrap(src_shape),dst_shape,unwrap(src_size)}};
} else {
return return_t{meta::Nothing};
}
return (m_dst_shape
? return_t{result_t{unwrap(src_shape),dst_shape,unwrap(src_size)}}
: return_t{meta::Nothing}
);
} else {
return reshape_t{unwrap(src_shape),dst_shape,unwrap(src_size)};
}
Expand Down
20 changes: 20 additions & 0 deletions include/nmtools/array/view/stack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef NMTOOLS_ARRAY_VIEW_STACK_HPP
#define NMTOOLS_ARRAY_VIEW_STACK_HPP

#include "nmtools/array/view/concatenate.hpp"
#include "nmtools/array/view/expand_dims.hpp"

namespace nmtools::view
{
template <typename lhs_t, typename rhs_t, typename axis_t=meta::ct<0>>
constexpr auto stack(const lhs_t& lhs, const rhs_t& rhs, axis_t axis=axis_t{})
{
return concatenate(
expand_dims(lhs,axis)
, expand_dims(rhs,axis)
, axis
);
}
} // namespace nmtools::view

#endif // NMTOOLS_ARRAY_VIEW_STACK_HPP
Loading

0 comments on commit d0def64

Please sign in to comment.