Skip to content

Commit

Permalink
Generalize single axis simd reduction to handle ndarray (#256)
Browse files Browse the repository at this point in the history
* update simd reduction evaluator to handle ndarray

* update simd reduction indexing to handle ndarray

* make macro NMTOOLS_CHECK_MESSAGE override-able

* add avx & sse simd reduction tests for ndarray

* add gcc vector extension reduction tests for ndarray

* add simde avx512 reduction tests for ndarray

* skip multiply.reduce tests
  • Loading branch information
alifahrri authored Oct 29, 2023
1 parent 7ffe45d commit 32ca11d
Show file tree
Hide file tree
Showing 13 changed files with 19,591 additions and 15 deletions.
43 changes: 34 additions & 9 deletions include/nmtools/array/eval/simd/evaluator/ufunc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "nmtools/array/eval/simd/index.hpp"
#include "nmtools/array/eval/simd/ufunc.hpp"
#include "nmtools/array/eval/simd/bit_width.hpp"
#include "nmtools/array/index/insert_index.hpp"

namespace nmtools::array
{
Expand Down Expand Up @@ -229,10 +230,6 @@ namespace nmtools::array
}
if constexpr (meta::is_index_v<axis_type>) {
auto out_data_ptr = nmtools::data(output);
auto inp_dim = dim(*input_array_ptr);
if (inp_dim != 2) {
return false;
}
using index::ReductionKind, index::SIMD;
const auto n_elem_pack = meta::as_type_v<N>;
auto identity = [&]()->element_type{
Expand All @@ -247,10 +244,34 @@ namespace nmtools::array
out_data_ptr[i] = identity;
}
auto inp_shape = nmtools::shape(*input_array_ptr);
auto reduction_axis = view.axis;
auto reduction_kind = (reduction_axis == -1) || ((int)reduction_axis == (int)(len(inp_shape)-1)) ? ReductionKind::HORIZONTAL : ReductionKind::VERTICAL;
// "normalize" the out shape as if keepdims=True
auto out_shape_ = [&](){
using keepdims_type = decltype(view.keepdims);
if constexpr (is_none_v<keepdims_type>) {
return out_shape;
} else {
if constexpr (!keepdims_type::value) {
return index::insert_index(out_shape,1,reduction_axis);
} else {
return out_shape;
}
}
}();
auto out_shape = [out_shape_](){
// TODO: create "unwrap" function
if constexpr (meta::is_maybe_v<decltype(out_shape_)>) {
return *out_shape_; // assume success, TODO: error handling
} else {
return out_shape_;
}
}();
// vertical reduction
if (view.axis == 0) {
switch (reduction_kind) {
case ReductionKind::VERTICAL: {
const auto reduction_kind = meta::as_type_v<ReductionKind::VERTICAL>;
const auto enumerator = index::reduction_2d_enumerator(reduction_kind,n_elem_pack,out_shape,inp_shape);
const auto enumerator = index::reduction_2d_enumerator(reduction_kind,n_elem_pack,out_shape,inp_shape,reduction_axis);
for (size_t i=0; i<enumerator.size(); i++) {
auto [out_pack, inp_pack] = enumerator[i];
auto [out_tag,out_offset] = out_pack;
Expand All @@ -269,9 +290,10 @@ namespace nmtools::array
break;
}
}
} else if (view.axis == 1) {
} break;
case ReductionKind::HORIZONTAL: {
const auto reduction_kind = meta::as_type_v<ReductionKind::HORIZONTAL>;
const auto enumerator = index::reduction_2d_enumerator(reduction_kind,n_elem_pack,out_shape,inp_shape);
const auto enumerator = index::reduction_2d_enumerator(reduction_kind,n_elem_pack,out_shape,inp_shape,reduction_axis);
auto accum = op.set1(identity);
for (size_t i=0; i<enumerator.size(); i++) {
const auto [out_pack, inp_pack] = enumerator[i];
Expand Down Expand Up @@ -320,10 +342,13 @@ namespace nmtools::array
break;
}
}
} break;
default: {
return false;
} break;
}
return true;
}

return false;
}

Expand Down
67 changes: 61 additions & 6 deletions include/nmtools/array/eval/simd/index/ufunc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,46 @@ namespace nmtools::index
VERTICAL=1,
};

