Skip to content

Commit

Permalink
add compile-time graph tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed Jan 20, 2024
1 parent e547269 commit 4902e7d
Show file tree
Hide file tree
Showing 14 changed files with 1,099 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ set(NMTOOLS_CUDA_TEST_SOURCES ${NMTOOLS_CUDA_TEST_SOURCES}
array/prod.cpp
array/cumprod.cpp
array/mean.cpp

# composition/mean_subtract.cpp
)

## TODO: support nvcc compilation
Expand Down
46 changes: 46 additions & 0 deletions tests/cuda/composition/mean_subtract.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "nmtools/array/array/mean.hpp"
#include "nmtools/array/array/ufuncs/subtract.hpp"
#include "nmtools/array/array/arange.hpp"
#include "nmtools/array/array/reshape.hpp"
#include "nmtools/array/eval/cuda.hpp"
#include "nmtools/testing/doctest.hpp"
#include "nmtools/testing/data/array/arange.hpp"

namespace nm = nmtools;
namespace na = nmtools::array;
namespace ix = nmtools::index;
namespace fn = nmtools::functional;
namespace view = nm::view;

namespace composition
{
template <typename array_t, typename axis_t, typename dtype_t=nmtools::none_t>
constexpr auto mean_subtract(const array_t& array, const axis_t& axis, dtype_t dtype=dtype_t{})
{
// must keep dimension to properly subtract
auto a = view::mean(array,axis,dtype,nmtools::True);
auto b = view::subtract(array,a);
return b;
} // mean_subtract
} // composition

TEST_CASE("mean_subtract(case1)" * doctest::test_suite("mean_subtract"))
{
auto shape = nmtools_array{128};
auto numel = ix::product(shape);
auto start = 0;
auto stop = start+numel;
auto step = 1;
auto dtype = nm::float32;

auto input = na::reshape(na::arange(start,stop,step,dtype),shape);
auto axis = 0;
auto array = composition::mean_subtract(input,axis,dtype);
}

TEST_CASE("mean" * doctest::test_suite("get_graph"))
{
auto nodes = nmtools_tuple{fn::alias,fn::alias,fn::reduce_add,fn::divide};
/* operand list */
auto adjacency_list = nmtools_tuple{nm::None,nm::None,nmtools_array{0},nmtools_array{2,1}};
}
12 changes: 12 additions & 0 deletions tests/functional/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ if (NMTOOLS_FUNCTIONAL_TEST_ALL)
src/transpose.cpp
src/var.cpp
src/where.cpp
src/graph/tanh.cpp
src/graph/multiply.cpp
src/graph/reduce_add_tanh.cpp
src/graph/multiply_tanh.cpp
src/graph/multiply_add_tanh.cpp
src/graph/multiply_add_tanh_add.cpp
src/graph/reduce_add_divide.cpp
src/graph/var.cpp
src/graph/stddev.cpp
src/graph/softmax.cpp

src/misc/ct_map.cpp
)
endif()

Expand Down
75 changes: 75 additions & 0 deletions tests/functional/src/graph/multiply.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "nmtools/array/array/ufuncs/multiply.hpp"
#include "nmtools/array/array/ufuncs/tanh.hpp"
#include "nmtools/array/array/arange.hpp"
#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/alias.hpp"
#include "nmtools/array/functional/ufuncs/multiply.hpp"
#include "nmtools/array/functional/ufuncs/tanh.hpp"
#include "nmtools/array/functional/ufunc/ufunc.hpp"
#include "nmtools/array/array/reshape.hpp"
#include "nmtools/testing/doctest.hpp"
#include "nmtools/testing/data/array/arange.hpp"

namespace nm = nmtools;
namespace na = nmtools::array;
namespace ix = nmtools::index;
namespace fn = nmtools::functional;
namespace meta = nm::meta;
namespace view = nm::view;

using namespace nmtools::literals;

TEST_CASE("multiply" * doctest::test_suite("functional::graph"))
{
auto lhs_shape = nmtools_array{3,4};
auto lhs_buffer = na::arange(12);
auto lhs_array = na::reshape(lhs_buffer,lhs_shape);

auto rhs_shape = nmtools_array{4};
auto rhs_buffer = na::arange(4);
auto rhs_array = na::reshape(rhs_buffer,rhs_shape);

auto a = view::multiply(lhs_array,rhs_array);

auto nodes = nmtools_tuple{
fn::alias
, fn::alias
, fn::multiply
};
auto edges = nmtools_tuple{
nmtools_tuple<>{}
, nmtools_tuple<>{}
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>}
};

auto graph = fn::get_graph(a);

NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> );
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> );

