Skip to content

Commit

Permalink
fixes for hip compilations
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed Jun 6, 2024
1 parent 6908213 commit 542b05d
Show file tree
Hide file tree
Showing 31 changed files with 1,186 additions and 1,112 deletions.
5 changes: 3 additions & 2 deletions include/nmtools/array/as_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ namespace nmtools::array

auto operator()() const noexcept
{
if constexpr (meta::is_dynamic_index_array_v<attribute_type>) {
// TODO: return maybe when size > max_dim
if constexpr (meta::is_resizable_v<attribute_type>) {
using element_type = meta::get_element_type_t<attribute_t>;
using result_type = utl::static_vector<element_type,max_dim>;
auto result = result_type{};
result.resize(attribute.size());
result.resize(nmtools::size(attribute));
for (size_t i=0; i<len(result); i++) {
at(result,i) = at(attribute,i);
}
Expand Down
51 changes: 49 additions & 2 deletions include/nmtools/array/eval/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,39 @@ namespace nmtools::array::hip
this->copy_buffer(output_buffer,output);
}

template <typename function_t, typename output_array_t, typename arg0_t, typename...args_t>
auto run(const function_t& f, output_array_t& output, const arg0_t& arg0, const args_t&...args)
template <typename F, typename operands_t, typename attributes_t>
auto map_to_device(const functional::functor_t<F,operands_t,attributes_t>& f)
{
static_assert( meta::len_v<operands_t> == 0 );
if constexpr (meta::is_same_v<attributes_t,meta::empty_attributes_t>) {
return f;
} else {
constexpr auto N = meta::len_v<attributes_t>;
auto attributes = meta::template_reduce<N>([&](auto init, auto I){
auto attribute = array::as_static(at(f.attributes,I));
return utility::tuple_append(init,attribute);
}, nmtools_tuple{});
return functional::functor_t<F,operands_t,decltype(attributes)>{
{f.fmap, f.operands, attributes}
};
}
} // map_to_device

template <template<typename...>typename tuple, typename...functors_t, typename operands_t>
auto map_to_device(const functional::functor_composition_t<tuple<functors_t...>,operands_t>& f)
{
static_assert( meta::len_v<operands_t> == 0 );
auto functors = meta::template_reduce<sizeof...(functors_t)>([&](auto init, auto I){
auto functor = map_to_device(at(f.functors,I));
return utility::tuple_append(init,functor);
}, nmtools_tuple{});
return functional::functor_composition_t<decltype(functors)>{functors};
} // map_to_device

template <typename function_t, typename output_array_t, template<typename...>typename tuple, typename...operands_t>
auto run(const function_t& f, output_array_t& output, const tuple<operands_t...>& operands)
{
#if 0
auto args_pack = [&](){
if constexpr (meta::is_tuple_v<arg0_t>) {
static_assert( sizeof...(args_t) == 0, "nmtools error" );
Expand All @@ -312,6 +342,23 @@ namespace nmtools::array::hip

using sequence_t = meta::make_index_sequence<meta::len_v<decltype(gpu_args_pack)>>;
this->run_(output,f,gpu_args_pack,sequence_t{});
#else
constexpr auto N = sizeof...(operands_t);
auto device_operands = meta::template_reduce<N>([&](auto init, auto index){
const auto& arg_i = nmtools::at(operands,index);
if constexpr (meta::is_num_v<decltype(arg_i)>) {
return utility::tuple_append(init,arg_i);
} else {
auto device_array = create_array(*arg_i);
return utility::tuple_append(init,device_array);
}
}, nmtools_tuple<>{});

// e.g. to convert dynamic allocation to static vector to run on device kernels
auto fn = map_to_device(f);
using sequence_t = meta::make_index_sequence<meta::len_v<decltype(device_operands)>>;
this->run_(output,fn,device_operands,sequence_t{});
#endif
}
};

Expand Down
19 changes: 12 additions & 7 deletions include/nmtools/array/index/flip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,27 @@ namespace nmtools::meta
>
{
static constexpr auto vtype = [](){
using slice_t [[maybe_unused]] = nmtools_tuple<none_t,none_t,int>;
if constexpr (
!is_index_v<dim_t>
|| !(is_index_array_v<axes_t> || is_index_v<axes_t> || is_none_v<axes_t>)
) {
using type = error::FLIP_SLICES_UNSUPPORTED<dim_t,axes_t>;
return as_value_v<type>;
} else if constexpr (
is_constant_index_v<dim_t>
&& (is_index_array_v<axes_t> || is_index_v<axes_t> || is_none_v<axes_t>)
) {
using slice_t = nmtools_tuple<none_t,none_t,int>;
using type = nmtools_array<slice_t,dim_t::value>;
return as_value_v<type>;
} else if constexpr (
is_index_v<dim_t>
&& (is_index_array_v<axes_t> || is_index_v<axes_t> || is_none_v<axes_t>)
is_clipped_integer_v<dim_t>
) {
using slice_t = nmtools_tuple<none_t,none_t,int>;
using type = nmtools_static_vector<slice_t,dim_t::max_value>;
return as_value_v<type>;
} else {
// TODO: use small_vector
using type = nmtools_list<slice_t>;
return as_value_v<type>;
} else {
return as_value_v<error::FLIP_SLICES_UNSUPPORTED<dim_t,axes_t>>;
}
}();
using type = type_t<decltype(vtype)>;
Expand Down
144 changes: 144 additions & 0 deletions include/nmtools/array/index/reduce.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#ifndef NMTOOLS_ARRAY_INDEX_REDUCE_HPP
#define NMTOOLS_ARRAY_INDEX_REDUCE_HPP

#include "nmtools/meta.hpp"
#include "nmtools/array/shape.hpp"

namespace nmtools::index
{
struct reduction_slices_t {};

template <typename indices_t, typename shape_type, typename axis_type, typename keepdims_type>
constexpr auto reduction_slices(const indices_t& indices_, const shape_type& src_shape, const axis_type& axis, keepdims_type keepdims)
{
using result_t = meta::resolve_optype_t<reduction_slices_t,indices_t,shape_type,axis_type,keepdims_type>;

auto slices = result_t {};
[[maybe_unused]] auto dim = len(src_shape);
if constexpr (meta::is_resizable_v<result_t>) {
slices.resize(dim);
}

// helper lambda to check if axis i is in the specified axis for reduction
auto in_axis = [&](auto i){
if constexpr (meta::is_index_v<axis_type> && meta::is_pointer_v<axis_type>) {
return i==*axis;
} else if constexpr (meta::is_index_v<axis_type>) {
using common_t = meta::promote_index_t<axis_type,decltype(i)>;
return (common_t)i==(common_t)axis;
} else {
auto f_predicate = [i](auto axis){
using common_t = meta::promote_index_t<decltype(i),decltype(axis)>;
return (common_t)i==(common_t)axis;
};
// axis is index array (reducing on multiple axes),
// axis may be pointer, but can't provide convenience function
// since may decay bounded array to pointer
if constexpr (meta::is_pointer_v<axis_type>) {
auto found = index::where(f_predicate, *axis);
return static_cast<bool>(len(found));
} else {
auto found = index::where(f_predicate, axis);
return static_cast<bool>(len(found));
}
}
};

// use the same type as axis_type for loop index
constexpr auto idx_vtype = [](){
if constexpr (meta::is_constant_index_array_v<axis_type>) {
// shortcut for now, just use int
return meta::as_value_v<int>;
} else if constexpr (meta::is_index_array_v<axis_type>) {
using type = meta::get_element_type_t<axis_type>;
return meta::as_value_v<type>;
} else if constexpr (meta::is_integer_v<axis_type>) {
return meta::as_value_v<axis_type>;
} else {
return meta::as_value_v<size_t>;
}
}();
using index_t = meta::get_index_element_type_t<shape_type>;
using idx_t [[maybe_unused]] = meta::type_t<meta::promote_index<index_t,meta::type_t<decltype(idx_vtype)>>>;

// indices and the referenced array may have different dim,
// this variable track index for indices_
auto ii = idx_t{0};
constexpr auto DIM = meta::len_v<shape_type>;
if constexpr (DIM > 0) {
// here, len(slices) already matched the dimension of source array
meta::template_for<DIM>([&](auto index){
constexpr auto i = decltype(index)::value;
// take all elements at given axis
if (in_axis(i)) {
// note that src_shape maybe constant index array
at(slices,i) = {
static_cast<nm_size_t>(0)
, static_cast<nm_size_t>(at(src_shape,meta::ct_v<i>))};
// if keepdims is true, also increment indices index
if (keepdims)
ii++;
}
// use indices otherwise, just slice with index:index+1
else {
auto s = at(indices_,ii++);
at(slices,i) = {
static_cast<nm_size_t>(s)
, static_cast<nm_size_t>(s+1)};
}
});
} else {
for (size_t i=0; i<dim; i++) {
// take all elements at given axis
if (in_axis(i)) {
// note that src_shape maybe constant index array
at(slices,i) = {
static_cast<nm_size_t>(0)
, static_cast<nm_size_t>(at(src_shape,i))};
// if keepdims is true, also increment indices index
if (keepdims)
ii++;
}
// use indices otherwise, just slice with index:index+1
else {
auto s = at(indices_,ii++);
at(slices,i) = {
static_cast<nm_size_t>(s)
, static_cast<nm_size_t>(s+1)};
}
}
}
return slices;
} // reduction_slices
} // namespace nmtools::index

namespace nmtools::meta
{
template <typename indices_t, typename shape_type, typename axis_type, typename keepdims_type>
struct resolve_optype<
void, index::reduction_slices_t, indices_t, shape_type, axis_type, keepdims_type
> {
static constexpr auto vtype = [](){
constexpr auto DIM = len_v<shape_type>;
[[maybe_unused]]
constexpr auto B_DIM = bounded_size_v<shape_type>;
using slice_type = nmtools_array<nm_size_t,2>;
// TODO: handle unsupported types
// TODO: compile-time inference
if constexpr (DIM > 0) {
using type = nmtools_array<slice_type,DIM>;
return as_value_v<type>;
} else if constexpr (!is_fail_v<decltype(B_DIM)>) {
using type = nmtools_static_vector<slice_type,B_DIM>;
return as_value_v<type>;
} else {
// TODO: support small_vector
using type = nmtools_list<slice_type>;
return as_value_v<type>;
}
}();
using type = type_t<decltype(vtype)>;
}; // reduction_slices_t
} // namespace nmtools::meta

#endif // NMTOOLS_ARRAY_INDEX_REDUCE_HPP
3 changes: 3 additions & 0 deletions include/nmtools/array/view/decorator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,18 @@ namespace nmtools::view
* - has index member function that transform indices
*/

// TODO: remove
#ifdef NMTOOLS_NO_BASE_ACCESS
view_type view;
decorator_t(const view_type& view)
: view(view)
{}
#endif

#if 0
nmtools_func_attribute
~decorator_t() = default;
#endif

/**
* @brief return the shape of this array
Expand Down
5 changes: 5 additions & 0 deletions include/nmtools/array/view/indexing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ namespace nmtools::view
, indexer(indexer)
{}

#if 0
nmtools_func_attribute
~indexing_t() = default;
#endif

constexpr auto operands() const noexcept
{
return nmtools_tuple<array_type>{array};
Expand Down
9 changes: 4 additions & 5 deletions include/nmtools/array/view/slice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ namespace nmtools::view
if constexpr (meta::is_maybe_v<decltype(m_dst_shape)>) {
using result_t = decltype(slice_t{unwrap(src_shape),unwrap(slices),unwrap(src_size)});
using return_t = nmtools_maybe<result_t>;
if (static_cast<bool>(m_dst_shape)) {
return return_t{result_t{unwrap(src_shape),unwrap(slices),unwrap(src_size)}};
} else {
return return_t{meta::Nothing};
}
return (has_value(m_dst_shape)
? return_t{result_t{unwrap(src_shape),unwrap(slices),unwrap(src_size)}}
: return_t{meta::Nothing}
);
} else {
return slice_t{unwrap(src_shape),unwrap(slices),unwrap(src_size)};
}
Expand Down
Loading

0 comments on commit 542b05d

Please sign in to comment.