template <typename index_t=size_t, ReductionKind reduction_kind, auto N_ELEM_PACK, typename inp_shape_t, typename out_shape_t, typename axis_t>
auto reduction_nd_reshape(meta::as_type<reduction_kind>, meta::as_type<N_ELEM_PACK>, const inp_shape_t& inp_shape, const out_shape_t&, [[maybe_unused]] axis_t axis)
{
using result_t = nmtools_array<index_t,2>;

auto result = result_t{};

auto dim = len(inp_shape);

if constexpr (meta::is_resizable_v<result_t>) {
result.resize(2); // strictly 2
}
at(result,0) = 1;
at(result,1) = 1;

if constexpr (reduction_kind == ReductionKind::HORIZONTAL) {
auto horizontal_axis = dim-1; // last
for (size_t i=0; i<(size_t)horizontal_axis; i++) {
at(result,0) *= at(inp_shape,i);
}
auto n_ops = at(inp_shape, horizontal_axis);
at(result,1) = n_ops;
} else if constexpr (reduction_kind == ReductionKind::VERTICAL) {
auto vertical_axis = dim;
int i = 0;
for (; i<=(int)axis; i++) {
at(result,0) *= at(inp_shape,i);
}
for (; i<(int)vertical_axis; i++) {
at(result,1) *= at(inp_shape,i);
}
}
if (dim == 1) {
at(result,0) = 1;
at(result,1) = at(inp_shape,0);
}

return result;
}

template <typename index_t=size_t, ReductionKind reduction_kind, auto N_ELEM_PACK, typename inp_shape_t, typename out_shape_t>
auto reduction_2d_shape(meta::as_type<reduction_kind>, meta::as_type<N_ELEM_PACK>, const inp_shape_t& inp_shape, const out_shape_t&)
{
Expand All @@ -156,8 +196,6 @@ namespace nmtools::index
result.resize(2); // strictly 2
}

// assume out (and result) is 2D

constexpr auto row_idx = meta::ct_v<0>;
constexpr auto col_idx = meta::ct_v<1>;

Expand All @@ -174,6 +212,14 @@ namespace nmtools::index

return result;
}

/**
 * @brief Axis-aware overload of reduction_2d_shape.
 *
 * Collapses `inp_shape` to 2D with reduction_nd_reshape (using the given
 * reduction kind and axis) and forwards the collapsed shape, together with
 * the untouched `out_shape`, to the 4-argument reduction_2d_shape overload.
 *
 * @param kind compile-time reduction kind tag
 * @param n_elem_pack compile-time number of elements per simd pack
 * @param inp_shape n-dimensional input shape
 * @param out_shape output shape, forwarded as-is
 * @param axis reduction axis
 * @return whatever the 4-argument overload returns for the collapsed input shape
 */
template <typename index_t=size_t, ReductionKind reduction_kind, auto N_ELEM_PACK, typename inp_shape_t, typename out_shape_t, typename axis_t=index_t>
auto reduction_2d_shape(meta::as_type<reduction_kind> kind, meta::as_type<N_ELEM_PACK> n_elem_pack, const inp_shape_t& inp_shape, const out_shape_t& out_shape, axis_t axis)
{
    const auto collapsed_inp = reduction_nd_reshape(kind,n_elem_pack,inp_shape,out_shape,axis);
    return reduction_2d_shape(kind,n_elem_pack,collapsed_inp,out_shape);
}

template <typename index_t=size_t, ReductionKind reduction_kind, auto N_ELEM_PACK, typename simd_index_t, typename simd_shape_t, typename out_shape_t, typename inp_shape_t>
auto reduction_2d(meta::as_type<reduction_kind>, meta::as_type<N_ELEM_PACK>, const simd_index_t& simd_index, const simd_shape_t&, const out_shape_t& out_shape, const inp_shape_t& inp_shape)
Expand All @@ -184,8 +230,8 @@ namespace nmtools::index
auto result = result_t {};