{
constexpr auto M = meta::len_v<decltype(graph.nodes)>;
constexpr auto N = meta::len_v<decltype(nodes)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(nodes,i);
auto result = nm::at(graph.nodes,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}

{
constexpr auto M = meta::len_v<decltype(graph.edges)>;
constexpr auto N = meta::len_v<decltype(edges)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(edges,i);
auto result = nm::at(graph.edges,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}
}
149 changes: 149 additions & 0 deletions tests/functional/src/graph/multiply_add_tanh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include "nmtools/array/array/ufuncs/multiply.hpp"
#include "nmtools/array/array/ufuncs/tanh.hpp"
#include "nmtools/array/array/arange.hpp"
#include "nmtools/array/functional/functor.hpp"
#include "nmtools/array/functional/alias.hpp"
#include "nmtools/array/functional/ufuncs/multiply.hpp"
#include "nmtools/array/functional/ufuncs/add.hpp"
#include "nmtools/array/functional/ufuncs/tanh.hpp"
#include "nmtools/array/functional/ufunc/ufunc.hpp"
#include "nmtools/array/array/reshape.hpp"
#include "nmtools/testing/doctest.hpp"
#include "nmtools/testing/data/array/arange.hpp"

namespace nm = nmtools;
namespace na = nmtools::array;
namespace ix = nmtools::index;
namespace fn = nmtools::functional;
namespace meta = nm::meta;
namespace view = nm::view;

using namespace nmtools::literals;

TEST_CASE("multiply_add_tanh" * doctest::test_suite("functional::graph"))
{
auto lhs_shape = nmtools_array{3,4};
auto lhs_buffer = na::arange(12);
auto lhs_array = na::reshape(lhs_buffer,lhs_shape);

auto rhs_shape = nmtools_array{4};
auto rhs_buffer = na::arange(4);
auto rhs_array = na::reshape(rhs_buffer,rhs_shape);

auto a = view::multiply(lhs_array,rhs_array);
auto b = view::add(lhs_array,a);
auto c = view::tanh(b);

// lhs_array duplicated because it is not aliased
auto nodes = nmtools_tuple{
fn::alias // add lhs (lhs_array)
, fn::alias // mul lhs (lhs_array)
, fn::alias // mul rhs (rhs_array)
, fn::multiply // add rhs
, fn::add
, fn::tanh
};
auto edges = nmtools_tuple{
nmtools_tuple<>{}
, nmtools_tuple<>{}
, nmtools_tuple<>{}
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>}
, nmtools_tuple{meta::ct_v<(int)-4>,meta::ct_v<(int)-1>}
, nmtools_tuple{meta::ct_v<(int)-1>}
};

auto graph = fn::get_graph(c);

NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> );
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> );

{
constexpr auto M = meta::len_v<decltype(graph.nodes)>;
constexpr auto N = meta::len_v<decltype(nodes)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(nodes,i);
auto result = nm::at(graph.nodes,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}

{
constexpr auto M = meta::len_v<decltype(graph.edges)>;
constexpr auto N = meta::len_v<decltype(edges)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(edges,i);
auto result = nm::at(graph.edges,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}
}

TEST_CASE("multiply_add_tanh" * doctest::test_suite("functional::graph"))
{
auto lhs_shape = nmtools_array{3,4};
auto lhs_buffer = na::arange(12);
auto lhs_array = na::reshape(lhs_buffer,lhs_shape);

auto rhs_shape = nmtools_array{4};
auto rhs_buffer = na::arange(4);
auto rhs_array = na::reshape(rhs_buffer,rhs_shape);

auto a = view::multiply(lhs_array,rhs_array);
auto b = view::add(a,rhs_array);
auto c = view::tanh(b);

// lhs_array duplicated because it is not aliased
auto nodes = nmtools_tuple{
fn::alias // mul lhs (lhs_array)
, fn::alias // mul rhs (rhs_array)
, fn::multiply // add lhs
, fn::alias // add rhs
, fn::add
, fn::tanh
};
auto edges = nmtools_tuple{
nmtools_tuple<>{}
, nmtools_tuple<>{}
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>}
, nmtools_tuple<>{}
, nmtools_tuple{meta::ct_v<(int)-2>,meta::ct_v<(int)-1>}
, nmtools_tuple{meta::ct_v<(int)-1>}
};

auto graph = fn::get_graph(c);

NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.nodes)>, meta::len_v<decltype(nodes)> );
NMTOOLS_ASSERT_EQUAL( meta::len_v<decltype(graph.edges)>, meta::len_v<decltype(edges)> );

{
constexpr auto M = meta::len_v<decltype(graph.nodes)>;
constexpr auto N = meta::len_v<decltype(nodes)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(nodes,i);
auto result = nm::at(graph.nodes,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}

{
constexpr auto M = meta::len_v<decltype(graph.edges)>;
constexpr auto N = meta::len_v<decltype(edges)>;
constexpr auto LEN = M > N ? N : M;
meta::template_for<LEN>([&](auto i){
auto expect = nm::at(edges,i);
auto result = nm::at(graph.edges,i);
using expect_t = decltype(expect);
using result_t = decltype(result);
NMTOOLS_STATIC_CHECK_IS_SAME( result_t, expect_t );
});
}
}
Loading

0 comments on commit 4902e7d

Please sign in to comment.