From e1e7bf53f0f975894bbb9da1563f41b6b7f690a3 Mon Sep 17 00:00:00 2001 From: Fahri Ali Rahman Date: Tue, 29 Oct 2024 14:25:24 +0700 Subject: [PATCH] Add kron, vecdot, and tensordot (#304) * add kron * add max_value & min_value metafunctions * add tests * add vecdot * add tensordot * move out index::contains form expand_dims, add range index function * add is_clipped_index metafunction * update tests * skip kron compile-time shape inference when on gcc * update compiler-notes * fix for gcc werror * fix for gcc werror --- docs/compiler-notes.txt | 5 +- include/nmtools/array/array/kron.hpp | 24 + include/nmtools/array/array/tensordot.hpp | 24 + include/nmtools/array/array/vecdot.hpp | 24 + include/nmtools/array/index/contains.hpp | 22 + include/nmtools/array/index/expand_dims.hpp | 21 +- include/nmtools/array/index/range.hpp | 96 + include/nmtools/array/view/kron.hpp | 385 ++++ include/nmtools/array/view/tensordot.hpp | 412 ++++ include/nmtools/array/view/vecdot.hpp | 32 + .../meta/bits/traits/is_clipped_index.hpp | 15 + .../nmtools/meta/bits/transform/max_value.hpp | 100 + .../nmtools/meta/bits/transform/min_value.hpp | 100 + include/nmtools/meta/traits.hpp | 1 + include/nmtools/meta/transform.hpp | 2 + include/nmtools/testing/data/array/kron.hpp | 2053 +++++++++++++++++ .../nmtools/testing/data/array/tensordot.hpp | 833 +++++++ include/nmtools/testing/data/array/vecdot.hpp | 427 ++++ include/nmtools/testing/data/index/kron.hpp | 587 +++++ .../nmtools/testing/data/index/tensordot.hpp | 575 +++++ tests/array/CMakeLists.txt | 3 + tests/array/array/kron.cpp | 246 ++ tests/array/array/tensordot.cpp | 267 +++ tests/array/array/vecdot.cpp | 149 ++ tests/index/CMakeLists.txt | 2 + tests/index/src/kron.cpp | 158 ++ tests/index/src/tensordot.cpp | 339 +++ tests/view/CMakeLists.txt | 3 + tests/view/src/kron.cpp | 246 ++ tests/view/src/tensordot.cpp | 267 +++ tests/view/src/vecdot.cpp | 149 ++ 31 files changed, 7546 insertions(+), 21 deletions(-) create mode 100644 include/nmtools/array/array/kron.hpp create mode 100644 include/nmtools/array/array/tensordot.hpp create mode 100644 include/nmtools/array/array/vecdot.hpp create mode 100644 include/nmtools/array/index/contains.hpp create mode 100644 include/nmtools/array/index/range.hpp create mode 100644 include/nmtools/array/view/kron.hpp create mode 100644 include/nmtools/array/view/tensordot.hpp create mode 100644 include/nmtools/array/view/vecdot.hpp create mode 100644 include/nmtools/meta/bits/traits/is_clipped_index.hpp create mode 100644 include/nmtools/meta/bits/transform/max_value.hpp create mode 100644 include/nmtools/meta/bits/transform/min_value.hpp create mode 100644 include/nmtools/testing/data/array/kron.hpp create mode 100644 include/nmtools/testing/data/array/tensordot.hpp create mode 100644 include/nmtools/testing/data/array/vecdot.hpp create mode 100644 include/nmtools/testing/data/index/kron.hpp create mode 100644 include/nmtools/testing/data/index/tensordot.hpp create mode 100644 tests/array/array/kron.cpp create mode 100644 tests/array/array/tensordot.cpp create mode 100644 tests/array/array/vecdot.cpp create mode 100644 tests/index/src/kron.cpp create mode 100644 tests/index/src/tensordot.cpp create mode 100644 tests/view/src/kron.cpp create mode 100644 tests/view/src/tensordot.cpp create mode 100644 tests/view/src/vecdot.cpp diff --git a/docs/compiler-notes.txt b/docs/compiler-notes.txt index 29a4a62b0..1de0108a6 100644 --- a/docs/compiler-notes.txt +++ b/docs/compiler-notes.txt @@ -101,4 +101,7 @@ Documenting various note on behaviour difference between clang & gcc (or with so 1. clang vs gcc disagree on capturing constexpr value in lambda expression clang ok, gcc not ok - https://godbolt.org/z/a1o8P9957 \ No newline at end of file + https://godbolt.org/z/a1o8P9957 + +1. gcc `for` loop becomes goto in constexpr context (and breaks), works fine on clang + https://github.com/alifahrri/nmtools/issues/303 \ No newline at end of file diff --git a/include/nmtools/array/array/kron.hpp b/include/nmtools/array/array/kron.hpp new file mode 100644 index 000000000..3fe5e0b03 --- /dev/null +++ b/include/nmtools/array/array/kron.hpp @@ -0,0 +1,24 @@ +#ifndef NMTOOLS_ARRAY_ARRAY_KRON_HPP +#define NMTOOLS_ARRAY_ARRAY_KRON_HPP + +#include "nmtools/array/view/kron.hpp" +#include "nmtools/array/eval.hpp" + +namespace nmtools::array +{ + template + , typename lhs_t, typename rhs_t> + constexpr auto kron(const lhs_t& lhs, const rhs_t& rhs + , context_t&& context=context_t{}, output_t&& output=output_t{}, meta::as_value resolver=meta::as_value_v) + { + auto a = view::kron(lhs,rhs); + return eval( + a + , nmtools::forward(context) + , nmtools::forward(output) + , resolver + ); + } // kron +} + +#endif // NMTOOLS_ARRAY_ARRAY_KRON_HPP \ No newline at end of file diff --git a/include/nmtools/array/array/tensordot.hpp b/include/nmtools/array/array/tensordot.hpp new file mode 100644 index 000000000..329b45b1a --- /dev/null +++ b/include/nmtools/array/array/tensordot.hpp @@ -0,0 +1,24 @@ +#ifndef NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP +#define NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP + +#include "nmtools/array/view/tensordot.hpp" +#include "nmtools/array/eval.hpp" + +namespace nmtools::array +{ + template + , typename lhs_t, typename rhs_t, typename axes_t=meta::ct<2>> + constexpr auto tensordot(const lhs_t& lhs, const rhs_t& rhs, axes_t axes=axes_t{} + , context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value resolver=meta::as_value_v) + { + auto a = view::tensordot(lhs,rhs,axes); + return eval( + a + , nmtools::forward(context) + , nmtools::forward(output) + , resolver + ); + } // tensordot +} // nmtools::array + +#endif // NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP \ No newline at end of file diff --git a/include/nmtools/array/array/vecdot.hpp b/include/nmtools/array/array/vecdot.hpp new file mode 100644 index 000000000..89953bbf2 --- /dev/null +++ b/include/nmtools/array/array/vecdot.hpp @@ -0,0 +1,24 @@ +#ifndef NMTOOLS_ARRAY_ARRAY_VECDOT_HPP +#define NMTOOLS_ARRAY_ARRAY_VECDOT_HPP + +#include "nmtools/array/view/vecdot.hpp" +#include "nmtools/array/eval.hpp" + +namespace nmtools::array +{ + template + , typename lhs_t, typename rhs_t, typename dtype_t=none_t, typename keepdims_t=meta::false_type> + constexpr auto vecdot(const lhs_t& lhs, const rhs_t& rhs, dtype_t dtype=dtype_t{}, keepdims_t keepdims=keepdims_t{} + , context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value resolver=meta::as_value_v) + { + auto a = view::vecdot(lhs,rhs,dtype,keepdims); + return eval( + a + , nmtools::forward(context) + , nmtools::forward(output) + , resolver + ); + } // vecdot +} // nmtools::array + +#endif // NMTOOLS_ARRAY_ARRAY_VECDOT_HPP \ No newline at end of file diff --git a/include/nmtools/array/index/contains.hpp b/include/nmtools/array/index/contains.hpp new file mode 100644 index 000000000..80c8beac4 --- /dev/null +++ b/include/nmtools/array/index/contains.hpp @@ -0,0 +1,22 @@ +#ifndef NMTOOLS_ARRAY_INDEX_CONTAINS_HPP +#define NMTOOLS_ARRAY_INDEX_CONTAINS_HPP + +#include "nmtools/meta.hpp" +#include "nmtools/array/shape.hpp" +#include "nmtools/utils/isequal.hpp" + +namespace nmtools::index +{ + template + constexpr auto contains(const array_t& array, const value_t& value) + { + for (nm_size_t i=0; i<(nm_size_t)len(array); i++) { + if (utils::isequal(at(array,i),value)) { + return true; + } + } + return false; + } // contains +} // nmtools::index + +#endif // NMTOOLS_ARRAY_INDEX_CONTAINS_HPP \ No newline at end of file diff --git a/include/nmtools/array/index/expand_dims.hpp b/include/nmtools/array/index/expand_dims.hpp index 81492110a..7686ab892 100644 --- a/include/nmtools/array/index/expand_dims.hpp +++ b/include/nmtools/array/index/expand_dims.hpp @@ -6,6 +6,7 @@ #include "nmtools/array/utility/at.hpp" #include "nmtools/utils/isequal.hpp" #include "nmtools/array/ndarray/hybrid.hpp" +#include "nmtools/array/index/contains.hpp" #include "nmtools/array/index/normalize_axis.hpp" #include "nmtools/utility/unwrap.hpp" @@ -22,26 +23,6 @@ namespace nmtools::index */ struct shape_expand_dims_t {}; - // TODO: remove - template - constexpr auto contains(const array_t& array, const value_t& value) - { - if constexpr (meta::is_fixed_index_array_v) { - bool contain = false; - meta::template_for>([&](auto i){ - if (utils::isequal(at(array,i),value)) - contain = true; - }); - return contain; - } - else { - for (size_t i=0; i> + constexpr auto range([[maybe_unused]] start_t start + , [[maybe_unused]] stop_t stop + , [[maybe_unused]] step_t step=step_t{} + ) { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + auto n = (stop - start) / step; + if constexpr (meta::is_resizable_v) { + result.resize(n); + } + + for (nm_size_t i=0; i<(nm_size_t)n; i++) { + at(result,i) = i * step; + } + } + + return result; + } // range + + template + constexpr auto range(stop_t stop) + { + return range(meta::ct_v<0>,stop,meta::ct_v<1>); + } +} // nmtools::index + +namespace nmtools::meta +{ + namespace error + { + template + struct RANGE_UNSUPPORTED : detail::fail_t {}; + } + + template + struct resolve_optype< + void, index::range_t, start_t, stop_t, step_t + > { + static constexpr auto vtype = [](){ + if constexpr ( + !is_index_v + || !is_index_v + || !is_index_v + ) { + using type = error::RANGE_UNSUPPORTED; + return as_value_v; + } else if constexpr (is_constant_index_v + && is_constant_index_v + && is_constant_index_v + ) { + constexpr auto start = to_value_v; + constexpr auto stop = to_value_v; + constexpr auto step = to_value_v; + constexpr auto start_cl = clipped_int64_t 0 ? start : 1)>(start); + constexpr auto stop_cl = clipped_int64_t<(int64_t)stop>(stop); + constexpr auto step_cl = clipped_int64_t<(int64_t)step>(step); + constexpr auto result = index::range(start_cl,stop_cl,step_cl); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + constexpr auto max_dim = max_value_v; + if constexpr (!is_fail_v) { + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: small vector optimization + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; // index::range_t +} // nmtools::meta + +#endif // NMTOOLS_ARRAY_INDEX_RANGE_HPP \ No newline at end of file diff --git a/include/nmtools/array/view/kron.hpp b/include/nmtools/array/view/kron.hpp new file mode 100644 index 000000000..ac6b233ff --- /dev/null +++ b/include/nmtools/array/view/kron.hpp @@ -0,0 +1,385 @@ +#ifndef NTMOOLS_ARRAY_VIEW_KRON_HPP +#define NTMOOLS_ARRAY_VIEW_KRON_HPP + +#include "nmtools/meta.hpp" +#include "nmtools/array/shape.hpp" + +namespace nmtools::index +{ + struct kron_dst_transpose_t {}; + + template + constexpr auto kron_dst_transpose(const lhs_dim_t& lhs_dim, const rhs_dim_t& rhs_dim) + -> meta::resolve_optype_t + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + auto dst_dim = (nm_size_t)lhs_dim + (nm_size_t)rhs_dim; + + if constexpr (meta::is_resizable_v) { + result.resize(dst_dim); + } + + for (nm_size_t i=0; i<(nm_size_t)dst_dim; i++) { + at(result,i) = i; + } + + if ((nm_size_t)lhs_dim == (nm_size_t)rhs_dim) { + for (nm_size_t i=0; i<(nm_size_t)dst_dim/2; i++) { + at(result,i*2) = i; + } + for (nm_size_t i=0; i<(nm_size_t)dst_dim/2; i++) { + at(result,i*2+1) = i + (dst_dim/2); + } + } else if ((nm_size_t)lhs_dim < (nm_size_t)rhs_dim) { + #ifdef __clang__ + auto rhs_dim_i = [&](){ + if constexpr (meta::is_constant_index_v) { + return meta::ct_v; + } else if constexpr (meta::is_clipped_integer_v) { + constexpr auto max_i = meta::max_value_v; + return clipped_size_t{rhs_dim-1}; + } else { + return rhs_dim-1; + } + }(); + #else + auto rhs_dim_i = nm_index_t(rhs_dim) - 1; + #endif + auto initial_axes = kron_dst_transpose(lhs_dim,rhs_dim_i); + for (nm_size_t i=0; i<(nm_size_t)(dst_dim-1); i++) { + at(result,i) = at(initial_axes,i); + } + for (nm_size_t i=0; i<(nm_size_t)lhs_dim; i++) { + nm_index_t idx = -(i+1) * 2; + auto tmp = at(result,idx); + at(result,idx) = at(result,idx-1); + at(result,idx-1) = tmp; + } + } else { + #ifdef __clang__ + auto rhs_dim_i = [&](){ + if constexpr (meta::is_constant_index_v) { + return meta::ct_v; + } else if constexpr (meta::is_clipped_integer_v) { + constexpr auto max_rdim = meta::max_value_v; + constexpr auto max_ldim = meta::max_value_v; + + // avoid calling the function recursively at compile-time + if constexpr (!meta::is_fail_v) { + if constexpr (max_rdim <= max_ldim) { + constexpr auto max_i = max_rdim + 1; + return clipped_size_t(rhs_dim+1); + } else { + // unreachable + return rhs_dim + 1; + } + } else { + return rhs_dim + 1; + } + } else { + return rhs_dim+1; + } + }(); + #else + auto rhs_dim_i = (nm_index_t)rhs_dim+1; + #endif + auto initial_axes = kron_dst_transpose(lhs_dim,rhs_dim_i); + for (nm_size_t i=0; i<(nm_size_t)dst_dim; i++) { + at(result,i) = at(initial_axes,i); + } + for (nm_size_t i=0; i<(nm_size_t)rhs_dim; i++) { + nm_index_t idx = -(i*2+1); + auto tmp = at(result,idx); + at(result,idx) = at(result,idx-1); + at(result,idx-1) = tmp; + } + } + } + + return result; + } // kron_dst_transpose + + struct kron_lhs_reshape_t {}; + + template + constexpr auto kron_lhs_reshape([[maybe_unused]] const lhs_shape_t& lhs_shape, [[maybe_unused]] rhs_dim_t rhs_dim) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + auto lhs_dim = len(lhs_shape); + auto dst_dim = lhs_dim + rhs_dim; + + if constexpr (meta::is_resizable_v) { + result.resize(dst_dim); + } + + for (nm_size_t i=0; i<(nm_size_t)dst_dim; i++) { + at(result,i) = 1; + } + for (nm_size_t i=0; i<(nm_size_t)lhs_dim; i++) { + at(result,i) = at(lhs_shape,i); + } + } + + return result; + } // kron_lhs_reshape + + struct kron_dst_reshape_t {}; + + template + constexpr auto kron_dst_reshape(const lhs_shape_t& lhs_shape, const rhs_shape_t& rhs_shape) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + auto lhs_dim = len(lhs_shape); + auto rhs_dim = len(rhs_shape); + + auto dst_dim = (lhs_dim > rhs_dim ? lhs_dim : rhs_dim); + + if constexpr (meta::is_resizable_v) { + result.resize(dst_dim); + } + + for (nm_index_t i=1; i<=(nm_index_t)dst_dim; i++) { + auto n_i_lhs = -i + lhs_dim; + auto n_i_rhs = -i + rhs_dim; + if ( (n_i_lhs < lhs_dim) + && (n_i_rhs < rhs_dim) + && (n_i_lhs >= 0) + && (n_i_rhs >= 0) + ) { + at(result,-i) = at(lhs_shape,-i) * at(rhs_shape,-i); + } else if ( + (n_i_lhs < lhs_dim) + && (n_i_lhs >= 0) + ) { + at(result,-i) = at(lhs_shape,-i); + } else { + at(result,-i) = at(rhs_shape,-i); + } + } + } + + return result; + } +} // nmtools::index + +namespace nmtools::meta +{ + namespace error + { + template + struct KRON_DST_TRANSPOSE_UNSUPPORTED : detail::fail_t {}; + + template + struct KRON_LHS_RESHAPE_UNSUPPORTED : detail::fail_t {}; + + template + struct KRON_DST_RESHAPE_UNSUPPORTED : detail::fail_t {}; + } + + template + struct resolve_optype< + void, index::kron_dst_transpose_t, lhs_dim_t, rhs_dim_t + > { + static constexpr auto vtype = [](){ + if constexpr (!is_index_v + || !is_index_v + ) { + using type = error::KRON_DST_TRANSPOSE_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_v + && is_constant_index_v + ) { + // TODO: fix compile-time shape inference on gcc + // quick workaround, skip on gcc (https://github.com/alifahrri/nmtools/issues/303) + #ifdef __clang__ + constexpr auto lhs_dim = to_value_v; + constexpr auto rhs_dim = to_value_v; + constexpr auto result = index::kron_dst_transpose(clipped_size_t(lhs_dim),clipped_size_t(rhs_dim)); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + #else + constexpr auto DST_DIM = lhs_dim_t::value + rhs_dim_t::value; + using type = nmtools_array; + return as_value_v; + #endif + } else { + [[maybe_unused]] constexpr auto LHS_MAX = max_value_v; + [[maybe_unused]] constexpr auto RHS_MAX = max_value_v; + if constexpr (is_constant_index_v && is_constant_index_v) { + constexpr auto DST_DIM = lhs_dim_t::value + rhs_dim_t::value; + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v + && !is_fail_v + ) { + constexpr auto DST_B_DIM = LHS_MAX + RHS_MAX; + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: support small vector + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; // index::kron_dst_transpose_t + + template + struct resolve_optype< + void, index::kron_lhs_reshape_t, lhs_shape_t, rhs_dim_t + > { + static constexpr auto vtype = [](){ + if constexpr (!is_index_v + || !is_index_array_v + ) { + using type = error::KRON_LHS_RESHAPE_UNSUPPORTED; + return as_value_v; + } else if constexpr (is_constant_index_v + && is_constant_index_array_v + ) { + constexpr auto lhs_shape = to_value_v; + constexpr auto result = index::kron_lhs_reshape(lhs_shape,rhs_dim_t{}); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + [[maybe_unused]] + constexpr auto LHS_B_DIM = bounded_size_v; + [[maybe_unused]] + constexpr auto RHS_MAX = max_value_v; + constexpr auto LHS_DIM = len_v; + if constexpr ((LHS_DIM > 0) && is_constant_index_v) { + constexpr auto DST_DIM = LHS_DIM + rhs_dim_t::value; + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v + && !is_fail_v + ) { + constexpr auto DST_B_DIM = LHS_B_DIM + RHS_MAX; + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: small vector optimization + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; // index::kron_lhs_reshape_t + + template + struct resolve_optype< + void, index::kron_dst_reshape_t, lhs_shape_t, rhs_shape_t + > { + static constexpr auto vtype = [](){ + if constexpr ( + !is_index_array_v + || !is_index_array_v + ) { + using type = error::KRON_DST_RESHAPE_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_array_v + && is_constant_index_array_v + ) { + constexpr auto lhs_shape = to_value_v; + constexpr auto rhs_shape = to_value_v; + constexpr auto result = index::kron_dst_reshape(lhs_shape,rhs_shape); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + [[maybe_unused]] constexpr auto LHS_B_DIM = bounded_size_v; + [[maybe_unused]] constexpr auto RHS_B_DIM = bounded_size_v; + constexpr auto LHS_DIM = len_v; + constexpr auto RHS_DIM = len_v; + if constexpr ((LHS_DIM > 0) && (RHS_DIM > 0)) { + constexpr auto DST_DIM = (LHS_DIM > RHS_DIM ? LHS_DIM : RHS_DIM); + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v + && !is_fail_v + ) { + constexpr auto DST_B_DIM = (LHS_B_DIM > RHS_B_DIM ? LHS_B_DIM : RHS_B_DIM); + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: small vector optimization + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; // index::kron_dst_reshape_t +} // nmtools::meta + +/*******************************************************************************/ + +#include "nmtools/array/view/decorator.hpp" +#include "nmtools/array/view/alias.hpp" +#include "nmtools/array/view/reshape.hpp" +#include "nmtools/array/view/tile.hpp" +#include "nmtools/array/view/transpose.hpp" +#include "nmtools/array/view/ufuncs/multiply.hpp" + +namespace nmtools::view +{ + template + constexpr auto kron(const lhs_t& lhs, const rhs_t& rhs) + { + auto lhs_shape = shape(lhs); + auto rhs_shape = shape(rhs); + auto lhs_dim = dim(lhs); + auto rhs_dim = dim(rhs); + + auto aliased = view::aliased(lhs,rhs); + + auto a_lhs = nmtools::get<0>(aliased); + auto a_rhs = nmtools::get<1>(aliased); + + auto lhs_dst_shape = index::kron_lhs_reshape(lhs_shape,rhs_dim); + auto dst_axes = index::kron_dst_transpose(lhs_dim,rhs_dim); + auto dst_shape = index::kron_dst_reshape(lhs_shape,rhs_shape); + + auto a = view::reshape(a_lhs,lhs_dst_shape); + auto b = view::tile(a,rhs_shape); + auto c = view::multiply(b,a_rhs); + auto d = view::transpose(c,dst_axes); + auto e = view::reshape(d,dst_shape); + return e; + } // kron +} // nmtools::view + +#endif // NTMOOLS_ARRAY_VIEW_KRON_HPP \ No newline at end of file diff --git a/include/nmtools/array/view/tensordot.hpp b/include/nmtools/array/view/tensordot.hpp new file mode 100644 index 000000000..8d82683eb --- /dev/null +++ b/include/nmtools/array/view/tensordot.hpp @@ -0,0 +1,412 @@ +#ifndef NMTOOLS_ARRAY_VIEW_TENSORDOT_HPP +#define NMTOOLS_ARRAY_VIEW_TENSORDOT_HPP + +#include "nmtools/meta.hpp" +#include "nmtools/array/shape.hpp" +#include "nmtools/array/index/range.hpp" +#include "nmtools/array/index/normalize_axis.hpp" +#include "nmtools/utility/unwrap.hpp" +#include "nmtools/array/index/contains.hpp" + +namespace nmtools::index +{ + struct tensordot_lhs_transpose_t {}; + + template + constexpr auto tensordot_lhs_transpose([[maybe_unused]] lhs_dim_t lhs_dim, [[maybe_unused]] const axes_t& axes) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + if constexpr (meta::is_resizable_v) { + result.resize(lhs_dim); + } + + for (nm_size_t i=0; i<(nm_size_t)lhs_dim; i++) { + at(result,i) = i; + } + + if constexpr (!meta::is_index_v) { + // TODO: propagate error + auto m_axes = unwrap(normalize_axis(axes,lhs_dim)); + nm_index_t index = 0; + for (nm_size_t i=0; i<(nm_size_t)lhs_dim; i++) { + if (contains(m_axes,i)) { + continue; + } + at(result,index++) = i; + } + for (nm_size_t i=0; i<(nm_size_t)len(m_axes); i++) { + at(result,index++) = at(m_axes,i); + } + } + } + + return result; + } // tensordot_lhs_transpose + + struct tensordot_rhs_transpose_t {}; + + template + constexpr auto tensordot_rhs_transpose([[maybe_unused]] rhs_dim_t rhs_dim, [[maybe_unused]] const axes_t& axes) + { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + if constexpr (meta::is_resizable_v) { + result.resize(rhs_dim); + } + + for (nm_size_t i=0; i<(nm_size_t)rhs_dim; i++) { + at(result,i) = i; + } + + nm_size_t index = 0; + + const auto m_axes = [&](){ + if constexpr (meta::is_index_v) { + return range(meta::ct_v<0>,axes); + } else { + // TODO: propagate error + return unwrap(normalize_axis(axes,rhs_dim)); + } + }(); + + for (nm_size_t i=0; i<(nm_size_t)rhs_dim; i++) { + if (contains(m_axes,i)) { + continue; + } + at(result,index) = i; + index++; + } + + for (nm_size_t i=0; i<(nm_size_t)len(m_axes); i++) { + at(result,index) = at(m_axes,i); + index++; + } + } + + return result; + } // tensordot_rhs_transpose + + struct tensordot_lhs_reshape_t {}; + + template + constexpr auto tensordot_lhs_reshape([[maybe_unused]] const lhs_shape_t& lhs_shape + , [[maybe_unused]] const rhs_shape_t& rhs_shape + , [[maybe_unused]] const lhs_axes_t& lhs_axes + ) { + using result_t = meta::resolve_optype_t; + + auto result = result_t {}; + + if constexpr (!meta::is_fail_v + && !meta::is_constant_index_array_v + ) { + auto lhs_dim = len(lhs_shape); + auto rhs_dim = len(rhs_shape); + auto sum_dim = [&](){ + if constexpr (meta::is_index_v) { + return lhs_axes; + } else { + return len(lhs_axes); + } + }(); + auto dst_dim = lhs_dim + rhs_dim - sum_dim; + + if constexpr (meta::is_resizable_v) { + result.resize(dst_dim); + } + + for (nm_size_t i=0; i<(nm_size_t)dst_dim; i++) { + at(result,i) = 1; + } + + auto non_contracted = lhs_dim - sum_dim; + for (nm_size_t i=0; i<(non_contracted); i++) { + at(result,i) = at(lhs_shape,i); + } + for (nm_size_t i=0; i<(sum_dim); i++) { + nm_index_t index = -(i+1); + at(result,index) = at(lhs_shape,index); + } + } + + return result; + } // tensordot_lhs_reshape +} // namespace nmtools::index + +namespace nmtools::meta +{ + namespace error + { + template + struct TENSORDOT_LHS_TRANSPOSE_UNSUPPORTED : detail::fail_t {}; + + template + struct TENSORDOT_RHS_TRANSPOSE_UNSUPPORTED : detail::fail_t {}; + + template + struct TENSORDOT_LHS_RESHAPE_UNSUPPORTED : detail::fail_t {}; + } + + template + struct resolve_optype< + void, index::tensordot_lhs_transpose_t, lhs_dim_t, axes_t + > { + static constexpr auto vtype = [](){ + if constexpr (!is_index_v + || !(is_index_array_v || is_index_v) + ) { + using type = error::TENSORDOT_LHS_TRANSPOSE_UNSUPPORTED; + return as_value_v; + } else if constexpr ( + is_constant_index_v + && (is_constant_index_array_v || is_constant_index_v) + ) { + constexpr auto lhs_dim = to_value_v; + constexpr auto lhs_dim_cl = clipped_size_t(lhs_dim); + constexpr auto axes = [&](){ + if constexpr (is_constant_index_v) { + constexpr auto axes = to_value_v; + constexpr auto axes_cl = clipped_size_t<(nm_size_t)axes>(axes); + return axes_cl; + } else { + return to_value_v; + } + }(); + constexpr auto result = index::tensordot_lhs_transpose(lhs_dim_cl,axes); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + if constexpr (is_constant_index_v) { + constexpr auto DIM = lhs_dim_t::value; + using type = nmtools_array; + return as_value_v; + } else if constexpr (is_clipped_index_v) { + constexpr auto MAX_DIM = max_value_v; + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: small vector optimization + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; // index::tensordot_lhs_transpose_t + + template + struct resolve_optype< + void, index::tensordot_rhs_transpose_t, rhs_dim_t, axes_t + > { + static constexpr auto vtype = [](){ + if constexpr (!is_index_v + || !(is_index_v || is_index_array_v) + ) { + using type = error::TENSORDOT_RHS_TRANSPOSE_UNSUPPORTED; + return as_value_v; + } else if constexpr (is_constant_index_v + && (is_constant_index_array_v || is_constant_index_v) + ) { + constexpr auto rhs_dim = to_value_v; + constexpr auto rhs_dim_cl = clipped_size_t(rhs_dim); + constexpr auto axes = [&](){ + if constexpr (is_constant_index_v) { + constexpr auto axes = to_value_v; + constexpr auto axes_cl = clipped_size_t<(nm_size_t)axes>(axes); + return axes_cl; + } else { + return to_value_v; + } + }(); + constexpr auto result = index::tensordot_rhs_transpose(rhs_dim_cl,axes); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + if constexpr (is_constant_index_v) { + constexpr auto DIM = rhs_dim_t::value; + using type = nmtools_array; + return as_value_v; + } else if constexpr (is_clipped_index_v) { + constexpr auto MAX_DIM = max_value_v; + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: small vector optimization + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; + + template + struct resolve_optype< + void, index::tensordot_lhs_reshape_t, lhs_shape_t, rhs_shape_t, lhs_axes_t + > { + static constexpr auto vtype = [](){ + if constexpr (!is_index_array_v + || !is_index_array_v + || !(is_index_array_v || is_index_v) + ) { + using type = error::TENSORDOT_LHS_RESHAPE_UNSUPPORTED; + return as_value_v; + } else if constexpr (is_constant_index_array_v + && is_constant_index_array_v + && (is_constant_index_array_v || is_constant_index_v) + ) { + constexpr auto lhs_shape = to_value_v; + constexpr auto rhs_shape = to_value_v; + constexpr auto lhs_axes = to_value_v; + constexpr auto result = index::tensordot_lhs_reshape(lhs_shape,rhs_shape,lhs_axes); + using nmtools::at, nmtools::len; + return template_reduce([&](auto init, auto I){ + using init_t = type_t; + using type = append_type_t>; + return as_value_v; + }, as_value_v>); + } else { + constexpr auto LHS_DIM = len_v; + constexpr auto RHS_DIM = len_v; + + [[maybe_unused]] constexpr auto LHS_B_DIM = bounded_size_v; + [[maybe_unused]] constexpr auto RHS_B_DIM = bounded_size_v; + + [[maybe_unused]] constexpr auto AXIS_DIM = len_v; + + if constexpr ((LHS_DIM > 0) && (RHS_DIM > 0) + && (is_constant_index_v || (AXIS_DIM > 0)) + ) { + constexpr auto SUM_DIM = [&](){ + if constexpr (AXIS_DIM > 0) { + return AXIS_DIM; + } else { + return lhs_axes_t::value; + } + }(); + constexpr auto DST_DIM = LHS_DIM + RHS_DIM - SUM_DIM; + using type = nmtools_array; + return as_value_v; + } else if constexpr (!is_fail_v + && !is_fail_v + ) { + // lhs_axes used as subtractor, contribute inversely to max dim + constexpr auto DST_B_DIM = LHS_B_DIM + RHS_B_DIM - 1; + using type = nmtools_static_vector; + return as_value_v; + } else { + // TODO: support small vector + using type = nmtools_list; + return as_value_v; + } + } + }(); + using type = type_t; + }; +} // nmtools::meta + +/*******************************************************************************/ + +#include "nmtools/array/view/decorator.hpp" +#include "nmtools/array/view/alias.hpp" +#include "nmtools/array/view/transpose.hpp" +#include "nmtools/array/view/reshape.hpp" +#include "nmtools/array/view/sum.hpp" +#include "nmtools/array/view/ufuncs/multiply.hpp" + +namespace nmtools::view +{ + template > + constexpr auto tensordot(const lhs_t& lhs, const rhs_t& rhs, const axes_t& axes=axes_t{}) + { + auto lhs_axes = [&](){ + if constexpr (meta::is_tuple_v) { + return nmtools::get<0>(axes); + } else { + return axes; + } + }(); + auto rhs_axes = [&](){ + if constexpr (meta::is_tuple_v) { + return nmtools::get<1>(axes); + } else { + return axes; + } + }(); + auto sum_dim = [&](){ + if constexpr (meta::is_tuple_v) { + return size(lhs_axes); + } else { + return axes; + } + }(); + auto sum_axis = [&](){ + if constexpr (meta::is_constant_index_v) { + constexpr auto N = meta::to_value_v; + return meta::template_reduce([&](auto init, auto I){ + constexpr auto i = decltype(I)::value; + using init_t = meta::remove_cvref_t; + using type = meta::append_type_t>; + return type{}; + }, nmtools_tuple{}); + } else { + // TODO: support small vector + using result_t = nmtools_list; + auto result = result_t {}; + result.resize(sum_dim); + for (nm_size_t i=0; i<(nm_size_t)len(result); i++) { + at(result,i) = -(i+1); + } + return result; + } + }(); + + // auto lhs_shape = shape(lhs); + auto lhs_dim = dim(lhs); + auto lhs_transpose_axes = index::tensordot_lhs_transpose(lhs_dim,lhs_axes); + + auto rhs_shape = shape(rhs); + auto rhs_dim = dim(rhs); + auto rhs_transpose_axes = index::tensordot_rhs_transpose(rhs_dim,rhs_axes); + + auto aliased = view::aliased(lhs,rhs); + auto a_lhs = nmtools::get<0>(aliased); + auto a_rhs = nmtools::get<1>(aliased); + + auto a = view::transpose(a_lhs,lhs_transpose_axes); + + auto lhs_dst_shape = index::tensordot_lhs_reshape(shape(a),rhs_shape,sum_axis); + + auto b = view::reshape(a,lhs_dst_shape); + auto c = view::transpose(a_rhs,rhs_transpose_axes); + auto d = view::multiply(b,c); + + auto dtype = None; + auto initial = None; + auto keepdims = False; + auto e = view::sum(d,sum_axis,dtype,initial,keepdims); + + return e; + } // tensordot +} // nmtools::view + +#endif // NMTOOLS_ARRAY_VIEW_TENSORDOT_HPP \ No newline at end of file diff --git a/include/nmtools/array/view/vecdot.hpp b/include/nmtools/array/view/vecdot.hpp new file mode 100644 index 000000000..3e9a850c8 --- /dev/null +++ b/include/nmtools/array/view/vecdot.hpp @@ -0,0 +1,32 @@ +#ifndef NMTOOLS_ARRAY_VIEW_VECDOT_HPP +#define NMTOOLS_ARRAY_VIEW_VECDOT_HPP + +#include "nmtools/array/view/decorator.hpp" +#include "nmtools/array/view/alias.hpp" +#include "nmtools/array/view/ufuncs/multiply.hpp" +#include "nmtools/array/view/sum.hpp" + +namespace nmtools::view +{ + template + constexpr auto vecdot(const lhs_t& lhs, const rhs_t& rhs, dtype_t dtype=dtype_t{}, keepdims_t keepdims=keepdims_t{}) + { + auto aliased = view::aliased(lhs,rhs); + + auto a_lhs = nmtools::get<0>(aliased); + auto a_rhs = nmtools::get<1>(aliased); + + auto axis = meta::ct_v<-1>; + auto initial = None; + + return view::sum( + view::multiply(a_lhs,a_rhs) + , axis + , dtype + , initial + , keepdims + ); + } // vecdot +} // nmtools::view + +#endif // NMTOOLS_ARRAY_VIEW_VECDOT_HPP \ No newline at end of file diff --git a/include/nmtools/meta/bits/traits/is_clipped_index.hpp b/include/nmtools/meta/bits/traits/is_clipped_index.hpp new file mode 100644 index 000000000..0a38c74af --- /dev/null +++ b/include/nmtools/meta/bits/traits/is_clipped_index.hpp @@ -0,0 +1,15 @@ +#ifndef NMTOOLS_META_BITS_TRAITS_IS_CLIPPED_INDEX_HPP +#define NMTOOLS_META_BITS_TRAITS_IS_CLIPPED_INDEX_HPP + +#include "nmtools/meta/bits/traits/is_clipped_integer.hpp" + +namespace nmtools::meta +{ + template + struct is_clipped_index : is_clipped_integer {}; + + template + constexpr inline auto is_clipped_index_v = is_clipped_index::value; +} + +#endif // NMTOOLS_META_BITS_TRAITS_IS_CLIPPED_INDEX_HPP \ No newline at end of file diff --git a/include/nmtools/meta/bits/transform/max_value.hpp b/include/nmtools/meta/bits/transform/max_value.hpp new file mode 100644 index 000000000..b52a2bd3b --- /dev/null +++ b/include/nmtools/meta/bits/transform/max_value.hpp @@ -0,0 +1,100 @@ +#ifndef NMTOOLS_META_BITS_TRANSFORM_MAX_VALUE_HPP +#define NMTOOLS_META_BITS_TRANSFORM_MAX_VALUE_HPP + +#include "nmtools/meta/common.hpp" +#include "nmtools/meta/bits/traits/is_tuple.hpp" +#include "nmtools/meta/bits/transform/promote_index.hpp" +#include "nmtools/meta/bits/transform/clipped_max.hpp" + +namespace nmtools::meta +{ + namespace error + { + template + struct MAX_VALUE_UNSUPPORTED : detail::fail_t {}; + } + + /** + * @brief Convert constant index to value + * + * If T is tuple, it is expected to transform to array + * for easy handling as value. + * + * @tparam T + */ + template + struct max_value + { + static inline constexpr auto value = error::MAX_VALUE_UNSUPPORTED{}; + }; // max_value + + template + struct max_value : max_value {}; + + template + struct max_value : max_value {}; + + template + struct max_value> + { + static constexpr auto value = v; + }; + + // nmtools' true_type & false_type is not an alias to integral_constant + template <> struct max_value + { static constexpr auto value = true; }; + + template <> struct max_value + { static constexpr auto value = false; }; + + // converting nmtools' clipped_integer using meta::max_value means + // converting the max value to value array + // note that currently the min value is ignored. + // this (using max value) is mostly useful for indexing function (where min=0, max=N) + // to carry the information about maximum number of elements per axis + // hence we can deduce the maximum number of elements to deduce the type of buffer at compile time + + template typename Tuple, typename...T, auto...Min, auto...Max> + struct max_value< + Tuple...>, + enable_if_t< is_tuple_v...>> && sizeof...(T)> + > + { + using index_t = promote_index_t; + static constexpr auto value = nmtools_array{index_t(Max)...}; + }; // max_value + + template typename tuple, typename...Ts> + struct max_value< + tuple + , enable_if_t< is_tuple_v> && (is_constant_index_v && ...) && sizeof...(Ts)> + > { + using index_t = promote_index_t; + static constexpr auto value = nmtools_array{index_t(Ts::value)...}; + }; // max_value + + template typename Array, typename T, auto Min, auto Max, auto N> + struct max_value< + Array,N> + > + { + static constexpr auto value = [](){ + using type = nmtools_array; + auto result = type{}; + for (size_t i=0; i + struct max_value> + : clipped_max> + {}; + + template + constexpr inline auto max_value_v = max_value::value; +} // namespace nmtools::meta + +#endif // NMTOOLS_META_BITS_TRANSFORM_MAX_VALUE_HPP \ No newline at end of file diff --git a/include/nmtools/meta/bits/transform/min_value.hpp b/include/nmtools/meta/bits/transform/min_value.hpp new file mode 100644 index 000000000..124fab4ed --- /dev/null +++ b/include/nmtools/meta/bits/transform/min_value.hpp @@ -0,0 +1,100 @@ +#ifndef NMTOOLS_META_BITS_TRANSFORM_MIN_VALUE_HPP +#define NMTOOLS_META_BITS_TRANSFORM_MIN_VALUE_HPP + +#include "nmtools/meta/common.hpp" +#include "nmtools/meta/bits/traits/is_tuple.hpp" +#include "nmtools/meta/bits/transform/promote_index.hpp" +#include "nmtools/meta/bits/transform/clipped_max.hpp" + +namespace nmtools::meta +{ + namespace error + { + template + struct MIN_VALUE_UNSUPPORTED : detail::fail_t {}; + } + + /** + * @brief Convert constant index to value + * + * If T is tuple, it is expected to transform to array + * for easy handling as value. + * + * @tparam T + */ + template + struct min_value + { + static inline constexpr auto value = error::MIN_VALUE_UNSUPPORTED{}; + }; // min_value + + template + struct min_value : min_value {}; + + template + struct min_value : min_value {}; + + template + struct min_value> + { + static constexpr auto value = v; + }; + + // nmtools' true_type & false_type is not an alias to integral_constant + template <> struct min_value + { static constexpr auto value = true; }; + + template <> struct min_value + { static constexpr auto value = false; }; + + // converting nmtools' clipped_integer using meta::min_value means + // converting the max value to value array + // note that currently the min value is ignored. + // this (using max value) is mostly useful for indexing function (where min=0, max=N) + // to carry the information about maximum number of elements per axis + // hence we can deduce the maximum number of elements to deduce the type of buffer at compile time + + template typename Tuple, typename...T, auto...Min, auto...Max> + struct min_value< + Tuple...>, + enable_if_t< is_tuple_v...>> && sizeof...(T)> + > + { + using index_t = promote_index_t; + static constexpr auto value = nmtools_array{index_t(Min)...}; + }; // min_value + + template typename tuple, typename...Ts> + struct min_value< + tuple + , enable_if_t< is_tuple_v> && (is_constant_index_v && ...) && sizeof...(Ts)> + > { + using index_t = promote_index_t; + static constexpr auto value = nmtools_array{index_t(Ts::value)...}; + }; // min_value + + template typename Array, typename T, auto Min, auto Max, auto N> + struct min_value< + Array,N> + > + { + static constexpr auto value = [](){ + using type = nmtools_array; + auto result = type{}; + for (size_t i=0; i + struct min_value> + : clipped_max> + {}; + + template + constexpr inline auto min_value_v = min_value::value; +} // namespace nmtools::meta + +#endif // NMTOOLS_META_BITS_TRANSFORM_MIN_VALUE_HPP \ No newline at end of file diff --git a/include/nmtools/meta/traits.hpp b/include/nmtools/meta/traits.hpp index 09998ab4b..62a7a461f 100644 --- a/include/nmtools/meta/traits.hpp +++ b/include/nmtools/meta/traits.hpp @@ -38,6 +38,7 @@ #include "nmtools/meta/bits/traits/is_copy_constructible.hpp" #include "nmtools/meta/bits/traits/is_default_constructible.hpp" #include "nmtools/meta/bits/traits/is_clipped_index_array.hpp" +#include "nmtools/meta/bits/traits/is_clipped_index.hpp" #include "nmtools/meta/bits/traits/is_clipped_integer.hpp" #include "nmtools/meta/bits/traits/is_const.hpp" #include "nmtools/meta/bits/traits/is_constant_index.hpp" diff --git a/include/nmtools/meta/transform.hpp b/include/nmtools/meta/transform.hpp index 7de661063..90e609d0d 100644 --- a/include/nmtools/meta/transform.hpp +++ b/include/nmtools/meta/transform.hpp @@ -23,6 +23,8 @@ #include "nmtools/meta/bits/transform/len.hpp" #include "nmtools/meta/bits/transform/make_signed.hpp" #include "nmtools/meta/bits/transform/make_unsigned.hpp" +#include "nmtools/meta/bits/transform/max_value.hpp" +#include "nmtools/meta/bits/transform/min_value.hpp" #include "nmtools/meta/bits/transform/promote_index.hpp" #include "nmtools/meta/bits/transform/promote_types.hpp" #include "nmtools/meta/bits/transform/remove_address_space.hpp" diff --git a/include/nmtools/testing/data/array/kron.hpp b/include/nmtools/testing/data/array/kron.hpp new file mode 100644 index 000000000..7d94b526d --- /dev/null +++ b/include/nmtools/testing/data/array/kron.hpp @@ -0,0 +1,2053 @@ +#ifndef NMTOOLS_TESTING_DATA_ARRAY_KRON_HPP +#define NMTOOLS_TESTING_DATA_ARRAY_KRON_HPP + +#include "nmtools/testing/testing.hpp" +#include "nmtools/testing/array_cast.hpp" + +NMTOOLS_TESTING_DECLARE_CASE(array,kron) +{ + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[3] = {0,1,2}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result[18] = {0,0,0,0,1,2,0,2,4,0,3,6,0,4,8,0,5,10}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1b) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1b) + { + inline int result[3][24] = {{ 0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, + 0, 4, 8, 12, 0, 5, 10, 15}, + { 0, 0, 0, 0, 4, 5, 6, 7, 8, 10, 12, 14, 12, 15, 18, 21, + 16, 20, 24, 28, 20, 25, 30, 35}, + { 0, 0, 0, 0, 8, 9, 10, 11, 16, 18, 20, 22, 24, 27, 30, 33, + 32, 36, 40, 44, 40, 45, 50, 55}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1c) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1c) + { + inline int result[2][3][12] = + {{{ 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5}, + { 0, 0, 2, 3, 4, 6, 6, 9, 8, 12, 10, 15}, + { 0, 0, 4, 5, 8, 10, 12, 15, 16, 20, 20, 25}}, + + {{ 0, 0, 6, 7, 12, 14, 18, 21, 24, 28, 30, 35}, + { 0, 0, 8, 9, 16, 18, 24, 27, 32, 36, 40, 45}, + { 0, 0, 10, 11, 20, 22, 30, 33, 40, 44, 50, 55}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1d) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1d) + { + inline int result[2][1][3][12] = + {{{{ 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5}, + { 0, 0, 2, 3, 4, 6, 6, 9, 8, 12, 10, 15}, + { 0, 0, 4, 5, 8, 10, 12, 15, 16, 20, 20, 25}}}, + + + {{{ 0, 0, 6, 7, 12, 14, 18, 21, 24, 28, 30, 35}, + { 0, 0, 8, 9, 16, 18, 24, 27, 32, 36, 40, 45}, + { 0, 0, 10, 11, 20, 22, 30, 33, 40, 44, 50, 55}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1e) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[2][1][3][1][2] = { + { + { + { + {0,1}, + }, + { + {2,3}, + }, + { + {4,5}, + }, + } + }, + { + { + { + {6,7}, + }, + { + {8,9}, + }, + { + {10,11}, + }, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1e) + { + inline int result[2][1][3][1][12] = + {{{{{ 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5}}, + + {{ 0, 0, 2, 3, 4, 6, 6, 9, 8, 12, 10, 15}}, + + {{ 0, 0, 4, 5, 8, 10, 12, 15, 16, 20, 20, 25}}}}, + + + + {{{{ 0, 0, 6, 7, 12, 14, 18, 21, 24, 28, 30, 35}}, + + {{ 0, 0, 8, 9, 16, 18, 24, 27, 32, 36, 40, 45}}, + + {{ 0, 0, 10, 11, 20, 22, 30, 33, 40, 44, 50, 55}}}}} + ; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3] = { + {0,1,2}, + {3,4,5}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[6][6] = + {{ 0, 0, 0, 0, 1, 2}, + { 0, 0, 0, 3, 4, 5}, + { 0, 2, 4, 0, 3, 6}, + { 6, 8, 10, 9, 12, 15}, + { 0, 4, 8, 0, 5, 10}, + {12, 16, 20, 15, 20, 25}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int a[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline int b[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[6][6] = + {{ 0, 0, 0, 1, 0, 2}, + { 0, 0, 2, 3, 4, 6}, + { 0, 0, 4, 5, 8, 10}, + { 0, 3, 0, 4, 0, 5}, + { 6, 9, 8, 12, 10, 15}, + {12, 15, 16, 20, 20, 25}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2c) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[6] = {0,1,2,3,4,5}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2c) + { + inline int result[3][12] = + {{ 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5}, + { 0, 2, 4, 6, 8, 10, 0, 3, 6, 9, 12, 15}, + { 0, 4, 8, 12, 16, 20, 0, 5, 10, 15, 20, 25}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int a[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[2][9][8] = + {{{ 0, 0, 0, 1, 0, 2, 0, 3}, + { 0, 0, 2, 3, 4, 6, 6, 9}, + { 0, 0, 4, 5, 8, 10, 12, 15}, + { 0, 4, 0, 5, 0, 6, 0, 7}, + { 8, 12, 10, 15, 12, 18, 14, 21}, + { 16, 20, 20, 25, 24, 30, 28, 35}, + { 0, 8, 0, 9, 0, 10, 0, 11}, + { 16, 24, 18, 27, 20, 30, 22, 33}, + { 32, 40, 36, 45, 40, 50, 44, 55}}, + + {{ 0, 0, 6, 7, 12, 14, 18, 21}, + { 0, 0, 8, 9, 16, 18, 24, 27}, + { 0, 0, 10, 11, 20, 22, 30, 33}, + { 24, 28, 30, 35, 36, 42, 42, 49}, + { 32, 36, 40, 45, 48, 54, 56, 63}, + { 40, 44, 50, 55, 60, 66, 70, 77}, + { 48, 56, 54, 63, 60, 70, 66, 77}, + { 64, 72, 72, 81, 80, 90, 88, 99}, + { 80, 88, 90, 99, 100, 110, 110, 121}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int a[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + inline int b[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[2][1][9][8] = + {{{{ 0, 0, 0, 1, 0, 2, 0, 3}, + { 0, 0, 2, 3, 4, 6, 6, 9}, + { 0, 0, 4, 5, 8, 10, 12, 15}, + { 0, 4, 0, 5, 0, 6, 0, 7}, + { 8, 12, 10, 15, 12, 18, 14, 21}, + { 16, 20, 20, 25, 24, 30, 28, 35}, + { 0, 8, 0, 9, 0, 10, 0, 11}, + { 16, 24, 18, 27, 20, 30, 22, 33}, + { 32, 40, 36, 45, 40, 50, 44, 55}}}, + + + {{{ 0, 0, 6, 7, 12, 14, 18, 21}, + { 0, 0, 8, 9, 16, 18, 24, 27}, + { 0, 0, 10, 11, 20, 22, 30, 33}, + { 24, 28, 30, 35, 36, 42, 42, 49}, + { 32, 36, 40, 45, 48, 54, 56, 63}, + { 40, 44, 50, 55, 60, 66, 70, 77}, + { 48, 56, 54, 63, 60, 70, 66, 77}, + { 64, 72, 72, 81, 80, 90, 88, 99}, + { 80, 88, 90, 99, 100, 110, 110, 121}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2f) + { + inline int a[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + inline int b[2][1][3][1][2] = { + { + { + { + {0,1}, + }, + { + {2,3}, + }, + { + {4,5}, + }, + } + }, + { + { + { + {6,7}, + }, + { + {8,9}, + }, + { + {10,11}, + }, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2f) + { + inline int result[2][1][3][3][8] = + {{{{{ 0, 0, 0, 1, 0, 2, 0, 3}, + { 0, 4, 0, 5, 0, 6, 0, 7}, + { 0, 8, 0, 9, 0, 10, 0, 11}}, + + {{ 0, 0, 2, 3, 4, 6, 6, 9}, + { 8, 12, 10, 15, 12, 18, 14, 21}, + { 16, 24, 18, 27, 20, 30, 22, 33}}, + + {{ 0, 0, 4, 5, 8, 10, 12, 15}, + { 16, 20, 20, 25, 24, 30, 28, 35}, + { 32, 40, 36, 45, 40, 50, 44, 55}}}}, + + + + {{{{ 0, 0, 6, 7, 12, 14, 18, 21}, + { 24, 28, 30, 35, 36, 42, 42, 49}, + { 48, 56, 54, 63, 60, 70, 66, 77}}, + + {{ 0, 0, 8, 9, 16, 18, 24, 27}, + { 32, 36, 40, 45, 48, 54, 56, 63}, + { 64, 72, 72, 81, 80, 90, 88, 99}}, + + {{ 0, 0, 10, 11, 20, 22, 30, 33}, + { 40, 44, 50, 55, 60, 66, 70, 77}, + { 80, 88, 90, 99, 100, 110, 110, 121}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2g) + { + inline int a[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + inline int b[2][1][3][1][2][1] = { + { + { + { + { + {0}, + {1}, + } + }, + { + { + {2}, + {3}, + } + }, + { + { + {4}, + {5}, + } + }, + } + }, + { + { + { + { + {6}, + {7}, + } + }, + { + { + {8}, + {9}, + } + }, + { + { + {10}, + {11}, + } + }, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2g) + { + inline int result[2][1][3][1][6][4] = + {{{{{{ 0, 0, 0, 0}, + { 0, 1, 2, 3}, + { 0, 0, 0, 0}, + { 4, 5, 6, 7}, + { 0, 0, 0, 0}, + { 8, 9, 10, 11}}}, + + + {{{ 0, 2, 4, 6}, + { 0, 3, 6, 9}, + { 8, 10, 12, 14}, + { 12, 15, 18, 21}, + { 16, 18, 20, 22}, + { 24, 27, 30, 33}}}, + + + {{{ 0, 4, 8, 12}, + { 0, 5, 10, 15}, + { 16, 20, 24, 28}, + { 20, 25, 30, 35}, + { 32, 36, 40, 44}, + { 40, 45, 50, 55}}}}}, + + + + + {{{{{ 0, 6, 12, 18}, + { 0, 7, 14, 21}, + { 24, 30, 36, 42}, + { 28, 35, 42, 49}, + { 48, 54, 60, 66}, + { 56, 63, 70, 77}}}, + + + {{{ 0, 8, 16, 24}, + { 0, 9, 18, 27}, + { 32, 40, 48, 56}, + { 36, 45, 54, 63}, + { 64, 72, 80, 88}, + { 72, 81, 90, 99}}}, + + + {{{ 0, 10, 20, 30}, + { 0, 11, 22, 33}, + { 40, 50, 60, 70}, + { 44, 55, 66, 77}, + { 80, 90, 100, 110}, + { 88, 99, 110, 121}}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2h) + { + inline int a[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + inline int b[2][1][1][3][1][2][1] = { + { + { + { + { + { + {0}, + {1}, + } + }, + { + { + {2}, + {3}, + } + }, + { + { + {4}, + {5}, + } + }, + } + } + }, + { + { + { + { + { + {6}, + {7}, + } + }, + { + { + {8}, + {9}, + } + }, + { + { + {10}, + {11}, + } + }, + } + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2h) + { + inline int result[2][1][1][3][1][6][4] = + {{{{{{{ 0, 0, 0, 0}, + { 0, 1, 2, 3}, + { 0, 0, 0, 0}, + { 4, 5, 6, 7}, + { 0, 0, 0, 0}, + { 8, 9, 10, 11}}}, + + + {{{ 0, 2, 4, 6}, + { 0, 3, 6, 9}, + { 8, 10, 12, 14}, + { 12, 15, 18, 21}, + { 16, 18, 20, 22}, + { 24, 27, 30, 33}}}, + + + {{{ 0, 4, 8, 12}, + { 0, 5, 10, 15}, + { 16, 20, 24, 28}, + { 20, 25, 30, 35}, + { 32, 36, 40, 44}, + { 40, 45, 50, 55}}}}}}, + + + + + + {{{{{{ 0, 6, 12, 18}, + { 0, 7, 14, 21}, + { 24, 30, 36, 42}, + { 28, 35, 42, 49}, + { 48, 54, 60, 66}, + { 56, 63, 70, 77}}}, + + + {{{ 0, 8, 16, 24}, + { 0, 9, 18, 27}, + { 32, 40, 48, 56}, + { 36, 45, 54, 63}, + { 64, 72, 80, 88}, + { 72, 81, 90, 99}}}, + + + {{{ 0, 10, 20, 30}, + { 0, 11, 22, 33}, + { 40, 50, 60, 70}, + { 44, 55, 66, 77}, + { 80, 90, 100, 110}, + { 88, 99, 110, 121}}}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[4][9][4] = + {{{ 0, 0, 0, 1}, + { 0, 0, 2, 3}, + { 0, 0, 4, 5}, + { 0, 2, 0, 3}, + { 4, 6, 6, 9}, + { 8, 10, 12, 15}, + { 0, 4, 0, 5}, + { 8, 12, 10, 15}, + { 16, 20, 20, 25}}, + + {{ 0, 0, 6, 7}, + { 0, 0, 8, 9}, + { 0, 0, 10, 11}, + { 12, 14, 18, 21}, + { 16, 18, 24, 27}, + { 20, 22, 30, 33}, + { 24, 28, 30, 35}, + { 32, 36, 40, 45}, + { 40, 44, 50, 55}}, + + {{ 0, 6, 0, 7}, + { 12, 18, 14, 21}, + { 24, 30, 28, 35}, + { 0, 8, 0, 9}, + { 16, 24, 18, 27}, + { 32, 40, 36, 45}, + { 0, 10, 0, 11}, + { 20, 30, 22, 33}, + { 40, 50, 44, 55}}, + + {{ 36, 42, 42, 49}, + { 48, 54, 56, 63}, + { 60, 66, 70, 77}, + { 48, 56, 54, 63}, + { 64, 72, 72, 81}, + { 80, 88, 90, 99}, + { 60, 70, 66, 77}, + { 80, 90, 88, 99}, + {100, 110, 110, 121}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][1][3] = { + { + {0,1,2}, + }, + { + + {3,4,5}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[4][3][6] = + {{{ 0, 0, 0, 0, 1, 2}, + { 0, 2, 4, 0, 3, 6}, + { 0, 4, 8, 0, 5, 10}}, + + {{ 0, 0, 0, 3, 4, 5}, + { 6, 8, 10, 9, 12, 15}, + {12, 16, 20, 15, 20, 25}}, + + {{ 0, 6, 12, 0, 7, 14}, + { 0, 8, 16, 0, 9, 18}, + { 0, 10, 20, 0, 11, 22}}, + + {{18, 24, 30, 21, 28, 35}, + {24, 32, 40, 27, 36, 45}, + {30, 40, 50, 33, 44, 55}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[3][4] = { + {0,1, 2, 3}, + {4,5, 6, 7}, + {8,9,10,11}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[2][9][8] = + {{{ 0, 0, 0, 0, 0, 1, 2, 3}, + { 0, 0, 0, 0, 4, 5, 6, 7}, + { 0, 0, 0, 0, 8, 9, 10, 11}, + { 0, 2, 4, 6, 0, 3, 6, 9}, + { 8, 10, 12, 14, 12, 15, 18, 21}, + { 16, 18, 20, 22, 24, 27, 30, 33}, + { 0, 4, 8, 12, 0, 5, 10, 15}, + { 16, 20, 24, 28, 20, 25, 30, 35}, + { 32, 36, 40, 44, 40, 45, 50, 55}}, + + {{ 0, 6, 12, 18, 0, 7, 14, 21}, + { 24, 30, 36, 42, 28, 35, 42, 49}, + { 48, 54, 60, 66, 56, 63, 70, 77}, + { 0, 8, 16, 24, 0, 9, 18, 27}, + { 32, 40, 48, 56, 36, 45, 54, 63}, + { 64, 72, 80, 88, 72, 81, 90, 99}, + { 0, 10, 20, 30, 0, 11, 22, 33}, + { 40, 50, 60, 70, 44, 55, 66, 77}, + { 80, 90, 100, 110, 88, 99, 110, 121}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }; + inline int b[6] = {0,1,2,3,4,5}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[2][3][12] = + {{{ 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5}, + { 0, 2, 4, 6, 8, 10, 0, 3, 6, 9, 12, 15}, + { 0, 4, 8, 12, 16, 20, 0, 5, 10, 15, 20, 25}}, + + {{ 0, 6, 12, 18, 24, 30, 0, 7, 14, 21, 28, 35}, + { 0, 8, 16, 24, 32, 40, 0, 9, 18, 27, 36, 45}, + { 0, 10, 20, 30, 40, 50, 0, 11, 22, 33, 44, 55}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3e) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][1][3][1] = { + { + { + {0}, + {1}, + {2} + } + }, + { + { + {3}, + {4}, + {5}, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3e) + { + inline int result[2][2][9][2] = + {{{{ 0, 0}, + { 0, 1}, + { 0, 2}, + { 0, 0}, + { 2, 3}, + { 4, 6}, + { 0, 0}, + { 4, 5}, + { 8, 10}}, + + {{ 0, 0}, + { 6, 7}, + {12, 14}, + { 0, 0}, + { 8, 9}, + {16, 18}, + { 0, 0}, + {10, 11}, + {20, 22}}}, + + + {{{ 0, 3}, + { 0, 4}, + { 0, 5}, + { 6, 9}, + { 8, 12}, + {10, 15}, + {12, 15}, + {16, 20}, + {20, 25}}, + + {{18, 21}, + {24, 28}, + {30, 35}, + {24, 27}, + {32, 36}, + {40, 45}, + {30, 33}, + {40, 44}, + {50, 55}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3f) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }; + inline int b[2][1][2][3][1] = { + { + { + { + {0}, + {1}, + {2}, + }, + { + {3}, + {4}, + {5}, + }, + } + }, + { + { + { + {6}, + {7}, + {8}, + }, + { + { 9}, + {10}, + {11}, + }, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3f) + { + inline int result[2][1][4][9][2] = + {{{{{ 0, 0}, + { 0, 1}, + { 0, 2}, + { 0, 0}, + { 2, 3}, + { 4, 6}, + { 0, 0}, + { 4, 5}, + { 8, 10}}, + + {{ 0, 3}, + { 0, 4}, + { 0, 5}, + { 6, 9}, + { 8, 12}, + { 10, 15}, + { 12, 15}, + { 16, 20}, + { 20, 25}}, + + {{ 0, 0}, + { 6, 7}, + { 12, 14}, + { 0, 0}, + { 8, 9}, + { 16, 18}, + { 0, 0}, + { 10, 11}, + { 20, 22}}, + + {{ 18, 21}, + { 24, 28}, + { 30, 35}, + { 24, 27}, + { 32, 36}, + { 40, 45}, + { 30, 33}, + { 40, 44}, + { 50, 55}}}}, + + + + {{{{ 0, 6}, + { 0, 7}, + { 0, 8}, + { 12, 18}, + { 14, 21}, + { 16, 24}, + { 24, 30}, + { 28, 35}, + { 32, 40}}, + + {{ 0, 9}, + { 0, 10}, + { 0, 11}, + { 18, 27}, + { 20, 30}, + { 22, 33}, + { 36, 45}, + { 40, 50}, + { 44, 55}}, + + {{ 36, 42}, + { 42, 49}, + { 48, 56}, + { 48, 54}, + { 56, 63}, + { 64, 72}, + { 60, 66}, + { 70, 77}, + { 80, 88}}, + + {{ 54, 63}, + { 60, 70}, + { 66, 77}, + { 72, 81}, + { 80, 90}, + { 88, 99}, + { 90, 99}, + {100, 110}, + {110, 121}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3g) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][1][2][1][3][1] = { + { + { + { + { + {0}, + {1}, + {2}, + } + }, + { + { + {3}, + {4}, + {5}, + } + }, + } + }, + { + { + { + { + {6}, + {7}, + {8}, + } + }, + { + { + { 9}, + {10}, + {11}, + } + }, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3g) + { + inline int result[2][1][2][2][9][2] = + {{{{{{ 0, 0}, + { 0, 1}, + { 0, 2}, + { 0, 0}, + { 2, 3}, + { 4, 6}, + { 0, 0}, + { 4, 5}, + { 8, 10}}, + + {{ 0, 0}, + { 6, 7}, + { 12, 14}, + { 0, 0}, + { 8, 9}, + { 16, 18}, + { 0, 0}, + { 10, 11}, + { 20, 22}}}, + + + {{{ 0, 3}, + { 0, 4}, + { 0, 5}, + { 6, 9}, + { 8, 12}, + { 10, 15}, + { 12, 15}, + { 16, 20}, + { 20, 25}}, + + {{ 18, 21}, + { 24, 28}, + { 30, 35}, + { 24, 27}, + { 32, 36}, + { 40, 45}, + { 30, 33}, + { 40, 44}, + { 50, 55}}}}}, + + + + + {{{{{ 0, 6}, + { 0, 7}, + { 0, 8}, + { 12, 18}, + { 14, 21}, + { 16, 24}, + { 24, 30}, + { 28, 35}, + { 32, 40}}, + + {{ 36, 42}, + { 42, 49}, + { 48, 56}, + { 48, 54}, + { 56, 63}, + { 64, 72}, + { 60, 66}, + { 70, 77}, + { 80, 88}}}, + + + {{{ 0, 9}, + { 0, 10}, + { 0, 11}, + { 18, 27}, + { 20, 30}, + { 22, 33}, + { 36, 45}, + { 40, 50}, + { 44, 55}}, + + {{ 54, 63}, + { 60, 70}, + { 66, 77}, + { 72, 81}, + { 80, 90}, + { 88, 99}, + { 90, 99}, + {100, 110}, + {110, 121}}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3h) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }; + inline int b[2][1][1][3][1][2][1] = { + { + { + { + { + { + {0}, + {1}, + } + }, + { + { + {2}, + {3}, + } + }, + { + { + {4}, + {5}, + } + }, + } + } + }, + { + { + { + { + { + {6}, + {7}, + } + }, + { + { + {8}, + {9}, + } + }, + { + { + {10}, + {11}, + } + }, + } + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3h) + { + inline int result[2][1][1][3][2][6][2] = + {{{{{{{ 0, 0}, + { 0, 1}, + { 0, 0}, + { 2, 3}, + { 0, 0}, + { 4, 5}}, + + {{ 0, 0}, + { 6, 7}, + { 0, 0}, + { 8, 9}, + { 0, 0}, + { 10, 11}}}, + + + {{{ 0, 2}, + { 0, 3}, + { 4, 6}, + { 6, 9}, + { 8, 10}, + { 12, 15}}, + + {{ 12, 14}, + { 18, 21}, + { 16, 18}, + { 24, 27}, + { 20, 22}, + { 30, 33}}}, + + + {{{ 0, 4}, + { 0, 5}, + { 8, 12}, + { 10, 15}, + { 16, 20}, + { 20, 25}}, + + {{ 24, 28}, + { 30, 35}, + { 32, 36}, + { 40, 45}, + { 40, 44}, + { 50, 55}}}}}}, + + + + + + {{{{{{ 0, 6}, + { 0, 7}, + { 12, 18}, + { 14, 21}, + { 24, 30}, + { 28, 35}}, + + {{ 36, 42}, + { 42, 49}, + { 48, 54}, + { 56, 63}, + { 60, 66}, + { 70, 77}}}, + + + {{{ 0, 8}, + { 0, 9}, + { 16, 24}, + { 18, 27}, + { 32, 40}, + { 36, 45}}, + + {{ 48, 56}, + { 54, 63}, + { 64, 72}, + { 72, 81}, + { 80, 88}, + { 90, 99}}}, + + + {{{ 0, 10}, + { 0, 11}, + { 20, 30}, + { 22, 33}, + { 40, 50}, + { 44, 55}}, + + {{ 60, 70}, + { 66, 77}, + { 80, 90}, + { 88, 99}, + {100, 110}, + {110, 121}}}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4a) + { + inline int a[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + inline int b[2][2][3][1] = { + { + { + {0}, + {1}, + {2}, + }, + { + {3}, + {4}, + {5}, + }, + }, + { + { + {6}, + {7}, + {8}, + }, + { + { 9}, + {10}, + {11}, + }, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4a) + { + inline int result[4][2][9][2] = + {{{{ 0, 0}, + { 0, 1}, + { 0, 2}, + { 0, 0}, + { 2, 3}, + { 4, 6}, + { 0, 0}, + { 4, 5}, + { 8, 10}}, + + {{ 0, 3}, + { 0, 4}, + { 0, 5}, + { 6, 9}, + { 8, 12}, + { 10, 15}, + { 12, 15}, + { 16, 20}, + { 20, 25}}}, + + + {{{ 0, 6}, + { 0, 7}, + { 0, 8}, + { 12, 18}, + { 14, 21}, + { 16, 24}, + { 24, 30}, + { 28, 35}, + { 32, 40}}, + + {{ 0, 9}, + { 0, 10}, + { 0, 11}, + { 18, 27}, + { 20, 30}, + { 22, 33}, + { 36, 45}, + { 40, 50}, + { 44, 55}}}, + + + {{{ 0, 0}, + { 6, 7}, + { 12, 14}, + { 0, 0}, + { 8, 9}, + { 16, 18}, + { 0, 0}, + { 10, 11}, + { 20, 22}}, + + {{ 18, 21}, + { 24, 28}, + { 30, 35}, + { 24, 27}, + { 32, 36}, + { 40, 45}, + { 30, 33}, + { 40, 44}, + { 50, 55}}}, + + + {{{ 36, 42}, + { 42, 49}, + { 48, 56}, + { 48, 54}, + { 56, 63}, + { 64, 72}, + { 60, 66}, + { 70, 77}, + { 80, 88}}, + + {{ 54, 63}, + { 60, 70}, + { 66, 77}, + { 72, 81}, + { 80, 90}, + { 88, 99}, + { 90, 99}, + {100, 110}, + {110, 121}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4b) + { + inline int a[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4b) + { + inline int result[2][2][9][4] = + {{{{ 0, 0, 0, 1}, + { 0, 0, 2, 3}, + { 0, 0, 4, 5}, + { 0, 2, 0, 3}, + { 4, 6, 6, 9}, + { 8, 10, 12, 15}, + { 0, 4, 0, 5}, + { 8, 12, 10, 15}, + { 16, 20, 20, 25}}, + + {{ 0, 0, 6, 7}, + { 0, 0, 8, 9}, + { 0, 0, 10, 11}, + { 12, 14, 18, 21}, + { 16, 18, 24, 27}, + { 20, 22, 30, 33}, + { 24, 28, 30, 35}, + { 32, 36, 40, 45}, + { 40, 44, 50, 55}}}, + + + {{{ 0, 6, 0, 7}, + { 12, 18, 14, 21}, + { 24, 30, 28, 35}, + { 0, 8, 0, 9}, + { 16, 24, 18, 27}, + { 32, 40, 36, 45}, + { 0, 10, 0, 11}, + { 20, 30, 22, 33}, + { 40, 50, 44, 55}}, + + {{ 36, 42, 42, 49}, + { 48, 54, 56, 63}, + { 60, 66, 70, 77}, + { 48, 56, 54, 63}, + { 64, 72, 72, 81}, + { 80, 88, 90, 99}, + { 60, 70, 66, 77}, + { 80, 90, 88, 99}, + {100, 110, 110, 121}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4c) + { + inline int a[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + inline int b[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4c) + { + inline int result[2][1][9][4] = + {{{{ 0, 0, 0, 1}, + { 0, 0, 2, 3}, + { 0, 0, 4, 5}, + { 0, 2, 0, 3}, + { 4, 6, 6, 9}, + { 8, 10, 12, 15}, + { 0, 4, 0, 5}, + { 8, 12, 10, 15}, + {16, 20, 20, 25}}}, + + + {{{ 0, 6, 0, 7}, + {12, 18, 14, 21}, + {24, 30, 28, 35}, + { 0, 8, 0, 9}, + {16, 24, 18, 27}, + {32, 40, 36, 45}, + { 0, 10, 0, 11}, + {20, 30, 22, 33}, + {40, 50, 44, 55}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4d) + { + inline int a[2][1][3][2] = { + { + { + {0,1}, + {2,3}, + {4,5}, + } + }, + { + { + { 6, 7}, + { 8, 9}, + {10,11}, + } + }, + }; + inline int b[6] = {0,1,2,3,4,5}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4d) + { + inline int result[2][1][3][12] = + {{{{ 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5}, + { 0, 2, 4, 6, 8, 10, 0, 3, 6, 9, 12, 15}, + { 0, 4, 8, 12, 16, 20, 0, 5, 10, 15, 20, 25}}}, + + + {{{ 0, 6, 12, 18, 24, 30, 0, 7, 14, 21, 28, 35}, + { 0, 8, 16, 24, 32, 40, 0, 9, 18, 27, 36, 45}, + { 0, 10, 20, 30, 40, 50, 0, 11, 22, 33, 44, 55}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case5a) + { + inline int a[2][1][3][1][2] = { + { + { + { + {0,1} + }, + { + {2,3}, + }, + { + {4,5}, + }, + } + }, + { + { + { + {6,7}, + }, + { + {8,9}, + }, + { + {10,11}, + }, + } + }, + }; + inline int b[2][2][1][3][1] = { + { + { + { + {0}, + {1}, + {2}, + } + }, + { + { + {3}, + {4}, + {5}, + } + }, + }, + { + { + { + {6}, + {7}, + {8}, + } + }, + { + { + { 9}, + {10}, + {11}, + } + }, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case5a) + { + inline int result[4][2][3][3][2] = + {{{{{ 0, 0}, + { 0, 1}, + { 0, 2}}, + + {{ 0, 0}, + { 2, 3}, + { 4, 6}}, + + {{ 0, 0}, + { 4, 5}, + { 8, 10}}}, + + + {{{ 0, 3}, + { 0, 4}, + { 0, 5}}, + + {{ 6, 9}, + { 8, 12}, + { 10, 15}}, + + {{ 12, 15}, + { 16, 20}, + { 20, 25}}}}, + + + + {{{{ 0, 6}, + { 0, 7}, + { 0, 8}}, + + {{ 12, 18}, + { 14, 21}, + { 16, 24}}, + + {{ 24, 30}, + { 28, 35}, + { 32, 40}}}, + + + {{{ 0, 9}, + { 0, 10}, + { 0, 11}}, + + {{ 18, 27}, + { 20, 30}, + { 22, 33}}, + + {{ 36, 45}, + { 40, 50}, + { 44, 55}}}}, + + + + {{{{ 0, 0}, + { 6, 7}, + { 12, 14}}, + + {{ 0, 0}, + { 8, 9}, + { 16, 18}}, + + {{ 0, 0}, + { 10, 11}, + { 20, 22}}}, + + + {{{ 18, 21}, + { 24, 28}, + { 30, 35}}, + + {{ 24, 27}, + { 32, 36}, + { 40, 45}}, + + {{ 30, 33}, + { 40, 44}, + { 50, 55}}}}, + + + + {{{{ 36, 42}, + { 42, 49}, + { 48, 56}}, + + {{ 48, 54}, + { 56, 63}, + { 64, 72}}, + + {{ 60, 66}, + { 70, 77}, + { 80, 88}}}, + + + {{{ 54, 63}, + { 60, 70}, + { 66, 77}}, + + {{ 72, 81}, + { 80, 90}, + { 88, 99}}, + + {{ 90, 99}, + {100, 110}, + {110, 121}}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case6a) + { + inline int a[2][1][3][1][2][1] = { + { + { + { + { + {0}, + {1}, + } + }, + { + { + {2}, + {3}, + } + }, + { + { + {4}, + {5}, + } + }, + } + }, + { + { + { + { + {6}, + {7}, + } + }, + { + { + {8}, + {9}, + } + }, + { + { + {10}, + {11}, + } + }, + } + }, + }; + inline int b[1][3][1][2][1][2] = { + { + { + { + { + {0,1}, + }, + { + {2,3}, + }, + } + }, + { + { + { + {4,5}, + }, + { + {6,7}, + } + } + }, + { + { + { + {8,9}, + }, + { + {10,11}, + } + } + }, + } + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case6a) + { + inline int result[2][3][3][2][2][2] = + {{{{{{ 0, 0}, + { 0, 1}}, + + {{ 0, 0}, + { 2, 3}}}, + + + {{{ 0, 2}, + { 0, 3}}, + + {{ 4, 6}, + { 6, 9}}}, + + + {{{ 0, 4}, + { 0, 5}}, + + {{ 8, 12}, + { 10, 15}}}}, + + + + {{{{ 0, 0}, + { 4, 5}}, + + {{ 0, 0}, + { 6, 7}}}, + + + {{{ 8, 10}, + { 12, 15}}, + + {{ 12, 14}, + { 18, 21}}}, + + + {{{ 16, 20}, + { 20, 25}}, + + {{ 24, 28}, + { 30, 35}}}}, + + + + {{{{ 0, 0}, + { 8, 9}}, + + {{ 0, 0}, + { 10, 11}}}, + + + {{{ 16, 18}, + { 24, 27}}, + + {{ 20, 22}, + { 30, 33}}}, + + + {{{ 32, 36}, + { 40, 45}}, + + {{ 40, 44}, + { 50, 55}}}}}, + + + + + {{{{{ 0, 6}, + { 0, 7}}, + + {{ 12, 18}, + { 14, 21}}}, + + + {{{ 0, 8}, + { 0, 9}}, + + {{ 16, 24}, + { 18, 27}}}, + + + {{{ 0, 10}, + { 0, 11}}, + + {{ 20, 30}, + { 22, 33}}}}, + + + + {{{{ 24, 30}, + { 28, 35}}, + + {{ 36, 42}, + { 42, 49}}}, + + + {{{ 32, 40}, + { 36, 45}}, + + {{ 48, 56}, + { 54, 63}}}, + + + {{{ 40, 50}, + { 44, 55}}, + + {{ 60, 70}, + { 66, 77}}}}, + + + + {{{{ 48, 54}, + { 56, 63}}, + + {{ 60, 66}, + { 70, 77}}}, + + + {{{ 64, 72}, + { 72, 81}}, + + {{ 80, 88}, + { 90, 99}}}, + + + {{{ 80, 90}, + { 88, 99}}, + + {{100, 110}, + {110, 121}}}}}}; + } +} + +#endif // NMTOOLS_TESTING_DATA_ARRAY_KRON_HPP \ No newline at end of file diff --git a/include/nmtools/testing/data/array/tensordot.hpp b/include/nmtools/testing/data/array/tensordot.hpp new file mode 100644 index 000000000..dc05324ed --- /dev/null +++ b/include/nmtools/testing/data/array/tensordot.hpp @@ -0,0 +1,833 @@ +#ifndef NMTOOLS_TESTING_DATA_ARRAY_TENSORDOT_HPP +#define NMTOOLS_TESTING_DATA_ARRAY_TENSORDOT_HPP + +#include "nmtools/testing/testing.hpp" +#include "nmtools/testing/array_cast.hpp" + +NMTOOLS_TESTING_DECLARE_CASE(array,tensordot) +{ + using namespace literals; + + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[6] = {0,1,2,3,4,5}; + + inline int axes = 1; + inline auto axes_ct = 1_ct; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result = 55; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result = 55; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[3][2][1] = { + { + {0}, + {1}, + }, + { + {2}, + {3}, + }, + { + {4}, + {5}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[1] = {55}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2c) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[3][2][2] = { + { + {0,1}, + {2,3}, + }, + { + {4,5}, + {6,7}, + }, + { + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2c) + { + inline int result[2] = {110,125}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[3][2][1][2] = { + { + { + {0,1}, + }, + { + {2,3}, + }, + }, + { + { + {4,5}, + }, + { + {6,7}, + } + }, + { + { + {8,9}, + }, + { + {10,11}, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[1][2] = { + {110,125} + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[1][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + } + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{0,1}, + nmtools_array{1,2}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct,1_ct}, + nmtools_tuple{1_ct,2_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[1] = {55}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2f) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + {6,7}, + {8,9}, + {10,11}, + }, + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{0,1}, + nmtools_array{1,2}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct,1_ct}, + nmtools_tuple{1_ct,2_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2f) + { + inline int result[2] = {55,145}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2g) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + {6,7}, + {8,9}, + {10,11}, + }, + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{0,1}, + nmtools_array{1,0}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct,1_ct}, + nmtools_tuple{1_ct,0_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2g) + { + inline int result[2] = {100,115}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2h) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + {6,7}, + {8,9}, + {10,11}, + }, + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{1,0}, + nmtools_array{0,1}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{1_ct,0_ct}, + nmtools_tuple{0_ct,1_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2h) + { + inline int result[2] = {100,115}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2i) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + {6,7}, + {8,9}, + {10,11}, + }, + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{1,0}, + nmtools_array{2,1}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{1_ct,0_ct}, + nmtools_tuple{2_ct,1_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2i) + { + inline int result[2] = {55,145}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2j) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline auto axes = nmtools_tuple{ + nmtools_array{0}, + nmtools_array{1}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct}, + nmtools_tuple{1_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2j) + { + inline int result[2][2] = { + {10,28}, + {13,40}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[2] = {55,145}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][3] = { + {0,1,2}, + {3,4,5}, + }; + + inline auto axes = nmtools_tuple{ + nmtools_array{0,1}, + nmtools_array{0,1}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct,1_ct}, + nmtools_tuple{0_ct,1_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[2] = {110,125}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[1][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + } + }; + inline auto axes = nmtools_tuple{ + nmtools_array{0,1}, + nmtools_array{-1,-2}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct,1_ct}, + nmtools_tuple{"-1"_ct,"-2"_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[2][1] = { + {100}, + {115}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[1][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + } + }; + inline auto axes = nmtools_tuple{ + nmtools_array{0}, + nmtools_array{-1}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{0_ct}, + nmtools_tuple{"-1"_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[3][2][1][3] = + {{{{ 6, 18, 30}}, + + {{ 7, 23, 39}}}, + + + {{{ 8, 28, 48}}, + + {{ 9, 33, 57}}}, + + + {{{10, 38, 66}}, + + {{11, 43, 75}}}}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3e) + { + inline int a[2][3][1] = { + { + {0}, + {1}, + {2}, + }, + { + {3}, + {4}, + {5}, + }, + }; + inline int b[2][3][1] = { + { + {0}, + {1}, + {2}, + }, + { + {3}, + {4}, + {5}, + }, + }; + inline int axes = 3; + inline auto axes_ct = 3_ct; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3e) + { + inline int result = 55; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4a) + { + inline int a[2][3][4][5] = + {{{{ 0, 1, 2, 3, 4}, + { 5, 6, 7, 8, 9}, + { 10, 11, 12, 13, 14}, + { 15, 16, 17, 18, 19}}, + + {{ 20, 21, 22, 23, 24}, + { 25, 26, 27, 28, 29}, + { 30, 31, 32, 33, 34}, + { 35, 36, 37, 38, 39}}, + + {{ 40, 41, 42, 43, 44}, + { 45, 46, 47, 48, 49}, + { 50, 51, 52, 53, 54}, + { 55, 56, 57, 58, 59}}}, + + + {{{ 60, 61, 62, 63, 64}, + { 65, 66, 67, 68, 69}, + { 70, 71, 72, 73, 74}, + { 75, 76, 77, 78, 79}}, + + {{ 80, 81, 82, 83, 84}, + { 85, 86, 87, 88, 89}, + { 90, 91, 92, 93, 94}, + { 95, 96, 97, 98, 99}}, + + {{100, 101, 102, 103, 104}, + {105, 106, 107, 108, 109}, + {110, 111, 112, 113, 114}, + {115, 116, 117, 118, 119}}}}; + + inline int b[3][4][5][2] = + {{{{ 0, 1}, + { 2, 3}, + { 4, 5}, + { 6, 7}, + { 8, 9}}, + + {{ 10, 11}, + { 12, 13}, + { 14, 15}, + { 16, 17}, + { 18, 19}}, + + {{ 20, 21}, + { 22, 23}, + { 24, 25}, + { 26, 27}, + { 28, 29}}, + + {{ 30, 31}, + { 32, 33}, + { 34, 35}, + { 36, 37}, + { 38, 39}}}, + + + {{{ 40, 41}, + { 42, 43}, + { 44, 45}, + { 46, 47}, + { 48, 49}}, + + {{ 50, 51}, + { 52, 53}, + { 54, 55}, + { 56, 57}, + { 58, 59}}, + + {{ 60, 61}, + { 62, 63}, + { 64, 65}, + { 66, 67}, + { 68, 69}}, + + {{ 70, 71}, + { 72, 73}, + { 74, 75}, + { 76, 77}, + { 78, 79}}}, + + + {{{ 80, 81}, + { 82, 83}, + { 84, 85}, + { 86, 87}, + { 88, 89}}, + + {{ 90, 91}, + { 92, 93}, + { 94, 95}, + { 96, 97}, + { 98, 99}}, + + {{100, 101}, + {102, 103}, + {104, 105}, + {106, 107}, + {108, 109}}, + + {{110, 111}, + {112, 113}, + {114, 115}, + {116, 117}, + {118, 119}}}}; + + inline auto axes = 3; + inline auto axes_ct = 3_ct; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4a) + { + inline int result[2][2] = { + {140420, 142190}, + {352820, 358190}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4b) + { + inline int a[2][3][4][5] = + {{{{ 0, 1, 2, 3, 4}, + { 5, 6, 7, 8, 9}, + { 10, 11, 12, 13, 14}, + { 15, 16, 17, 18, 19}}, + + {{ 20, 21, 22, 23, 24}, + { 25, 26, 27, 28, 29}, + { 30, 31, 32, 33, 34}, + { 35, 36, 37, 38, 39}}, + + {{ 40, 41, 42, 43, 44}, + { 45, 46, 47, 48, 49}, + { 50, 51, 52, 53, 54}, + { 55, 56, 57, 58, 59}}}, + + + {{{ 60, 61, 62, 63, 64}, + { 65, 66, 67, 68, 69}, + { 70, 71, 72, 73, 74}, + { 75, 76, 77, 78, 79}}, + + {{ 80, 81, 82, 83, 84}, + { 85, 86, 87, 88, 89}, + { 90, 91, 92, 93, 94}, + { 95, 96, 97, 98, 99}}, + + {{100, 101, 102, 103, 104}, + {105, 106, 107, 108, 109}, + {110, 111, 112, 113, 114}, + {115, 116, 117, 118, 119}}}}; + + inline int b[3][5][4][2] = + {{{{ 0, 1}, + { 2, 3}, + { 4, 5}, + { 6, 7}}, + + {{ 8, 9}, + { 10, 11}, + { 12, 13}, + { 14, 15}}, + + {{ 16, 17}, + { 18, 19}, + { 20, 21}, + { 22, 23}}, + + {{ 24, 25}, + { 26, 27}, + { 28, 29}, + { 30, 31}}, + + {{ 32, 33}, + { 34, 35}, + { 36, 37}, + { 38, 39}}}, + + + {{{ 40, 41}, + { 42, 43}, + { 44, 45}, + { 46, 47}}, + + {{ 48, 49}, + { 50, 51}, + { 52, 53}, + { 54, 55}}, + + {{ 56, 57}, + { 58, 59}, + { 60, 61}, + { 62, 63}}, + + {{ 64, 65}, + { 66, 67}, + { 68, 69}, + { 70, 71}}, + + {{ 72, 73}, + { 74, 75}, + { 76, 77}, + { 78, 79}}}, + + + {{{ 80, 81}, + { 82, 83}, + { 84, 85}, + { 86, 87}}, + + {{ 88, 89}, + { 90, 91}, + { 92, 93}, + { 94, 95}}, + + {{ 96, 97}, + { 98, 99}, + {100, 101}, + {102, 103}}, + + {{104, 105}, + {106, 107}, + {108, 109}, + {110, 111}}, + + {{112, 113}, + {114, 115}, + {116, 117}, + {118, 119}}}}; + + inline auto axes = nmtools_tuple{ + nmtools_array{2,3}, + nmtools_array{-2,-3}, + }; + inline auto axes_ct = nmtools_tuple{ + nmtools_tuple{2_ct,3_ct}, + nmtools_tuple{"-2"_ct,"-3"_ct}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4b) + { + inline int result[2][3][3][2] = + {{{{ 4180, 4370}, + { 11780, 11970}, + { 19380, 19570}}, + + {{ 11780, 12370}, + { 35380, 35970}, + { 58980, 59570}}, + + {{ 19380, 20370}, + { 58980, 59970}, + { 98580, 99570}}}, + + + {{{ 26980, 28370}, + { 82580, 83970}, + {138180, 139570}}, + + {{ 34580, 36370}, + {106180, 107970}, + {177780, 179570}}, + + {{ 42180, 44370}, + {129780, 131970}, + {217380, 219570}}}}; + } +} + +#endif // NMTOOLS_TESTING_DATA_ARRAY_TENSORDOT_HPP \ No newline at end of file diff --git a/include/nmtools/testing/data/array/vecdot.hpp b/include/nmtools/testing/data/array/vecdot.hpp new file mode 100644 index 000000000..cc47b9d09 --- /dev/null +++ b/include/nmtools/testing/data/array/vecdot.hpp @@ -0,0 +1,427 @@ +#ifndef NMTOOLS_TESTING_DATA_ARRAY_VECDOT_HPP +#define NMTOOLS_TESTING_DATA_ARRAY_VECDOT_HPP + +#include "nmtools/testing/testing.hpp" +#include "nmtools/testing/array_cast.hpp" + +NMTOOLS_TESTING_DECLARE_CASE(array,vecdot) +{ + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[6] = {6,7,8,9,10,11}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result = 145; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1b) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[6] = {6,7,8,9,10,11}; + + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1b) + { + inline int result[1] = {145}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1c) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[2][6] = { + { 6, 7, 8, 9,10,11}, + {12,13,14,15,16,17}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1c) + { + inline int result[2] = {145,235}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1d) + { + inline int a[6] = {0,1,2,3,4,5}; + inline int b[2][6] = { + { 6, 7, 8, 9,10,11}, + {12,13,14,15,16,17}, + }; + + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1d) + { + inline int result[2][1] = { + {145}, + {235}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1e) + { + inline int a[3] = {0,1,2}; + inline int b[2][2][3] = { + { + {6, 7, 8}, + {9,10,11}, + }, + { + {12,13,14}, + {15,16,17}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1e) + { + inline int result[2][2] = { + {23,32}, + {41,50}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1f) + { + inline int a[3] = {0,1,2}; + inline int b[2][2][3] = { + { + {6, 7, 8}, + {9,10,11}, + }, + { + {12,13,14}, + {15,16,17}, + }, + }; + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1f) + { + inline int result[2][2][1] = { + { + {23}, + {32}, + }, + { + {41}, + {50}, + } + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int a[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline int b[2][3] = { + {6, 7, 8}, + {9,10,11}, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[2] = {23,122}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int a[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline int b[2][3] = { + {6, 7, 8}, + {9,10,11}, + }; + + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[2][1] = { + { 23}, + {122}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2c) + { + inline int a[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline int b[3] = {6,7,8}; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2c) + { + inline int result[2] = {23,86}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int a[2][3] = { + {0,1,2}, + {3,4,5}, + }; + inline int b[3] = {6,7,8}; + + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[2][1] = { + {23}, + {86}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[2][3] = { + {1,13,41}, + {7,43,95}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2f) + { + inline int a[3][2] = { + {0,1}, + {2,3}, + {4,5}, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2f) + { + inline int result[2][3][1] = { + { + { 1}, + {13}, + {41}, + }, + { + { 7}, + {43}, + {95}, + }, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[2][3] = { + { 1, 13, 41}, + {85,145,221}, + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[1][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + } + }; + inline auto keepdims = True; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[2][3][1] = { + { + { 1}, + {13}, + {41}, + }, + { + { 7}, + {43}, + {95}, + } + }; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int a[2][3][2] = { + { + {0,1}, + {2,3}, + {4,5}, + }, + { + { 6, 7}, + { 8, 9}, + {10,11}, + }, + }; + inline int b[3][2][1][2] = { + { + { + {0,1}, + }, + { + {2,3}, + }, + }, + { + { + {4,5}, + }, + { + {6,7}, + } + }, + { + { + {8,9}, + }, + { + {10,11}, + } + }, + }; + + NMTOOLS_CAST_ARRAYS(a) + NMTOOLS_CAST_ARRAYS(b) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[3][2][3] = { + { + { 1, 3, 5}, + {33,43,53}, + }, + { + { 5, 23, 41}, + {85,111,137}, + }, + { + { 9, 43, 77}, + {137,179,221}, + } + }; + } + +} + +#endif // NMTOOLS_TESTING_DATA_ARRAY_VECDOT_HPP \ No newline at end of file diff --git a/include/nmtools/testing/data/index/kron.hpp b/include/nmtools/testing/data/index/kron.hpp new file mode 100644 index 000000000..f5a189dce --- /dev/null +++ b/include/nmtools/testing/data/index/kron.hpp @@ -0,0 +1,587 @@ +#ifndef NMTOOLS_TESTING_DATA_INDEX_KRON_HPP +#define NMTOOLS_TESTING_DATA_INDEX_KRON_HPP + +#include "nmtools/testing/testing.hpp" +#include "nmtools/testing/array_cast.hpp" + +NMTOOLS_TESTING_DECLARE_CASE(index,kron_dst_transpose) +{ + using namespace literals; + + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[1] = {3}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 1; + inline int rhs_dim = 1; + + inline auto lhs_dim_ct = 1_ct; + inline auto rhs_dim_ct = 1_ct; + + inline auto lhs_dim_cl = "1:[1]"_ct; + inline auto rhs_dim_cl = "1:[1]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1b) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[2] = {3,4}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 1; + inline int rhs_dim = 2; + + inline auto lhs_dim_ct = 1_ct; + inline auto rhs_dim_ct = 2_ct; + + inline auto lhs_dim_cl = "1:[1]"_ct; + inline auto rhs_dim_cl = "2:[2]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1b) + { + inline int result[3] = {1,0,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1c) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[3] = {2,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 1; + inline int rhs_dim = 3; + + inline auto lhs_dim_ct = 1_ct; + inline auto rhs_dim_ct = 3_ct; + + inline auto lhs_dim_cl = "1:[1]"_ct; + inline auto rhs_dim_cl = "3:[3]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1c) + { + inline int result[4] = {1,2,0,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1d) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[4] = {2,1,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 1; + inline int rhs_dim = 4; + + inline auto lhs_dim_ct = 1_ct; + inline auto rhs_dim_ct = 4_ct; + + inline auto lhs_dim_cl = "1:[1]"_ct; + inline auto rhs_dim_cl = "4:[4]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1d) + { + inline int result[5] = {1,2,3,0,4}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case1e) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[5] = {2,1,3,1,2}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,1_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 1; + inline int rhs_dim = 5; + + inline auto lhs_dim_ct = 1_ct; + inline auto rhs_dim_ct = 5_ct; + + inline auto lhs_dim_cl = "1:[1]"_ct; + inline auto rhs_dim_cl = "5:[5]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1e) + { + inline int result[6] = {1,2,3,4,0,5}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int lhs_shape[2] = {2,3}; + inline int rhs_shape[2] = {3,4}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 2; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 2_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "2:[2]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[4] = {0,2,1,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[2] = {2,3}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 2; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 2_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "2:[2]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[4] = {0,2,1,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2c) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[1] = {6}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{6_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 1; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 1_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "1:[1]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2c) + { + inline int result[3] = {0,1,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int lhs_shape[2] = {3,4}; + inline int rhs_shape[3] = {2,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 3; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 3_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "3:[3]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[5] = {2,0,3,1,4}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int lhs_shape[2] = {3,4}; + inline int rhs_shape[4] = {2,1,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 4; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 4_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "4:[4]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[6] = {2,3,0,4,1,5}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2f) + { + inline int lhs_shape[2] = {3,4}; + inline int rhs_shape[5] = {2,1,3,1,2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,1_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 2; + inline int rhs_dim = 5; + + inline auto lhs_dim_ct = 2_ct; + inline auto rhs_dim_ct = 5_ct; + + inline auto lhs_dim_cl = "2:[2]"_ct; + inline auto rhs_dim_cl = "5:[5]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2f) + { + inline int result[7] = {2,3,4,0,5,1,6}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[3] = {2,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 3; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 3_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "3:[3]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[6] = {0,3,1,4,2,5}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[3] = {2,1,3}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 3; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 3_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "3:[3]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[6] = {0,3,1,4,2,5}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[2] = {3,4}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,4_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 2; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 2_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "2:[2]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[5] = {0,1,3,2,4}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[1] = {6}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{6_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 1; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 1_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "1:[1]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[4] = {0,1,2,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3e) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[4] = {2,1,3,1}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 4; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 4_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "4:[4]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3e) + { + inline int result[7] = {3,0,4,1,5,2,6}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3f) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[5] = {2,1,2,3,1}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,1_ct,2_ct,3_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 3; + inline int rhs_dim = 5; + + inline auto lhs_dim_ct = 3_ct; + inline auto rhs_dim_ct = 5_ct; + + inline auto lhs_dim_cl = "3:[3]"_ct; + inline auto rhs_dim_cl = "5:[5]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3f) + { + inline int result[8] = {3,4,0,5,1,6,2,7}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4a) + { + inline int lhs_shape[4] = {2,1,3,2}; + inline int rhs_shape[4] = {2,2,3,1}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,2_ct,3_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 4; + inline int rhs_dim = 4; + + inline auto lhs_dim_ct = 4_ct; + inline auto rhs_dim_ct = 4_ct; + + inline auto lhs_dim_cl = "4:[4]"_ct; + inline auto rhs_dim_cl = "4:[4]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4a) + { + inline int result[8] = {0,4,1,5,2,6,3,7}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4b) + { + inline int lhs_shape[4] = {2,1,3,2}; + inline int rhs_shape[3] = {2,3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 4; + inline int rhs_dim = 3; + + inline auto lhs_dim_ct = 4_ct; + inline auto rhs_dim_ct = 3_ct; + + inline auto lhs_dim_cl = "4:[4]"_ct; + inline auto rhs_dim_cl = "3:[3]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4b) + { + inline int result[7] = {0,1,4,2,5,3,6}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4c) + { + inline int lhs_shape[4] = {2,1,3,2}; + inline int rhs_shape[2] = {3,2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 4; + inline int rhs_dim = 2; + + inline auto lhs_dim_ct = 4_ct; + inline auto rhs_dim_ct = 2_ct; + + inline auto lhs_dim_cl = "4:[4]"_ct; + inline auto rhs_dim_cl = "2:[2]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4c) + { + inline int result[6] = {0,1,2,4,3,5}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case4d) + { + inline int lhs_shape[4] = {2,1,3,2}; + inline int rhs_shape[1] = {6}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{6_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 4; + inline int rhs_dim = 1; + + inline auto lhs_dim_ct = 4_ct; + inline auto rhs_dim_ct = 1_ct; + + inline auto lhs_dim_cl = "4:[4]"_ct; + inline auto rhs_dim_cl = "1:[1]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case4d) + { + inline int result[5] = {0,1,2,3,4}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case5a) + { + inline int lhs_shape[5] = {2,1,3,1,2}; + inline int rhs_shape[5] = {2,2,1,3,1}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,1_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,2_ct,1_ct,3_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 5; + inline int rhs_dim = 5; + + inline auto lhs_dim_ct = 5_ct; + inline auto rhs_dim_ct = 5_ct; + + inline auto lhs_dim_cl = "5:[5]"_ct; + inline auto rhs_dim_cl = "5:[5]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case5a) + { + inline int result[10] = {0,5,1,6,2,7,3,8,4,9}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case6a) + { + inline int lhs_shape[6] = {2,1,3,1,2,1}; + inline int rhs_shape[6] = {1,3,1,2,1,2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,1_ct,3_ct,1_ct,2_ct,1_ct}; + inline auto rhs_shape_ct = nmtools_tuple{1_ct,3_ct,1_ct,2_ct,1_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + + inline int lhs_dim = 6; + inline int rhs_dim = 6; + + inline auto lhs_dim_ct = 6_ct; + inline auto rhs_dim_ct = 6_ct; + + inline auto lhs_dim_cl = "6:[6]"_ct; + inline auto rhs_dim_cl = "6:[6]"_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case6a) + { + inline int result[12] = {0,6,1,7,2,8,3,9,4,10,5,11}; + } +} + +#endif // NMTOOLS_TESTING_DATA_INDEX_KRON_HPP \ No newline at end of file diff --git a/include/nmtools/testing/data/index/tensordot.hpp b/include/nmtools/testing/data/index/tensordot.hpp new file mode 100644 index 000000000..566c4bd89 --- /dev/null +++ b/include/nmtools/testing/data/index/tensordot.hpp @@ -0,0 +1,575 @@ +#ifndef NMTOOLS_TESTING_DATA_INDEX_TENSORDOT_HPP +#define NMTOOLS_TESTING_DATA_INDEX_TENSORDOT_HPP + +#include "nmtools/testing/testing.hpp" +#include "nmtools/testing/array_cast.hpp" + +NMTOOLS_TESTING_DECLARE_CASE(index,tensordot_lhs_transpose) +{ + using namespace literals; + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int dim = 1; + inline int axes = 1; + + inline auto dim_ct = 1_ct; + inline auto axes_ct = 1_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result[1] = {0}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int dim = 2; + inline int axes = 2; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = 2_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int dim = 2; + inline int axes = 2; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = 2_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int dim = 2; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2h) + { + inline int dim = 2; + inline int axes[2] = {1,0}; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = nmtools_tuple{1_ct,0_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2h) + { + inline int result[2] = {1,0}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int dim = 3; + inline int axes = 2; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = 2_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[3] = {0,1,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int dim = 3; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[3] = {2,0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int dim = 3; + inline int axes[1] = {0}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{0_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[3] = {1,2,0}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3e) + { + inline int dim = 3; + inline int axes = 3; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = 3_ct; + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3e) + { + inline int result[3] = {0,1,2}; + } +} + +NMTOOLS_TESTING_DECLARE_CASE(index,tensordot_rhs_transpose) +{ + using namespace literals; + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int dim = 1; + inline int axes[1] = {0}; + + inline auto dim_ct = 1_ct; + inline auto axes_ct = nmtools_tuple{0_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes); + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result[1] = {0}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int dim = 2; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int dim = 3; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[3] = {2,0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int dim = 4; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 4_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[4] = {2,3,0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int dim = 3; + inline int axes[2] = {1,2}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{1_ct,2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[3] = {0,1,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2g) + { + inline int dim = 3; + inline int axes[2] = {1,0}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{1_ct,0_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2g) + { + inline int result[3] = {2,1,0}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2i) + { + inline int dim = 3; + inline int axes[2] = {2,1}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{2_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2i) + { + inline int result[3] = {0,2,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2j) + { + inline int dim = 2; + inline int axes[1] = {1}; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = nmtools_tuple{1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2j) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int dim = 2; + inline int axes[2] = {0,1}; + + inline auto dim_ct = 2_ct; + inline auto axes_ct = nmtools_tuple{0_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[2] = {0,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int dim = 3; + inline int axes[2] = {2,1}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{2_ct,1_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[3] = {0,2,1}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int dim = 3; + inline int axes[1] = {2}; + + inline auto dim_ct = 3_ct; + inline auto axes_ct = nmtools_tuple{2_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[3] = {0,1,2}; + } +} + +NMTOOLS_TESTING_DECLARE_CASE(index,tensordot_lhs_reshape) +{ + using namespace literals; + NMTOOLS_TESTING_DECLARE_ARGS(case1a) + { + inline int lhs_shape[1] = {6}; + inline int rhs_shape[1] = {6}; + inline int lhs_axes[1] = {-1}; + + inline auto lhs_shape_ct = nmtools_tuple{6_ct}; + inline auto rhs_shape_ct = nmtools_tuple{6_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case1a) + { + inline int result[1] = {6}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2a) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[2] = {3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2a) + { + inline int result[2] = {3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2b) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[3] = {3,2,1}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct,1_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2b) + { + inline int result[3] = {1,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2c) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[3] = {3,2,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2c) + { + inline int result[3] = {1,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2d) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[4] = {3,2,1,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct,1_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2d) + { + inline int result[4] = {1,1,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2e) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[3] = {1,3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{1_ct,3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2e) + { + inline int result[3] = {1,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2f) + { + inline int lhs_shape[2] = {3,2}; + inline int rhs_shape[3] = {2,3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2f) + { + inline int result[3] = {1,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2h) + { + inline int lhs_shape[2] = {2,3}; + inline int rhs_shape[3] = {2,3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2h) + { + inline int result[3] = {1,2,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case2j) + { + inline int lhs_shape[2] = {2,3}; + inline int rhs_shape[2] = {2,3}; + inline int lhs_axes[1] = {-1}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case2j) + { + inline int result[3] = {2,1,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3a) + { + inline int lhs_shape[3] = {2,3,2}; + inline int rhs_shape[2] = {3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3a) + { + inline int result[3] = {2,3,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3b) + { + inline int lhs_shape[3] = {2,2,3}; + inline int rhs_shape[2] = {2,3}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,2_ct,3_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3b) + { + inline int result[3] = {2,2,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3c) + { + inline int lhs_shape[3] = {2,2,3}; + inline int rhs_shape[3] = {1,3,2}; + inline int lhs_axes[2] = {-1,-2}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,2_ct,3_ct}; + inline auto rhs_shape_ct = nmtools_tuple{1_ct,3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3c) + { + inline int result[4] = {2,1,2,3}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3d) + { + inline int lhs_shape[3] = {3,2,2}; + inline int rhs_shape[3] = {1,3,2}; + inline int lhs_axes[1] = {-1}; + + inline auto lhs_shape_ct = nmtools_tuple{3_ct,2_ct,2_ct}; + inline auto rhs_shape_ct = nmtools_tuple{1_ct,3_ct,2_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3d) + { + inline int result[5] = {3,2,1,1,2}; + } + + NMTOOLS_TESTING_DECLARE_ARGS(case3e) + { + inline int lhs_shape[3] = {2,3,1}; + inline int rhs_shape[3] = {2,3,1}; + inline int lhs_axes[3] = {-1,-2,-3}; + + inline auto lhs_shape_ct = nmtools_tuple{2_ct,3_ct,1_ct}; + inline auto rhs_shape_ct = nmtools_tuple{2_ct,3_ct,1_ct}; + inline auto lhs_axes_ct = nmtools_tuple{"-1"_ct,"-2"_ct,"-3"_ct}; + + NMTOOLS_CAST_INDEX_ARRAYS(lhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(rhs_shape) + NMTOOLS_CAST_INDEX_ARRAYS(lhs_axes) + } + NMTOOLS_TESTING_DECLARE_EXPECT(case3e) + { + inline int result[3] = {2,3,1}; + } +} + +#endif // NMTOOLS_TESTING_DATA_INDEX_TENSORDOT_HPP \ No newline at end of file diff --git a/tests/array/CMakeLists.txt b/tests/array/CMakeLists.txt index b85b68b11..00e65f88c 100644 --- a/tests/array/CMakeLists.txt +++ b/tests/array/CMakeLists.txt @@ -119,6 +119,7 @@ set(ARRAY_EVAL_TEST_SOURCES # break matmul to multiple files to avoid high memory peak # array/matmul.cpp array/instance_norm.cpp + array/kron.cpp array/linspace.cpp array/mean.cpp array/moveaxis.cpp @@ -142,6 +143,7 @@ set(ARRAY_EVAL_TEST_SOURCES array/split.cpp array/swapaxes.cpp array/take.cpp + array/tensordot.cpp array/tile.cpp array/trace.cpp array/tri.cpp @@ -149,6 +151,7 @@ set(ARRAY_EVAL_TEST_SOURCES array/triu.cpp array/transpose.cpp array/var.cpp + array/vecdot.cpp array/vector_norm.cpp array/vstack.cpp array/where.cpp diff --git a/tests/array/array/kron.cpp b/tests/array/array/kron.cpp new file mode 100644 index 000000000..da7ca6da7 --- /dev/null +++ b/tests/array/array/kron.cpp @@ -0,0 +1,246 @@ +#include "nmtools/array/array/kron.hpp" +#include "nmtools/testing/data/array/kron.hpp" +#include "nmtools/testing/doctest.hpp" + +#define KRON_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,kron,case_name); \ + auto result = nmtools::array::kron(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("kron(case1a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case1a, a, b ); + KRON_SUBCASE( case1a, a_a, b_a ); + KRON_SUBCASE( case1a, a_f, b_f ); + KRON_SUBCASE( case1a, a_h, b_h ); + KRON_SUBCASE( case1a, a_d, b_d ); +} + +TEST_CASE("kron(case1b)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case1b, a, b ); + KRON_SUBCASE( case1b, a_a, b_a ); + KRON_SUBCASE( case1b, a_f, b_f ); + KRON_SUBCASE( case1b, a_h, b_h ); + KRON_SUBCASE( case1b, a_d, b_d ); +} + +TEST_CASE("kron(case1c)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case1c, a, b ); + KRON_SUBCASE( case1c, a_a, b_a ); + KRON_SUBCASE( case1c, a_f, b_f ); + KRON_SUBCASE( case1c, a_h, b_h ); + KRON_SUBCASE( case1c, a_d, b_d ); +} + +TEST_CASE("kron(case1d)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case1d, a, b ); + KRON_SUBCASE( case1d, a_a, b_a ); + KRON_SUBCASE( case1d, a_f, b_f ); + KRON_SUBCASE( case1d, a_h, b_h ); + KRON_SUBCASE( case1d, a_d, b_d ); +} + +TEST_CASE("kron(case1e)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case1e, a, b ); + KRON_SUBCASE( case1e, a_a, b_a ); + KRON_SUBCASE( case1e, a_f, b_f ); + KRON_SUBCASE( case1e, a_h, b_h ); + KRON_SUBCASE( case1e, a_d, b_d ); +} + +TEST_CASE("kron(case2a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2a, a, b ); + KRON_SUBCASE( case2a, a_a, b_a ); + KRON_SUBCASE( case2a, a_f, b_f ); + KRON_SUBCASE( case2a, a_h, b_h ); + KRON_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("kron(case2b)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2b, a, b ); + KRON_SUBCASE( case2b, a_a, b_a ); + KRON_SUBCASE( case2b, a_f, b_f ); + KRON_SUBCASE( case2b, a_h, b_h ); + KRON_SUBCASE( case2b, a_d, b_d ); +} + +TEST_CASE("kron(case2c)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2c, a, b ); + KRON_SUBCASE( case2c, a_a, b_a ); + KRON_SUBCASE( case2c, a_f, b_f ); + KRON_SUBCASE( case2c, a_h, b_h ); + KRON_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("kron(case2d)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2d, a, b ); + KRON_SUBCASE( case2d, a_a, b_a ); + KRON_SUBCASE( case2d, a_f, b_f ); + KRON_SUBCASE( case2d, a_h, b_h ); + KRON_SUBCASE( case2d, a_d, b_d ); +} + +TEST_CASE("kron(case2e)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2e, a, b ); + KRON_SUBCASE( case2e, a_a, b_a ); + KRON_SUBCASE( case2e, a_f, b_f ); + KRON_SUBCASE( case2e, a_h, b_h ); + KRON_SUBCASE( case2e, a_d, b_d ); +} + +TEST_CASE("kron(case2f)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2f, a, b ); + KRON_SUBCASE( case2f, a_a, b_a ); + KRON_SUBCASE( case2f, a_f, b_f ); + KRON_SUBCASE( case2f, a_h, b_h ); + KRON_SUBCASE( case2f, a_d, b_d ); +} + +TEST_CASE("kron(case2g)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2g, a, b ); + KRON_SUBCASE( case2g, a_a, b_a ); + KRON_SUBCASE( case2g, a_f, b_f ); + KRON_SUBCASE( case2g, a_h, b_h ); + KRON_SUBCASE( case2g, a_d, b_d ); +} + +TEST_CASE("kron(case2h)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case2h, a, b ); + KRON_SUBCASE( case2h, a_a, b_a ); + KRON_SUBCASE( case2h, a_f, b_f ); + KRON_SUBCASE( case2h, a_h, b_h ); + KRON_SUBCASE( case2h, a_d, b_d ); +} + +TEST_CASE("kron(case3a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3a, a, b ); + KRON_SUBCASE( case3a, a_a, b_a ); + KRON_SUBCASE( case3a, a_f, b_f ); + KRON_SUBCASE( case3a, a_h, b_h ); + KRON_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("kron(case3b)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3b, a, b ); + KRON_SUBCASE( case3b, a_a, b_a ); + KRON_SUBCASE( case3b, a_f, b_f ); + KRON_SUBCASE( case3b, a_h, b_h ); + KRON_SUBCASE( case3b, a_d, b_d ); +} + +TEST_CASE("kron(case3d)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3d, a, b ); + KRON_SUBCASE( case3d, a_a, b_a ); + KRON_SUBCASE( case3d, a_f, b_f ); + KRON_SUBCASE( case3d, a_h, b_h ); + KRON_SUBCASE( case3d, a_d, b_d ); +} + +TEST_CASE("kron(case3e)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3e, a, b ); + KRON_SUBCASE( case3e, a_a, b_a ); + KRON_SUBCASE( case3e, a_f, b_f ); + KRON_SUBCASE( case3e, a_h, b_h ); + KRON_SUBCASE( case3e, a_d, b_d ); +} + +TEST_CASE("kron(case3f)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3f, a, b ); + KRON_SUBCASE( case3f, a_a, b_a ); + KRON_SUBCASE( case3f, a_f, b_f ); + KRON_SUBCASE( case3f, a_h, b_h ); + KRON_SUBCASE( case3f, a_d, b_d ); +} + +TEST_CASE("kron(case3g)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3g, a, b ); + KRON_SUBCASE( case3g, a_a, b_a ); + KRON_SUBCASE( case3g, a_f, b_f ); + KRON_SUBCASE( case3g, a_h, b_h ); + KRON_SUBCASE( case3g, a_d, b_d ); +} + +TEST_CASE("kron(case3h)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case3h, a, b ); + KRON_SUBCASE( case3h, a_a, b_a ); + KRON_SUBCASE( case3h, a_f, b_f ); + KRON_SUBCASE( case3h, a_h, b_h ); + KRON_SUBCASE( case3h, a_d, b_d ); +} + +TEST_CASE("kron(case4a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case4a, a, b ); + KRON_SUBCASE( case4a, a_a, b_a ); + KRON_SUBCASE( case4a, a_f, b_f ); + KRON_SUBCASE( case4a, a_h, b_h ); + KRON_SUBCASE( case4a, a_d, b_d ); +} + +TEST_CASE("kron(case4b)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case4b, a, b ); + KRON_SUBCASE( case4b, a_a, b_a ); + KRON_SUBCASE( case4b, a_f, b_f ); + KRON_SUBCASE( case4b, a_h, b_h ); + KRON_SUBCASE( case4b, a_d, b_d ); +} + +TEST_CASE("kron(case4c)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case4c, a, b ); + KRON_SUBCASE( case4c, a_a, b_a ); + KRON_SUBCASE( case4c, a_f, b_f ); + KRON_SUBCASE( case4c, a_h, b_h ); + KRON_SUBCASE( case4c, a_d, b_d ); +} + +TEST_CASE("kron(case4d)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case4d, a, b ); + KRON_SUBCASE( case4d, a_a, b_a ); + KRON_SUBCASE( case4d, a_f, b_f ); + KRON_SUBCASE( case4d, a_h, b_h ); + KRON_SUBCASE( case4d, a_d, b_d ); +} + +TEST_CASE("kron(case5a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case5a, a, b ); + KRON_SUBCASE( case5a, a_a, b_a ); + KRON_SUBCASE( case5a, a_f, b_f ); + KRON_SUBCASE( case5a, a_h, b_h ); + KRON_SUBCASE( case5a, a_d, b_d ); +} + +TEST_CASE("kron(case6a)" * doctest::test_suite("array::kron")) +{ + KRON_SUBCASE( case6a, a, b ); + KRON_SUBCASE( case6a, a_a, b_a ); + KRON_SUBCASE( case6a, a_f, b_f ); + KRON_SUBCASE( case6a, a_h, b_h ); + KRON_SUBCASE( case6a, a_d, b_d ); +} \ No newline at end of file diff --git a/tests/array/array/tensordot.cpp b/tests/array/array/tensordot.cpp new file mode 100644 index 000000000..83a3a53a6 --- /dev/null +++ b/tests/array/array/tensordot.cpp @@ -0,0 +1,267 @@ +#include "nmtools/array/array/tensordot.hpp" +#include "nmtools/testing/data/array/tensordot.hpp" +#include "nmtools/testing/doctest.hpp" + +#define TENSORDOT_SUBCASE(case_name, ...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,tensordot,case_name); \ + auto result = nmtools::array::tensordot(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("tensordot(case1a)" * doctest::test_suite("array::tensordot")) +{ + // TODO: fix + // TENSORDOT_SUBCASE( case1a, a, b, axes ); + TENSORDOT_SUBCASE( case1a, a, b, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_h, b_h, axes_ct ); + // TENSORDOT_SUBCASE( case1a, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2a)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2a, a, b ); + TENSORDOT_SUBCASE( case2a, a_a, b_a ); + TENSORDOT_SUBCASE( case2a, a_f, b_f ); + TENSORDOT_SUBCASE( case2a, a_h, b_h ); + // TODO: fix + // TENSORDOT_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("tensordot(case2b)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2b, a, b ); + TENSORDOT_SUBCASE( case2b, a_a, b_a ); + TENSORDOT_SUBCASE( case2b, a_f, b_f ); + TENSORDOT_SUBCASE( case2b, a_h, b_h ); + TENSORDOT_SUBCASE( case2b, a_d, b_d ); +} + +TEST_CASE("tensordot(case2c)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2c, a, b ); + TENSORDOT_SUBCASE( case2c, a_a, b_a ); + TENSORDOT_SUBCASE( case2c, a_f, b_f ); + TENSORDOT_SUBCASE( case2c, a_h, b_h ); + TENSORDOT_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("tensordot(case2d)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2d, a, b ); + TENSORDOT_SUBCASE( case2d, a_a, b_a ); + TENSORDOT_SUBCASE( case2d, a_f, b_f ); + TENSORDOT_SUBCASE( case2d, a_h, b_h ); + TENSORDOT_SUBCASE( case2d, a_d, b_d ); +} + +TEST_CASE("tensordot(case2e)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2e, a, b, axes ); + TENSORDOT_SUBCASE( case2e, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2e, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2e, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2e, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2e, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2f)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2f, a, b, axes ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2f, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2f)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2f, a, b, axes ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2f, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2g)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2g, a, b, axes ); + TENSORDOT_SUBCASE( case2g, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2g, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2g, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2g, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2g, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_d, b_d, axes_ct ); + +} + +TEST_CASE("tensordot(case2h)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2h, a, b, axes ); + TENSORDOT_SUBCASE( case2h, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2h, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2h, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2h, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2h, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2i)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2i, a, b, axes ); + TENSORDOT_SUBCASE( case2i, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2i, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2i, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2i, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2i, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2j)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case2j, a, b, axes ); + TENSORDOT_SUBCASE( case2j, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2j, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2j, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2j, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2j, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3a)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case3a, a, b ); + TENSORDOT_SUBCASE( case3a, a_a, b_a ); + TENSORDOT_SUBCASE( case3a, a_f, b_f ); + TENSORDOT_SUBCASE( case3a, a_h, b_h ); + TENSORDOT_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("tensordot(case3b)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case3b, a, b, axes ); + TENSORDOT_SUBCASE( case3b, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3b, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3b, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3b, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3b, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3c)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case3c, a, b, axes ); + TENSORDOT_SUBCASE( case3c, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3c, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3c, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3c, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3c, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_d, b_d, axes_ct ); +} + + +TEST_CASE("tensordot(case3d)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case3d, a, b, axes ); + TENSORDOT_SUBCASE( case3d, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3d, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3d, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3d, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3d, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3e)" * doctest::test_suite("array::tensordot")) +{ + // TODO: fix compilation + // TENSORDOT_SUBCASE( case3e, a, b, axes ); + // TENSORDOT_SUBCASE( case3e, a_a, b_a, axes ); + // TENSORDOT_SUBCASE( case3e, a_f, b_f, axes ); + // TENSORDOT_SUBCASE( case3e, a_h, b_h, axes ); + // TENSORDOT_SUBCASE( case3e, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3e, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_h, b_h, axes_ct ); + // TENSORDOT_SUBCASE( case3e, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case4a)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case4a, a, b, axes ); + TENSORDOT_SUBCASE( case4a, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case4a, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case4a, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case4a, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case4a, a, b, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case4b)" * doctest::test_suite("array::tensordot")) +{ + TENSORDOT_SUBCASE( case4b, a, b, axes ); + TENSORDOT_SUBCASE( case4b, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case4b, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case4b, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case4b, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case4b, a, b, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_d, b_d, axes_ct ); +} \ No newline at end of file diff --git a/tests/array/array/vecdot.cpp b/tests/array/array/vecdot.cpp new file mode 100644 index 000000000..2a4b8513c --- /dev/null +++ b/tests/array/array/vecdot.cpp @@ -0,0 +1,149 @@ +#include "nmtools/array/array/vecdot.hpp" +#include "nmtools/testing/data/array/vecdot.hpp" +#include "nmtools/testing/doctest.hpp" + +using nmtools::None; + +#define VECDOT_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,vecdot,case_name); \ + auto result = nmtools::array::vecdot(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("vecdot(case1a)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1a, a, b ); + VECDOT_SUBCASE( case1a, a_a, b_a ); + VECDOT_SUBCASE( case1a, a_f, b_f ); + VECDOT_SUBCASE( case1a, a_h, b_h ); + // VECDOT_SUBCASE( case1a, a_d, b_d ); +} + +TEST_CASE("vecdot(case1b)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1b, a, b, None, keepdims ); + VECDOT_SUBCASE( case1b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case1c)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1c, a, b ); + VECDOT_SUBCASE( case1c, a_a, b_a ); + VECDOT_SUBCASE( case1c, a_f, b_f ); + VECDOT_SUBCASE( case1c, a_h, b_h ); + // VECDOT_SUBCASE( case1c, a_d, b_d ); +} + +TEST_CASE("vecdot(case1d)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1d, a, b, None, keepdims ); + VECDOT_SUBCASE( case1d, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1d, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1d, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1d, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case1e)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1e, a, b ); + VECDOT_SUBCASE( case1e, a_a, b_a ); + VECDOT_SUBCASE( case1e, a_f, b_f ); + VECDOT_SUBCASE( case1e, a_h, b_h ); + // VECDOT_SUBCASE( case1e, a_d, b_d ); +} + +TEST_CASE("vecdot(case1f)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case1f, a, b, None, keepdims ); + VECDOT_SUBCASE( case1f, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1f, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1f, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1f, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2a)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2a, a, b ); + VECDOT_SUBCASE( case2a, a_a, b_a ); + VECDOT_SUBCASE( case2a, a_f, b_f ); + VECDOT_SUBCASE( case2a, a_h, b_h ); + VECDOT_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("vecdot(case2b)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2b, a, b, None, keepdims ); + VECDOT_SUBCASE( case2b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2c)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2c, a, b ); + VECDOT_SUBCASE( case2c, a_a, b_a ); + VECDOT_SUBCASE( case2c, a_f, b_f ); + VECDOT_SUBCASE( case2c, a_h, b_h ); + VECDOT_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("vecdot(case2d)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2d, a, b, None, keepdims ); + VECDOT_SUBCASE( case2d, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2d, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2d, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2d, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2e)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2e, a, b ); + VECDOT_SUBCASE( case2e, a_a, b_a ); + VECDOT_SUBCASE( case2e, a_f, b_f ); + VECDOT_SUBCASE( case2e, a_h, b_h ); + VECDOT_SUBCASE( case2e, a_d, b_d ); +} + +TEST_CASE("vecdot(case2f)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case2f, a, b, None, keepdims ); + VECDOT_SUBCASE( case2f, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2f, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2f, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2f, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case3a)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case3a, a, b ); + VECDOT_SUBCASE( case3a, a_a, b_a ); + VECDOT_SUBCASE( case3a, a_f, b_f ); + VECDOT_SUBCASE( case3a, a_h, b_h ); + VECDOT_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("vecdot(case3b)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case3b, a, b, None, keepdims ); + VECDOT_SUBCASE( case3b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case3b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case3b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case3b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case3c)" * doctest::test_suite("array::vecdot")) +{ + VECDOT_SUBCASE( case3c, a, b ); + VECDOT_SUBCASE( case3c, a_a, b_a ); + VECDOT_SUBCASE( case3c, a_f, b_f ); + VECDOT_SUBCASE( case3c, a_h, b_h ); + VECDOT_SUBCASE( case3c, a_d, b_d ); +} diff --git a/tests/index/CMakeLists.txt b/tests/index/CMakeLists.txt index 2b703358f..421bbb42d 100644 --- a/tests/index/CMakeLists.txt +++ b/tests/index/CMakeLists.txt @@ -62,6 +62,7 @@ if (NMTOOLS_INDEX_TEST_ALL) src/filter.cpp src/free_axes.cpp src/gather.cpp + src/kron.cpp src/insert_index.cpp src/logical_not.cpp src/matmul.cpp @@ -88,6 +89,7 @@ if (NMTOOLS_INDEX_TEST_ALL) src/slice.cpp src/take.cpp src/take_along_axis.cpp + src/tensordot.cpp src/tile.cpp src/tril.cpp ) diff --git a/tests/index/src/kron.cpp b/tests/index/src/kron.cpp new file mode 100644 index 000000000..b16f1173a --- /dev/null +++ b/tests/index/src/kron.cpp @@ -0,0 +1,158 @@ +#include "nmtools/array/view/kron.hpp" +#include "nmtools/testing/data/index/kron.hpp" +#include "nmtools/testing/doctest.hpp" + +#define KRON_DST_TRANSPOSE_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE( index, kron_dst_transpose, case_name ); \ + auto result = nmtools::index::kron_dst_transpose( __VA_ARGS__ ); \ + NMTOOLS_ASSERT_EQUAL( result, expect::result ); \ +} + +TEST_CASE("kron_dst_transpose(case1a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case1a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case1a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case1a, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case1b)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case1b, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case1b, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case1b, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case1c)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case1c, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case1c, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case1c, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case1d)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case1d, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case1d, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case1d, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case1e)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case1e, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case1e, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case1e, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2a, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2b)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2b, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2b, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2b, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2c)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2c, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2c, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2c, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2d)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2d, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2d, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2d, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2e)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2e, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2e, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2e, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case2f)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case2f, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case2f, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case2f, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case3a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case3a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case3a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case3a, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case3b)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case3b, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case3b, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case3b, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case3c)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case3c, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case3c, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case3c, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case3d)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case3d, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case3d, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case3d, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case4a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case4a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case4a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case4a, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case4b)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case4b, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case4b, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case4b, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case4c)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case4c, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case4c, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case4c, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case4d)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case4d, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case4d, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case4d, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case5a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case5a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case5a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case5a, lhs_dim_cl, rhs_dim_cl ); +} + +TEST_CASE("kron_dst_transpose(case6a)" * doctest::test_suite("index::kron_dst_transpose")) +{ + KRON_DST_TRANSPOSE_SUBCASE( case6a, lhs_dim, rhs_dim ); + KRON_DST_TRANSPOSE_SUBCASE( case6a, lhs_dim_ct, rhs_dim_ct ); + KRON_DST_TRANSPOSE_SUBCASE( case6a, lhs_dim_cl, rhs_dim_cl ); +} \ No newline at end of file diff --git a/tests/index/src/tensordot.cpp b/tests/index/src/tensordot.cpp new file mode 100644 index 000000000..d820b75dd --- /dev/null +++ b/tests/index/src/tensordot.cpp @@ -0,0 +1,339 @@ +#include "nmtools/array/view/tensordot.hpp" +#include "nmtools/testing/data/index/tensordot.hpp" +#include "nmtools/testing/doctest.hpp" + +#define TENSORDOT_LHS_TRANSPOSE_SUBCASE(case_name, ...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(index,tensordot_lhs_transpose,case_name) \ + auto result = nmtools::index::tensordot_lhs_transpose(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( result, expect::result ); \ +} + +TEST_CASE("tensordot_lhs_transpose(case1a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case1a, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case1a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case2a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2a, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case2b)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2b, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2b, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case2e)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim, axes_a ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim, axes_f ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim, axes_h ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim, axes_v ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2e, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case2h)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim, axes_a ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim, axes_f ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim, axes_h ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim, axes_v ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case2h, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case3a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3a, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case3b)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3b, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3b, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case3d)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3d, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3d, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_lhs_transpose(case3e)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3e, dim, axes ); + TENSORDOT_LHS_TRANSPOSE_SUBCASE( case3e, dim_ct, axes_ct ); +} + +#define TENSORDOT_RHS_TRANSPOSE_SUBCASE(case_name, ...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(index,tensordot_rhs_transpose,case_name); \ + auto result = nmtools::index::tensordot_rhs_transpose(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( result, expect::result ); \ +} + +TEST_CASE("tensordot_rhs_transpose(case1a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case1a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2b)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2b, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2d)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2d, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2e)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2e, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2g)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2g, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2i)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2i, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case2j)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case2j, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case3a)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3a, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case3c)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3c, dim_ct, axes_ct ); +} + +TEST_CASE("tensordot_rhs_transpose(case3d)" * doctest::test_suite("index::tensordot_lhs_transpose")) +{ + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim, axes ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim, axes_a ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim, axes_f ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim, axes_h ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim, axes_v ); + TENSORDOT_RHS_TRANSPOSE_SUBCASE( case3d, dim_ct, axes_ct ); +} + +#define TENSORDOT_LHS_RESHAPE_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE( index, tensordot_lhs_reshape, case_name ); \ + auto result = nmtools::index::tensordot_lhs_reshape(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( result, expect::result ); \ +} + +TEST_CASE("tensordot_lhs_reshape(case1a)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case1a, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2a)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2a, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2b)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2b, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2c)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2c, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2d)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2d, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2e)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2e, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2f)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2f, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2h)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2h, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case2j)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case2j, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case3a)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3a, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case3b)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3b, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case3c)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3c, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case3d)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3d, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} + +TEST_CASE("tensordot_lhs_reshape(case3e)" * doctest::test_suite("index::tensordot_lhs_reshape")) +{ + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape, rhs_shape, lhs_axes ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape_a, rhs_shape_a, lhs_axes_a ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape_f, rhs_shape_f, lhs_axes_f ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape_h, rhs_shape_h, lhs_axes_h ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape_v, rhs_shape_v, lhs_axes_v ); + TENSORDOT_LHS_RESHAPE_SUBCASE( case3e, lhs_shape_ct, rhs_shape_ct, lhs_axes_ct ); +} \ No newline at end of file diff --git a/tests/view/CMakeLists.txt b/tests/view/CMakeLists.txt index f0cf64b72..e069998c5 100644 --- a/tests/view/CMakeLists.txt +++ b/tests/view/CMakeLists.txt @@ -142,9 +142,12 @@ set(ARRAY_UFUNCS_1_TEST_SOURCES src/cumsum.cpp src/dot.cpp src/inner.cpp + src/kron.cpp src/outer.cpp src/prod.cpp src/sum.cpp + src/tensordot.cpp + src/vecdot.cpp src/vector_norm.cpp src/ufuncs/add.cpp src/ufuncs/amax.cpp diff --git a/tests/view/src/kron.cpp b/tests/view/src/kron.cpp new file mode 100644 index 000000000..d6b5b72e2 --- /dev/null +++ b/tests/view/src/kron.cpp @@ -0,0 +1,246 @@ +#include "nmtools/array/view/kron.hpp" +#include "nmtools/testing/data/array/kron.hpp" +#include "nmtools/testing/doctest.hpp" + +#define KRON_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,kron,case_name); \ + auto result = nmtools::view::kron(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("kron(case1a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case1a, a, b ); + KRON_SUBCASE( case1a, a_a, b_a ); + KRON_SUBCASE( case1a, a_f, b_f ); + KRON_SUBCASE( case1a, a_h, b_h ); + KRON_SUBCASE( case1a, a_d, b_d ); +} + +TEST_CASE("kron(case1b)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case1b, a, b ); + KRON_SUBCASE( case1b, a_a, b_a ); + KRON_SUBCASE( case1b, a_f, b_f ); + KRON_SUBCASE( case1b, a_h, b_h ); + KRON_SUBCASE( case1b, a_d, b_d ); +} + +TEST_CASE("kron(case1c)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case1c, a, b ); + KRON_SUBCASE( case1c, a_a, b_a ); + KRON_SUBCASE( case1c, a_f, b_f ); + KRON_SUBCASE( case1c, a_h, b_h ); + KRON_SUBCASE( case1c, a_d, b_d ); +} + +TEST_CASE("kron(case1d)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case1d, a, b ); + KRON_SUBCASE( case1d, a_a, b_a ); + KRON_SUBCASE( case1d, a_f, b_f ); + KRON_SUBCASE( case1d, a_h, b_h ); + KRON_SUBCASE( case1d, a_d, b_d ); +} + +TEST_CASE("kron(case1e)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case1e, a, b ); + KRON_SUBCASE( case1e, a_a, b_a ); + KRON_SUBCASE( case1e, a_f, b_f ); + KRON_SUBCASE( case1e, a_h, b_h ); + KRON_SUBCASE( case1e, a_d, b_d ); +} + +TEST_CASE("kron(case2a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2a, a, b ); + KRON_SUBCASE( case2a, a_a, b_a ); + KRON_SUBCASE( case2a, a_f, b_f ); + KRON_SUBCASE( case2a, a_h, b_h ); + KRON_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("kron(case2b)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2b, a, b ); + KRON_SUBCASE( case2b, a_a, b_a ); + KRON_SUBCASE( case2b, a_f, b_f ); + KRON_SUBCASE( case2b, a_h, b_h ); + KRON_SUBCASE( case2b, a_d, b_d ); +} + +TEST_CASE("kron(case2c)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2c, a, b ); + KRON_SUBCASE( case2c, a_a, b_a ); + KRON_SUBCASE( case2c, a_f, b_f ); + KRON_SUBCASE( case2c, a_h, b_h ); + KRON_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("kron(case2d)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2d, a, b ); + KRON_SUBCASE( case2d, a_a, b_a ); + KRON_SUBCASE( case2d, a_f, b_f ); + KRON_SUBCASE( case2d, a_h, b_h ); + KRON_SUBCASE( case2d, a_d, b_d ); +} + +TEST_CASE("kron(case2e)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2e, a, b ); + KRON_SUBCASE( case2e, a_a, b_a ); + KRON_SUBCASE( case2e, a_f, b_f ); + KRON_SUBCASE( case2e, a_h, b_h ); + KRON_SUBCASE( case2e, a_d, b_d ); +} + +TEST_CASE("kron(case2f)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2f, a, b ); + KRON_SUBCASE( case2f, a_a, b_a ); + KRON_SUBCASE( case2f, a_f, b_f ); + KRON_SUBCASE( case2f, a_h, b_h ); + KRON_SUBCASE( case2f, a_d, b_d ); +} + +TEST_CASE("kron(case2g)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2g, a, b ); + KRON_SUBCASE( case2g, a_a, b_a ); + KRON_SUBCASE( case2g, a_f, b_f ); + KRON_SUBCASE( case2g, a_h, b_h ); + KRON_SUBCASE( case2g, a_d, b_d ); +} + +TEST_CASE("kron(case2h)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case2h, a, b ); + KRON_SUBCASE( case2h, a_a, b_a ); + KRON_SUBCASE( case2h, a_f, b_f ); + KRON_SUBCASE( case2h, a_h, b_h ); + KRON_SUBCASE( case2h, a_d, b_d ); +} + +TEST_CASE("kron(case3a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3a, a, b ); + KRON_SUBCASE( case3a, a_a, b_a ); + KRON_SUBCASE( case3a, a_f, b_f ); + KRON_SUBCASE( case3a, a_h, b_h ); + KRON_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("kron(case3b)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3b, a, b ); + KRON_SUBCASE( case3b, a_a, b_a ); + KRON_SUBCASE( case3b, a_f, b_f ); + KRON_SUBCASE( case3b, a_h, b_h ); + KRON_SUBCASE( case3b, a_d, b_d ); +} + +TEST_CASE("kron(case3d)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3d, a, b ); + KRON_SUBCASE( case3d, a_a, b_a ); + KRON_SUBCASE( case3d, a_f, b_f ); + KRON_SUBCASE( case3d, a_h, b_h ); + KRON_SUBCASE( case3d, a_d, b_d ); +} + +TEST_CASE("kron(case3e)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3e, a, b ); + KRON_SUBCASE( case3e, a_a, b_a ); + KRON_SUBCASE( case3e, a_f, b_f ); + KRON_SUBCASE( case3e, a_h, b_h ); + KRON_SUBCASE( case3e, a_d, b_d ); +} + +TEST_CASE("kron(case3f)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3f, a, b ); + KRON_SUBCASE( case3f, a_a, b_a ); + KRON_SUBCASE( case3f, a_f, b_f ); + KRON_SUBCASE( case3f, a_h, b_h ); + KRON_SUBCASE( case3f, a_d, b_d ); +} + +TEST_CASE("kron(case3g)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3g, a, b ); + KRON_SUBCASE( case3g, a_a, b_a ); + KRON_SUBCASE( case3g, a_f, b_f ); + KRON_SUBCASE( case3g, a_h, b_h ); + KRON_SUBCASE( case3g, a_d, b_d ); +} + +TEST_CASE("kron(case3h)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case3h, a, b ); + KRON_SUBCASE( case3h, a_a, b_a ); + KRON_SUBCASE( case3h, a_f, b_f ); + KRON_SUBCASE( case3h, a_h, b_h ); + KRON_SUBCASE( case3h, a_d, b_d ); +} + +TEST_CASE("kron(case4a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case4a, a, b ); + KRON_SUBCASE( case4a, a_a, b_a ); + KRON_SUBCASE( case4a, a_f, b_f ); + KRON_SUBCASE( case4a, a_h, b_h ); + KRON_SUBCASE( case4a, a_d, b_d ); +} + +TEST_CASE("kron(case4b)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case4b, a, b ); + KRON_SUBCASE( case4b, a_a, b_a ); + KRON_SUBCASE( case4b, a_f, b_f ); + KRON_SUBCASE( case4b, a_h, b_h ); + KRON_SUBCASE( case4b, a_d, b_d ); +} + +TEST_CASE("kron(case4c)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case4c, a, b ); + KRON_SUBCASE( case4c, a_a, b_a ); + KRON_SUBCASE( case4c, a_f, b_f ); + KRON_SUBCASE( case4c, a_h, b_h ); + KRON_SUBCASE( case4c, a_d, b_d ); +} + +TEST_CASE("kron(case4d)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case4d, a, b ); + KRON_SUBCASE( case4d, a_a, b_a ); + KRON_SUBCASE( case4d, a_f, b_f ); + KRON_SUBCASE( case4d, a_h, b_h ); + KRON_SUBCASE( case4d, a_d, b_d ); +} + +TEST_CASE("kron(case5a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case5a, a, b ); + KRON_SUBCASE( case5a, a_a, b_a ); + KRON_SUBCASE( case5a, a_f, b_f ); + KRON_SUBCASE( case5a, a_h, b_h ); + KRON_SUBCASE( case5a, a_d, b_d ); +} + +TEST_CASE("kron(case6a)" * doctest::test_suite("view::kron")) +{ + KRON_SUBCASE( case6a, a, b ); + KRON_SUBCASE( case6a, a_a, b_a ); + KRON_SUBCASE( case6a, a_f, b_f ); + KRON_SUBCASE( case6a, a_h, b_h ); + KRON_SUBCASE( case6a, a_d, b_d ); +} \ No newline at end of file diff --git a/tests/view/src/tensordot.cpp b/tests/view/src/tensordot.cpp new file mode 100644 index 000000000..1fd51bb20 --- /dev/null +++ b/tests/view/src/tensordot.cpp @@ -0,0 +1,267 @@ +#include "nmtools/array/view/tensordot.hpp" +#include "nmtools/testing/data/array/tensordot.hpp" +#include "nmtools/testing/doctest.hpp" + +#define TENSORDOT_SUBCASE(case_name, ...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,tensordot,case_name); \ + auto result = nmtools::view::tensordot(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("tensordot(case1a)" * doctest::test_suite("view::tensordot")) +{ + // TODO: fix + // TENSORDOT_SUBCASE( case1a, a, b, axes ); + TENSORDOT_SUBCASE( case1a, a, b, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case1a, a_h, b_h, axes_ct ); + // TENSORDOT_SUBCASE( case1a, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2a)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2a, a, b ); + TENSORDOT_SUBCASE( case2a, a_a, b_a ); + TENSORDOT_SUBCASE( case2a, a_f, b_f ); + TENSORDOT_SUBCASE( case2a, a_h, b_h ); + // TODO: fix + // TENSORDOT_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("tensordot(case2b)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2b, a, b ); + TENSORDOT_SUBCASE( case2b, a_a, b_a ); + TENSORDOT_SUBCASE( case2b, a_f, b_f ); + TENSORDOT_SUBCASE( case2b, a_h, b_h ); + TENSORDOT_SUBCASE( case2b, a_d, b_d ); +} + +TEST_CASE("tensordot(case2c)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2c, a, b ); + TENSORDOT_SUBCASE( case2c, a_a, b_a ); + TENSORDOT_SUBCASE( case2c, a_f, b_f ); + TENSORDOT_SUBCASE( case2c, a_h, b_h ); + TENSORDOT_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("tensordot(case2d)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2d, a, b ); + TENSORDOT_SUBCASE( case2d, a_a, b_a ); + TENSORDOT_SUBCASE( case2d, a_f, b_f ); + TENSORDOT_SUBCASE( case2d, a_h, b_h ); + TENSORDOT_SUBCASE( case2d, a_d, b_d ); +} + +TEST_CASE("tensordot(case2e)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2e, a, b, axes ); + TENSORDOT_SUBCASE( case2e, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2e, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2e, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2e, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2e, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2e, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2f)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2f, a, b, axes ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2f, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2f)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2f, a, b, axes ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2f, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2f, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2g)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2g, a, b, axes ); + TENSORDOT_SUBCASE( case2g, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2g, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2g, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2g, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2g, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2g, a_d, b_d, axes_ct ); + +} + +TEST_CASE("tensordot(case2h)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2h, a, b, axes ); + TENSORDOT_SUBCASE( case2h, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2h, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2h, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2h, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2h, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2h, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2i)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2i, a, b, axes ); + TENSORDOT_SUBCASE( case2i, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2i, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2i, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2i, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2i, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2i, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case2j)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case2j, a, b, axes ); + TENSORDOT_SUBCASE( case2j, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case2j, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case2j, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case2j, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case2j, a, b, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case2j, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3a)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case3a, a, b ); + TENSORDOT_SUBCASE( case3a, a_a, b_a ); + TENSORDOT_SUBCASE( case3a, a_f, b_f ); + TENSORDOT_SUBCASE( case3a, a_h, b_h ); + TENSORDOT_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("tensordot(case3b)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case3b, a, b, axes ); + TENSORDOT_SUBCASE( case3b, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3b, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3b, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3b, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3b, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3b, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3c)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case3c, a, b, axes ); + TENSORDOT_SUBCASE( case3c, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3c, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3c, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3c, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3c, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3c, a_d, b_d, axes_ct ); +} + + +TEST_CASE("tensordot(case3d)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case3d, a, b, axes ); + TENSORDOT_SUBCASE( case3d, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case3d, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case3d, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case3d, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3d, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case3d, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case3e)" * doctest::test_suite("view::tensordot")) +{ + // TODO: fix compilation + // TENSORDOT_SUBCASE( case3e, a, b, axes ); + // TENSORDOT_SUBCASE( case3e, a_a, b_a, axes ); + // TENSORDOT_SUBCASE( case3e, a_f, b_f, axes ); + // TENSORDOT_SUBCASE( case3e, a_h, b_h, axes ); + // TENSORDOT_SUBCASE( case3e, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case3e, a, b, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case3e, a_h, b_h, axes_ct ); + // TENSORDOT_SUBCASE( case3e, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case4a)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case4a, a, b, axes ); + TENSORDOT_SUBCASE( case4a, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case4a, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case4a, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case4a, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case4a, a, b, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case4a, a_d, b_d, axes_ct ); +} + +TEST_CASE("tensordot(case4b)" * doctest::test_suite("view::tensordot")) +{ + TENSORDOT_SUBCASE( case4b, a, b, axes ); + TENSORDOT_SUBCASE( case4b, a_a, b_a, axes ); + TENSORDOT_SUBCASE( case4b, a_f, b_f, axes ); + TENSORDOT_SUBCASE( case4b, a_h, b_h, axes ); + TENSORDOT_SUBCASE( case4b, a_d, b_d, axes ); + + TENSORDOT_SUBCASE( case4b, a, b, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_a, b_a, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_f, b_f, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_h, b_h, axes_ct ); + TENSORDOT_SUBCASE( case4b, a_d, b_d, axes_ct ); +} \ No newline at end of file diff --git a/tests/view/src/vecdot.cpp b/tests/view/src/vecdot.cpp new file mode 100644 index 000000000..b9e991538 --- /dev/null +++ b/tests/view/src/vecdot.cpp @@ -0,0 +1,149 @@ +#include "nmtools/array/view/vecdot.hpp" +#include "nmtools/testing/data/array/vecdot.hpp" +#include "nmtools/testing/doctest.hpp" + +using nmtools::None; + +#define VECDOT_SUBCASE(case_name,...) \ +SUBCASE(#case_name) \ +{ \ + NMTOOLS_TESTING_USE_CASE(array,vecdot,case_name); \ + auto result = nmtools::view::vecdot(__VA_ARGS__); \ + NMTOOLS_ASSERT_EQUAL( nmtools::shape(result), nmtools::shape(expect::result) ); \ + NMTOOLS_ASSERT_CLOSE( result, expect::result ); \ +} + +TEST_CASE("vecdot(case1a)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1a, a, b ); + VECDOT_SUBCASE( case1a, a_a, b_a ); + VECDOT_SUBCASE( case1a, a_f, b_f ); + VECDOT_SUBCASE( case1a, a_h, b_h ); + // VECDOT_SUBCASE( case1a, a_d, b_d ); +} + +TEST_CASE("vecdot(case1b)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1b, a, b, None, keepdims ); + VECDOT_SUBCASE( case1b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case1c)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1c, a, b ); + VECDOT_SUBCASE( case1c, a_a, b_a ); + VECDOT_SUBCASE( case1c, a_f, b_f ); + VECDOT_SUBCASE( case1c, a_h, b_h ); + // VECDOT_SUBCASE( case1c, a_d, b_d ); +} + +TEST_CASE("vecdot(case1d)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1d, a, b, None, keepdims ); + VECDOT_SUBCASE( case1d, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1d, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1d, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1d, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case1e)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1e, a, b ); + VECDOT_SUBCASE( case1e, a_a, b_a ); + VECDOT_SUBCASE( case1e, a_f, b_f ); + VECDOT_SUBCASE( case1e, a_h, b_h ); + // VECDOT_SUBCASE( case1e, a_d, b_d ); +} + +TEST_CASE("vecdot(case1f)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case1f, a, b, None, keepdims ); + VECDOT_SUBCASE( case1f, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case1f, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case1f, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case1f, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2a)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2a, a, b ); + VECDOT_SUBCASE( case2a, a_a, b_a ); + VECDOT_SUBCASE( case2a, a_f, b_f ); + VECDOT_SUBCASE( case2a, a_h, b_h ); + VECDOT_SUBCASE( case2a, a_d, b_d ); +} + +TEST_CASE("vecdot(case2b)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2b, a, b, None, keepdims ); + VECDOT_SUBCASE( case2b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2c)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2c, a, b ); + VECDOT_SUBCASE( case2c, a_a, b_a ); + VECDOT_SUBCASE( case2c, a_f, b_f ); + VECDOT_SUBCASE( case2c, a_h, b_h ); + VECDOT_SUBCASE( case2c, a_d, b_d ); +} + +TEST_CASE("vecdot(case2d)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2d, a, b, None, keepdims ); + VECDOT_SUBCASE( case2d, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2d, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2d, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2d, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case2e)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2e, a, b ); + VECDOT_SUBCASE( case2e, a_a, b_a ); + VECDOT_SUBCASE( case2e, a_f, b_f ); + VECDOT_SUBCASE( case2e, a_h, b_h ); + VECDOT_SUBCASE( case2e, a_d, b_d ); +} + +TEST_CASE("vecdot(case2f)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case2f, a, b, None, keepdims ); + VECDOT_SUBCASE( case2f, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case2f, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case2f, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case2f, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case3a)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case3a, a, b ); + VECDOT_SUBCASE( case3a, a_a, b_a ); + VECDOT_SUBCASE( case3a, a_f, b_f ); + VECDOT_SUBCASE( case3a, a_h, b_h ); + VECDOT_SUBCASE( case3a, a_d, b_d ); +} + +TEST_CASE("vecdot(case3b)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case3b, a, b, None, keepdims ); + VECDOT_SUBCASE( case3b, a_a, b_a, None, keepdims ); + VECDOT_SUBCASE( case3b, a_f, b_f, None, keepdims ); + VECDOT_SUBCASE( case3b, a_h, b_h, None, keepdims ); + VECDOT_SUBCASE( case3b, a_d, b_d, None, keepdims ); +} + +TEST_CASE("vecdot(case3c)" * doctest::test_suite("view::vecdot")) +{ + VECDOT_SUBCASE( case3c, a, b ); + VECDOT_SUBCASE( case3c, a_a, b_a ); + VECDOT_SUBCASE( case3c, a_f, b_f ); + VECDOT_SUBCASE( case3c, a_h, b_h ); + VECDOT_SUBCASE( case3c, a_d, b_d ); +}