const auto n_ops = [&](){
if (reduction_kind == ReductionKind::VERTICAL) {
return at(out_shape,meta::ct_v<-1>);
if constexpr (reduction_kind == ReductionKind::VERTICAL) {
return at(inp_shape,meta::ct_v<-1>);
} else {
return at(inp_shape,meta::ct_v<-1>);
}
Expand All @@ -206,11 +252,12 @@ namespace nmtools::index
at(result,meta::ct_v<1>) = tagged_index_t{inp_tag,inp_index};
} else if constexpr (reduction_kind == ReductionKind::VERTICAL) {
auto inp_offset = at(simd_index,meta::ct_v<0>) * at(inp_shape,meta::ct_v<1>);
auto out_offset = len(out_shape) > 1 ? (at(simd_index,meta::ct_v<0>) / (at(inp_shape,meta::ct_v<0>) / at(out_shape,meta::ct_v<0>))) * at(out_shape,meta::ct_v<-1>) : 0;
auto out_index = at(simd_index,meta::ct_v<1>) * N_ELEM_PACK;
auto inp_index = at(simd_index,meta::ct_v<1>) * N_ELEM_PACK;

auto rel_scalar_index = n_simd * N_ELEM_PACK + (at(simd_index,meta::ct_v<1>) - n_simd);
out_index = out_index > n_ops ? rel_scalar_index : out_index;
out_index = out_index > n_ops ? (out_offset + rel_scalar_index) : (out_offset + out_index);
inp_index = inp_index > n_ops ? (inp_offset + rel_scalar_index) : (inp_offset + inp_index);

// prefer scalar instead of padding because scalar store for output
Expand All @@ -228,8 +275,8 @@ namespace nmtools::index
{
using inp_shape_type = const inp_shape_t;
using out_shape_type = const out_shape_t;
using simd_shape_type = const nmtools_array<index_t,2>;

using simd_shape_type = meta::remove_cvref_t<decltype(reduction_2d_shape(meta::as_type_v<reduction_kind>,meta::as_type_v<N_ELEM_PACK>,meta::declval<inp_shape_t>(),meta::declval<out_shape_t>()))>;
using simd_index_type = nmtools_array<index_t,2>;
using index_type = index_t;
using size_type = index_t;
Expand Down Expand Up @@ -268,6 +315,14 @@ namespace nmtools::index
return enumerator_t{n_elem_pack,kind,out_shape,inp_shape};
}

/**
 * @brief Axis-aware overload of reduction_2d_enumerator.
 *
 * Both the output shape and the input shape are collapsed to 2D with
 * reduction_nd_reshape (same reduction kind and axis for each), then the
 * collapsed pair is handed to the 4-argument reduction_2d_enumerator.
 *
 * @param kind compile-time reduction kind tag
 * @param n_elem_pack compile-time number of elements per simd pack
 * @param out_shape n-dimensional output shape
 * @param inp_shape n-dimensional input shape
 * @param axis reduction axis
 * @return the enumerator built by the 4-argument overload for the collapsed shapes
 */
template <typename index_t=size_t, ReductionKind reduction_kind, auto N_ELEM_PACK, typename out_shape_t, typename inp_shape_t, typename axis_t>
constexpr auto reduction_2d_enumerator(meta::as_type<reduction_kind> kind, meta::as_type<N_ELEM_PACK> n_elem_pack, const out_shape_t& out_shape, const inp_shape_t& inp_shape, axis_t axis)
{
    // collapse each shape independently; the out-shape argument of the
    // reshape helper is passed through unchanged in both calls
    auto out_2d = reduction_nd_reshape(kind,n_elem_pack,out_shape,out_shape,axis);
    auto inp_2d = reduction_nd_reshape(kind,n_elem_pack,inp_shape,out_shape,axis);
    return reduction_2d_enumerator(kind,n_elem_pack,out_2d,inp_2d);
}

template <auto N_ELEM_PACK, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t>
constexpr auto outer_simd_shape(meta::as_type<N_ELEM_PACK>, const out_shape_t& out_shape, const lhs_shape_t& lhs_shape, const rhs_shape_t& rhs_shape)
{
Expand Down
2 changes: 2 additions & 0 deletions include/nmtools/testing/testing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ nmtools::utils::to_string(array)
*/
#define NMTOOLS_TESTING_LOG_TYPEINFO_IMPL_DOCTEST INFO

#ifndef NMTOOLS_CHECK_MESSAGE
#if !defined(__EMSCRIPTEN__) && !defined(__ANDROID__) && !defined(__arm__) && !defined(__MINGW32__)
#define NMTOOLS_CHECK_MESSAGE(result, message) \
{ \
Expand All @@ -53,6 +54,7 @@ nmtools::utils::to_string(array)
CHECK(result); \
}
#endif // (__EMSCRIPTEN__ || __ANDROID__ || __arm__ || __MINGW32__)
#endif // NMTOOLS_CHECK_MESSAGE

/**
* @brief implementation of doctest assert macro with message
Expand Down
2 changes: 2 additions & 0 deletions tests/simd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ if (NMTOOLS_SIMD_TEST_SSE)
set(NMTOOLS_SIMD_TEST_SOURCES ${NMTOOLS_SIMD_TEST_SOURCES}
x86/reduction_sse.cpp
x86/reduction_2d_sse.cpp
x86/reduction_nd_sse.cpp
)
endif (NMTOOLS_SIMD_TEST_REDUCTION)
endif (NMTOOLS_SIMD_TEST_SSE)
Expand All @@ -63,6 +64,7 @@ if (NMTOOLS_SIMD_TEST_AVX)
set(NMTOOLS_SIMD_TEST_SOURCES ${NMTOOLS_SIMD_TEST_SOURCES}
x86/reduction_avx.cpp
x86/reduction_2d_avx.cpp
x86/reduction_nd_avx.cpp
)
endif (NMTOOLS_SIMD_TEST_REDUCTION)
endif (NMTOOLS_SIMD_TEST_AVX)
Expand Down
Loading

0 comments on commit 32ca11d

Please sign in to comment.