-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,099 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#include "nmtools/array/array/mean.hpp" | ||
#include "nmtools/array/array/ufuncs/subtract.hpp" | ||
#include "nmtools/array/array/arange.hpp" | ||
#include "nmtools/array/array/reshape.hpp" | ||
#include "nmtools/array/eval/cuda.hpp" | ||
#include "nmtools/testing/doctest.hpp" | ||
#include "nmtools/testing/data/array/arange.hpp" | ||
|
||
namespace nm = nmtools; | ||
namespace na = nmtools::array; | ||
namespace ix = nmtools::index; | ||
namespace fn = nmtools::functional; | ||
namespace view = nm::view; | ||
|
||
namespace composition | ||
{ | ||
template <typename array_t, typename axis_t, typename dtype_t=nmtools::none_t> | ||
constexpr auto mean_subtract(const array_t& array, const axis_t& axis, dtype_t dtype=dtype_t{}) | ||
{ | ||
// must keep dimension to properly subtract | ||
auto a = view::mean(array,axis,dtype,nmtools::True); | ||
auto b = view::subtract(array,a); | ||
return b; | ||
} // mean_subtract | ||
} // composition | ||
|
||
TEST_CASE("mean_subtract(case1)" * doctest::test_suite("mean_subtract")) | ||
{ | ||
auto shape = nmtools_array{128}; | ||
auto numel = ix::product(shape); | ||
auto start = 0; | ||
auto stop = start+numel; | ||
auto step = 1; | ||
auto dtype = nm::float32; | ||
|
||
auto input = na::reshape(na::arange(start,stop,step,dtype),shape); | ||
auto axis = 0; | ||
auto array = composition::mean_subtract(input,axis,dtype); | ||
} | ||
|
||
TEST_CASE("mean" * doctest::test_suite("get_graph")) | ||
{ | ||
auto nodes = nmtools_tuple{fn::alias,fn::alias,fn::reduce_add,fn::divide}; | ||
/* operand list */ | ||
auto adjacency_list = nmtools_tuple{nm::None,nm::None,nmtools_array{0},nmtools_array{2,1}}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#include "nmtools/array/array/ufuncs/multiply.hpp" | ||
#include "nmtools/array/array/ufuncs/tanh.hpp" | ||
#include "nmtools/array/array/arange.hpp" | ||
#include "nmtools/array/functional/functor.hpp" | ||
#include "nmtools/array/functional/alias.hpp" | ||
#include "nmtools/array/functional/ufuncs/multiply.hpp" | ||
#include "nmtools/array/functional/ufuncs/tanh.hpp" | ||
#include "nmtools/array/functional/ufunc/ufunc.hpp" | ||
#include "nmtools/array/array/reshape.hpp" | ||
#include "nmtools/testing/doctest.hpp" | ||
#include "nmtools/testing/data/array/arange.hpp" | ||
|
||
namespace nm = nmtools; | ||
namespace na = nmtools::array; | ||
namespace ix = nmtools::index; | ||
namespace fn = nmtools::functional; | ||
namespace meta = nm::meta; | ||
namespace view = nm::view; | ||
|
||
using namespace nmtools::literals; | ||
|
||
TEST_CASE("multiply" * doctest::test_suite("functional::graph")) | ||
{ | ||
auto lhs_shape = nmtools_array{3,4}; | ||
auto lhs_buffer = na::arange(12); | ||
auto lhs_array = na::reshape(lhs_buffer,lhs_shape); | ||
|
||
auto rhs_shape = nmtools_array{4}; | ||
auto rhs_buffer = na::arange(4); | ||
auto rhs_array = na::reshape(rhs_buffer,rhs_shape); | ||
|
||
auto a = view::multiply(lhs_array,rhs_array); | ||
|
||
auto nodes = nmtools_tuple{ | ||
fn::alias | ||
, fn::alias | ||
, fn::multiply | ||
}; | ||
auto edges = nmtools_tuple{ | ||
nmtools_tuple<>{} | ||
, nmtools_tuple<>{} | ||
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>} | ||
}; | ||
|
||
auto graph = fn::get_graph(a); | ||
|
||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> ); | ||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> ); | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.nodes)>; | ||
constexpr auto N = meta::len_v<decltype(nodes)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(nodes,i); | ||
auto result = nm::at(graph.nodes,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.edges)>; | ||
constexpr auto N = meta::len_v<decltype(edges)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(edges,i); | ||
auto result = nm::at(graph.edges,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
#include "nmtools/array/array/ufuncs/multiply.hpp" | ||
#include "nmtools/array/array/ufuncs/tanh.hpp" | ||
#include "nmtools/array/array/arange.hpp" | ||
#include "nmtools/array/functional/functor.hpp" | ||
#include "nmtools/array/functional/alias.hpp" | ||
#include "nmtools/array/functional/ufuncs/multiply.hpp" | ||
#include "nmtools/array/functional/ufuncs/add.hpp" | ||
#include "nmtools/array/functional/ufuncs/tanh.hpp" | ||
#include "nmtools/array/functional/ufunc/ufunc.hpp" | ||
#include "nmtools/array/array/reshape.hpp" | ||
#include "nmtools/testing/doctest.hpp" | ||
#include "nmtools/testing/data/array/arange.hpp" | ||
|
||
namespace nm = nmtools; | ||
namespace na = nmtools::array; | ||
namespace ix = nmtools::index; | ||
namespace fn = nmtools::functional; | ||
namespace meta = nm::meta; | ||
namespace view = nm::view; | ||
|
||
using namespace nmtools::literals; | ||
|
||
TEST_CASE("multiply_add_tanh" * doctest::test_suite("functional::graph")) | ||
{ | ||
auto lhs_shape = nmtools_array{3,4}; | ||
auto lhs_buffer = na::arange(12); | ||
auto lhs_array = na::reshape(lhs_buffer,lhs_shape); | ||
|
||
auto rhs_shape = nmtools_array{4}; | ||
auto rhs_buffer = na::arange(4); | ||
auto rhs_array = na::reshape(rhs_buffer,rhs_shape); | ||
|
||
auto a = view::multiply(lhs_array,rhs_array); | ||
auto b = view::add(lhs_array,a); | ||
auto c = view::tanh(b); | ||
|
||
// lhs_array duplicated because it is not aliased | ||
auto nodes = nmtools_tuple{ | ||
fn::alias // add lhs (lhs_array) | ||
, fn::alias // mul lhs (lhs_array) | ||
, fn::alias // mul rhs (rhs_array) | ||
, fn::multiply // add rhs | ||
, fn::add | ||
, fn::tanh | ||
}; | ||
auto edges = nmtools_tuple{ | ||
nmtools_tuple<>{} | ||
, nmtools_tuple<>{} | ||
, nmtools_tuple<>{} | ||
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>} | ||
, nmtools_tuple{meta::ct_v<(int)-4>,meta::ct_v<(int)-1>} | ||
, nmtools_tuple{meta::ct_v<(int)-1>} | ||
}; | ||
|
||
auto graph = fn::get_graph(c); | ||
|
||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> ); | ||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> ); | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.nodes)>; | ||
constexpr auto N = meta::len_v<decltype(nodes)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(nodes,i); | ||
auto result = nm::at(graph.nodes,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.edges)>; | ||
constexpr auto N = meta::len_v<decltype(edges)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(edges,i); | ||
auto result = nm::at(graph.edges,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
} | ||
|
||
TEST_CASE("multiply_add_tanh" * doctest::test_suite("functional::graph")) | ||
{ | ||
auto lhs_shape = nmtools_array{3,4}; | ||
auto lhs_buffer = na::arange(12); | ||
auto lhs_array = na::reshape(lhs_buffer,lhs_shape); | ||
|
||
auto rhs_shape = nmtools_array{4}; | ||
auto rhs_buffer = na::arange(4); | ||
auto rhs_array = na::reshape(rhs_buffer,rhs_shape); | ||
|
||
auto a = view::multiply(lhs_array,rhs_array); | ||
auto b = view::add(a,rhs_array); | ||
auto c = view::tanh(b); | ||
|
||
// lhs_array duplicated because it is not aliased | ||
auto nodes = nmtools_tuple{ | ||
fn::alias // mul lhs (lhs_array) | ||
, fn::alias // mul rhs (rhs_array) | ||
, fn::multiply // add lhs | ||
, fn::alias // add rhs | ||
, fn::add | ||
, fn::tanh | ||
}; | ||
auto edges = nmtools_tuple{ | ||
nmtools_tuple<>{} | ||
, nmtools_tuple<>{} | ||
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>} | ||
, nmtools_tuple<>{} | ||
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>} | ||
, nmtools_tuple{meta::ct_v<(int)-1>} | ||
}; | ||
|
||
auto graph = fn::get_graph(c); | ||
|
||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> ); | ||
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> ); | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.nodes)>; | ||
constexpr auto N = meta::len_v<decltype(nodes)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(nodes,i); | ||
auto result = nm::at(graph.nodes,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
|
||
{ | ||
constexpr auto M = meta::len_v<decltype(graph.edges)>; | ||
constexpr auto N = meta::len_v<decltype(edges)>; | ||
constexpr auto LEN = M > N ? N : M; | ||
meta::template_for<LEN>([&](auto i){ | ||
auto expect = nm::at(edges,i); | ||
auto result = nm::at(graph.edges,i); | ||
using expect_t = decltype(expect); | ||
using result_t = decltype(result); | ||
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t ); | ||
}); | ||
} | ||
} |
Oops, something went wrong.